Sample from a multivariate normal distribution on mlx arrays
Usage
mlx_rand_multivariate_normal(
  dim,
  mean,
  cov,
  dtype = c("float32", "float64"),
  device = "cpu"
)
Arguments
- dim: Integer vector specifying the array shape (dimensions).
- mean: An mlx array or vector for the mean.
- cov: An mlx array or matrix for the covariance.
- dtype: Desired MLX dtype ("float32" or "float64").
- device: Execution target: "cpu" (the default, as shown under Usage) or a CPU mlx_stream created via mlx_new_stream(). GPU execution is currently unavailable for this function; see Details.
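Since cov has to describe a valid covariance matrix, it can help to check the inputs in plain R before sampling. The sketch below is a hypothetical helper (check_mvn_inputs() is not part of the package). It assumes mean and cov are ordinary R vectors and matrices rather than mlx arrays, and it treats "valid" as: square, matching length(mean), symmetric, and positive definite. Positive definiteness is what a Cholesky factorisation needs, so a merely semi-definite matrix is rejected by this check even though it is a legal covariance.

```r
# Hypothetical helper, not part of the package: checks that plain R
# `mean` / `cov` inputs describe a multivariate normal distribution.
check_mvn_inputs <- function(mean, cov) {
  d <- length(mean)
  # cov must be a d x d matrix, where d is the length of the mean vector
  if (!is.matrix(cov) || !identical(dim(cov), c(d, d))) {
    return("cov must be a d x d matrix, with d = length(mean)")
  }
  # isSymmetric() compares cov with t(cov) up to a numerical tolerance
  if (!isSymmetric(cov)) {
    return("cov must be symmetric")
  }
  # chol() signals an error when cov is not positive definite
  ok <- tryCatch({ chol(cov); TRUE }, error = function(e) FALSE)
  if (!ok) {
    return("cov must be positive definite")
  }
  "ok"
}

mu <- c(0, 1)
check_mvn_inputs(mu, matrix(c(1, 0.5, 0.5, 2), nrow = 2))  # "ok"
check_mvn_inputs(mu, matrix(c(1, 2, 2, 1), nrow = 2))      # "cov must be positive definite"
```

The second matrix is symmetric but has eigenvalues 3 and -1, so no real distribution has it as a covariance.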
Details
Samples are generated on the CPU. GPU execution is currently unavailable
because the covariance factorisation runs on the host. To use this function
in an asynchronous workflow, supply a CPU stream created with mlx_new_stream()
as device.
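The covariance factorisation mentioned above is what turns independent standard normal draws into correlated ones. This page does not say which factorisation is used; the base R sketch below assumes a Cholesky factor and works on ordinary R matrices, not mlx arrays. rmvnorm_chol() is a hypothetical name used only for illustration.

```r
# Hypothetical illustration in base R, not the package's implementation:
# draw n samples from N(mu, sigma) via a Cholesky factor of sigma.
rmvnorm_chol <- function(n, mu, sigma) {
  d <- length(mu)
  stopifnot(is.matrix(sigma), nrow(sigma) == d, ncol(sigma) == d)
  # Upper-triangular r with t(r) %*% r == sigma (errors if not positive definite)
  r <- chol(sigma)
  # n x d matrix of independent standard normal draws
  z <- matrix(rnorm(n * d), nrow = n, ncol = d)
  # Each row of z %*% r has covariance t(r) %*% r = sigma; then add mu to every row
  sweep(z %*% r, 2, mu, `+`)
}

set.seed(1)
mu <- c(0, 1)
sigma <- matrix(c(1, 0.5, 0.5, 2), nrow = 2)
x <- rmvnorm_chol(100000, mu, sigma)
colMeans(x)  # close to mu
cov(x)       # close to sigma
```

A symmetric square root from eigen() or svd() would serve the same purpose and also copes with semi-definite covariances, at a higher cost than chol().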