Skip to content

vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils

dequantize_to_dtype

dequantize_to_dtype(
    tensor_fp4,
    tensor_sf,
    global_scale,
    dtype,
    device,
    block_size=16,
)

Dequantize the packed fp4 tensor back to a higher-precision dtype, using its per-block scale factors and the global scale.

Source code in vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
def dequantize_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Expand a packed FP4 tensor into ``dtype`` using its block scales.

    Args:
        tensor_fp4: 2-D ``torch.uint8`` tensor of shape ``(m, packed_k)``.
            Each byte holds two fp4 values, so the logical width is
            ``k = 2 * packed_k``.
        tensor_sf: Per-block scale factors. They are reinterpreted (not
            converted) as ``torch.float8_e4m3fn`` and passed through
            ``convert_swizzled_to_linear``. Presumably they are stored in a
            swizzled layout; confirm against the quantizer.
        global_scale: Divisor applied to every block scale after it is
            widened to float32.
        dtype: Target dtype of the returned tensor.
        device: Accepted for interface compatibility. It is not referenced
            in this body.
        block_size: Number of consecutive fp4 values sharing one scale
            factor. ``k`` is expected to be divisible by it, otherwise the
            block reshape fails.

    Returns:
        Tensor of shape ``(m, k)`` in ``dtype``. Each element is computed
        as ``fp4_value * (block_scale / global_scale)``, in float32, before
        the final cast.
    """
    # Two fp4 values share one uint8 byte, hence the doubled logical width.
    assert tensor_fp4.dtype == torch.uint8
    rows, packed_cols = tensor_fp4.shape
    cols = 2 * packed_cols

    # Unpack the nibbles to float32 and group them into scale blocks.
    values = break_fp4_bytes(tensor_fp4, torch.float32).reshape(
        rows, cols // block_size, block_size
    )

    # Reinterpret the scale bytes as fp8 e4m3, restore a linear layout,
    # then fold in the global scale.
    linear_scales = convert_swizzled_to_linear(
        tensor_sf.view(torch.float8_e4m3fn), rows, cols, block_size
    )
    block_scales = linear_scales.to(torch.float32) / global_scale

    # Broadcast one scale across its block, then flatten back to (rows, cols).
    dequantized = (values * block_scales[..., None]).reshape(rows, cols)
    return dequantized.to(dtype)