Wraps mlx.core.gather() so you can pull elements out of an array by position along one or more axes. Provide one index per axis listed in axes. Axes are 1-based and must be positive integers; negative values are not allowed, unlike Python's negative indexing.
Element-wise indexing
The output has the same shape as the indices. The element at position p of the output is x[index_1[p], index_2[p], ...]. In other words, the indices are paired position by position rather than combined as a Cartesian product. See the examples below.
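For intuition, the element-wise rule matches base R's matrix indexing, where indexing an array by a matrix of coordinates returns one element per row. The sketch below uses a hypothetical helper, gather_ref(), written only to illustrate the semantics on ordinary R arrays. It is not part of the package and does not call MLX. It assumes one index per axis, with every index array sharing the same shape.

```r
# Hypothetical helper, not part of the package: a base R analogue of the
# element-wise rule, gathering from every axis of an ordinary R array.
gather_ref <- function(x, indices) {
  x <- as.array(x)
  # Assumption: exactly one index per axis of x.
  stopifnot(length(indices) == length(dim(x)))
  # Assumption: all index arrays share one shape, which the output inherits.
  shape <- dim(as.array(indices[[1]]))
  for (idx in indices) {
    stopifnot(identical(dim(as.array(idx)), shape))
  }
  # Coordinate matrix: row p holds (index_1[p], index_2[p], ...).
  coords <- do.call(cbind, lapply(indices, as.vector))
  array(x[coords], dim = shape)
}

x <- matrix(1:9, 3, 3)
gather_ref(x, list(1:2, 1:2))                    # elements 1 and 5

a <- array(1:24, dim = c(2, 3, 4))
gather_ref(a, list(c(1, 2), c(3, 1), c(4, 2)))   # a[1, 3, 4] and a[2, 1, 2]: 23 and 8
```

The coordinate-matrix form makes the pairing explicit. Each output element reads exactly one cell, so the result is never larger than the indices themselves.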
Examples
x <- mlx_matrix(1:9, 3, 3)
# Simple element-wise gather: picks x[1, 1] and x[2, 2]
mlx_gather(x, list(1:2, 1:2), axes = 1:2)
#> mlx array [2]
#> dtype: float32
#> device: gpu
#> values:
#> [1] 1 5
# Element-wise pairs: grab a custom 2x2 grid of coordinates
row_idx <- matrix(c(1, 1,
                    2, 3), nrow = 2, byrow = TRUE)
col_idx <- matrix(c(1, 3,
                    2, 2), nrow = 2, byrow = TRUE)
mlx_gather(x, list(row_idx, col_idx), axes = c(1L, 2L))
#> mlx array [2 x 2]
#> dtype: float32
#> device: gpu
#> values:
#>      [,1] [,2]
#> [1,]    1    7
#> [2,]    5    6
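As a hand check of the 2 x 2 result above, the same four cells can be read from an ordinary R matrix with base matrix indexing. No MLX is involved here. Each row of the coordinate matrix is one (row, column) pair, taken in column-major order and then reshaped back to the indices' shape.

```r
# Base R cross-check of the 2 x 2 example (plain R matrix, not an mlx array).
x <- matrix(1:9, 3, 3)
row_idx <- matrix(c(1, 1,
                    2, 3), nrow = 2, byrow = TRUE)
col_idx <- matrix(c(1, 3,
                    2, 2), nrow = 2, byrow = TRUE)
# Pair coordinates position by position, then restore the index shape.
picked <- matrix(x[cbind(as.vector(row_idx), as.vector(col_idx))],
                 nrow = nrow(row_idx))
picked
#      [,1] [,2]
# [1,]    1    7
# [2,]    5    6
```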