
Sample from a multivariate normal distribution on mlx arrays

Usage

mlx_rand_multivariate_normal(
  dim,
  mean,
  cov,
  dtype = c("float32", "float64"),
  device = "cpu"
)

Arguments

dim

Integer vector specifying array dimensions (shape).

mean

An mlx array or vector for the mean.

cov

An mlx array or matrix for the covariance.

dtype

Data type string. Supported types include:

  • Floating point: "float32", "float64"

  • Integer: "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"

  • Other: "bool", "complex64"

Not all functions support all types. This function accepts only the floating-point types "float32" (the default) and "float64", as shown in Usage.

device

Execution target for APIs that expose a one-off device or stream override. Supply "gpu", "cpu", or an mlx_stream created via mlx_new_stream(). Ordinary array operations use the current mlx_device() instead.

Value

An mlx array of samples drawn from the multivariate normal distribution with the given mean and covariance.

Details

GPU execution is currently unavailable because the covariance factorisation runs on the host, so keep the default device = "cpu".
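To show why sampling needs a factorisation of cov at all, here is a rough base R sketch of the standard approach: if Sigma = t(R) %*% R, then rows of Z %*% R, where Z holds independent standard normals, have covariance Sigma. The helper rmvnorm_chol is hypothetical and uses a Cholesky factor (which requires a positive-definite covariance); it is an illustration only, and mlx's own implementation may use a different decomposition.

```r
# Hypothetical helper (not part of Rmlx): draw n samples from N(mu, Sigma)
# using a Cholesky factor. Illustrative only; mlx may factorise cov differently.
rmvnorm_chol <- function(n, mu, Sigma) {
  d <- length(mu)
  stopifnot(is.matrix(Sigma), nrow(Sigma) == d, ncol(Sigma) == d)
  # chol() returns upper-triangular R with t(R) %*% R == Sigma
  # (fails if Sigma is not positive definite)
  R <- chol(Sigma)
  # n x d matrix of independent standard normal draws
  Z <- matrix(rnorm(n * d), nrow = n, ncol = d)
  # Each row z %*% R has covariance t(R) %*% R = Sigma; then shift by mu
  sweep(Z %*% R, 2, mu, "+")
}

set.seed(1)
Sigma <- matrix(c(2, 0.5, 0.5, 1), 2, 2)
x <- rmvnorm_chol(10000, c(1, -1), Sigma)
```

With many draws, colMeans(x) and cov(x) should be close to c(1, -1) and Sigma respectively.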

Examples

mean <- as_mlx(c(0, 0))
cov <- as_mlx(matrix(c(1, 0, 0, 1), 2, 2))
samples <- with_device("cpu", mlx_rand_multivariate_normal(10, mean, cov))