Performs matrix multiplication with a quantized weight matrix. This operation is essential for efficient inference with quantized models, significantly reducing memory usage and computation time while maintaining reasonable accuracy.
Usage
mlx_quantized_matmul(
x,
w,
scales = NULL,
biases = NULL,
transpose = TRUE,
group_size = 64L,
bits = 4L,
mode = "affine",
device = mlx_default_device()
)
Arguments
- x
An mlx array.
- w
An mlx array representing the weight matrix. Accepts either an unquantized matrix (which may be quantized automatically) or a pre-quantized uint32 matrix produced by mlx_quantize().
- scales
An optional mlx array of quantization scales. Required when w is already quantized.
- biases
An optional mlx array of quantization biases (affine mode); use NULL for symmetric quantization.
- transpose
Whether to transpose the weight matrix before multiplication.
- group_size
The group size for quantization. Smaller groups improve accuracy at the cost of slightly higher memory. Default: 64.
- bits
Number of bits for quantization (typically 4 or 8). Default: 4.
- mode
Quantization mode, either "affine" or "mxfp4".
- device
Execution target: supply "gpu", "cpu", or an mlx_stream created via mlx_new_stream(). Defaults to the current mlx_default_device() unless noted otherwise (helpers that act on an existing array typically reuse that array's device or stream).
Details
Quantized matrix multiplication uses low-precision representations (typically 4-bit or 8-bit integers) for weights, which reduces memory footprint by up to 8x compared to float32. The scales parameter contains the dequantization factors needed to reconstruct approximate float values during computation.
The group_size parameter controls the granularity of quantization: smaller groups provide better accuracy but slightly higher memory usage.
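As a rough, package-independent illustration of what the scales and biases represent: each group of group_size consecutive weights is stored as low-bit integer codes plus one scale and one bias, and approximate float values are recovered as code * scale + bias. The sketch below is plain base R for intuition only, not the routine the package uses internally.

# Conceptual group-wise affine quantization and dequantization (base R)
quantize_group <- function(g, bits = 4) {
  levels <- 2^bits - 1
  scale  <- (max(g) - min(g)) / levels   # one scale per group
  bias   <- min(g)                       # one bias per group (affine mode)
  q      <- round((g - bias) / scale)    # integer codes in 0..levels
  list(q = q, scale = scale, bias = bias)
}
g  <- rnorm(64)                          # one group of 64 weights
qg <- quantize_group(g)
g_approx <- qg$q * qg$scale + qg$bias    # dequantized approximation
max(abs(g - g_approx))                   # reconstruction error is at most scale / 2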
Automatic Quantization: If only w is provided (without scales), the function will
automatically quantize w using mlx_quantize() before performing the multiplication.
For repeated operations, it's more efficient to pre-quantize weights once using
mlx_quantize() and reuse them.
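For instance, when the same weights are multiplied against many input batches, a sketch of the reuse pattern looks like this (shapes and the list of batches are illustrative):

# Quantize the weights once, then reuse them for several input batches
w <- mlx_rand_normal(c(128, 64))
quant <- mlx_quantize(w, group_size = 32, bits = 4)
batches <- list(mlx_rand_normal(c(4, 64)), mlx_rand_normal(c(4, 64)))
results <- lapply(batches, function(x) {
  mlx_quantized_matmul(x, quant$w_q, quant$scales, quant$biases, group_size = 32)
})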
Examples
# Automatic quantization (convenient but slower for repeated use)
x <- mlx_rand_normal(c(4, 64))
w <- mlx_rand_normal(c(128, 64))
result <- mlx_quantized_matmul(x, w, group_size = 32)
# Pre-quantized weights (faster for repeated operations)
quant <- mlx_quantize(w, group_size = 32, bits = 4)
result <- mlx_quantized_matmul(x, quant$w_q, quant$scales, quant$biases, group_size = 32)