
Samples indices from categorical distributions. Each row (or slice along the specified axis) represents a separate categorical distribution over classes.

Usage

mlx_rand_categorical(logits, axis = NULL, num_samples = 1L)

Arguments

logits

A matrix or mlx array of logits (unnormalized log-probabilities). The values do not need to be normalized: the function applies a softmax internally, so the sampling probabilities are proportional to exp(logits). For a single distribution over K classes, use a 1×K matrix. For multiple independent distributions, use an N×K matrix where each row is a distribution.
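Because of the internal softmax, only differences between logits in the same distribution matter: adding one constant to every logit leaves the probabilities unchanged, and log-probabilities map back to the original probabilities. The base-R sketch below illustrates that conversion; the helper softmax() is hypothetical and not part of Rmlx.

```r
# Hypothetical helper (not part of Rmlx): convert a vector of logits
# to probabilities that sum to 1.
softmax <- function(x) {
  z <- exp(x - max(x))  # subtracting the max avoids overflow, result unchanged
  z / sum(z)
}

logits <- c(0.5, 0.2, 0.3)
p <- softmax(logits)

# Shifting every logit by the same constant gives the same distribution
stopifnot(isTRUE(all.equal(p, softmax(logits + 10))))

# Log-probabilities round-trip: softmax(log(q)) recovers q
q <- c(0.5, 0.2, 0.3)
stopifnot(isTRUE(all.equal(softmax(log(q)), q)))
```

This is also why passing raw probabilities such as c(0.5, 0.2, 0.3) as logits gives a nearly uniform distribution rather than those probabilities.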

axis

Axis (1-indexed) along which to sample. The default, NULL, uses the last dimension (typically the class dimension).
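For a 2-D logits matrix, the default (last dimension, axis 2) treats each row as one distribution, while axis = 1 would treat each column as a distribution. The base-R sketch below normalizes along a chosen 1-indexed axis to show which entries form one distribution under each choice. The helper softmax_along() is hypothetical, and the row/column mapping is an assumption drawn from the description above rather than from the Rmlx source.

```r
# Hypothetical helper (not part of Rmlx): softmax along a 1-indexed axis
# of a matrix. axis = 2 (the last dimension) normalizes each row;
# axis = 1 normalizes each column.
softmax_along <- function(m, axis = 2L) {
  stopifnot(is.matrix(m), axis %in% c(1, 2))
  margin <- 3 - axis                                  # dimension indexing the distributions
  shifted <- sweep(m, margin, apply(m, margin, max))  # subtract the max for stability
  e <- exp(shifted)
  sweep(e, margin, apply(e, margin, sum), "/")        # normalize each distribution
}

m <- matrix(c(1, 2, 3,
              3, 2, 1), nrow = 2, byrow = TRUE)
rowwise <- softmax_along(m)            # each row sums to 1
colwise <- softmax_along(m, axis = 1)  # each column sums to 1
```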

num_samples

Number of samples to draw from each distribution.

Value

An mlx array of integer indices (1-indexed) sampled from the categorical distributions.
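To make the 1-indexed convention concrete, the base-R sketch below draws num_samples indices per row with sample.int(), so every value lies in 1:K. It is a reference model for checking results, not the Rmlx implementation (which runs on MLX); the helper name rand_categorical_ref() and its N × num_samples output layout are assumptions.

```r
# Hypothetical reference sampler (not part of Rmlx): each row of the
# N x K logits matrix is one distribution; returns an N x num_samples
# integer matrix of 1-indexed class indices.
rand_categorical_ref <- function(logits, num_samples = 1L) {
  stopifnot(is.matrix(logits))  # a plain vector would be read as a K x 1 matrix
  K <- ncol(logits)
  out <- matrix(NA_integer_, nrow = nrow(logits), ncol = num_samples)
  for (i in seq_len(nrow(logits))) {
    x <- logits[i, ]
    p <- exp(x - max(x))  # softmax numerator, shifted for stability
    p <- p / sum(p)
    out[i, ] <- sample.int(K, size = num_samples, replace = TRUE, prob = p)
  }
  out
}

set.seed(1)
logits <- matrix(c(1, 2, 3,
                   3, 2, 1), nrow = 2, byrow = TRUE)
idx <- rand_categorical_ref(logits, num_samples = 5)
dim(idx)  # 2 rows (distributions) by 5 samples; all values lie within 1:3
```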

Examples

# Single distribution over 3 classes, given as log-probabilities
logits <- matrix(log(c(0.5, 0.2, 0.3)), nrow = 1, ncol = 3)
samples <- mlx_rand_categorical(logits, num_samples = 10)

# Multiple distributions
logits <- matrix(c(1, 2, 3,
                   3, 2, 1), nrow = 2, byrow = TRUE)
samples <- mlx_rand_categorical(logits, num_samples = 5)