vllm.model_executor.models.qwen3_5_mtp

Inference-only Qwen3_5 MTP model.
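
This module provides the multi-token-prediction (MTP) head for Qwen3.5, which is typically used as the draft model for speculative decoding in vLLM. As a rough orientation, the sketch below shows how an MTP draft is usually enabled through the speculative_config argument of LLM; the model identifier and the "method" string are hypothetical placeholders, not confirmed names.

from vllm import LLM, SamplingParams

# Minimal sketch. "Qwen/Qwen3.5-Example" and the "method" value are
# hypothetical placeholders; consult the supported speculative methods.
llm = LLM(
    model="Qwen/Qwen3.5-Example",
    speculative_config={
        "method": "mtp",               # assumed method name
        "num_speculative_tokens": 1,   # one draft token per MTP step
    },
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=32))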

logger module-attribute

logger = init_logger(__name__)

Qwen3_5MTP

Bases: Module, SupportsMultiModal

Source code in vllm/model_executor/models/qwen3_5_mtp.py
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        "hidden_states": 0,
    }
)
class Qwen3_5MTP(nn.Module, SupportsMultiModal):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["up_proj", "down_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_text_config
        self.vllm_config = vllm_config
        cache_config = vllm_config.cache_config
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3_5MTP currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        self.quant_config = vllm_config.quant_config

        super().__init__()
        self.config = config
        self.model = Qwen3_5MultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
        )

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    prefix=maybe_prefix(prefix, "lm_head"),
                )
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.model.embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        is_multimodal = _require_is_multimodal(is_multimodal)

        inputs_embeds = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        hidden_states = self.model(
            input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        def remap_weight_names(weights):
            for name, weight in weights:
                if name.startswith("mtp."):
                    name = name.replace("mtp.", "model.")
                elif any(key in name for key in ["embed_tokens", "lm_head"]):
                    if "embed_tokens" in name:
                        name = name.replace("language_model.", "")
                else:
                    continue
                yield name, weight

        loader = AutoWeightsLoader(self)
        return loader.load_weights(remap_weight_names(weights))

config instance-attribute

config = config

lm_head instance-attribute

lm_head = embed_tokens

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

model instance-attribute

model = Qwen3_5MultiTokenPredictor(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "mtp"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["up_proj", "down_proj"],
}

quant_config instance-attribute

quant_config = quant_config

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    config = vllm_config.model_config.hf_text_config
    self.vllm_config = vllm_config
    cache_config = vllm_config.cache_config
    if cache_config.mamba_cache_mode == "all":
        raise NotImplementedError(
            "Qwen3_5MTP currently does not support 'all' prefix caching, "
            "please use '--mamba-cache-mode=align' instead"
        )

    self.quant_config = vllm_config.quant_config

    super().__init__()
    self.config = config
    self.model = Qwen3_5MultiTokenPredictor(
        vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
    )

    if get_pp_group().is_last_rank:
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
    else:
        self.lm_head = PPMissingLayer()

    self.logits_processor = LogitsProcessor(config.vocab_size)
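
When config.tie_word_embeddings is set, the output head reuses the embedding matrix instead of allocating a separate ParallelLMHead; on non-last pipeline ranks a PPMissingLayer placeholder is used. The snippet below is a plain-PyTorch sketch of the tying idea, with nn.Embedding and nn.Linear as stand-ins for the parallel layers above.

import torch
import torch.nn as nn

vocab_size, hidden_size = 32, 8
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
lm_head.weight = embed_tokens.weight          # tied: logits = hidden @ E^T

hidden = torch.randn(4, hidden_size)
print(lm_head(hidden).shape)                  # torch.Size([4, 32])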

compute_logits

compute_logits(
    hidden_states: Tensor, spec_step_idx: int = 0
) -> Tensor | None
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    spec_step_idx: int = 0,
) -> torch.Tensor | None:
    return self.logits_processor(self.lm_head, hidden_states)

embed_input_ids

embed_input_ids(
    input_ids: Tensor,
    multimodal_embeddings: MultiModalEmbeddings
    | None = None,
    *,
    is_multimodal: Tensor | None = None,
    handle_oov_mm_token: bool = False,
) -> Tensor
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
    handle_oov_mm_token: bool = False,
) -> torch.Tensor:
    inputs_embeds = self._embed_text_input_ids(
        input_ids,
        self.model.embed_input_ids,
        is_multimodal=is_multimodal,
        handle_oov_mm_token=handle_oov_mm_token,
    )

    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds

    is_multimodal = _require_is_multimodal(is_multimodal)

    inputs_embeds = _merge_multimodal_embeddings(
        inputs_embeds=inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )

    return inputs_embeds
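
Conceptually, the merge step writes the precomputed multimodal embeddings into the rows of the text embedding tensor flagged by is_multimodal. The snippet below is a simplified stand-in for the merge performed by _merge_multimodal_embeddings; the tensors and sizes are made up for illustration.

import torch

seq_len, hidden_size = 6, 4
inputs_embeds = torch.zeros(seq_len, hidden_size)              # text embeddings
is_multimodal = torch.tensor([False, True, True, False, False, False])
mm_embeds = torch.ones(int(is_multimodal.sum()), hidden_size)  # e.g. image features

# Overwrite the flagged positions with the multimodal embeddings.
inputs_embeds[is_multimodal] = mm_embeds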

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    hidden_states: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: object,
)
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
):
    hidden_states = self.model(
        input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
    )
    return hidden_states

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    def remap_weight_names(weights):
        for name, weight in weights:
            if name.startswith("mtp."):
                name = name.replace("mtp.", "model.")
            elif any(key in name for key in ["embed_tokens", "lm_head"]):
                if "embed_tokens" in name:
                    name = name.replace("language_model.", "")
            else:
                continue
            yield name, weight

    loader = AutoWeightsLoader(self)
    return loader.load_weights(remap_weight_names(weights))
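
The remapping generator renames checkpoint tensors stored under the draft-model prefix so that AutoWeightsLoader can match them against this module's submodules, and silently drops everything else. The mapping below illustrates the renaming rules with hypothetical checkpoint tensor names.

# checkpoint name -> name yielded to AutoWeightsLoader
remapped_names = {
    "mtp.layers.0.self_attn.q_proj.weight": "model.layers.0.self_attn.q_proj.weight",
    "language_model.embed_tokens.weight": "embed_tokens.weight",
    "lm_head.weight": "lm_head.weight",   # kept unchanged
    # tensors of the target model (e.g. "model.layers.3.mlp.up_proj.weight") are skipped
}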

Qwen3_5MoeMTP

Bases: Qwen3_5MTP, QwenNextMixtureOfExperts

Source code in vllm/model_executor/models/qwen3_5_mtp.py
class Qwen3_5MoeMTP(Qwen3_5MTP, QwenNextMixtureOfExperts):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.set_moe_parameters()

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__(vllm_config=vllm_config, prefix=prefix)
    self.set_moe_parameters()

Qwen3_5MultiTokenPredictor

Bases: Module

Source code in vllm/model_executor/models/qwen3_5_mtp.py
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        "hidden_states": 0,
    }
)
class Qwen3_5MultiTokenPredictor(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        quant_config = vllm_config.quant_config

        config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config

        self.config = config

        self.vocab_size = config.vocab_size

        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = getattr(config, "mtp_num_hidden_layers", 1)

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        self.fc = ColumnParallelLinear(
            self.config.hidden_size * 2,
            self.config.hidden_size,
            gather_output=True,
            bias=False,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.fc",
        )

        self.layers = torch.nn.ModuleList(
            Qwen3_5DecoderLayer(
                vllm_config,
                layer_type="full_attention",
                prefix=f"{prefix}.layers.{idx}",
            )
            for idx in range(self.num_mtp_layers)
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_fc_norm_hidden = Qwen3_5RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_fc_norm_embedding = Qwen3_5RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.embed_input_ids(input_ids)
            assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
            inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
            hidden_states = self.pre_fc_norm_hidden(hidden_states)
            hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
            hidden_states = self.fc(hidden_states)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        current_step_idx = spec_step_idx % self.num_mtp_layers
        hidden_states, residual = self.layers[current_step_idx](
            positions=positions,
            hidden_states=hidden_states,
            residual=residual,
        )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_fused_expert_weights(
        self,
        name: str,
        params_dict: dict,
        loaded_weight: torch.Tensor,
        shard_id: str,
        num_experts: int,
    ) -> bool:
        param = params_dict[name]
        weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
        loaded_local_expert = False
        for expert_id in range(num_experts):
            curr_expert_weight = loaded_weight[expert_id]
            success = weight_loader(
                param,
                curr_expert_weight,
                name,
                shard_id,
                expert_id,
                return_success=True,
            )
            if success:
                loaded_local_expert = True

        return loaded_local_expert

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts
            if hasattr(self.config, "num_experts")
            else 0,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        is_fused_expert = False
        fused_expert_params_mapping = [
            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
        ]
        num_experts = (
            self.config.num_experts if hasattr(self.config, "num_experts") else 0
        )
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                    is_fused_expert = True
                    expert_params_mapping = fused_expert_params_mapping

                if weight_name not in name:
                    continue

                if "mlp.experts" in name:
                    continue

                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name_mapped, self):
                        continue
                    if is_fused_expert:
                        # Qwen3.5 weights are already in the expected layout;
                        # no transpose is needed here.
                        # loaded_weight = loaded_weight.transpose(-1, -2)
                        if "experts.gate_up_proj" in name:
                            loaded_weight = loaded_weight.chunk(2, dim=-2)
                            success_w1 = self.load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight[0],
                                "w1",
                                num_experts,
                            )
                            success_w3 = self.load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight[1],
                                "w3",
                                num_experts,
                            )
                            success = success_w1 and success_w3
                        else:
                            # down_proj
                            success = self.load_fused_expert_weights(
                                name_mapped,
                                params_dict,
                                loaded_weight,
                                shard_id,
                                num_experts,
                            )
                        if success:
                            name = name_mapped
                            break
                    else:
                        # Skip loading extra bias for GPTQ models.
                        if (
                            name_mapped.endswith(".bias")
                            or name_mapped.endswith("_bias")
                        ) and name_mapped not in params_dict:
                            continue
                        param = params_dict[name_mapped]
                        weight_loader = param.weight_loader
                        success = weight_loader(
                            param,
                            loaded_weight,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        logger.warning_once(
                            f"Parameter {name} not found in params_dict, skip loading"
                        )
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, hidden_size
)

fc instance-attribute

fc = ColumnParallelLinear(
    hidden_size * 2,
    hidden_size,
    gather_output=True,
    bias=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.fc",
)

layers instance-attribute

layers = ModuleList(
    (
        Qwen3_5DecoderLayer(
            vllm_config,
            layer_type="full_attention",
            prefix=f"{prefix}.layers.{idx}",
        )
    )
    for idx in (range(num_mtp_layers))
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size
    )
)

mtp_start_layer_idx instance-attribute

mtp_start_layer_idx = num_hidden_layers

norm instance-attribute

norm = Qwen3_5RMSNorm(hidden_size, eps=rms_norm_eps)

num_mtp_layers instance-attribute

num_mtp_layers = getattr(config, "mtp_num_hidden_layers", 1)

pre_fc_norm_embedding instance-attribute

pre_fc_norm_embedding = Qwen3_5RMSNorm(
    hidden_size, eps=rms_norm_eps
)

pre_fc_norm_hidden instance-attribute

pre_fc_norm_hidden = Qwen3_5RMSNorm(
    hidden_size, eps=rms_norm_eps
)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()

    model_config = vllm_config.model_config
    quant_config = vllm_config.quant_config

    config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config

    self.config = config

    self.vocab_size = config.vocab_size

    self.mtp_start_layer_idx = config.num_hidden_layers
    self.num_mtp_layers = getattr(config, "mtp_num_hidden_layers", 1)

    self.embed_tokens = VocabParallelEmbedding(
        self.vocab_size,
        config.hidden_size,
    )

    self.fc = ColumnParallelLinear(
        self.config.hidden_size * 2,
        self.config.hidden_size,
        gather_output=True,
        bias=False,
        return_bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.fc",
    )

    self.layers = torch.nn.ModuleList(
        Qwen3_5DecoderLayer(
            vllm_config,
            layer_type="full_attention",
            prefix=f"{prefix}.layers.{idx}",
        )
        for idx in range(self.num_mtp_layers)
    )

    self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], config.hidden_size
    )

    self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.pre_fc_norm_hidden = Qwen3_5RMSNorm(
        config.hidden_size, eps=config.rms_norm_eps
    )
    self.pre_fc_norm_embedding = Qwen3_5RMSNorm(
        config.hidden_size, eps=config.rms_norm_eps
    )

embed_input_ids

embed_input_ids(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.embed_tokens(input_ids)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    hidden_states: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    spec_step_idx: int = 0,
) -> Tensor
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    spec_step_idx: int = 0,
) -> torch.Tensor:
    if get_pp_group().is_first_rank:
        if inputs_embeds is None:
            inputs_embeds = self.embed_input_ids(input_ids)
        assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
        inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
        hidden_states = self.pre_fc_norm_hidden(hidden_states)
        hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
        hidden_states = self.fc(hidden_states)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    current_step_idx = spec_step_idx % self.num_mtp_layers
    hidden_states, residual = self.layers[current_step_idx](
        positions=positions,
        hidden_states=hidden_states,
        residual=residual,
    )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )

    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states
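
On the first pipeline rank, the draft-token embeddings and the hidden states handed over from the target model are each RMS-normalized, concatenated along the hidden dimension, and projected from 2 * hidden_size back to hidden_size before entering the MTP decoder layer. The snippet below is a minimal shape sketch using plain-PyTorch stand-ins (nn.Linear for ColumnParallelLinear and an unscaled RMS normalization for Qwen3_5RMSNorm).

import torch
import torch.nn as nn

seq_len, hidden_size = 5, 16
inputs_embeds = torch.randn(seq_len, hidden_size)   # embedded draft tokens
hidden_states = torch.randn(seq_len, hidden_size)   # hidden states from the target model

def rms_norm(x, eps=1e-6):
    # Stand-in for Qwen3_5RMSNorm, without the learned scale.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)  # stand-in for the fc projection

fused = fc(torch.cat([rms_norm(inputs_embeds), rms_norm(hidden_states)], dim=-1))
print(fused.shape)  # torch.Size([5, 16]) -- back to hidden_size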

load_fused_expert_weights

load_fused_expert_weights(
    name: str,
    params_dict: dict,
    loaded_weight: Tensor,
    shard_id: str,
    num_experts: int,
) -> bool
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def load_fused_expert_weights(
    self,
    name: str,
    params_dict: dict,
    loaded_weight: torch.Tensor,
    shard_id: str,
    num_experts: int,
) -> bool:
    param = params_dict[name]
    weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
    loaded_local_expert = False
    for expert_id in range(num_experts):
        curr_expert_weight = loaded_weight[expert_id]
        success = weight_loader(
            param,
            curr_expert_weight,
            name,
            shard_id,
            expert_id,
            return_success=True,
        )
        if success:
            loaded_local_expert = True

    return loaded_local_expert
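
The helper walks over every expert slice of a fused checkpoint tensor, hands each slice to the parameter's weight_loader, and reports whether at least one expert was actually placed on this rank (with expert parallelism, most slices belong to other ranks). The toy sketch below mimics only that bookkeeping with a dummy loader; the names and rank assignment are made up.

def load_all_experts(expert_weights, weight_loader):
    loaded_local_expert = False
    for expert_id, expert_weight in enumerate(expert_weights):
        # The dummy loader returns True only for experts owned by this rank.
        if weight_loader(expert_weight, expert_id):
            loaded_local_expert = True
    return loaded_local_expert

print(load_all_experts(range(4), lambda w, eid: eid in (2, 3)))  # True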

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/qwen3_5_mtp.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        self,
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.num_experts
        if hasattr(self.config, "num_experts")
        else 0,
    )

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    is_fused_expert = False
    fused_expert_params_mapping = [
        ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
        ("experts.w2_weight", "experts.down_proj", 0, "w2"),
    ]
    num_experts = (
        self.config.num_experts if hasattr(self.config, "num_experts") else 0
    )
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                is_fused_expert = True
                expert_params_mapping = fused_expert_params_mapping

            if weight_name not in name:
                continue

            if "mlp.experts" in name:
                continue

            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                is_expert_weight = True
                name_mapped = name.replace(weight_name, param_name)
                # Skip layers on other devices.
                if is_pp_missing_parameter(name_mapped, self):
                    continue
                if is_fused_expert:
                    # Qwen3.5 weights are already in the expected layout;
                    # no transpose is needed here.
                    # loaded_weight = loaded_weight.transpose(-1, -2)
                    if "experts.gate_up_proj" in name:
                        loaded_weight = loaded_weight.chunk(2, dim=-2)
                        success_w1 = self.load_fused_expert_weights(
                            name_mapped,
                            params_dict,
                            loaded_weight[0],
                            "w1",
                            num_experts,
                        )
                        success_w3 = self.load_fused_expert_weights(
                            name_mapped,
                            params_dict,
                            loaded_weight[1],
                            "w3",
                            num_experts,
                        )
                        success = success_w1 and success_w3
                    else:
                        # down_proj
                        success = self.load_fused_expert_weights(
                            name_mapped,
                            params_dict,
                            loaded_weight,
                            shard_id,
                            num_experts,
                        )
                    if success:
                        name = name_mapped
                        break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name_mapped.endswith(".bias")
                        or name_mapped.endswith("_bias")
                    ) and name_mapped not in params_dict:
                        continue
                    param = params_dict[name_mapped]
                    weight_loader = param.weight_loader
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    logger.warning_once(
                        f"Parameter {name} not found in params_dict, skip loading"
                    )
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
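
For checkpoints that store each expert's gate and up projections as a single fused gate_up_proj tensor, the loader splits that tensor into two halves along dim=-2 and loads them as the "w1" (gate) and "w3" (up) shards. The sketch below mirrors only that split; the tensor sizes and the exact checkpoint layout are assumptions for illustration.

import torch

num_experts, hidden_size, intermediate_size = 4, 16, 32
# Hypothetical fused tensor with gate and up stacked along dim -2.
gate_up = torch.randn(num_experts, 2 * intermediate_size, hidden_size)

w1, w3 = gate_up.chunk(2, dim=-2)        # gate half, up half
print(w1.shape, w3.shape)                # torch.Size([4, 32, 16]) each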