@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
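# ROCm AITER supplies an alternative fused-MoE kernel for AMD GPUs; import it
# lazily only when it is enabled.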
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
else:
self.rocm_aiter_fused_experts = None # type: ignore
# FlashInfer TRTLLM MoE is only supported on Blackwell and later GPUs
self.flashinfer_trtllm_moe_enabled = (
has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
and current_platform.has_device_capability(100)
)
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUs
self.flashinfer_cutlass_moe_enabled = (
has_flashinfer_cutlass_fused_moe()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1
and current_platform.has_device_capability(90)
)
if self.flashinfer_trtllm_moe_enabled:
logger.info_once(
"Enabling FlashInfer TRTLLM MoE for UnquantizedFusedMoEMethod"
)
# Import the module to register the custom op
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: F401
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
self.flashinfer_trtllm_moe = torch.ops.vllm.flashinfer_fused_moe_bf16
elif self.flashinfer_cutlass_moe_enabled:
logger.info_once(
"Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
)
from functools import partial
from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
self.flashinfer_cutlass_moe = partial(
flashinfer_cutlass_moe,
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
tp_rank=self.moe.moe_parallel_config.tp_rank,
tp_size=self.moe.moe_parallel_config.tp_size,
ep_rank=self.moe.moe_parallel_config.ep_rank,
ep_size=self.moe.moe_parallel_config.ep_size,
)
else:
if (
self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1
):
logger.info_once(
"FlashInfer CUTLASS MoE is available for EP"
" but not enabled, consider setting"
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
scope="local",
)
elif self.moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"FlashInfer CUTLASS MoE is currently not available for DP.",
scope="local",
)
self.flashinfer_cutlass_moe = None # type: ignore
@property
def supports_eplb(self) -> bool:
return True
@property
def allow_inplace(self) -> bool:
return True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
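# The ROCm AITER path does not go through the modular prepare/finalize
# pipeline, so nothing needs to be constructed for it.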
if self.rocm_aiter_moe_enabled:
return None
else:
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
assert self.moe_quant_config is not None
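# The prepare/finalize object dictates the activation layout: a batched
# per-expert layout needs BatchedTritonExperts, while the standard contiguous
# layout uses TritonExperts.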
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts(
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
return TritonExperts(self.moe_quant_config)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
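# With a gated activation (act-and-mul, e.g. SiLU-and-mul) the gate and up
# projections are fused into w13, doubling its output dimension.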
if self.moe.is_act_and_mul:
w13_up_dim = 2 * intermediate_size_per_partition
else:
w13_up_dim = intermediate_size_per_partition
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_up_dim,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(num_experts, w13_up_dim, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.moe.has_bias:
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is a ROCm-specific optimization: kernels can
# benefit from tensors that are spaced further apart in memory.
if (
envs.VLLM_ROCM_MOE_PADDING
and current_platform.is_rocm()
and weight.stride(-1) == 1
and (weight.stride(-2) * weight.element_size()) % 512 == 0
):
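# Pad the last dim by 256 bytes worth of elements, then slice the padding off
# again: the values are unchanged, but the resulting view has a larger row
# stride, spacing rows further apart in the underlying storage.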
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
return weight
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
# Pad the weights for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
if self.flashinfer_trtllm_moe_enabled:
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
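# Convert the weights to the block layout expected by the FlashInfer TRTLLM
# kernel; permutation indices are cached per weight shape so they are only
# computed once.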
w13_weights_shuffled, w2_weights_shuffled = (
convert_moe_weights_to_flashinfer_trtllm_block_layout(
self._cache_permute_indices,
layer.w13_weight.data,
layer.w2_weight.data,
)
)
layer.w13_weight = Parameter(w13_weights_shuffled, requires_grad=False)
layer.w2_weight = Parameter(w2_weights_shuffled, requires_grad=False)
if self.flashinfer_cutlass_moe_enabled:
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
if current_platform.is_xpu():
import intel_extension_for_pytorch as ipex
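# Hand the weights to IPEX's fused gated-MLP MoE module; this rank serves the
# local experts starting at ep_rank_start.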
ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=True,
experts_start_id=ep_rank_start,
)
elif current_platform.is_cpu():
from vllm.model_executor.layers.fused_moe import cpu_fused_moe
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
dtype_w13 = layer.w13_weight.dtype
_, n_w13, k_w13 = layer.w13_weight.size()
dtype_w2 = layer.w2_weight.dtype
_, n_w2, k_w2 = layer.w2_weight.size()
if (
envs.VLLM_CPU_SGL_KERNEL
and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13)
and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)
):
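# Pack the weights in place into the layout expected by the SGL kernel; the
# packed tensors keep the same shape, so the existing parameter storage can
# be reused.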
packed_w13_weight = torch.ops._C.convert_weight_packed(
layer.w13_weight
)
assert packed_w13_weight.size() == layer.w13_weight.size()
layer.w13_weight.copy_(packed_w13_weight)
del packed_w13_weight
packed_w2_weight = torch.ops._C.convert_weight_packed(
layer.w2_weight
)
assert packed_w2_weight.size() == layer.w2_weight.size()
layer.w2_weight.copy_(packed_w2_weight)
layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
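# EPLB (expert-parallel load balancing) requires the load-tracking and
# replica-mapping tensors to be provided by the caller.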
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
return self.forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
global_num_experts=global_num_experts,
expert_map=expert_map,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
enable_eplb=enable_eplb,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
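# Even without quantization a quant config is returned so that downstream
# kernels can pick up the optional expert biases.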
if self.moe.has_bias:
return biased_moe_quant_config(
layer.w13_bias,
layer.w2_bias,
)
else:
return FUSED_MOE_UNQUANTIZED_CONFIG
def forward_cuda(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
if self.flashinfer_trtllm_moe_enabled:
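# The FlashInfer TRTLLM kernel performs routing itself, so it consumes the
# raw router logits rather than the top-k results computed above.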
return self.flashinfer_trtllm_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=x,
gemm1_weights=layer.w13_weight,
gemm2_weights=layer.w2_weight,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routing_method_type=layer.routing_method_type,
tune_max_num_tokens=8192,
)
if self.rocm_aiter_moe_enabled:
result = self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.flashinfer_cutlass_moe_enabled:
return self.flashinfer_cutlass_moe(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
result = fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
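# When zero experts are configured, select_experts also produces their
# contribution, which is returned alongside the fused-MoE output.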
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
return result, zero_expert_result
else:
return result
def forward_cpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
enable_eplb is not False
or expert_load_view is not None
or logical_to_physical_map is not None
or logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for CPU.")
return layer.cpu_fused_moe(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
global_num_experts,
expert_map,
custom_routing_function,
scoring_func,
routed_scaling_factor,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
)
def forward_xpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
enable_eplb is not False
or expert_load_view is not None
or logical_to_physical_map is not None
or logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for XPU.")
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function=custom_routing_function,
)
def forward_tpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
assert apply_router_weight_on_input is False
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU."
)
if e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for TPU."
)
assert activation == "silu", f"{activation} is not supported for TPU."
assert routed_scaling_factor == 1.0, (
f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU."
)
if (
enable_eplb is not False
or expert_load_view is not None
or logical_to_physical_map is not None
or logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for TPU.")
return fused_moe_pallas(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=top_k,
gating_output=router_logits,
global_num_experts=global_num_experts,
expert_map=expert_map,
renormalize=renormalize,
)
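# Bind forward_native to the platform-specific forward at class definition
# time.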
if current_platform.is_tpu():
forward_native = forward_tpu
elif current_platform.is_cpu():
forward_native = forward_cpu
elif current_platform.is_xpu():
forward_native = forward_xpu
else:
forward_native = forward_cuda