Samples indices from categorical distributions. Each slice along the specified axis (by default, each row of a matrix) is treated as a separate categorical distribution over classes.
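As a rough illustration of what sampling one index from unnormalized logits involves, here is a minimal stdlib Python sketch. It is not the package's implementation: the helper names `softmax` and `sample_categorical` are hypothetical, sampling uses the inverse-CDF method (the library may use a different technique), and indices are 0-based as in Python.

```python
import math
import random

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability;
    # softmax is shift-invariant, so the probabilities are unchanged.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_categorical(logits, rng=random):
    # Hypothetical helper: draw one 0-based class index from
    # unnormalized logits using inverse-CDF sampling.
    probs = softmax(logits)
    u = rng.random()  # uniform in [0, 1)
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    # Guard against floating-point round-off leaving cumulative < 1.
    return len(probs) - 1
```

For example, the logits `[0, 0, log(2)]` become the probabilities `[0.25, 0.25, 0.5]`, so the last class is drawn about half the time.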
Arguments
- logits
A matrix or mlx array of unnormalized log-probabilities (logits). The values do not need to be normalized, because the function applies a softmax internally. For a single distribution over K classes, use a 1×K matrix. For multiple independent distributions, use an N×K matrix where each row is a distribution.
- axis
Axis (1-indexed) holding the classes, i.e. the axis along which to sample. If omitted, the last dimension is used (typically the class dimension).
- num_samples
Number of samples to draw from each distribution.
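To show how `axis` and `num_samples` interact for a matrix of logits, here is a minimal stdlib Python sketch. The function name `sample_rows` is hypothetical, and its output layout (one list of `num_samples` 0-based indices per distribution) is an assumption of this sketch, not a statement of what the R function returns. The `axis` argument mirrors the 1-indexed convention above: `axis=2` (the last dimension of a matrix) treats each row as a distribution, and `axis=1` treats each column as one.

```python
import math
import random

def softmax(row):
    # Numerically stable softmax over one distribution's logits.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def sample_rows(logits_matrix, num_samples=1, axis=2, rng=random):
    # Hypothetical helper. axis is 1-indexed like the documented argument:
    # axis=2 -> classes run along each row (the default, last dimension);
    # axis=1 -> classes run down each column, so transpose first.
    if axis == 1:
        logits_matrix = [list(col) for col in zip(*logits_matrix)]
    elif axis != 2:
        raise ValueError("this sketch only supports axis 1 or 2")
    out = []
    for row in logits_matrix:
        probs = softmax(row)
        # Draw num_samples independent 0-based indices from this distribution.
        out.append(rng.choices(range(len(row)), weights=probs, k=num_samples))
    return out
```

With an N×K matrix and the default axis, the result has N entries of `num_samples` indices each; with `axis=1` it has K entries, one per column.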