Sample from a multivariate normal distribution on mlx arrays
Source: R/random.R
mlx_rand_multivariate_normal.Rd
Usage
mlx_rand_multivariate_normal(
  dim,
  mean,
  cov,
  dtype = c("float32", "float64"),
  device = "cpu"
)
Arguments
- dim
Integer vector specifying array dimensions (shape).
- mean
An mlx array or vector for the mean.
- cov
An mlx array or matrix for the covariance.
- dtype
Data type string. Supported types include:
  Floating point: "float32", "float64"
  Integer: "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"
  Other: "bool", "complex64"
Not all functions support all types; for this function, the usage signature accepts only "float32" and "float64".
- device
Execution target: supply "gpu", "cpu", or an mlx_stream created via mlx_new_stream(). Most functions default to the current mlx_default_device(), and helpers that act on an existing array typically reuse that array's device or stream. This function defaults to "cpu" because GPU execution is currently unavailable for it (see Details).
Details
Samples are generated on the CPU. GPU execution is currently unavailable
because the covariance factorisation runs on the host. To integrate with
asynchronous workflows, supply a CPU stream created via mlx_new_stream().
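The Details above mention that the covariance matrix is factorised on the host before sampling. The plain base-R sketch below illustrates the general technique: factor Sigma as A %*% t(A), draw standard normals z, and return mu + A z. It does not call Rmlx. The helper name rmvnorm_sketch() is hypothetical, and the use of svd() for the factorisation is an assumption, not a description of what Rmlx or MLX does internally.

```r
# Hypothetical helper: draw n samples from N(mu, Sigma) by factorising Sigma.
# Assumption: illustrates the general technique only; not Rmlx's implementation.
rmvnorm_sketch <- function(n, mu, Sigma) {
  d <- length(mu)
  stopifnot(is.matrix(Sigma), nrow(Sigma) == d, ncol(Sigma) == d)
  # For a symmetric positive semi-definite Sigma, svd() gives Sigma = U diag(s) U^T,
  # so A = U diag(sqrt(s)) satisfies A %*% t(A) == Sigma.
  f <- svd(Sigma)
  A <- f$u %*% diag(sqrt(f$d), nrow = d)
  # Standard normal draws, one row per sample.
  z <- matrix(rnorm(n * d), nrow = n, ncol = d)
  # Row-wise form of x = mu + A z: z %*% t(A), then add mu to each column.
  sweep(z %*% t(A), 2, mu, `+`)
}

set.seed(1)
mu <- c(1, -2)
Sigma <- matrix(c(2, 0.5, 0.5, 1), nrow = 2)
x <- rmvnorm_sketch(20000, mu, Sigma)
colMeans(x)  # should be close to mu
cov(x)       # should be close to Sigma
```

A Cholesky factor (base R chol()) is a common alternative when Sigma is strictly positive definite. An SVD-based factor also tolerates singular, positive semi-definite covariances.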