mlx_grad() computes gradients of an R function that operates on mlx
arrays. The function must keep all differentiable computations in MLX
(e.g., via as_mlx() and MLX operators) and return an mlx object.
Arguments
- f
An R function. Its arguments should be mlx objects, and its return value must be an mlx array (typically a scalar loss).
- ...
Arguments to pass to f. They will be coerced to mlx if needed.
- argnums
Indices (1-based) identifying which arguments to differentiate with respect to. Defaults to all arguments.
- value
Should the function value be returned alongside gradients? Set to TRUE to receive a list with components value and grads.
Value
When value = FALSE (default), a list of mlx arrays containing the
gradients in the same order as argnums. When value = TRUE, a list with
elements value (the function output as mlx) and grads.
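As a rough sketch of how argnums and value combine, the snippet below differentiates a two-argument loss with respect to its first argument only and also requests the loss value. The package name Rmlx used in the guard is an assumption (this page does not name the package), and loss2 is a hypothetical function made up for illustration.

```r
# Hypothetical loss for illustration: sum of squared entries of x %*% w.
loss2 <- function(w, x) sum((x %*% w) * (x %*% w))

# Assumption: the package providing mlx_grad() is installed as "Rmlx".
if (requireNamespace("Rmlx", quietly = TRUE)) {
  library(Rmlx)
  x <- mlx_matrix(1:4, 2, 2)
  w <- mlx_matrix(c(1, -1), 2, 1)

  # Differentiate with respect to the first argument (w) only and
  # return the loss value alongside the gradient.
  res <- mlx_grad(loss2, w, x, argnums = 1, value = TRUE)
  res$value       # function output as an mlx array
  res$grads[[1]]  # gradient with respect to w
}
```

Because argnums = 1 selects a single argument, grads should hold one element; x is still passed to the function but is treated as a constant.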
Details
Keep the differentiated closure inside MLX operations. Coercing arrays back
to base R objects (e.g. via as.matrix() or [[ extraction)
breaks the gradient tape and results in an error.
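The contrast can be sketched as follows. Both functions are hypothetical illustrations: good_loss stays inside MLX operations, while bad_loss converts an intermediate result with as.matrix() and so is expected to fail as described above. The package name Rmlx in the guard is an assumption, and the exact error message is not shown because it is not documented here.

```r
# Stays in MLX throughout: differentiable.
good_loss <- function(w, x) sum((x %*% w) * (x %*% w))

# Leaves MLX partway through: as.matrix() returns a base R matrix,
# so the computation can no longer be traced.
bad_loss <- function(w, x) {
  preds <- as.matrix(x %*% w)
  sum(preds * preds)
}

# Assumption: the package providing mlx_grad() is installed as "Rmlx".
if (requireNamespace("Rmlx", quietly = TRUE)) {
  library(Rmlx)
  x <- mlx_matrix(1:4, 2, 2)
  w <- mlx_matrix(c(1, -1), 2, 1)

  grads <- mlx_grad(good_loss, w, x)

  # Capture the failure instead of stopping the script.
  outcome <- tryCatch(
    mlx_grad(bad_loss, w, x),
    error = function(e) conditionMessage(e)
  )
  print(outcome)
}
```

If a base R value is needed, for example for plotting, converting the result after mlx_grad() has returned keeps the conversion outside the differentiated function.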
Examples
loss <- function(w, x, y) {
  preds <- x %*% w
  resids <- preds - y
  sum(resids * resids) / length(y)
}
x <- mlx_matrix(1:8, 4, 2)
y <- mlx_matrix(c(1, 3, 2, 4), 4, 1)
# data must contain nrow * ncol elements, so supply both zeros explicitly
w <- mlx_matrix(rep(0, 2), 2, 1)
mlx_grad(loss, w, x, y)[[1]]
loss <- function(w, x) sum((x %*% w) * (x %*% w))
x <- mlx_matrix(1:4, 2, 2)
w <- mlx_matrix(c(1, -1), 2, 1)
mlx_value_grad(loss, w, x)
#> $value
#> mlx array []
#> dtype: float32
#> device: gpu
#> values:
#> [1] 8
#>
#> $grads
#> $grads[[1]]
#> mlx array [2 x 1]
#> dtype: float32
#> device: gpu
#> values:
#>      [,1]
#> [1,]  -12
#> [2,]  -28
#>
#> $grads[[2]]
#> mlx array [2 x 2]
#> dtype: float32
#> device: gpu
#> values:
#>      [,1] [,2]
#> [1,]   -4    4
#> [2,]   -4    4
#>
#>
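The printed value and gradients can be checked by hand in base R. For loss(w, x) = sum((x w)^2), the gradient with respect to w is 2 x^T (x w) and the gradient with respect to x is 2 (x w) w^T. The sketch below uses ordinary base R matrices rather than mlx arrays. It assumes mlx_matrix() fills column-major like matrix(), which is consistent with the value of 8 shown above.

```r
# Base R stand-ins for the mlx arrays in the example above.
x <- matrix(1:4, 2, 2)
w <- matrix(c(1, -1), 2, 1)

r <- x %*% w               # residual-like term x w
value <- sum(r * r)        # loss value

grad_w <- 2 * t(x) %*% r   # d loss / d w = 2 x^T (x w)
grad_x <- 2 * r %*% t(w)   # d loss / d x = 2 (x w) w^T

value
grad_w
grad_x
```

These closed-form results agree with the $value and $grads entries printed above: 8, (-12, -28), and a 2 x 2 matrix whose rows are both (-4, 4).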