
Sample from a multivariate normal distribution on mlx arrays

Usage

mlx_rand_multivariate_normal(
  dim,
  mean,
  cov,
  dtype = c("float32", "float64"),
  device = "cpu"
)

Arguments

dim

Integer vector specifying the shape (dimensions) of the output array.

mean

An mlx array or numeric vector giving the mean of the distribution.

cov

An mlx array or matrix giving the covariance matrix of the distribution. It must be symmetric and positive semi-definite.

dtype

Desired MLX dtype ("float32" or "float64").

device

Execution target: "cpu" or a CPU mlx_stream created via mlx_new_stream(). Defaults to "cpu". GPU execution is not currently supported (see Details).
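A covariance matrix must be square, symmetric and positive semi-definite. It can therefore be worth checking a matrix before passing it as cov. The following is a minimal base-R sketch of such a check. The helper is_valid_covariance() is hypothetical and is not part of Rmlx. It tests symmetry with all.equal() and tests semi-definiteness with the eigenvalues from eigen(). Both tests use a small tolerance to absorb rounding error.

```r
# Hypothetical helper (not part of Rmlx): returns TRUE when `m` is square,
# symmetric, and positive semi-definite up to a numerical tolerance.
is_valid_covariance <- function(m, tol = 1e-8) {
  # Must be a square matrix
  if (!is.matrix(m) || nrow(m) != ncol(m)) return(FALSE)
  # Must equal its own transpose (up to tolerance)
  if (!isTRUE(all.equal(m, t(m), tolerance = tol))) return(FALSE)
  # A symmetric matrix is PSD exactly when all its eigenvalues are >= 0
  ev <- eigen(m, symmetric = TRUE, only.values = TRUE)$values
  all(ev >= -tol * max(1, abs(ev)))
}

is_valid_covariance(diag(2))                     # identity: valid
is_valid_covariance(matrix(c(1, 2, 2, 1), 2, 2)) # eigenvalues 3 and -1: not PSD
```

The last line returns FALSE because one of its eigenvalues is negative, so that matrix cannot be a covariance matrix.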

Value

An mlx array of samples drawn from the multivariate normal distribution.

Details

Samples are generated on the CPU. GPU execution is currently unavailable because the factorisation of the covariance matrix runs on the host. To integrate with asynchronous workflows, supply a CPU stream created with mlx_new_stream().
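The role of the covariance factorisation can be illustrated in plain base R. The following sketch is not Rmlx's or MLX's actual implementation. It assumes a symmetric eigendecomposition cov = V diag(lambda) V^T and builds a matrix square root A = V diag(sqrt(lambda)), so that A %*% t(A) equals cov. It then draws standard normal vectors z and returns mean + A z, whose covariance is A A^T = cov. The function name sample_mvn_sketch() is hypothetical.

```r
# Hypothetical helper (not part of Rmlx): draws n samples from N(mean, cov)
# by factorising the covariance matrix, one standard way to sample a
# multivariate normal.
sample_mvn_sketch <- function(n, mean, cov) {
  d <- length(mean)
  stopifnot(is.matrix(cov), nrow(cov) == d, ncol(cov) == d)
  # Symmetric eigendecomposition: cov = V diag(lambda) V^T
  e <- eigen(cov, symmetric = TRUE)
  # Clamp tiny negative eigenvalues caused by rounding to zero
  lambda <- pmax(e$values, 0)
  # Matrix square root A with A %*% t(A) == cov
  A <- e$vectors %*% diag(sqrt(lambda), nrow = d)
  # Each row of z is an independent standard normal vector
  z <- matrix(rnorm(n * d), nrow = n, ncol = d)
  # Row i becomes mean + A z_i, so its covariance is A A^T = cov
  sweep(z %*% t(A), 2, mean, "+")
}

set.seed(1)
sigma <- matrix(c(2, 0.5, 0.5, 1), 2, 2)
x <- sample_mvn_sketch(10000, c(1, -1), sigma)
colMeans(x)  # should be close to c(1, -1)
cov(x)       # should be close to sigma
```

An eigendecomposition is used here instead of a Cholesky factorisation (chol()) because it also tolerates covariance matrices that are only positive semi-definite.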

Examples

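# Standard bivariate normal: zero mean, identity covariance
mean <- as_mlx(c(0, 0), device = "cpu")
cov <- as_mlx(matrix(c(1, 0, 0, 1), 2, 2), device = "cpu")
# Sampling runs on the CPU (GPU execution is not supported)
samples <- mlx_rand_multivariate_normal(c(100, 2), mean, cov, device = "cpu")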