mlx is Apple’s array framework for Apple silicon Macs, with Python, C and C++ APIs, which runs computations on the GPU as well as the CPU. Rmlx is an R interface to mlx. The core idea behind Rmlx is simple: statistics uses a lot of matrix operations, GPUs can do matrix operations fast, so R should have access to the GPU.
RmlxStats is a
showcase for Rmlx: it rewrites some well-known R model fitting functions
using mlx as a backend. These include mlxs_lm(),
mlxs_glm(), mlxs_prcomp() for principal
components analysis, and mlxs_glmnet() which is inspired by
glmnet() from the glmnet package. The aim of RmlxStats is
to learn what Rmlx can do, and to provide fast versions of basic
statistics.
mlx has some nice advantages. In particular, because the CPU and GPU
on Apple silicon share memory, you can switch between GPU and CPU
operations without copying data around, which helps with many
statistical computations. Still, it isn’t a perfect solution. Not all
operations work on the GPU yet; for example, solve() is CPU-only.
Despite this, mlx can be pretty fast:
# On my machine
> system.time(lm(arr_delay ~ dep_delay + factor(paste(month, day)),
                 data = nycflights13::flights))
   user  system elapsed
 31.769   0.544  32.764
> system.time({
    fit <- mlxs_lm(arr_delay ~ dep_delay + factor(paste(month, day)),
                   data = nycflights13::flights)
    Rmlx::mlx_eval(fit$coefficients)
  })
   user  system elapsed
  4.274   0.739   3.351
Or even, very fast:
> mat <- matrix(rnorm(1e7), 1e4)
> system.time({pr <- prcomp(mat); summary(pr)})
   user  system elapsed
 26.785   0.151  26.943
> system.time({pr <- mlxs_prcomp(mat); summary(pr)})
   user  system elapsed
  0.437   0.063   0.548
Those snippets show off another feature of mlx: evaluation is lazy, so no work is done until a result is actually needed. That’s why the mlxs_lm() timing explicitly evaluates the coefficients with Rmlx::mlx_eval().
Here’s a subtler problem: Apple’s GPUs work in 32 bits. mlx supports float64 data types, but operations on them run only on the CPU. So, to work on the GPU, we are in float32 land, and that limits the precision of our results:
library(RmlxStats)
base_fit <- lm(mpg ~ gear, mtcars)
mlxs_fit <- mlxs_lm(mpg ~ gear, mtcars)
max(abs(coef(mlxs_fit) - coef(base_fit)))
## [1] 3.217061e-06
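You can see where differences of this size come from without a GPU. The sketch below is plain base R, not part of RmlxStats: a hypothetical helper, to_f32(), rounds doubles to the nearest float32 value by writing them as 4-byte floats with writeBin() and reading them back with readBin(). Refitting the same regression on float32-rounded mpg values shows the effect of rounding the inputs alone; it does not capture the extra error from doing the arithmetic itself in float32.

```r
# Hypothetical helper: round doubles to the nearest float32 value.
# writeBin(size = 4) stores single-precision floats; readBin(size = 4)
# reads them back as doubles.
to_f32 <- function(x) {
  readBin(writeBin(as.double(x), raw(), size = 4),
          what = "double", n = length(x), size = 4)
}

# float32 has a 24-bit significand, so the relative rounding error
# is at most 2^-24 (about 6e-8)
vals <- c(1, 1/3, pi, 22.8)
rel_err <- abs(to_f32(vals) - vals) / abs(vals)
max(rel_err) <= 2^-24

# Refit lm(mpg ~ gear) with mpg rounded to float32. gear holds small
# integers, which float32 represents exactly, so only mpg changes.
fit64 <- lm(mpg ~ gear, data = mtcars)
mtcars32 <- transform(mtcars, mpg = to_f32(mpg))
fit32 <- lm(mpg ~ gear, data = mtcars32)
max(abs(coef(fit32) - coef(fit64)))
```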
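Does that matter? After all, differences in the sixth decimal place will often be drowned out by sampling error. To check, I simulated simple regressions with different sample sizes and R-squared values, and compared the largest coefficient difference between the two fits with the slope’s standard error: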
params <- expand.grid(n = 10^(5:7), r_sq = c(0.5, 0.9, 0.99))
results <- params |>
  split(seq_len(nrow(params))) |>
  lapply(function (param) {
    n <- param$n
    r_sq <- param$r_sq
    x <- rnorm(n)
    # Noise scale that gives a population R-squared of r_sq
    a <- sqrt((1 - r_sq)/r_sq)
    y <- x + a * rnorm(n)
    base_fit <- lm(y ~ x)
    mlxs_fit <- mlxs_lm(y ~ x)
    coef_diff <- max(abs(coef(base_fit) - coef(mlxs_fit)))
    se <- coef(summary(base_fit))["x", "Std. Error"]
    data.frame(n, r_sq, coef_diff, se)
  })
results <- do.call(rbind, results)
results$n <- factor(results$n)
results$`R squared` <- factor(results$r_sq)
library(ggplot2)
ggplot(results,
       aes(coef_diff, se, colour = n, group = n)) +
  geom_point(aes(shape = `R squared`)) +
  geom_line() +
  geom_function(aes(group = 1), fun = \(x) 100*x,
                linetype = "dotted", colour = "black") +
  scale_x_log10() +
  scale_y_log10() +
  theme_minimal() +
  labs(
    title = "Computational error versus sampling error for mlxs_lm()",
    subtitle = "Dotted line marks where computational error is 1% of a s.e.",
    x = "Computational difference",
    y = "Standard error"
  )
The graph above plots sampling standard errors against computational
differences between stats::lm() and
RmlxStats::mlxs_lm(). (I’m taking stats::lm()
as ground truth.) Until your N is 10 million and your R-squared is high,
sampling error dominates by far.
Fine, but on the other hand…
Recent updates to Rmlx added support for float64 types. So one possibility is to do the early work in float32 on the GPU, then finish in float64 on the CPU. This might give us both speed and precision.
A natural place to try this was mlxs_glm(). Generalized
linear models are fitted by iteratively reweighted least squares (IRLS),
so we can run the early iterations in float32 on the GPU, then switch to
float64 on the CPU once the fit gets within a tolerance threshold.
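Here’s a minimal sketch of the idea in base R. It is not the RmlxStats implementation, and irls_logit() and its arguments are hypothetical. It fits a logistic regression by IRLS, rounding the working quantities to float32 (simulated with writeBin() and readBin()) while the relative change in deviance is above epsilon_f64, then continues in ordinary double precision until the change falls below epsilon. The stopping rule copies the one used by stats::glm.fit(), |dev - dev_old| / (|dev| + 0.1); for simplicity the deviance itself is always computed in double.

```r
# Hypothetical helper: round to the nearest float32 value, keeping any dims.
to_f32 <- function(x) {
  out <- readBin(writeBin(as.double(x), raw(), size = 4),
                 what = "double", n = length(x), size = 4)
  dim(out) <- dim(x)
  out
}

# Hypothetical sketch: logistic regression by IRLS, in simulated float32
# until the relative change in deviance falls below epsilon_f64, then in
# double precision until it falls below epsilon.
irls_logit <- function(X, y, epsilon = 1e-8, epsilon_f64 = 1e-6,
                       maxit = 50) {
  X32 <- to_f32(X)
  beta <- rep(0, ncol(X))
  dev_old <- Inf
  use_f64 <- FALSE
  for (iter in seq_len(maxit)) {
    # In the float32 phase, round the data and every intermediate result
    prec <- if (use_f64) identity else to_f32
    Xp <- if (use_f64) X else X32
    eta <- prec(drop(Xp %*% beta))
    mu <- prec(plogis(eta))
    w <- prec(mu * (1 - mu))          # IRLS weights for the logit link
    z <- prec(eta + (y - mu) / w)     # working response
    # Weighted least squares step: solve (X'WX) beta = X'Wz
    XtW <- prec(t(Xp * w))
    beta <- prec(drop(solve(prec(XtW %*% Xp), prec(XtW %*% z))))
    # Deviance and glm.fit()-style convergence check, always in double
    mu_new <- plogis(drop(X %*% beta))
    dev <- -2 * sum(y * log(mu_new) + (1 - y) * log(1 - mu_new))
    rel_change <- abs(dev - dev_old) / (abs(dev) + 0.1)
    dev_old <- dev
    if (use_f64 && rel_change < epsilon) break
    if (!use_f64 && rel_change < epsilon_f64) use_f64 <- TRUE
  }
  list(coefficients = beta, iterations = iter, switched = use_f64)
}

# Usage: compare with glm() on simulated data
set.seed(1)
n <- 1000
X <- cbind(1, matrix(rnorm(n * 3), n))
y <- as.numeric(runif(n) < plogis(drop(X %*% c(-0.5, 1, 0.5, -1))))
fit <- irls_logit(X, y, epsilon = 1e-10, epsilon_f64 = 1e-6)
ref <- coef(glm(y ~ X - 1, family = binomial,
                control = glm.control(epsilon = 1e-12)))
max(abs(fit$coefficients - ref))
```

Rounding after each operation only approximates float32 arithmetic (real float32 hardware also rounds inside each matrix product), so this shows the control flow of the switch rather than realistic float32 error.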
Rmlx and RmlxStats are both vibe-coded. The process has taught me about the benefits and dangers of AI coding: the AI is a competent programmer and extremely knowledgeable about statistics, but it will also regularly do things that are profoundly silly, and it is quite good at hiding them. (For example, earlier versions of Rmlx were silently ignoring the user’s choice of device.) This is particularly risky when mistakes can lead not to an error message, but to a silently wrong statistical result! So, before working on the code itself, I wanted to build a test harness.
I coded a series of fuzz tests, working with the AI. Its specialist knowledge helped me do better than I could have done on my own: I would probably have just used large datasets of random numbers, whereas it knew to focus on “problematic” or interesting sets of variables. It also pointed me to the NIST statistical reference datasets, which are useful small test cases whose certified answers were computed using very high-precision arithmetic.
One problem with statistical fuzz tests is knowing what the right
answer is! Mostly, I took the approach of trusting the base R
implementation (or glmnet::glmnet()) and treating differences
from it as errors. But in some cases we can do better: for example, by
calculating performance on a defined objective function, or by testing
model prediction accuracy against an “oracle” based on the true
data-generating process.
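A minimal base R sketch of these checks might look like this; fuzz_once() and normal_eq() are hypothetical names, not RmlxStats functions. It builds a “problematic” design with two nearly collinear predictors, then scores a candidate solver three ways: its distance from lm(), its excess residual sum of squares (the objective that least squares minimises, so a positive value means the candidate did worse than lm()), and its distance from the true mean, the “oracle”. A naive normal-equations solver stands in for a real candidate.

```r
# Hypothetical fuzz-test helper: one random "problematic" dataset,
# scored against stats::lm(), the objective, and the true mean.
fuzz_once <- function(solver, n = 500, collinearity = 1e-4) {
  x1 <- rnorm(n)
  x2 <- x1 + collinearity * rnorm(n)       # nearly collinear with x1
  X <- cbind(1, x1, x2)
  beta_true <- c(1, 2, -1)
  mu_true <- drop(X %*% beta_true)         # the "oracle": true mean of y
  y <- mu_true + rnorm(n)

  base_coef <- coef(lm(y ~ x1 + x2))       # reference: stats::lm()
  cand_coef <- solver(X, y)
  rss <- function(b) sum((y - X %*% b)^2)  # least-squares objective
  rmse_vs_truth <- function(b) sqrt(mean((X %*% b - mu_true)^2))

  data.frame(
    coef_diff = max(abs(cand_coef - base_coef)),
    rss_excess = rss(cand_coef) - rss(base_coef),  # > 0: candidate is worse
    oracle_base = rmse_vs_truth(base_coef),
    oracle_cand = rmse_vs_truth(cand_coef)
  )
}

# A deliberately naive candidate: solve the normal equations directly
normal_eq <- function(X, y) drop(solve(crossprod(X), crossprod(X, y)))

set.seed(42)
fuzz_results <- do.call(rbind, lapply(1:20, function(i) fuzz_once(normal_eq)))
summary(fuzz_results)
```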
Once I was happy with the fuzz tests, I set up a GitHub Action that stores their results in a CSV file. This lets me check how results improve over time or with new pull requests. An R Markdown vignette shows the results; it’s a little crude right now, but it helps show when something gets better.
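The logging step can be as simple as appending each run’s results, tagged with the commit, to a CSV file. Here’s a minimal sketch with a hypothetical append_results() helper and file name; the only CI-specific piece is the GITHUB_SHA environment variable, which GitHub Actions sets to the commit that triggered the workflow.

```r
# Hypothetical helper: append one run's fuzz-test results to a CSV file,
# tagged with the commit and a UTC timestamp. Outside GitHub Actions,
# GITHUB_SHA is unset and the commit is recorded as "local".
append_results <- function(results, path = "fuzz-results.csv") {
  results$commit <- Sys.getenv("GITHUB_SHA", unset = "local")
  results$date <- format(Sys.time(), "%Y-%m-%d %H:%M:%S", tz = "UTC")
  new_file <- !file.exists(path)
  # Write the header only when creating the file; otherwise append rows
  write.table(results, path, sep = ",", row.names = FALSE,
              col.names = new_file, append = !new_file)
  invisible(path)
}

# Usage with a dummy results table, written twice
res <- data.frame(test = c("test_a", "test_b"), coef_diff = c(1, 2))
path <- tempfile(fileext = ".csv")
append_results(res, path)
append_results(res, path)
read.csv(path)
```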
We then turned to mlxs_glm(). We added an
epsilon_f64 control
variable, which lets the user tune when to switch from float32 on the
GPU to float64 on the CPU.
Does it work? Here I plot speed against accuracy for different values
of epsilon and epsilon_f64.
library(RmlxStats)
library(ggplot2)
set.seed(123L)

# Simulated logistic-regression data: 10 predictors, 100,000 rows
n <- 1e5
x <- matrix(rnorm(n * 10), n)
y <- drop(rnorm(n) + x %*% rep(0.1, 10))
x <- as.data.frame(x)
x$y <- y
x$y_lgl <- x$y > 0
fml <- reformulate(colnames(x)[1:10], response = "y_lgl")

# Reference fit from base R
base_fit <- glm(fml, data = x, family = "binomial")
base_coefs <- coef(base_fit)

# Grid of tolerances; keep only combinations where the final
# tolerance is tighter than the float64 switch-over tolerance
eps <- 10^-(6:12)
eps_f64 <- 10^-(5:7)
params <- expand.grid(eps = eps, eps_f64 = eps_f64)
params <- params[params$eps < params$eps_f64, ]

results <- params |>
  split(seq_len(nrow(params))) |>
  lapply(function (param) {
    nreps <- 25
    time <- system.time({
      for (i in seq_len(nreps)) {
        fit <- mlxs_glm(fml, data = x, family = "binomial",
                        control = list(epsilon_f64 = param$eps_f64,
                                       epsilon = param$eps))
        coefs <- coef(fit)
      }
    })
    elapsed <- time[["elapsed"]]/nreps
    coef_diff <- max(abs(coefs - base_coefs))
    data.frame(epsilon = param$eps, epsilon_f64 = param$eps_f64,
               elapsed = elapsed, coef_diff = coef_diff)
  })
results <- do.call(rbind, results)

library(forcats)
results |>
  transform(epsilon_f64 = factor(epsilon_f64), epsilon = factor(epsilon)) |>
  transform(epsilon = fct_rev(epsilon)) |>
  ggplot(aes(epsilon, coef_diff,
             size = elapsed, colour = epsilon_f64)) +
  geom_point() +
  scale_color_viridis_d(end = 0.9, option = "A") +
  scale_y_log10() +
  theme_minimal() +
  labs(
    title = "Coefficient error in mlxs_glm()",
    y = "Coefficient error",
    size = "Time elapsed"
  )
There’s a clear jump in accuracy at epsilon = 1e-7, but
this only happens when epsilon_f64 is large enough.
Otherwise, conversion to float64 never kicks in, because the float32
iterations never hit the threshold. We never get closer than 3e-9 to the
base result, even at very low epsilons. This may simply be due to the
different algorithms being used.
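Why would the float32 iterations never reach a tight threshold? One plausible contributor is that the deviance is a sum over every observation, and a long float32 sum can carry rounding noise larger than float32’s unit roundoff of about 6e-8, so the relative change between iterations can stall instead of shrinking. This base R sketch illustrates the effect with a naive sequential float32 sum; sum_f32() is a hypothetical helper that rounds after every addition. mlx’s GPU reductions are likely organised differently and may well be more accurate, so treat the numbers as illustrative only.

```r
# Hypothetical helper: round to the nearest float32 value.
to_f32 <- function(x) {
  readBin(writeBin(as.double(x), raw(), size = 4),
          what = "double", n = length(x), size = 4)
}

# Hypothetical helper: a naive float32 running sum that rounds after
# every addition. (Parallel GPU reductions usually add in a tree instead.)
sum_f32 <- function(x) {
  x <- to_f32(x)
  total <- 0
  for (xi in x) total <- to_f32(total + xi)
  total
}

set.seed(1)
terms <- rexp(1e4)   # stand-in for per-observation deviance contributions
exact <- sum(terms)  # double-precision sum
rel_err <- abs(sum_f32(terms) - exact) / exact

# Relative error as a multiple of float32's unit roundoff, 2^-24
rel_err / 2^-24
```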
You can also see that the extra accuracy doesn’t cost much time. The next graph shows this in more detail:
library(santoku)
results |>
  transform(
    epsilon_f64 = factor(epsilon_f64),
    epsilon = factor(epsilon),
    Error = santoku::chop(coef_diff, 10^-(9:5),
                          labels = lbl_glue(label = "< {r}"))) |>
  transform(epsilon = fct_rev(epsilon)) |>
  ggplot(aes(epsilon, elapsed, colour = epsilon_f64, group = epsilon_f64)) +
  geom_line() +
  geom_point(aes(shape = Error), fill = "white") +
  scale_y_log10() +
  scale_shape_manual(values = c("< 1e-08" = "circle",
                                "< 1e-06" = "circle filled")) +
  theme_minimal() +
  labs(
    title = "Speed in mlxs_glm() with float64 operations",
    y = "Time elapsed (log scale)"
  )
Moving to float64 too early slows you down, though not by much; never
moving to it (epsilon_f64 = 1e-7) is faster, but your results stay at
float32 levels of accuracy. The sweet spot is epsilon_f64 = 1e-6, and
that is the current default. (These are quite rough measurements for a
blog post; the benchmark and fuzz test vignettes have more details.)
It’s worth sharing an anecdote about vibe-coding gone wrong. At some
point, Codex asked me if we should check for rank deficiency in linear
fits. I thought about it and said yes, and it implemented a function
.mlxs_check_full_rank(x, ...), which I accepted without
checking too hard.
Oops. The function recomputed the QR decomposition for the design
matrix x, meaning we were now doing the most
computationally expensive part of linear regression twice. Despite all
my benchmarks, I didn’t notice the slowdown until later, and an earlier
version of this article said “mlxs_lm() is still slower
than stats::lm()”. No, it’s not!
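In base R terms, one way to avoid paying twice is to take the rank check from the QR decomposition that the fit needs anyway. Here’s a minimal sketch with a hypothetical fit_ls_once(), using base R’s qr() and qr.coef() rather than anything from Rmlx: qr() already reports the numerical rank of the matrix, so the check costs nothing extra.

```r
# Hypothetical sketch: least squares with a rank check that reuses the
# QR decomposition needed for the fit, instead of computing it twice.
fit_ls_once <- function(X, y, tol = 1e-7) {
  qr_x <- qr(X, tol = tol)    # the expensive step, done once
  # qr() already reports the numerical rank of X
  if (qr_x$rank < ncol(X)) {
    stop("design matrix is rank deficient (rank ", qr_x$rank,
         ", ", ncol(X), " columns)")
  }
  qr.coef(qr_x, y)            # solve using the same decomposition
}

# Usage: a full-rank design, then one with a redundant column
set.seed(3)
X <- cbind(1, rnorm(50), rnorm(50))
y <- drop(X %*% c(1, 2, 3)) + rnorm(50)
b <- fit_ls_once(X, y)
X_bad <- cbind(X, 2 * X[, 2])
tryCatch(fit_ls_once(X_bad, y), error = conditionMessage)
```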
Fortunately, Codex also saved me when I asked it to find the source of the slowdown. It bisected the commits and found the problem. I fixed it, and then asked:
You have genius-level statistical and computing knowledge. Is there some instruction you could be given, that would stop you doing very stupid things?
Encouraged by the mlxs_glm() results, I tried similar tricks for
mlxs_lm() and the other package functions. None of these
really worked: either the limits on accuracy were not just down to the
GPU, or my approach was wrong. So for now, most of RmlxStats remains in
float32.
Still, the testing infrastructure I’ve set up should help check any future work, and progress in the mlx framework should open up new speedups. I’ve also got more comfortable working with LLMs. It’s about finding the balance between speed and accuracy: kind of like statistics, in fact.