
Fit generalized linear models using iteratively reweighted least squares (IRLS), with MLX performing the weighted least-squares solve at each iteration.
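The update that IRLS repeats can be sketched in base R (double precision, no MLX). This is a minimal illustration of the standard algorithm, not the code inside mlxs_glm(). The helper name irls_glm is hypothetical, and the weighted least-squares step that MLX handles in mlxs_glm() is done here by stats::lm.wfit(). Each iteration forms the working response z = eta + (y - mu) / mu.eta(eta) and the working weights w = prior weights * mu.eta(eta)^2 / V(mu), solves the weighted least-squares problem, and stops when the relative change in deviance falls below tol.

```r
# Hypothetical helper: a minimal IRLS loop in base R (not the mlxs_glm() code).
irls_glm <- function(X, y, family = gaussian(), weights = rep(1, NROW(y)),
                     tol = 1e-8, maxit = 25) {
  # Variables that family$initialize expects to find in the calling frame.
  nobs <- NROW(y)
  etastart <- NULL
  start <- NULL
  mustart <- NULL
  eval(family$initialize)  # sets mustart (and n), as stats::glm.fit() does

  eta <- family$linkfun(mustart)
  mu <- family$linkinv(eta)
  dev_old <- sum(family$dev.resids(y, mu, weights))

  for (iter in seq_len(maxit)) {
    mu_eta <- family$mu.eta(eta)
    z <- eta + (y - mu) / mu_eta                   # working response
    w <- weights * mu_eta^2 / family$variance(mu)  # working weights
    beta <- stats::lm.wfit(X, z, w)$coefficients   # weighted least-squares solve
    eta <- drop(X %*% beta)
    mu <- family$linkinv(eta)
    dev <- sum(family$dev.resids(y, mu, weights))
    # Relative change in deviance, the same stopping rule stats::glm.fit() uses.
    if (abs(dev - dev_old) / (abs(dev) + 0.1) < tol) break
    dev_old <- dev
  }
  list(coefficients = beta, deviance = dev, iterations = iter)
}

X <- model.matrix(~ cyl + disp, data = mtcars)
fit <- irls_glm(X, mtcars$carb, family = poisson())
fit$coefficients
coef(glm(carb ~ cyl + disp, family = poisson(), data = mtcars))
```

Because both fits start from the family's own initialisation and use the same stopping rule, the two coefficient vectors should agree to within numerical noise.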

Usage

mlxs_glm(
  formula,
  family = mlxs_gaussian(),
  data,
  subset,
  weights,
  na.action,
  start = NULL,
  control = list(),
  ...
)

Arguments

formula

an object of class "formula" (or one that can be coerced to that class): a symbolic description of the model to be fitted. Model specification follows the same conventions as stats::glm(); see the ‘Details’ section of ?stats::glm.

family

An mlxs family object (e.g., mlxs_gaussian(), mlxs_binomial(), mlxs_poisson()).

data

an optional data frame, list or environment (or object coercible by as.data.frame to a data frame) containing the variables in the model. If not found in data, the variables are taken from environment(formula), typically the environment from which mlxs_glm() is called.

subset

an optional vector specifying a subset of observations to be used in the fitting process.

weights

an optional vector of ‘prior weights’ to be used in the fitting process. Should be NULL or a numeric vector.

na.action

a function which indicates what should happen when the data contain NAs. The default is set by the na.action setting of options, and is na.fail if that is unset. The ‘factory-fresh’ default is na.omit. Another possible value is NULL (no action). The value na.exclude can also be useful.
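This argument description follows stats::glm(), so the difference between na.omit and na.exclude can be illustrated with stats::glm() itself. Whether the methods for mlxs_glm objects pad residuals in the same way is an assumption not confirmed here; the toy data frame d is made up for the example.

```r
# Toy data with one missing predictor value (row 3).
d <- data.frame(y = c(1, 3, 2, 5, 4), x = c(1, 2, NA, 4, 5))

f_omit <- glm(y ~ x, data = d, na.action = na.omit)
f_excl <- glm(y ~ x, data = d, na.action = na.exclude)

# Both fits use the same four complete rows, so the coefficients agree.
all.equal(coef(f_omit), coef(f_excl))  # TRUE

# na.omit drops the incomplete row from residuals(); na.exclude pads it with NA.
length(residuals(f_omit))  # 4
length(residuals(f_excl))  # 5
residuals(f_excl)
```

With na.exclude, residuals and fitted values line up with the rows of the original data, which is convenient when adding them back as a column.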

start

starting values for the parameters in the linear predictor.

control

Optional list of control parameters passed to stats::glm.control().
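Because control is handed to stats::glm.control(), one way to check what a given list resolves to is to call glm.control() directly. The sketch assumes the list elements are matched to glm.control()'s epsilon, maxit and trace arguments; the mlxs_glm() call is shown only as a comment and is not run.

```r
# glm.control() validates the settings and fills in defaults (here, trace = FALSE).
ctrl <- stats::glm.control(epsilon = 1e-10, maxit = 50)
str(ctrl)

# The corresponding mlxs_glm() call (not run here) would be:
# mlxs_glm(mpg ~ cyl + disp, data = mtcars, control = list(epsilon = 1e-10, maxit = 50))
```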

...

Further arguments. As in stats::glm(), these may be used to form the default control argument if control is not supplied directly. (For a weights() method, these are further arguments passed to or from other methods.)

Value

An object of class c("mlxs_glm", "mlxs_model") containing elements similar to those returned by stats::glm(). Computations use single-precision (float32) MLX arrays, so results typically agree with stats::glm() to around 1e-6 unless a tighter tolerance is supplied via control.

Examples

fit <- mlxs_glm(mpg ~ cyl + disp, family = mlxs_gaussian(), data = mtcars)
coef(fit)
#> mlx array [3 x 1]
#>   dtype: float32
#>   device: gpu
#>   values:
#>             [,1]
#> [1,] 34.66099167
#> [2,] -1.58727658
#> [3,] -0.02058364
#> attr(,"coef_names")
#> [1] "(Intercept)" "cyl"         "disp"