Automatic differentiation for MLX functions
mlx_grad() computes gradients of an R function that operates on mlx
arrays. The function must keep all differentiable computations in MLX
(e.g., via as_mlx() and MLX operators) and return an mlx object.
Arguments
- f
- An R function. Its arguments should be mlx objects, and its return value must be an mlx array (typically a scalar loss). 
- ...
- Arguments to pass to f. They will be coerced to mlx if needed.
- argnums
- Indices (1-based) identifying which arguments to differentiate with respect to. Defaults to all arguments. 
- value
- Should the function value be returned alongside gradients? Set to TRUE to receive a list with components value and grads.
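As a minimal sketch of how argnums might be used, the snippet below differentiates a three-argument loss with respect to its first argument only. It assumes the package is attached with library(Rmlx) (the package name is an assumption; this page does not state it) and otherwise uses only functions named on this page.

```r
library(Rmlx)  # assumed package name, not stated on this page

# Mean squared error of a linear model; every step stays in MLX
loss <- function(w, x, y) {
  resids <- x %*% w - y
  sum(resids * resids) / length(y)
}

x <- as_mlx(matrix(1:8, 4, 2))
y <- as_mlx(matrix(c(1, 3, 2, 4), 4, 1))
w <- as_mlx(matrix(0, 2, 1))

# argnums is 1-based: 1 selects w, the first argument of loss
g <- mlx_grad(loss, w, x, y, argnums = 1)

# One gradient is requested, so the returned list should hold one mlx array
g[[1]]
```

Restricting argnums to the parameters being optimised should avoid computing gradients for data arguments such as x and y, which are rarely needed.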
Value
When value = FALSE (default), a list of mlx arrays containing the
gradients in the same order as argnums. When value = TRUE, a list with
elements value (the function output as mlx) and grads.
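The two return shapes can be illustrated with a short sketch. It assumes the package is attached with library(Rmlx) (an assumed name) and relies only on the value argument as described above.

```r
library(Rmlx)  # assumed package name, not stated on this page

# Sum of squared entries of x %*% w
loss <- function(w, x) sum((x %*% w) * (x %*% w))

x <- as_mlx(matrix(1:4, 2, 2))
w <- as_mlx(matrix(c(1, -1), 2, 1))

# Default (value = FALSE): a plain list of gradients
grads_only <- mlx_grad(loss, w, x, argnums = 1)

# value = TRUE: a list with components value and grads
res <- mlx_grad(loss, w, x, argnums = 1, value = TRUE)
res$value       # the loss, still an mlx array
res$grads[[1]]  # gradient with respect to w
```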
Details
Keep the differentiated closure inside MLX operations. Coercing arrays back
to base R objects (such as as.matrix(), as.numeric(), or [[ extraction)
breaks the gradient tape and results in an error.
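A hedged illustration of the pitfall: the hypothetical bad_loss below converts an intermediate result to a base R vector, while good_loss keeps everything in MLX. The sketch assumes the package is attached with library(Rmlx) (an assumed name) and wraps the failing call in tryCatch() so the error message can be inspected; the exact wording of that message is not documented here.

```r
library(Rmlx)  # assumed package name, not stated on this page

# Hypothetical loss that leaves MLX: as.numeric() yields a base R vector
bad_loss <- function(w, x, y) {
  resids <- as.numeric(x %*% w - y)
  sum(resids^2) / length(resids)
}

# Equivalent loss that stays entirely in MLX operations
good_loss <- function(w, x, y) {
  resids <- x %*% w - y
  sum(resids * resids) / length(y)
}

x <- as_mlx(matrix(1:8, 4, 2))
y <- as_mlx(matrix(c(1, 3, 2, 4), 4, 1))
w <- as_mlx(matrix(0, 2, 1))

# Per the description above, this call is expected to fail;
# capture the message instead of stopping the script
msg <- tryCatch(
  mlx_grad(bad_loss, w, x, y),
  error = function(e) conditionMessage(e)
)

# The MLX-only version should differentiate normally
g <- mlx_grad(good_loss, w, x, y, argnums = 1)
```

If a base R value is needed (for printing or plotting, say), convert the result after mlx_grad() returns rather than inside f.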
Examples
# Mean squared error of the linear model x %*% w against targets y
loss <- function(w, x, y) {
  preds <- x %*% w
  resids <- preds - y
  sum(resids * resids) / length(y)
}
x <- as_mlx(matrix(1:8, 4, 2))
y <- as_mlx(matrix(c(1, 3, 2, 4), 4, 1))
w <- as_mlx(matrix(0, 2, 1))
mlx_grad(loss, w, x, y)[[1]]
#> mlx array [2 x 1]
#>   dtype: float32
#>   device: gpu
#>   values:
#>       [,1]
#> [1,] -14.5
#> [2,] -34.5
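# Sum of squared entries of x %*% w
loss <- function(w, x) sum((x %*% w) * (x %*% w))
x <- as_mlx(matrix(1:4, 2, 2))
w <- as_mlx(matrix(c(1, -1), 2, 1))
# Returns the loss value together with gradients for both w and x
mlx_value_grad(loss, w, x)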
#> $value
#> mlx array []
#>   dtype: float32
#>   device: gpu
#>   values:
#> [1] 8
#> 
#> $grads
#> $grads[[1]]
#> mlx array [2 x 1]
#>   dtype: float32
#>   device: gpu
#>   values:
#>      [,1]
#> [1,]  -12
#> [2,]  -28
#> 
#> $grads[[2]]
#> mlx array [2 x 2]
#>   dtype: float32
#>   device: gpu
#>   values:
#>      [,1] [,2]
#> [1,]   -4    4
#> [2,]   -4    4
#> 
#>
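As a sanity check on the first example, the gradient of the mean squared error has the closed form (2 / n) * t(x) %*% (x %*% w - y), which can be evaluated in base R without MLX. This is a sketch for verification only, not part of the package; grad_closed is a name introduced here.

```r
# Same data as the first example, kept as ordinary R matrices
x <- matrix(1:8, 4, 2)
y <- matrix(c(1, 3, 2, 4), 4, 1)
w <- matrix(0, 2, 1)

# Analytic gradient of sum((x %*% w - y)^2) / n with respect to w
n <- length(y)
grad_closed <- (2 / n) * t(x) %*% (x %*% w - y)

grad_closed
#>       [,1]
#> [1,] -14.5
#> [2,] -34.5
```

The values match the mlx_grad() output shown above. Because the mlx arrays in the examples are float32, a comparison against double-precision base R results is best done with a tolerance (for example via all.equal()) rather than exact equality.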