
Sample from a truncated normal distribution on mlx arrays

Usage

mlx_rand_truncated_normal(
  lower,
  upper,
  dim,
  dtype = c("float32", "float64"),
  device = mlx_default_device()
)

Arguments

lower

Lower bound of the truncated normal distribution.

upper

Upper bound of the truncated normal distribution.

dim

Integer vector specifying the shape (dimensions) of the output array.

dtype

Desired MLX dtype: "float32" (the default) or "float64".

device

Execution target: "gpu", "cpu", or an mlx_stream created with mlx_new_stream(). Defaults to the device returned by mlx_default_device().

Value

An mlx array of shape dim whose entries are drawn from the truncated normal distribution.
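To make the sampling target concrete, the following is a minimal base R sketch of one standard way to draw from a normal distribution truncated to [lower, upper]: inverse-CDF sampling. It assumes the untruncated distribution is a standard normal (mean 0, sd 1), since the arguments above expose no mean or sd; that is an assumption, not a documented property. The helper rtruncnorm_ref() is hypothetical, is not part of this package, and is not necessarily how mlx_rand_truncated_normal() is implemented.

```r
# Hypothetical reference helper (NOT part of this package).
# Assumption: the distribution being truncated is a standard normal.
rtruncnorm_ref <- function(n, lower, upper) {
  stopifnot(lower < upper)
  # Map the bounds onto the standard normal CDF scale.
  p_lo <- pnorm(lower)
  p_hi <- pnorm(upper)
  # Draw uniformly between the two CDF values, then invert the CDF.
  qnorm(runif(n, min = p_lo, max = p_hi))
}

set.seed(1)
# 25 draws arranged as a 5 x 5 matrix, mirroring dim = c(5, 5) above.
m <- matrix(rtruncnorm_ref(25, -1, 1), nrow = 5, ncol = 5)
range(m)
```

Because each uniform draw falls strictly between pnorm(lower) and pnorm(upper), every sample lands inside the bounds. This method is simple, but it can lose precision when both bounds lie far out in the same tail, where pnorm() rounds to 0 or 1.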

Examples

samples <- mlx_rand_truncated_normal(-1, 1, c(5, 5))