Performs quantized matrix multiplication with optional gather operations on inputs. This is useful for combining embedding lookups with quantized linear transformations, a common pattern in transformer models.
Usage
mlx_gather_qmm(
x,
w,
scales,
biases = NULL,
lhs_indices = NULL,
rhs_indices = NULL,
transpose = TRUE,
group_size = 64L,
bits = 4L,
mode = "affine",
sorted_indices = FALSE,
device = mlx_default_device()
)
Arguments
- x
An mlx array of input activations.
- w
An mlx array representing the weight matrix. Accepts either an unquantized matrix (which may be quantized automatically) or a pre-quantized uint32 matrix produced by mlx_quantize(); see the sketch after this argument list.
- scales
An optional mlx array of quantization scales. Required when w is already quantized.
- biases
An optional mlx array of quantization biases (affine mode); use NULL for symmetric quantization.
- lhs_indices
An optional integer vector/array (1-indexed) or mlx tensor of indices for gathering from x's leading (batch) dimension. Default: NULL.
- rhs_indices
An optional integer vector/array (1-indexed) or mlx tensor of indices for gathering from w's leading (batch) dimension. Default: NULL.
- transpose
Whether to transpose the weight matrix before multiplication. Default: TRUE.
- group_size
The group size for quantization. Smaller groups improve accuracy at the cost of slightly higher memory. Default: 64.
- bits
Number of bits for quantization (typically 4 or 8). Default: 4.
- mode
Quantization mode, either "affine" or "mxfp4". Default: "affine".
- sorted_indices
Whether the supplied indices are already sorted; when TRUE, faster gather kernels can be used. Default: FALSE.
- device
Execution target: supply "gpu", "cpu", or an mlx_stream created via mlx_new_stream(). Defaults to the current mlx_default_device() unless noted otherwise (helpers that act on an existing array typically reuse that array's device or stream).
Details
This function fuses gather operations (indexed lookups) with quantized matrix multiplication. When lhs_indices is provided, rows are gathered from x (conceptually x[lhs_indices]) before the multiplication; similarly, rhs_indices gathers weight matrices from w's leading dimension.
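As a rough sketch of the lhs_indices semantics, reusing x and q from the example above (1-indexed [ subsetting on mlx arrays is an assumption here):

idx <- c(3L, 1L, 1L, 7L)   # 1-indexed rows of x to gather

# Fused gather + quantized matmul ...
y_fused <- mlx_gather_qmm(x, q$w, scales = q$scales, biases = q$biases,
                          lhs_indices = idx)

# ... conceptually the same as indexing first, then multiplying.
y_ref <- mlx_gather_qmm(x[idx, ], q$w, scales = q$scales, biases = q$biases)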
This is particularly efficient for transformer models, where a token-embedding lookup and a quantized linear transformation can be applied in one fused operation; gathering over the weights works the same way, as in the expert-routing sketch below.
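One common pattern for rhs_indices is per-token expert routing in a mixture-of-experts layer, where the weights carry a leading experts dimension. A hedged sketch (shapes are illustrative, and mlx_array() is the same assumed constructor as above):

# Four experts, each a (256, 128) matrix, stacked along a leading dimension.
experts <- mlx_array(array(rnorm(4 * 256 * 128), dim = c(4, 256, 128)))
qe <- mlx_quantize(experts, group_size = 64L, bits = 4L)

tokens <- mlx_array(matrix(rnorm(4 * 128), nrow = 4))   # one row per token
expert_ids <- sort(c(2L, 4L, 1L, 2L))                   # 1-indexed expert per token

# Pre-sorting the indices lets sorted_indices = TRUE pick the faster kernels.
out <- mlx_gather_qmm(tokens, qe$w, scales = qe$scales, biases = qe$biases,
                      rhs_indices = expert_ids, sorted_indices = TRUE)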