Single training step helper
Arguments
- `module`: An `mlx_module`.
- `loss_fn`: A function that takes the module followed by the data arguments and returns an mlx scalar loss.
- `optimizer`: An optimizer object created by `mlx_optimizer_sgd()`.
- `...`: Additional data arguments passed on to `loss_fn`.
Examples
# Seed R's RNG before the model is built. This presumably makes the initial
# weights reproducible if mlx_linear() draws on R's RNG (confirm).
set.seed(1)

# A linear layer mapping 2 inputs to 1 output with no bias term, and an SGD
# optimizer over its parameters with learning rate 0.1.
net <- mlx_linear(2, 1, bias = FALSE)
sgd <- mlx_optimizer_sgd(mlx_parameters(net), lr = 0.1)

# A 2 x 2 input matrix and the matching 2 x 1 target matrix.
x_train <- as_mlx(matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2))
y_train <- as_mlx(matrix(c(5, 6), nrow = 2, ncol = 1))

# Mean squared error. The module comes first; the data arguments follow.
mse_loss <- function(model, x, y) {
  residual <- model$forward(x) - y
  mean(residual^2)
}

# Take a single optimization step on (x_train, y_train).
step_result <- mlx_train_step(net, mse_loss, sgd, x_train, y_train)