Linear Regression with MLX
Overview
This vignette demonstrates linear regression using Rmlx, based on the MLX linear regression example. We’ll train a linear model using automatic differentiation and stochastic gradient descent (SGD) on GPU-accelerated arrays.
Problem Setup
We’ll create synthetic data for linear regression with high dimensionality:
- A random “true” weight vector w_star of dimension 100
- A random design matrix X of 10,000 cases × 100 features
- Noisy labels y = X %*% w_star + eps, where eps is small Gaussian noise
library(Rmlx)
#> 
#> Attaching package: 'Rmlx'
#> The following object is masked from 'package:stats':
#> 
#>     fft
#> The following objects are masked from 'package:base':
#> 
#>     chol2inv, colMeans, colSums, diag, outer, rowMeans, rowSums, svd
# Problem metadata
num_features <- 100
num_cases <- 10000
num_iters <- 1200          # iterations of SGD
learning_rate <- 0.01      # learning rate for SGD
# Set seed for reproducibility
set.seed(42)
# True parameters (what we're trying to learn)
w_star <- mlx_rand_normal(c(num_features, 1))
# Input examples (design matrix)
X <- mlx_rand_normal(c(num_cases, num_features))
# Noisy labels
eps <- 1e-2 * mlx_rand_normal(c(num_cases, 1))
y <- X %*% w_star + eps
Define the Loss Function
The mean squared error loss is a standard choice for regression:
# Define loss function
loss_fn <- function(w) {
  preds <- X %*% w
  residuals <- preds - y
  0.5 * mean(residuals * residuals)
}
The loss measures how well our parameters w predict the labels. Lower loss means better predictions.
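As a quick sanity check (illustrative; it assumes loss_fn and the data defined above are in scope), the loss at the true weights should sit near the noise floor of roughly 0.5 × 0.01² = 5e-05, while a random weight vector scores far worse:
# Sanity check: loss at the true weights vs. a random guess
cat("Loss at w_star:  ", as.vector(loss_fn(w_star)), "\n")
cat("Loss at random w:", as.vector(loss_fn(mlx_rand_normal(c(num_features, 1)))), "\n")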
Automatic Differentiation
Rmlx provides mlx_grad() to compute gradients via
automatic differentiation. This computes the gradient of the loss with
respect to our parameters:
# Get the gradient function
grad_fn <- function(w) {
  mlx_grad(loss_fn, w)[[1]]
}
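# Optional check (illustrative sketch): for this loss the analytic gradient is
# t(X) %*% (X %*% w - y) / num_cases, so the autodiff result should agree with
# it up to floating-point noise. The comparison uses as.matrix() to move the
# arrays into base R and assumes the objects defined above are still in scope.
w_check <- 1e-2 * mlx_rand_normal(c(num_features, 1))
g_auto <- as.matrix(grad_fn(w_check))
X_m <- as.matrix(X)
y_m <- as.matrix(y)
g_manual <- crossprod(X_m, X_m %*% as.matrix(w_check) - y_m) / num_cases
max(abs(g_auto - g_manual))  # expected to be near zero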
train_sgd <- function(steps = num_iters, step_size = learning_rate, verbose = TRUE) {
  w <- 1e-2 * mlx_rand_normal(c(num_features, 1))
  for (i in seq_len(steps)) {
    grad <- grad_fn(w)
    w <- w - step_size * grad
    mlx_eval(w)
    if (verbose && i %% 1000 == 0) {
      cat("Iteration", i, "- Loss:", as.vector(loss_fn(w)), "\n")
    }
  }
  w
}
Training Loop with SGD
We train by repeatedly computing gradients and updating the parameters. In each iteration, we:
- Compute the gradient of the loss with respect to w
- Update the parameters with a gradient step scaled by the learning rate
- Force evaluation with mlx_eval() so the computation graph does not grow unbounded
- Monitor progress by printing the loss every 1000 iterations
w_sgd <- train_sgd()
#> Iteration 1000 - Loss: 5.110821e-05
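The reported loss is already close to the noise floor (about 5e-05), so the fit is roughly as good as the noise allows. As an additional check (illustrative, mirroring the error norm computed for the closed-form solution below), we can measure how far the SGD estimate is from the true weights:
sgd_error <- w_sgd - w_star
cat("SGD ||w - w*|| =", as.vector(sqrt(sum(sgd_error * sgd_error))), "\n")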
Method 2: Closed-form Regression via Matrix Algebra
Gradient descent is flexible, but linear regression also has a closed-form solution, obtained here via the QR decomposition. Rather than forming XᵀX explicitly, we factor X = QR with qr() and solve the triangular system Rw = Qᵀy:
mlx_normal_eq <- function(X, y) {
  qr_res <- qr(X)
  q <- qr_res$Q
  r <- qr_res$R
  q_ty <- crossprod(q, y)
  mlx_solve_triangular(r, q_ty, upper = TRUE)
}
w_closed <- mlx_normal_eq(X, y)
mlx_eval(w_closed)
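# Optional cross-check (illustrative): solving the same least-squares problem
# with base R's qr.solve() on in-memory copies of the data should give
# essentially the same coefficients.
w_qr_base <- qr.solve(as.matrix(X), as.matrix(y))
max(abs(as.matrix(w_closed) - w_qr_base))  # expected to be near zero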
closed_error <- w_closed - w_star
closed_error_norm <- sqrt(sum(closed_error * closed_error))
cat("Closed-form ||w - w*|| =", as.vector(closed_error_norm), "\n")
#> Closed-form ||w - w*|| = 0.001063472
Accelerating the Closed-form Solution with mlx_compile()
The closed-form function mixes several MLX primitives. We can trace
and fuse those operations with mlx_compile(). The first
call incurs the tracing cost; subsequent calls reuse the compiled
graph.
compiled_normal_eq <- mlx_compile(mlx_normal_eq)
# Warm-up call performs tracing and compilation
mlx_eval(compiled_normal_eq(X, y))
# Re-use the compiled function
w_compiled <- compiled_normal_eq(X, y)
mlx_eval(w_compiled)
compiled_error <- w_compiled - w_star
compiled_error_norm <- sqrt(sum(compiled_error * compiled_error))
cat("Compiled closed-form ||w - w*|| =", as.vector(compiled_error_norm), "\n")
#> Compiled closed-form ||w - w*|| = 0.001063472
Accuracy and Performance Comparison
To compare the approaches, we measure elapsed time over several repetitions and the distance between each estimate and the true coefficients. We also include base R’s lm.fit() as a reference.
library(bench)
# Fit models once for accuracy measurements
w_sgd <- train_sgd(verbose = FALSE)
w_closed <- mlx_normal_eq(X, y)
compiled_normal_eq <- mlx_compile(mlx_normal_eq)
mlx_eval(compiled_normal_eq(X, y))
w_compiled <- compiled_normal_eq(X, y)
X_r <- as.matrix(X)
y_r <- as.matrix(y)
w_base <- matrix(lm.fit(X_r, y_r[, 1])$coefficients, ncol = 1)
# Accuracy comparisons
to_norm <- function(w_hat) {
  diff <- w_hat - as.matrix(w_star)
  sqrt(sum(diff * diff))
}
# Benchmark timings (compiled solution already warm)
timings <- bench::mark(
  sgd = {
    res <- train_sgd(verbose = FALSE)
    mlx_eval(res)
  },
  mlx_closed = {
    res <- mlx_normal_eq(X, y)
    mlx_eval(res)
  },
  mlx_closed_compiled = {
    res <- compiled_normal_eq(X, y)
    mlx_eval(res)
  },
  base_R = {
    lm.fit(X_r, y_r[, 1])$coefficients
  },
  iterations = 3,
  check = FALSE
) |>
  as.data.frame()
#> Warning: Some expressions had a GC in every iteration; so filtering is
#> disabled.
results <- data.frame(
  method = c("SGD", "MLX closed form", "MLX closed form (compiled)", "Base R"),
  median_time = timings$median,
  parameter_error = c(
    to_norm(as.matrix(w_sgd)),
    to_norm(as.matrix(w_closed)),
    to_norm(as.matrix(w_compiled)),
    to_norm(w_base)
  )
)
knitr::kable(results, digits = 4)
| method                     | median_time | parameter_error |
|----------------------------|-------------|-----------------|
| SGD                        | 2.25s       | 0.0011          |
| MLX closed form            | 22.16ms     | 0.0011          |
| MLX closed form (compiled) | 20.66ms     | 0.0011          |
| Base R                     | 86.51ms     | 0.0011          |
Device Selection
By default, computations run on the GPU for speed. Switch to the CPU if needed:
# Use CPU (useful for debugging)
mlx_default_device("cpu")
#> [1] "cpu"
# Or back to GPU
mlx_default_device("gpu")
#> [1] "gpu"