class TritonOrDeepGemmExperts(FallbackExperts):
"""DeepGemm with fallback to Triton for low latency shapes."""
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__(
experts=DeepGemmExperts(moe_config, quant_config),
fallback_experts=TritonExperts(moe_config, quant_config),
)

    @staticmethod
    def get_clses() -> tuple[
        type[mk.FusedMoEPermuteExpertsUnpermute],
        type[mk.FusedMoEPermuteExpertsUnpermute],
    ]:
        """Return the (primary, fallback) expert implementation classes."""
        return (DeepGemmExperts, TritonExperts)

    def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # Note: the DeepGemm workspaces are strictly larger than the Triton
        # workspaces, so we can be pessimistic here and allocate for DeepGemm
        # even if we fall back to Triton later, e.g. if expert maps are set.
if is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K):
return self.experts.workspace_shapes(
M,
N,
K,
topk,
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
else:
return self.fallback_experts.workspace_shapes(
M,
N,
K,
topk,
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
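
    # Note on the shape gate (illustrative, not authoritative): the helpers
    # `_valid_deep_gemm_shape(M, N, K)` and `_valid_deep_gemm(...)` used here
    # and in `_select_experts_impl` below are assumed to check DeepGemm's
    # shape/alignment requirements, conceptually something along the lines of
    #
    #     M >= block_m and N % block_n == 0 and K % block_k == 0
    #
    # for DeepGemm's block sizes. The authoritative logic lives in the
    # helpers' own definitions; correctness here only relies on the fact,
    # noted above, that DeepGemm-sized workspaces also fit the Triton path.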

    def _select_experts_impl(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        """Use DeepGemm when E8M0 scales are in use or the problem shape
        supports it; otherwise fall back to the Triton experts."""
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
return self.experts
else:
return self.fallback_experts
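
# Illustrative usage sketch (hypothetical names; `_select_experts_impl` is
# presumably invoked by the `FallbackExperts` base class rather than called
# directly):
#
#     experts = TritonOrDeepGemmExperts(moe_config, quant_config)
#     impl = experts._select_experts_impl(hidden_states, w1, w2)
#     # `impl` is the DeepGemmExperts instance when E8M0 scales are in use or
#     # the problem shape is DeepGemm-compatible, and TritonExperts otherwise.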