Computes cross-entropy loss for multi-class classification.
Usage
mlx_cross_entropy(logits, targets, reduction = c("mean", "sum", "none"))
Examples
# Logits for 3 samples, 4 classes
logits <- mlx_matrix(rnorm(12), 3, 4)
targets <- as_mlx(c(1, 3, 2))
mlx_cross_entropy(logits, targets)
#> mlx array []
#> dtype: float32
#> device: gpu
#> values:
#> [1] 2.377907
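
To make the computation concrete, the following base-R sketch shows one common definition of cross-entropy loss. It is not the Rmlx implementation. It assumes that rows of `logits` are samples and columns are classes, that a softmax is applied across each row, and that `targets` holds 1-based class indices; none of these is stated above. The name `cross_entropy_ref` is hypothetical and is used only for illustration.

```r
# Reference cross-entropy in plain R (a sketch, not the Rmlx implementation).
# Assumptions: rows are samples, columns are classes, targets are 1-based.
cross_entropy_ref <- function(logits, targets,
                              reduction = c("mean", "sum", "none")) {
  reduction <- match.arg(reduction)
  # Numerically stable log-softmax: subtract each row's maximum first.
  shifted <- logits - apply(logits, 1, max)
  log_probs <- shifted - log(rowSums(exp(shifted)))
  # Negative log-probability of the target class, one value per sample.
  per_sample <- -log_probs[cbind(seq_len(nrow(logits)), targets)]
  switch(reduction,
    mean = mean(per_sample),
    sum = sum(per_sample),
    none = per_sample
  )
}

# With all-zero logits every class has probability 1/4,
# so each sample contributes -log(1/4) = log(4).
logits <- matrix(0, nrow = 3, ncol = 4)
targets <- c(1, 3, 2)
cross_entropy_ref(logits, targets)          # log(4), about 1.386
cross_entropy_ref(logits, targets, "none")  # three values, each log(4)
```

Under these assumptions, "mean" averages the per-sample losses, "sum" adds them, and "none" returns them unreduced. Because the example above draws its logits with rnorm() and sets no seed, the printed loss value will differ from run to run.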