Sample from a multivariate normal distribution on mlx arrays
Source: R/random.R, mlx_rand_multivariate_normal.Rd

Sample from a multivariate normal distribution on mlx arrays.
Usage
mlx_rand_multivariate_normal(
  dim,
  mean,
  cov,
  dtype = c("float32", "float64"),
  device = "cpu"
)

Arguments
- dim
Integer vector specifying array dimensions (shape).
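- mean
An mlx array or numeric vector giving the mean of the distribution.
- cov
An mlx array or numeric matrix giving the covariance. It should be a square, symmetric positive semi-definite matrix whose dimension matches the length of mean.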
- dtype
Desired MLX dtype ("float32" or "float64").
- device
Execution target: "cpu" or a CPU mlx_stream created via mlx_new_stream(). Unlike helpers that default to mlx_default_device(), this function defaults to "cpu", and "gpu" is not currently supported (see Details).
Details
Samples are generated on the CPU. GPU execution is currently
unavailable because the covariance factorisation runs on the host. To
integrate with asynchronous workflows, supply a CPU stream created with mlx_new_stream().
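To illustrate why a covariance factorisation is needed at all, here is a minimal base-R sketch of the standard technique: factor the covariance, transform independent standard normal draws by the factor, then shift by the mean. It is an assumption-laden illustration, not the Rmlx or MLX implementation. It uses Cholesky factorisation via chol() as one common choice (which requires a positive definite matrix); the factorisation MLX actually uses is not stated on this page, and an eigen- or SVD-based factor would also work. The helper name rmvnorm_sketch() is hypothetical.

```r
# Hypothetical helper: base-R sketch of multivariate normal sampling via a
# Cholesky factor. Illustrative only; NOT the Rmlx/MLX implementation.
rmvnorm_sketch <- function(n, mean, cov) {
  d <- length(mean)
  stopifnot(is.matrix(cov), nrow(cov) == d, ncol(cov) == d)
  # Upper-triangular factor R with t(R) %*% R == cov
  # (chol() requires cov to be positive definite).
  R <- chol(cov)
  # n x d matrix of independent standard normal draws; each row has covariance I
  z <- matrix(rnorm(n * d), nrow = n, ncol = d)
  # A row z_i times R has covariance t(R) %*% I %*% R == cov
  x <- z %*% R
  # Add mean[j] to every entry of column j
  sweep(x, 2, mean, `+`)
}

set.seed(1)
mu <- c(1, -2)
sigma <- matrix(c(2, 0.8, 0.8, 1), nrow = 2)
draws <- rmvnorm_sketch(10000, mu, sigma)
colMeans(draws)  # approximately mu
cov(draws)       # approximately sigma
```

The factorisation step is the part that, according to the Details above, currently runs on the host, which is why the function is restricted to CPU execution.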