Sample from a categorical distribution on mlx arrays
Samples class indices from one or more categorical distributions. Each slice of `logits` along `axis` is treated as a separate categorical distribution over classes. With the default `axis = -1` and a matrix input, each row is one distribution.
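To make these semantics concrete, here is a base-R reference sketch that does not use mlx at all. It treats each row of a logits matrix as one distribution, converts it to probabilities with a numerically stable softmax, and draws one class per row. The helper name `sample_rows_ref` is hypothetical and not part of the package. It returns 1-based indices following R convention; this page does not state whether `mlx_rand_categorical()` returns 0- or 1-based indices.

```r
# Hypothetical reference helper (not part of the package): draw one class
# index per row of an N x K logits matrix, using base R only.
sample_rows_ref <- function(logits) {
  apply(logits, 1, function(row) {
    # Numerically stable softmax: subtracting the row max leaves the result unchanged
    p <- exp(row - max(row))
    p <- p / sum(p)
    # Draw one 1-based class index with probabilities p
    sample.int(length(p), size = 1, prob = p)
  })
}

logits <- rbind(c(0, 0, 0),      # uniform over 3 classes
                c(10, 0, -10))   # class 1 with probability close to 1
set.seed(1)
idx <- sample_rows_ref(logits)
print(idx)
```

The result has one entry per row, so two distributions give two sampled indices.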
Arguments
- logits
- A matrix or mlx array of unnormalized log-probabilities (logits). The values need not be normalized: a softmax is applied internally, so only the differences between entries of the same distribution matter. For a single distribution over K classes, use a 1×K matrix. For N independent distributions, use an N×K matrix with one distribution per row. 
- axis
- The axis that indexes the classes; each slice along this axis is treated as one distribution. Default is -1, the last axis (the columns of an N×K matrix). 
- num_samples
- Number of samples to draw from each distribution.
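The following base-R sketch shows how `logits`, `axis`, and `num_samples` fit together for an N×K matrix with the classes on the last axis. It draws `num_samples` indices independently from each row and returns an N × `num_samples` matrix. That layout mirrors MLX's own `categorical`, which places the samples in the last output dimension; whether the R wrapper returns exactly this shape is an assumption here. The helper `sample_categorical_ref` is hypothetical, uses 1-based indices, and is not the package's implementation. It also demonstrates why normalization is unnecessary: adding a constant to every logit leaves the probabilities, and therefore the draws under the same seed, unchanged.

```r
# Hypothetical reference helper (not part of the package): draw `num_samples`
# class indices from each row of an N x K logits matrix (classes on the last axis).
sample_categorical_ref <- function(logits, num_samples = 1L) {
  # Row-wise stable softmax: subtract each row's max, exponentiate, normalize
  probs <- exp(logits - apply(logits, 1, max))
  probs <- probs / rowSums(probs)
  # One column of num_samples draws per row (or a plain vector if num_samples == 1)
  draws <- apply(probs, 1, function(p) {
    sample.int(length(p), size = num_samples, replace = TRUE, prob = p)
  })
  # Arrange as N x num_samples: row i holds the draws for distribution i
  matrix(draws, nrow = nrow(logits), ncol = num_samples, byrow = TRUE)
}

logits <- rbind(c(1, 2, 3),
                c(3, 2, 1))
set.seed(42)
res <- sample_categorical_ref(logits, num_samples = 4)
print(dim(res))

# Shift invariance: adding 5 to every logit gives the same draws under the same seed
set.seed(42)
res_shifted <- sample_categorical_ref(logits + 5, num_samples = 4)
print(identical(res, res_shifted))
```

Integer-valued logits are used so that the shifted and unshifted softmax inputs are bit-for-bit identical.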