vllm.distributed.tpu_distributed_utils ¶
MODULE_TYPE_TO_WRAPPING_FUNC module-attribute ¶
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
    [
        (
            "QKVParallelLinear",
            partition_qkv_parallel_linear,
        ),
        (
            "ColumnParallelLinear",
            partition_column_parallel_linear,
        ),
        (
            "RowParallelLinear",
            partition_row_parallel_linear,
        ),
    ]
)
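The ordering of this mapping matters: QKVParallelLinear is a ColumnParallelLinear subclass in vLLM, so the more specific entry is listed first. Below is a minimal sketch of how such an ordered, name-keyed mapping can drive dispatch while walking a model; the traversal and the (module, mesh) call signature are illustrative assumptions, not the actual shard_model implementation.

```python
import torch.nn as nn

def apply_wrapping_funcs(model: nn.Module, mesh, wrapping_funcs) -> None:
    """Illustrative dispatch over an ordered {class name: wrapping function} mapping."""
    for module in model.modules():
        # Match against the module's full class hierarchy so that a subclass such
        # as QKVParallelLinear is not also handled as a plain ColumnParallelLinear.
        mro_names = {cls.__name__ for cls in type(module).__mro__}
        for type_name, wrap in wrapping_funcs.items():
            if type_name in mro_names:
                wrap(module, mesh)  # assumed (module, mesh) signature
                break
```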
  XlaQKVParallelLinear ¶
  Bases: Module
   __init__ ¶
   _load_weights_from_qkv_linear ¶
 _load_weights_from_qkv_linear(qkv_linear: Module)
   _shard_weight ¶
   forward ¶
   get_fqn ¶
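Taken together, the methods above describe an XLA-friendly stand-in for a fused QKV projection: weights are loaded from an existing QKVParallelLinear and sharded before the layer is used. As a rough illustration of the underlying idea only, the sketch below splits a fused QKV weight into its Q, K, and V chunks and annotates each with an XLA SPMD sharding along its own output dimension; the sizes, the standalone-tensor layout, and the "x" mesh axis are assumptions, not the actual XlaQKVParallelLinear internals.

```python
import torch
import torch_xla.distributed.spmd as xs

def shard_fused_qkv(qkv_weight: torch.Tensor, q_size: int, kv_size: int,
                    mesh: xs.Mesh) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Assumes the weight already lives on an XLA device with SPMD enabled.
    q, k, v = (t.contiguous()
               for t in torch.split(qkv_weight, [q_size, kv_size, kv_size], dim=0))
    for chunk in (q, k, v):
        # Shard each projection's output (row) dimension across the "x" axis;
        # sharding the fused tensor as a whole would slice across the Q/K/V
        # boundaries instead of sharding each projection by itself.
        xs.mark_sharding(chunk, mesh, ("x", None))
    return q, k, v
```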
     partition_column_parallel_linear ¶
    partition_qkv_parallel_linear ¶
    partition_row_parallel_linear ¶
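Each of these helpers targets one vLLM linear flavor. The sketch below shows the kind of SPMD annotation such a helper might apply using torch_xla's mark_sharding; the helper names, the "x" axis, and the omission of biases and loader metadata are simplifications for illustration, not vLLM's actual partition functions.

```python
import torch.nn as nn
import torch_xla.distributed.spmd as xs

def shard_column_parallel(linear: nn.Module, mesh: xs.Mesh) -> None:
    # Column parallelism splits output features: with an [out, in] weight
    # layout, shard dim 0 across the (assumed) "x" mesh axis.
    xs.mark_sharding(linear.weight, mesh, ("x", None))

def shard_row_parallel(linear: nn.Module, mesh: xs.Mesh) -> None:
    # Row parallelism splits input features instead, so shard dim 1.
    xs.mark_sharding(linear.weight, mesh, (None, "x"))
```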
    shard_model ¶
 shard_model(model: Module, mesh: Mesh) -> None
Recursively traverse a PyTorch model and apply the appropriate sharding based on the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | torch.nn.Module to process | required |
| mesh | Mesh | An XLA SPMD mesh object used for sharding | required |
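The sketch below is a hedged usage example: it builds a one-dimensional XLA SPMD mesh over every TPU device and passes the model to shard_model. The flat mesh shape and the "x" axis name are illustrative assumptions; the real caller in vLLM may choose a different mesh layout.

```python
import numpy as np
import torch.nn as nn
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from vllm.distributed.tpu_distributed_utils import shard_model

def shard_vllm_model_for_tpu(model: nn.Module) -> None:
    # Enable SPMD execution before building the mesh or marking any shardings.
    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    # One flat mesh axis over all TPU devices; the "x" name is illustrative.
    mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("x",))
    # Walks `model` (already on the XLA device) and shards the submodules
    # listed in MODULE_TYPE_TO_WRAPPING_FUNC in place; returns None.
    shard_model(model, mesh)
```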