Wraps mlx.core.gather() so you can select elements along each axis. Provide one index per axis of x.
Axes and index values must be positive integers. Unlike Python, negative indices are not supported.
Element-wise indexing
The output has the same shape as the indices (after they are broadcast against each other).
Element [i, j, ...] of the output is x[index_1[i, j, ...], index_2[i, j, ...], ...]: each index
supplies the coordinate along its own axis, read from the same position. See the examples below.
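To make the rule concrete, here is a minimal base-R sketch of it. `gather_reference()` is a hypothetical helper written only for illustration, not part of this package. It works on ordinary R arrays and, as a simplifying assumption, requires all indices to have the same shape instead of broadcasting them.

```r
# Hypothetical reference version of the element-wise gather rule,
# for ordinary R arrays only (not the package's implementation).
gather_reference <- function(x, indices) {
  # One index per axis of x
  stopifnot(is.list(indices), length(indices) == length(dim(x)))
  # Each index must hold positive whole numbers (1-based, no negatives)
  for (ix in indices) {
    stopifnot(is.numeric(ix), all(ix >= 1), all(ix == round(ix)))
  }
  # Simplifying assumption: no broadcasting, so all indices share one length
  lens <- vapply(indices, length, integer(1))
  stopifnot(all(lens == lens[1]))
  # Row p of `coords` is (index_1[p], index_2[p], ...), so element p of
  # x[coords] is x[index_1[p], index_2[p], ...]
  coords <- do.call(cbind, lapply(indices, as.vector))
  out <- x[coords]
  # Give the result the shape of the indices (matrix indices -> matrix result)
  if (!is.null(dim(indices[[1]]))) dim(out) <- dim(indices[[1]])
  out
}

x <- matrix(1:9, 3, 3)
gather_reference(x, list(1:2, 1:2))   # 1 5
```

Applied to the row_idx and col_idx matrices from the second example, this sketch returns the same 2x2 layout as mlx_gather(), because both index matrices are flattened in the same column-major order before being paired.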
Examples
x <- mlx_matrix(1:9, 3, 3)
# Paired (element-wise) gather: picks x[1, 1] and x[2, 2], not the 2x2 block
mlx_gather(x, list(1:2, 1:2))
#> mlx array [2]
#> dtype: float32
#> device: cpu
#> values:
#> [1] 1 5
# Element-wise pairs: grab a custom 2x2 grid of coordinates
row_idx <- matrix(c(1, 1,
                    2, 3), nrow = 2, byrow = TRUE)
col_idx <- matrix(c(1, 3,
                    2, 2), nrow = 2, byrow = TRUE)
# The result is a 2x2 matrix; for example, its bottom-right element is x[3, 2]
mlx_gather(x, list(row_idx, col_idx))
#> mlx array [2 x 2]
#> dtype: float32
#> device: cpu
#> values:
#>      [,1] [,2]
#> [1,]    1    7
#> [2,]    5    6
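The first example is easy to misread as a cartesian (block) selection. The base-R sketch below runs on an ordinary R matrix rather than an mlx array and shows the difference. Pairing coordinates with cbind() gives the element-wise result. A block like x[1:2, 1:2] is what you get once the row and column indices are expanded into full grids. Under the element-wise rule described above, passing such grids to mlx_gather() should produce the block too. The sketch itself only checks the base-R analogue, and the names row_grid and col_grid are illustrative.

```r
x <- matrix(1:9, 3, 3)

# Element-wise: pairs (1, 1) and (2, 2), the analogue of the first example
x[cbind(1:2, 1:2)]        # 1 5

# Cartesian block: every row in i combined with every column in j
i <- 1:2
j <- 1:2
x[i, j]

# The same block via element-wise indexing, once i and j are expanded to grids
row_grid <- matrix(i, nrow = length(i), ncol = length(j))
col_grid <- matrix(j, nrow = length(i), ncol = length(j), byrow = TRUE)
block <- matrix(x[cbind(as.vector(row_grid), as.vector(col_grid))],
                nrow = length(i))
identical(block, x[i, j])  # TRUE
```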