
Softmax for mlx arrays

Usage

mlx_softmax(x, axes = NULL, precise = FALSE)

Arguments

x

An mlx array, or an R array/matrix/vector that will be converted via as_mlx().

axes

Integer vector of axes (1-indexed), each a positive integer between 1 and the rank of the array. Many helpers interpret NULL as "all axes"; see the function details for specifics.
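To make the axis semantics concrete, here is a minimal base-R sketch that applies softmax over a chosen set of 1-indexed axes. It uses no mlx arrays, and `softmax_axes` is a hypothetical helper written only for illustration; it is not part of the package. Treating NULL as "all axes" is an assumption that follows the convention described above.

```r
# Base-R sketch: softmax over a chosen set of 1-indexed axes.
# `softmax_axes` is a hypothetical illustration, not a package function.
softmax_axes <- function(x, axes = NULL) {
  x <- as.array(x)                            # plain vectors become 1-d arrays
  rank <- length(dim(x))
  if (is.null(axes)) axes <- seq_len(rank)    # assumption: NULL means all axes
  stopifnot(all(axes >= 1), all(axes <= rank))
  keep <- setdiff(seq_len(rank), axes)        # axes that are NOT normalized over
  if (length(keep) == 0) {
    # Normalize over every axis: a single distribution for the whole array.
    e <- exp(x - max(x))                      # subtract the max so exp() cannot overflow
    return(e / sum(e))
  }
  m <- apply(x, keep, max)                    # max of each slice taken along `axes`
  e <- exp(sweep(x, keep, m))                 # shift each slice by its max, then exponentiate
  s <- apply(e, keep, sum)                    # per-slice normalizing constant
  sweep(e, keep, s, "/")                      # divide so each slice sums to 1
}

x <- matrix(1:6, 2, 3)
rowSums(softmax_axes(x, axes = 2))            # each row sums to 1
```

With `axes = 2`, every row is normalized independently. With `axes = 1`, every column is normalized independently. With all axes, the whole matrix forms a single distribution.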

precise

Logical; if TRUE, compute in higher precision for numerical stability.

Value

An mlx array with the same shape as x, containing normalized probabilities: entries are non-negative and sum to 1 along axes.
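For reference, the textbook softmax over the elements of one slice (indices i and j range over the positions along the chosen axes) is given below. Subtracting the slice maximum m leaves the result unchanged, because the factor e^{-m} cancels between numerator and denominator, and it also keeps the exponentials from overflowing. This is the standard definition only, not a description of how mlx_softmax is implemented internally.

```latex
\operatorname{softmax}(x)_i
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}},
  \qquad m = \max_j x_j
```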

See also

Examples

x <- mlx_matrix(1:6, 2, 3)
# Softmax along axis 2 (across the columns), so each row sums to 1
sm <- mlx_softmax(x, axes = 2)
rowSums(sm)
#> mlx array [2]
#>   dtype: float32
#>   device: gpu
#>   values:
#> [1] 1 1