
Single training step helper

Usage

mlx_train_step(module, loss_fn, optimizer, ...)

Arguments

module

An mlx_module.

loss_fn

A function that takes the module as its first argument, followed by the data arguments, and returns the loss as an mlx scalar.

optimizer

Optimizer object from mlx_optimizer_sgd().

...

Additional data arguments, passed on to loss_fn after module.

Value

A list with one element, loss, holding the current loss as an mlx scalar.

Examples

set.seed(1)
# Linear layer mapping 2 inputs to 1 output, without a bias term
model <- mlx_linear(2, 1, bias = FALSE)
opt <- mlx_optimizer_sgd(mlx_parameters(model), lr = 0.1)
data_x <- as_mlx(matrix(c(1, 2, 3, 4), 2, 2))
data_y <- as_mlx(matrix(c(1, 2), 2, 1))
# Sum-of-squared-errors loss; data_x and data_y arrive here via `...`
loss_fn <- function(mod, x, y) {
  preds <- mlx_forward(mod, x)
  diff <- preds - y
  sum(diff * diff)
}
mlx_train_step(model, loss_fn, opt, data_x, data_y)
#> $loss
#> mlx array []
#>   dtype: float32
#>   device: gpu
#>   values:
#> [1] 7.49876
#>
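For intuition, the following base-R sketch writes out one plain gradient-descent step for the same bias-free linear model and sum-of-squares loss, with the gradient derived by hand. It is an illustration under stated assumptions, not the mlx_train_step implementation. sgd_step_sketch() is a hypothetical helper. The weights start at zero rather than at a random initialization, and the reported loss is the one evaluated before the update. The sketch uses lr = 0.01 because, on this data, a step of 0.1 overshoots: in this sketch the loss jumps from 5 to 121 after one step.

```r
# Hypothetical sketch of one SGD step for a bias-free linear model
# with loss sum((x %*% w - y)^2); not the package's implementation.
sgd_step_sketch <- function(w, x, y, lr) {
  resid <- x %*% w - y             # prediction error
  loss <- sum(resid^2)             # loss evaluated before the update
  grad <- 2 * crossprod(x, resid)  # d loss / d w = 2 t(x) (x w - y)
  list(loss = loss, w = w - lr * grad)
}

x <- matrix(c(1, 2, 3, 4), 2, 2)
y <- matrix(c(1, 2), 2, 1)
w <- matrix(0, 2, 1)               # assumed zero start, for reproducibility

step <- sgd_step_sketch(w, x, y, lr = 0.01)
step$loss                          # 5
sum((x %*% step$w - y)^2)          # 0.904, smaller after the step
```

Calling sgd_step_sketch() repeatedly, feeding step$w back in as w, gives a basic training loop for this toy problem.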