Skip to content

vllm.model_executor.layers.quantization.utils.mxfp8_utils

MXFP8_BLOCK_SIZE module-attribute

MXFP8_BLOCK_SIZE = 32

MXFP8_SCALE_DTYPE module-attribute

MXFP8_SCALE_DTYPE = uint8

MXFP8_VALUE_DTYPE module-attribute

MXFP8_VALUE_DTYPE = float8_e4m3fn

logger module-attribute

logger = init_logger(__name__)

Mxfp8LinearBackend

Bases: Enum

Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
class Mxfp8LinearBackend(Enum):
    """Identifiers for the backends an MXFP8 linear op can be built with."""

    # Only member visible here. Presumably the dequantize-then-matmul
    # reference path rather than a native MXFP8 kernel -- TODO confirm.
    EMULATION = "emulation"

EMULATION class-attribute instance-attribute

EMULATION = 'emulation'

Mxfp8LinearOp

Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
class Mxfp8LinearOp:
    """Linear layer op for MXFP8-quantized weights.

    The weight is dequantized to bf16 with ``dequant_mxfp8_to_bf16`` and the
    matmul is then delegated to ``torch.nn.functional.linear``.
    """

    def __init__(self, backend: Mxfp8LinearBackend):
        """Store the backend this op was configured with.

        Args:
            backend: A ``Mxfp8LinearBackend`` member (or one of its values).

        Raises:
            ValueError: If ``backend`` is not a supported backend.
        """
        # `value in EnumClass` is version-dependent for non-members: it raises
        # TypeError before Python 3.12 and matches raw values from 3.12 on.
        # Going through the Enum constructor gives one behavior everywhere:
        # members pass through unchanged, valid raw values are normalized to
        # their member, and anything else raises ValueError.
        try:
            self.backend = Mxfp8LinearBackend(backend)
        except ValueError:
            raise ValueError(f"Unsupported backend: {backend}") from None

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``linear(input, dequant(weight), bias)`` cast to ``out_dtype``.

        Args:
            input: Activation tensor passed to ``torch.nn.functional.linear``.
            weight: Quantized weight handed to ``dequant_mxfp8_to_bf16``.
            weight_scale: Block scales for ``weight``; must be 2D and of
                dtype ``MXFP8_SCALE_DTYPE``.
            out_dtype: Dtype of the returned tensor.
            bias: Optional bias forwarded to ``torch.nn.functional.linear``.

        Returns:
            The linear output converted to ``out_dtype``.

        Raises:
            ValueError: If ``weight_scale`` has the wrong dtype or is not 2D.
        """
        # NOTE(review): the messages below say "TORCH backend" although no
        # TORCH member is visible in Mxfp8LinearBackend; the wording is left
        # untouched in case anything matches on it.
        if weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"TORCH backend requires {MXFP8_SCALE_DTYPE} weight_scale dtype, "
                f"got {weight_scale.dtype}."
            )
        if weight_scale.ndim != 2:
            raise ValueError(
                f"TORCH backend requires 2D weight_scale, got {weight_scale.ndim}D. "
                f"Ensure process_weights_after_loading was called."
            )

        # The weight is expanded to bf16 on every call; the matmul itself
        # runs on the dequantized tensor.
        weight_bf16 = dequant_mxfp8_to_bf16(weight, weight_scale)

        output = torch.nn.functional.linear(input, weight_bf16, bias)
        return output.to(out_dtype)

backend instance-attribute

backend = backend

__init__

__init__(backend: Mxfp8LinearBackend)
Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def __init__(self, backend: Mxfp8LinearBackend):
    """Store the backend this op was configured with.

    Args:
        backend: A ``Mxfp8LinearBackend`` member (or one of its values).

    Raises:
        ValueError: If ``backend`` is not a supported backend.
    """
    # `value in EnumClass` is version-dependent for non-members: it raises
    # TypeError before Python 3.12 and matches raw values from 3.12 on.
    # Going through the Enum constructor gives one behavior everywhere:
    # members pass through unchanged, valid raw values are normalized to
    # their member, and anything else raises ValueError.
    try:
        self.backend = Mxfp8LinearBackend(backend)
    except ValueError:
        raise ValueError(f"Unsupported backend: {backend}") from None

apply

apply(
    input: Tensor,
    weight: Tensor,
    weight_scale: Tensor,
    out_dtype: dtype,
    bias: Tensor | None = None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def apply(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run a linear layer on an MXFP8 weight by dequantizing it first.

    Args:
        input: Activation tensor passed to ``torch.nn.functional.linear``.
        weight: Quantized weight handed to ``dequant_mxfp8_to_bf16``.
        weight_scale: Block scales for ``weight``; must be 2D and of dtype
            ``MXFP8_SCALE_DTYPE``.
        out_dtype: Dtype of the returned tensor.
        bias: Optional bias forwarded to ``torch.nn.functional.linear``.

    Returns:
        The linear output converted to ``out_dtype``.

    Raises:
        ValueError: If ``weight_scale`` has the wrong dtype or is not 2D.
    """
    # Guard 1: the scale tensor must use the MXFP8 scale dtype.
    scale_dtype_ok = weight_scale.dtype == MXFP8_SCALE_DTYPE
    if not scale_dtype_ok:
        dtype_msg = (
            f"TORCH backend requires {MXFP8_SCALE_DTYPE} weight_scale dtype, "
            f"got {weight_scale.dtype}."
        )
        raise ValueError(dtype_msg)

    # Guard 2: the scale tensor must be a matrix.
    scale_is_matrix = weight_scale.ndim == 2
    if not scale_is_matrix:
        rank_msg = (
            f"TORCH backend requires 2D weight_scale, got {weight_scale.ndim}D. "
            f"Ensure process_weights_after_loading was called."
        )
        raise ValueError(rank_msg)

    # Expand the weight to bf16, then let the stock linear kernel do the rest.
    dequantized_weight = dequant_mxfp8_to_bf16(weight, weight_scale)
    result = torch.nn.functional.linear(input, dequantized_weight, bias)
    return result.to(out_dtype)

_mxfp8_e4m3_quantize_impl

_mxfp8_e4m3_quantize_impl(
    x: Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def _mxfp8_e4m3_quantize_impl(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` with FlashInfer's ``mxfp8_quantize``.

    Args:
        x: Tensor to quantize.
        is_sf_swizzled_layout: Forwarded to FlashInfer unchanged.

    Returns:
        ``(quantized, scales)`` as produced by FlashInfer. When ``x`` is 2D,
        the layout is not swizzled and the scales come back 1D, the scales
        are viewed as ``(x.shape[0], -1)`` so they have one row per row of
        ``x``.
    """
    # Imported inside the function, so flashinfer is only needed when this
    # code path actually runs.
    from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize

    quantized, scale_factors = flashinfer_mxfp8_quantize(
        x, is_sf_swizzled_layout=is_sf_swizzled_layout
    )

    flat_scales_for_matrix = (
        not is_sf_swizzled_layout and x.ndim == 2 and scale_factors.ndim == 1
    )
    if flat_scales_for_matrix:
        scale_factors = scale_factors.view(x.shape[0], -1)

    return quantized, scale_factors

dequant_mxfp8_to_bf16

dequant_mxfp8_to_bf16(x: Tensor, scales: Tensor) -> Tensor

Dequantize MXFP8 tensor to BF16.

Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Dequantize MXFP8 tensor to BF16.

    The last dimension of ``x`` is split into blocks of ``MXFP8_BLOCK_SIZE``
    elements and every block is multiplied by ``2 ** (scale - 127)``.

    Args:
        x: Quantized values. The size of the last dimension must be a
            multiple of ``MXFP8_BLOCK_SIZE``.
        scales: Per-block biased exponents. Must broadcast against
            ``x.shape[:-1] + (num_blocks,)``.

    Returns:
        A ``torch.bfloat16`` tensor with the same shape as ``x``.

    Raises:
        ValueError: If the last dimension of ``x`` is not a multiple of
            ``MXFP8_BLOCK_SIZE``.
    """
    last_dim = x.shape[-1]
    if last_dim % MXFP8_BLOCK_SIZE != 0:
        # Fail with an explicit message instead of an opaque reshape error.
        raise ValueError(
            f"Last dimension of x ({last_dim}) must be a multiple of "
            f"MXFP8_BLOCK_SIZE ({MXFP8_BLOCK_SIZE})."
        )
    num_blocks = last_dim // MXFP8_BLOCK_SIZE

    # The math is done in float32; only the final result is cast to bf16.
    # reshape returns a view whenever one is possible and copies otherwise,
    # so it also copes with memory layouts a strict view would reject.
    x_blocked = x.to(torch.float32).reshape(
        *x.shape[:-1], num_blocks, MXFP8_BLOCK_SIZE
    )

    # Scales are stored as exponents with a bias of 127.
    descale = torch.exp2(scales.to(torch.float32) - 127.0)

    # The trailing singleton axis broadcasts one scale over its whole block.
    dequantized = (x_blocked * descale.unsqueeze(-1)).reshape(x.shape)

    return dequantized.to(torch.bfloat16)

mxfp8_e4m3_quantize

mxfp8_e4m3_quantize(
    x: Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def mxfp8_e4m3_quantize(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` through the ``torch.ops.vllm.mxfp8_quantize`` operator.

    Both arguments are passed positionally and unchanged; whatever the
    operator returns is handed back to the caller.
    """
    # NOTE(review): the operator is presumably registered elsewhere in this
    # module -- confirm where.
    quantize_op = torch.ops.vllm.mxfp8_quantize
    quantized_and_scales = quantize_op(x, is_sf_swizzled_layout)
    return quantized_and_scales

mxfp8_e4m3_quantize_fake

mxfp8_e4m3_quantize_fake(
    x: Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[Tensor, Tensor]

Fake implementation for torch.compile tracing.

Source code in vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
def mxfp8_e4m3_quantize_fake(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for MXFP8 quantization, for torch.compile tracing.

    Returns uninitialized tensors with the shapes and dtypes of a real
    quantization result:

    * values: same shape as ``x``, dtype ``MXFP8_VALUE_DTYPE``;
    * scales: dtype ``MXFP8_SCALE_DTYPE``. Normally ``x.shape`` with the last
      dimension replaced by the number of ``MXFP8_BLOCK_SIZE`` blocks
      (rounded up). For 2D/3D inputs with ``is_sf_swizzled_layout`` set, a
      flat 1D buffer with rows padded to a multiple of 128 and blocks padded
      to a multiple of 4.
    """
    values = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE)

    # Ceil-divide the last dimension into scale blocks.
    blocks = (x.shape[-1] + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE

    rank = x.ndim
    if is_sf_swizzled_layout and (rank == 2 or rank == 3):
        # The swizzled layout is only modelled for matrices and batches of
        # matrices; other ranks fall through to the plain layout below.
        rows = x.shape[-2]
        rows_padded = ((rows + 127) // 128) * 128
        blocks_padded = ((blocks + 3) // 4) * 4
        flat_len = rows_padded * blocks_padded
        if rank == 3:
            flat_len = x.shape[0] * flat_len
        scale_shape = [flat_len]
    else:
        scale_shape = [*x.shape[:-1], blocks]

    scale_buf = torch.empty(scale_shape, dtype=MXFP8_SCALE_DTYPE, device=x.device)
    return values, scale_buf