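mlx_flatten() mirrors mlx.core.flatten(),
collapsing a contiguous range of axes, from start_axis to end_axis inclusive (numbered from 1), into a single dimension. With no axes given, the whole array is flattened to 1-D. Elements are taken in row-major (C) order, as in MLX, so the result differs from as.vector() on an R array holding the same values.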
Examples
x <- mlx_array(1:12, dim = c(2, 3, 2))
mlx_flatten(x)
#> mlx array [12]
#> dtype: float32
#> device: gpu
#> values:
#>  [1]  1  7  3  9  5 11  2  8  4 10  6 12
mlx_flatten(x, start_axis = 2, end_axis = 3)
#> mlx array [2 x 6]
#> dtype: float32
#> device: gpu
#> values:
#>      [,1] [,2] [,3] [,4] [,5] [,6]
#> [1,]    1    7    3    9    5   11
#> [2,]    2    8    4   10    6   12
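The value order in these outputs (1 7 3 9 ... rather than 1 2 3 4 ...) comes from the row-major traversal. As a rough sketch for checking that ordering without MLX, the base-R function below reproduces the same layout on an ordinary R array. flatten_row_major() is a hypothetical helper written for illustration, not part of the package, and it assumes the 1-based, inclusive axis convention shown in the examples above.

```r
# Hypothetical helper (not part of the package): mimics the row-major
# flatten shown by mlx_flatten() using only base R.
flatten_row_major <- function(x, start_axis = 1L, end_axis = length(dim(x))) {
  d <- dim(x)
  n <- length(d)
  stopifnot(start_axis >= 1L, end_axis <= n, start_axis <= end_axis)
  # Reversing the axes turns R's column-major traversal into a
  # row-major (C-order) one.
  v <- as.vector(aperm(x, rev(seq_len(n))))
  # New shape: axes before the range, one merged axis, axes after it.
  new_dim <- c(d[seq_len(start_axis - 1L)],
               prod(d[start_axis:end_axis]),
               d[seq_len(n - end_axis) + end_axis])
  # Refill in row-major order, then reverse the axes back for R.
  out <- aperm(array(v, dim = rev(new_dim)))
  if (length(new_dim) == 1L) as.vector(out) else out
}

x <- array(1:12, dim = c(2, 3, 2))
flatten_row_major(x)        # same values as mlx_flatten(x)
flatten_row_major(x, 2, 3)  # same 2 x 6 layout as the second example
```

Transposing with aperm() on the way in and on the way out is the usual base-R way to switch between column-major and row-major layouts.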