Overview

This vignette demonstrates linear regression using Rmlx, based on the MLX linear regression example. We’ll train a linear model using automatic differentiation and stochastic gradient descent (SGD) on GPU-accelerated arrays.

Problem Setup

We’ll create synthetic data for high-dimensional linear regression:

- A random “true” weight vector w_star of dimension 100
- A random design matrix X of 10,000 cases × 100 features
- Noisy labels y = X %*% w_star + eps, where eps is small Gaussian noise

library(Rmlx)
#> 
#> Attaching package: 'Rmlx'
#> The following object is masked from 'package:stats':
#> 
#>     fft
#> The following objects are masked from 'package:base':
#> 
#>     chol2inv, colMeans, colSums, diag, outer, rowMeans, rowSums, svd

# Problem metadata
num_features <- 100
num_cases <- 10000
num_iters <- 1200          # iterations of SGD
learning_rate <- 0.01      # learning rate for SGD

# Set seed for reproducibility
set.seed(42)

# True parameters (what we're trying to learn)
w_star <- mlx_rand_normal(c(num_features, 1))

# Input examples (design matrix)
X <- mlx_rand_normal(c(num_cases, num_features))

# Noisy labels
eps <- 1e-2 * mlx_rand_normal(c(num_cases, 1))
y <- X %*% w_star + eps

Define the Loss Function

The mean squared error loss is a standard choice for regression:

# Define loss function
loss_fn <- function(w) {
  preds <- X %*% w
  residuals <- preds - y
  0.5 * mean(residuals * residuals)
}

The loss measures how well our parameters w predict the labels. Lower loss means better predictions.
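As a quick sanity check (a sketch; exact values depend on the random seed), the loss at the true weights should sit near the noise floor, while a random guess scores far worse:

# Sketch: evaluate the loss at the true weights vs. a random guess
as.vector(loss_fn(w_star))
as.vector(loss_fn(mlx_rand_normal(c(num_features, 1))))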

Automatic Differentiation

Rmlx provides mlx_grad() to compute gradients via automatic differentiation. This computes the gradient of the loss with respect to our parameters:

# Get the gradient function
grad_fn <- function(w) {
  mlx_grad(loss_fn, w)[[1]]
}

train_sgd <- function(steps = num_iters, step_size = learning_rate, verbose = TRUE) {
  w <- 1e-2 * mlx_rand_normal(c(num_features, 1))
  for (i in seq_len(steps)) {
    grad <- grad_fn(w)
    w <- w - step_size * grad
    mlx_eval(w)
    if (verbose && i %% 1000 == 0) {
      cat("Iteration", i, "- Loss:", as.vector(loss_fn(w)), "\n")
    }
  }
  w
}
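To build confidence in the autodiff result, here is a minimal sketch that compares grad_fn() against the analytic gradient of the mean squared error, $\frac{1}{n} X^\top (Xw - y)$. It reuses crossprod() and as.matrix(), which also appear later in this vignette; the comparison itself is done on plain R matrices.

# Sketch: autodiff gradient vs. analytic MSE gradient (1/n) * t(X) %*% (X %*% w - y).
# The discrepancy should be near single-precision rounding error.
w_check <- 1e-2 * mlx_rand_normal(c(num_features, 1))
g_auto <- grad_fn(w_check)
g_manual <- (1 / num_cases) * crossprod(X, X %*% w_check - y)
max(abs(as.matrix(g_auto) - as.matrix(g_manual)))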

Method 1: Training Loop with SGD

We train by repeatedly computing gradients and updating parameters. In each iteration, we:

  1. Compute the gradient of loss with respect to w
  2. Update parameters using the gradient step
  3. Force evaluation to prevent the computation graph from growing unbounded
  4. Monitor progress by printing loss every 1000 iterations

w_sgd <- train_sgd()
#> Iteration 1000 - Loss: 5.110821e-05
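For a direct comparison with the closed-form solution below, we can also compute the distance between the SGD estimate and the true weights (a sketch mirroring the error check used in the next section):

# Sketch: distance between the SGD estimate and the true weights
sgd_error <- w_sgd - w_star
cat("SGD ||w - w*|| =", as.vector(sqrt(sum(sgd_error * sgd_error))), "\n")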

Method 2: Closed-form Regression via Matrix Algebra

Gradient descent is flexible, but linear regression also has a closed-form solution that can be obtained via the QR decomposition. Rather than forming $X^\top X$ explicitly, we factor $X = QR$ with $Q^\top Q = I$ and solve the triangular system $Rw = Q^\top y$:

mlx_normal_eq <- function(X, y) {
  qr_res <- qr(X)
  q <- qr_res$Q
  r <- qr_res$R
  q_ty <- crossprod(q, y)
  mlx_solve_triangular(r, q_ty, upper = TRUE)
}

w_closed <- mlx_normal_eq(X, y)
mlx_eval(w_closed)

closed_error <- w_closed - w_star
closed_error_norm <- sqrt(sum(closed_error * closed_error))
cat("Closed-form ||w - w*|| =", as.vector(closed_error_norm), "\n")
#> Closed-form ||w - w*|| = 0.001063472
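As an optional sanity check (a sketch; the tolerances reflect single-precision arithmetic), we can confirm that the QR factors behave as described above, with $Q^\top Q \approx I$ and $QR \approx X$. The comparisons are done on plain R matrices to keep the check simple.

# Sketch: verify the QR factorization properties used by mlx_normal_eq()
qr_chk <- qr(X)
q_mat <- as.matrix(qr_chk$Q)
r_mat <- as.matrix(qr_chk$R)
max(abs(crossprod(q_mat, q_mat) - base::diag(ncol(q_mat))))  # ~ 0, so Q^T Q = I
max(abs(q_mat %*% r_mat - as.matrix(X)))                     # ~ 0, so QR = X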

Accelerating the Closed-form Solution with mlx_compile()

The closed-form function mixes several MLX primitives. We can trace and fuse those operations with mlx_compile(). The first call incurs the tracing cost; subsequent calls reuse the compiled graph.

compiled_normal_eq <- mlx_compile(mlx_normal_eq)

# Warm-up call performs tracing and compilation
mlx_eval(compiled_normal_eq(X, y))

# Re-use the compiled function
w_compiled <- compiled_normal_eq(X, y)
mlx_eval(w_compiled)

compiled_error <- w_compiled - w_star
compiled_error_norm <- sqrt(sum(compiled_error * compiled_error))
cat("Compiled closed-form ||w - w*|| =", as.vector(compiled_error_norm), "\n")
#> Compiled closed-form ||w - w*|| = 0.001063472

Accuracy and Performance Comparison

To compare approaches, we measure elapsed time over several repetitions and the resulting distance between each estimate and the true coefficients. We also include base R’s lm.fit() as a reference; it solves the same least-squares problem via its own QR factorization.

library(bench)

# Fit models once for accuracy measurements
w_sgd <- train_sgd(verbose = FALSE)
w_closed <- mlx_normal_eq(X, y)
compiled_normal_eq <- mlx_compile(mlx_normal_eq)
mlx_eval(compiled_normal_eq(X, y))
w_compiled <- compiled_normal_eq(X, y)
X_r <- as.matrix(X)
y_r <- as.matrix(y)
w_base <- matrix(lm.fit(X_r, y_r[, 1])$coefficients, ncol = 1)

# Accuracy comparisons
to_norm <- function(w_hat) {
  diff <- w_hat - as.matrix(w_star)
  sqrt(sum(diff * diff))
}

# Benchmark timings (compiled solution already warm)
timings <- bench::mark(
  sgd = {
    res <- train_sgd(verbose = FALSE)
    mlx_eval(res)
  },
  mlx_closed = {
    res <- mlx_normal_eq(X, y)
    mlx_eval(res)
  },
  mlx_closed_compiled = {
    res <- compiled_normal_eq(X, y)
    mlx_eval(res)
  },
  base_R = {
    lm.fit(X_r, y_r[, 1])$coefficients
  },
  iterations = 3,
  check = FALSE
) |>
  as.data.frame()
#> Warning: Some expressions had a GC in every iteration; so filtering is
#> disabled.

results <- data.frame(
  method = c("SGD", "MLX closed form", "MLX closed form (compiled)", "Base R"),
  median_time = timings$median,
  parameter_error = c(
    to_norm(as.matrix(w_sgd)),
    to_norm(as.matrix(w_closed)),
    to_norm(as.matrix(w_compiled)),
    to_norm(w_base)
  )
)
knitr::kable(results, digits = 4)
method                      median_time  parameter_error
SGD                         2.25s        0.0011
MLX closed form             22.16ms      0.0011
MLX closed form (compiled)  20.66ms      0.0011
Base R                      86.51ms      0.0011

Device Selection

By default, computations run on GPU for speed. Switch to CPU if needed:

# Use CPU (useful for debugging)
mlx_default_device("cpu")
#> [1] "cpu"

# Or back to GPU
mlx_default_device("gpu")
#> [1] "gpu"