
mlx_flatten() mirrors mlx.core.flatten(), collapsing a contiguous range of axes into a single dimension.

Usage

mlx_flatten(x, start_axis = 1L, end_axis = -1L)

Arguments

x

An mlx array.

start_axis

First axis of the range to flatten (1-indexed; negative values count from the end). Defaults to 1L, the first axis.

end_axis

Last axis of the range to flatten (1-indexed; negative values count from the end, so -1L is the last axis). Defaults to -1L.
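To make the axis conventions concrete, here is a minimal base-R sketch of how the two arguments could be resolved into a result shape. flattened_dim() is a hypothetical helper, not part of the package, and the sketch assumes that start_axis must not come after end_axis once negative values are resolved.

```r
# Hypothetical helper (not part of the package): resolve 1-indexed axes,
# where negative values count from the end, and compute the result shape.
flattened_dim <- function(d, start_axis = 1L, end_axis = -1L) {
  n <- length(d)
  # -1 maps to the last axis, -2 to the second-to-last, and so on
  if (start_axis < 0) start_axis <- n + start_axis + 1L
  if (end_axis < 0) end_axis <- n + end_axis + 1L
  # Assumption: the range must be non-empty and lie within the array
  stopifnot(1 <= start_axis, start_axis <= end_axis, end_axis <= n)
  c(d[seq_len(start_axis - 1L)],           # axes before the range
    prod(d[start_axis:end_axis]),          # the collapsed range
    d[seq_len(n - end_axis) + end_axis])   # axes after the range
}

flattened_dim(c(2, 3, 2))
#> [1] 12
flattened_dim(c(2, 3, 2), 2, 3)
#> [1] 2 6
flattened_dim(c(2, 3, 2), 1, -2)
#> [1] 6 2
```

The first two calls match the shapes printed in the examples ([12] and [2 x 6]).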

Value

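An mlx array in which axes start_axis through end_axis are replaced by a single axis whose length is the product of their lengths. Elements are taken in row-major order (the last axis varies fastest), as in MLX, which differs from the column-major order that as.vector() uses on the original R array.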
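The values printed by mlx_flatten(x) in the examples (1, 7, 3, 9, ...) are in row-major order. The following base-R snippet, which does not use the package, reproduces that order by reversing the axes before reading the array column-major, and contrasts it with plain as.vector().

```r
x <- array(1:12, dim = c(2, 3, 2))

# R stores arrays column-major, so as.vector() reads the first axis fastest
as.vector(x)
#>  [1]  1  2  3  4  5  6  7  8  9 10 11 12

# Reversing the axes first yields row-major order (last axis fastest),
# matching the values printed by mlx_flatten(x)
as.vector(aperm(x, rev(seq_along(dim(x)))))
#>  [1]  1  7  3  9  5 11  2  8  4 10  6 12
```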


Examples

x <- as_mlx(array(1:12, dim = c(2, 3, 2)))
mlx_flatten(x)
#> mlx array [12]
#>   dtype: float32
#>   device: gpu
#>   values:
#>  [1]  1  7  3  9  5 11  2  8  4 10  6 12
mlx_flatten(x, start_axis = 2, end_axis = 3)
#> mlx array [2 x 6]
#>   dtype: float32
#>   device: gpu
#>   values:
#>      [,1] [,2] [,3] [,4] [,5] [,6]
#> [1,]    1    7    3    9    5   11
#> [2,]    2    8    4   10    6   12
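As a cross-check of the second example, here is a minimal base-R reference for a partial row-major flatten. flatten_rowmajor() is a hypothetical helper, not part of the package, and for brevity it assumes positive, already-resolved axes. It works on a plain R array, so it returns an ordinary matrix or array rather than an mlx array.

```r
# Hypothetical base-R reference (not part of the package): collapse axes
# start_axis..end_axis (positive, 1-indexed) using row-major ordering.
flatten_rowmajor <- function(x, start_axis, end_axis) {
  d <- dim(x)
  n <- length(d)
  new_dim <- c(d[seq_len(start_axis - 1L)],
               prod(d[start_axis:end_axis]),
               d[seq_len(n - end_axis) + end_axis])
  # Reversing the axes makes R's column-major layout follow row-major order
  y <- aperm(x, rev(seq_len(n)))
  # Reshape in that order, then reverse the axes back
  dim(y) <- rev(new_dim)
  aperm(y, rev(seq_along(new_dim)))
}

x <- array(1:12, dim = c(2, 3, 2))
flatten_rowmajor(x, 2, 3)
#>      [,1] [,2] [,3] [,4] [,5] [,6]
#> [1,]    1    7    3    9    5   11
#> [2,]    2    8    4   10    6   12
```

The result has the same layout as mlx_flatten(x, start_axis = 2, end_axis = 3) above. Note that simply assigning dim(x) <- c(2, 6) would reshape column-major and give different rows.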