Computes cross-entropy loss for multi-class classification.
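For reference, the standard definition of multi-class cross-entropy on logits is given below. It assumes a softmax over the class dimension of each sample (row), with $z_{i,k}$ the logit for sample $i$ and class $k$, $y_i$ the target class of sample $i$, and $N$ the number of samples. The exact internal computation of `mlx_cross_entropy()` is not stated on this page.

```latex
\ell_i = -\log \frac{\exp(z_{i,y_i})}{\sum_{k} \exp(z_{i,k})}
       = \log \sum_{k} \exp(z_{i,k}) - z_{i,y_i}

\text{mean: } \frac{1}{N} \sum_{i=1}^{N} \ell_i, \qquad
\text{sum: } \sum_{i=1}^{N} \ell_i, \qquad
\text{none: } (\ell_1, \dots, \ell_N)
```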

Usage

mlx_cross_entropy(logits, targets, reduction = c("mean", "sum", "none"))

Arguments

logits

Unnormalized predictions (logits) as an mlx array, with one row per sample and one column per class.

targets

Target class indices (0-indexed, as in the example below) as an mlx array or integer vector, one per sample.

reduction

Type of reduction: "mean" (default), "sum", or "none".
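To illustrate what the three `reduction` options mean, here is a minimal base-R sketch of the same loss. `cross_entropy_ref()` is a hypothetical helper, not part of the package. It assumes rows of `logits` are samples, columns are classes, and `targets` are 0-indexed labels (matching the `# 0-indexed class labels` comment in the Examples section). It is not the package's implementation.

```r
# Hypothetical base-R reference for multi-class cross-entropy (not the package API).
# Assumptions: rows of `logits` are samples, columns are classes,
# and `targets` holds 0-indexed class labels.
cross_entropy_ref <- function(logits, targets,
                              reduction = c("mean", "sum", "none")) {
  reduction <- match.arg(reduction)
  # Numerically stable log-softmax over each row:
  # log p[i, k] = z[i, k] - log(sum_j exp(z[i, j]))
  row_max <- apply(logits, 1, max)
  shifted <- logits - row_max  # recycling subtracts row_max[i] from row i
  log_probs <- shifted - log(rowSums(exp(shifted)))
  # Each sample's log-probability of its true class (shift labels to 1-based)
  picked <- log_probs[cbind(seq_len(nrow(logits)), targets + 1)]
  losses <- -picked
  switch(reduction,
         mean = mean(losses),
         sum  = sum(losses),
         none = losses)
}
```

With all-zero logits every class has probability 1/4, so each per-sample loss is `log(4)`: `"none"` returns one such value per sample, `"sum"` adds them, and `"mean"` averages them.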

Value

An mlx array containing the loss: a scalar for "mean" and "sum", or one loss per sample for "none".

Examples

# Logits for 3 samples, 4 classes
logits <- as_mlx(matrix(rnorm(12), 3, 4))
targets <- as_mlx(c(1, 3, 2))  # 0-indexed class labels
mlx_cross_entropy(logits, targets)
#> mlx array []
#>   dtype: float32
#>   device: gpu
#>   values:
#> [1] 1.120893