Core Tensor API

as_mlx()
Create MLX array from R object
is.mlx()
Test if object is an MLX array
Rmlx-package Rmlx
Rmlx: R Interface to Apple's MLX Arrays
mlx-methods
Base R generics with mlx methods
mlx_dim()
Get dimensions helper
mlx_dtype()
Get data type helper
mlx_eval()
Force evaluation of lazy MLX operations
`[`(<mlx>) `[<-`(<mlx>)
Subset MLX array
dim(<mlx>)
Get dimensions of MLX array
`dim<-`(<mlx>)
Set dimensions of MLX array
length(<mlx>)
Get length of MLX array
print(<mlx>)
Print MLX array
str(<mlx>)
Object structure for MLX array
t(<mlx>)
Transpose of MLX matrix
as.matrix(<mlx>)
Convert MLX array to R matrix/array
as.array(<mlx>)
Convert MLX array to R array
as.vector(<mlx>)
Convert MLX array to R vector

Device & Execution

mlx_default_device()
Get or set default MLX device
with_default_device()
Temporarily set the default MLX device
mlx_new_stream() mlx_default_stream()
MLX streams for asynchronous execution
mlx_set_default_stream()
Set the default MLX stream
mlx_synchronize()
Synchronize MLX execution
mlx_forward()
Forward pass utility
mlx_grad() mlx_value_grad()
Automatic differentiation for MLX functions
mlx_stop_gradient()
Stop gradient propagation through an mlx array
mlx_compile()
Compile an MLX function for optimized execution
mlx_disable_compile() mlx_enable_compile()
Control global compilation behavior

Creation & Randomness

mlx_zeros()
Create arrays of zeros on MLX devices
mlx_ones()
Create arrays of ones on MLX devices
mlx_zeros_like()
Zeros shaped like an existing mlx array
mlx_ones_like()
Ones shaped like an existing mlx array
mlx_full()
Fill an mlx array with a constant value
mlx_eye()
Identity-like matrices on MLX devices
mlx_identity()
Identity matrices on MLX devices
mlx_arange()
Numerical ranges on MLX devices
mlx_linspace()
Evenly spaced ranges on MLX devices
mlx_rand_bernoulli()
Sample Bernoulli random variables on mlx arrays
mlx_rand_categorical()
Sample from a categorical distribution on mlx arrays
mlx_rand_gumbel()
Sample from the Gumbel distribution on mlx arrays
mlx_rand_laplace()
Sample from the Laplace distribution on mlx arrays
mlx_rand_multivariate_normal()
Sample from a multivariate normal distribution on mlx arrays
mlx_rand_normal()
Sample from a normal distribution on mlx arrays
mlx_rand_permutation()
Generate random permutations on mlx arrays
mlx_rand_randint()
Sample random integers on mlx arrays
mlx_rand_truncated_normal()
Sample from a truncated normal distribution on mlx arrays
mlx_rand_uniform()
Sample from a uniform distribution on mlx arrays
mlx_key() mlx_key_split()
Construct MLX random number generator keys
mlx_key_bits()
Generate raw random bits on MLX arrays

Shape & Indexing

mlx_reshape()
Reshape an mlx array
mlx_stack()
Stack mlx arrays along a new axis
mlx_squeeze()
Remove singleton dimensions
mlx_expand_dims()
Insert singleton dimensions
mlx_repeat()
Repeat array elements
mlx_tile()
Tile an array
mlx_pad() mlx_split()
Pad or split mlx arrays
mlx_roll()
Roll array elements
mlx_moveaxis() aperm(<mlx>)
Reorder mlx array axes
mlx_contiguous()
Ensure contiguous memory layout
mlx_flatten()
Flatten axes of an mlx array
mlx_swapaxes()
Swap two axes of an mlx array
mlx_unflatten()
Unflatten an axis into multiple axes
mlx_meshgrid()
Construct coordinate arrays from input vectors
mlx_broadcast_to()
Broadcast an array to a new shape
mlx_broadcast_arrays()
Broadcast multiple arrays to a shared shape
mlx_where()
Elementwise conditional selection
mlx_tri() mlx_tril() mlx_triu()
Triangular helpers for MLX arrays
mlx_slice_update()
Update a slice of an mlx array
mlx_gather()
Gather elements from an mlx array
abind()
Bind mlx arrays along an axis
rbind(<mlx>)
Row-bind mlx arrays
cbind(<mlx>)
Column-bind mlx arrays

Ordering & Selection

mlx_sort() mlx_argsort()
Sort and argsort for mlx arrays
mlx_topk() mlx_partition() mlx_argpartition()
Top-k selection and partitioning on mlx arrays
mlx_argmax() mlx_argmin()
Argmax and argmin on mlx arrays

Math & Reductions

Math(<mlx>)
Math operations for MLX arrays
Ops(<mlx>)
Arithmetic and comparison operators for MLX arrays
mlx_sum() mlx_prod() mlx_all() mlx_any() mlx_mean() mlx_var() mlx_std()
Reduce mlx arrays
mean(<mlx>)
Mean of MLX array elements
mlx_cumsum() mlx_cumprod()
Cumulative sum and product
mlx_clip()
Clip mlx array values into a range
mlx_maximum()
Elementwise maximum of two mlx arrays
mlx_minimum()
Elementwise minimum of two mlx arrays
mlx_hadamard_transform()
Hadamard transform for MLX arrays
mlx_softmax()
Softmax for mlx arrays
mlx_logsumexp()
Log-sum-exp reduction for mlx arrays
mlx_logcumsumexp()
Log cumulative sum exponential for mlx arrays
mlx_isnan() mlx_isinf() mlx_isfinite()
Elementwise NaN and infinity predicates
mlx_isposinf() mlx_isneginf()
Detect signed infinities in mlx arrays
mlx_nan_to_num()
Replace NaN and infinite values with finite numbers
mlx_real() mlx_imag() mlx_conjugate()
Complex-valued helpers for mlx arrays
mlx_degrees() mlx_radians()
Convert between radians and degrees
mlx_isclose()
Elementwise approximate equality
mlx_allclose()
Test if all elements of two arrays are close
all.equal(<mlx>)
Test if two MLX arrays are (nearly) equal
colSums()
Column sums for mlx arrays
rowSums()
Row sums for mlx arrays
colMeans()
Column means for mlx arrays
rowMeans()
Row means for mlx arrays
fft()
Fast Fourier Transform
mlx_fft() mlx_fft2() mlx_fftn()
Fast Fourier transforms for MLX arrays

Linear Algebra

`%*%`(<mlx>)
Matrix multiplication for MLX arrays
mlx_addmm()
Fused matrix multiply and add for MLX arrays
crossprod(<mlx>)
Matrix cross product, t(x) %*% y
tcrossprod(<mlx>)
Transposed matrix cross product, x %*% t(y)
outer()
Outer product of two vectors
diag()
Diagonal matrix extraction and construction
chol(<mlx>)
Cholesky decomposition for mlx arrays
chol2inv()
Inverse from Cholesky decomposition
kronecker() kronecker.default()
Kronecker product dispatcher
qr(<mlx>)
QR decomposition for mlx arrays
svd()
Singular value decomposition
svd(<mlx>)
Singular value decomposition for mlx arrays
solve(<mlx>)
Solve a system of linear equations
pinv()
Moore-Penrose pseudoinverse for MLX arrays
mlx_kron()
Kronecker product for mlx arrays
mlx_inv()
Compute matrix inverse
mlx_tri_inv()
Compute triangular matrix inverse
mlx_cholesky_inv()
Compute matrix inverse via Cholesky decomposition
mlx_lu()
LU factorization
mlx_norm()
Matrix and vector norms for mlx arrays
mlx_solve_triangular()
Solve triangular systems with mlx arrays
mlx_trace()
Matrix trace for mlx arrays
diag(<mlx>) mlx_diagonal()
Extract diagonal or construct diagonal matrix for mlx arrays
mlx_eig()
Eigendecomposition for mlx arrays
mlx_eigh()
Eigendecomposition of Hermitian mlx arrays
mlx_eigvals()
Eigenvalues of mlx arrays
mlx_eigvalsh()
Eigenvalues of Hermitian mlx arrays
mlx_cross()
Vector cross product with mlx arrays

Input & Output

mlx_save()
Save an MLX array to disk
mlx_load()
Load an MLX array from disk
mlx_save_safetensors()
Save MLX arrays to the safetensors format
mlx_load_safetensors()
Load MLX arrays from the safetensors format
mlx_save_gguf()
Save MLX arrays to the GGUF format
mlx_load_gguf()
Load MLX tensors from the GGUF format

Neural Network Layers

mlx_linear()
Create a learnable linear transformation
mlx_sequential()
Compose modules sequentially
mlx_set_training()
Toggle training mode for MLX modules
mlx_embedding()
Embedding layer
mlx_conv1d()
1D convolution
mlx_conv2d()
2D convolution
mlx_conv3d()
3D convolution
mlx_conv_transpose1d()
1D transposed convolution
mlx_conv_transpose2d()
2D transposed convolution
mlx_conv_transpose3d()
3D transposed convolution
mlx_quantize()
Quantize a matrix
mlx_dequantize()
Dequantize a matrix
mlx_quantized_matmul()
Quantized matrix multiplication
mlx_gather_qmm()
Gather-based quantized matrix multiplication

Activation Functions

mlx_relu()
Rectified linear activation module
mlx_gelu()
GELU activation
mlx_sigmoid()
Sigmoid activation
mlx_tanh()
Tanh activation
mlx_silu()
SiLU (Swish) activation
mlx_leaky_relu()
Leaky ReLU activation
mlx_softmax_layer()
Softmax activation

Regularization & Normalization

mlx_dropout()
Dropout layer
mlx_layer_norm()
Layer normalization
mlx_batch_norm()
Batch normalization

Loss Functions

mlx_mse_loss()
Mean squared error loss
mlx_l1_loss()
L1 loss (mean absolute error)
mlx_binary_cross_entropy()
Binary cross-entropy loss
mlx_cross_entropy()
Cross-entropy loss

Training Utilities

mlx_parameters()
Collect parameters from modules
mlx_param_values()
Retrieve parameter arrays
mlx_param_set_values()
Assign arrays back to parameters
mlx_optimizer_sgd()
Stochastic gradient descent optimizer
mlx_train_step()
Single training step helper