The inverse of flattening: expands a single axis into multiple axes whose sizes are given by `shape`.
Examples
# Flatten and unflatten
x <- mlx_array(1:24, c(2, 3, 4))
x_flat <- mlx_reshape(x, c(2, 12)) # flatten last two dims
mlx_unflatten(x_flat, axis = 2, shape = c(3, 4)) # restore original shape
#> mlx array [2 x 3 x 4]
#> dtype: float32
#> device: gpu
#> (24 elements, not shown)
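The shape bookkeeping behind unflattening can be sketched in base R. The helper `unflatten_dims()` below is hypothetical and not part of the package. It assumes the usual unflatten rule that the entries of `shape` are positive and multiply to the length of the axis being expanded. It computes only the resulting dimensions, not where individual elements end up in memory.

```r
# Hypothetical helper (not part of the package): compute the dimensions
# produced by unflattening axis `axis` of an array with dimensions `dims`
# into the sizes given by `shape`. Assumes all entries of `shape` are positive.
unflatten_dims <- function(dims, axis, shape) {
  stopifnot(axis >= 1, axis <= length(dims))
  # The new axes must account for every element along `axis`.
  if (prod(shape) != dims[axis]) {
    stop(sprintf("prod(shape) = %g does not match dims[%d] = %g",
                 prod(shape), axis, dims[axis]))
  }
  # Axes before `axis`, then the new axes, then the axes after `axis`.
  c(dims[seq_len(axis - 1)], shape, dims[-seq_len(axis)])
}

unflatten_dims(c(2, 12), axis = 2, shape = c(3, 4))
#> [1] 2 3 4
```

This reproduces the `[2 x 3 x 4]` shape reported in the example above, and a mismatched `shape` such as `c(5, 2)` for an axis of length 12 stops with an error.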