Single training step helper
Source: `mlx_train_step.Rd`

Single training step helper.
Arguments
- module
- An `mlx_module`.
- loss_fn
- Function of `module` and the data supplied through `...`, returning an mlx scalar.
- optimizer
- Optimizer object from `mlx_optimizer_sgd()`.
- ...
- Additional data passed to `loss_fn`.
Examples
set.seed(1)

# Linear model without a bias term, optimised with plain SGD.
net <- mlx_linear(2, 1, bias = FALSE)
sgd <- mlx_optimizer_sgd(mlx_parameters(net), lr = 0.1)

# Toy inputs (2 x 2) and targets (2 x 1).
x_train <- as_mlx(matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2))
y_train <- as_mlx(matrix(c(1, 2), nrow = 2, ncol = 1))

# Sum-of-squared-errors loss: takes the module first, then the data
# that mlx_train_step() forwards through `...`.
sse_loss <- function(mod, x, y) {
  residual <- mlx_forward(mod, x) - y
  sum(residual * residual)
}

mlx_train_step(net, sse_loss, sgd, x_train, y_train)
#> $loss
#> mlx array []
#>   dtype: float32
#>   device: gpu
#>   values:
#> [1] 7.49876
#>