mlx_tri() creates a lower-triangular mask: ones on and below the k-th diagonal,
zeros elsewhere. mlx_tril() and mlx_triu() keep only the lower or upper
triangular part of an existing array, respectively, and set every other
entry to zero.
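To make these semantics concrete, here is a minimal base-R sketch that reproduces the three operations on ordinary R matrices. The helpers tri_ref(), tril_ref() and triu_ref() are hypothetical and are not part of the package. They assume the diagonal-offset convention described under k, where entry [i, j] lies on diagonal j - i. They ignore dtype and device entirely.

```r
# Hypothetical base-R reference helpers; they are NOT part of the package.
# Entry [i, j] lies on diagonal j - i, so diagonal 0 is the main diagonal,
# positive diagonals sit above it and negative ones below it.

# Lower-triangular mask: 1 on and below diagonal k, 0 elsewhere.
tri_ref <- function(n, m = n, k = 0L) {
  z <- matrix(0, nrow = n, ncol = m)
  (col(z) - row(z) <= k) * 1
}

# Keep entries on and below diagonal k; zero the rest.
tril_ref <- function(x, k = 0L) {
  x <- as.matrix(x)
  x[col(x) - row(x) > k] <- 0
  x
}

# Keep entries on and above diagonal k; zero the rest.
triu_ref <- function(x, k = 0L) {
  x <- as.matrix(x)
  x[col(x) - row(x) < k] <- 0
  x
}

tri_ref(3)                      # square lower-triangular mask
tri_ref(2, 3, k = 1L)           # rectangular mask, one extra upper diagonal
tril_ref(diag(3) + 2)           # lower part of a matrix
triu_ref(diag(3) + 2, k = 1L)   # strictly upper part (k = 1)
```

Because the mask and the two extractors share the same j - i test, tril_ref(x, k) gives the same result as multiplying x element-wise by tri_ref(nrow(x), ncol(x), k).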
Usage
mlx_tri(
n,
m = NULL,
k = 0L,
dtype = c("float32", "float64"),
device = mlx_default_device()
)
mlx_tril(x, k = 0L)
mlx_triu(x, k = 0L)
Arguments
- n
Number of rows.
- m
Optional number of columns (defaults to n for square output). Used by mlx_tri() only.
- k
Diagonal offset: 0 selects the main diagonal, positive values select diagonals above it, and negative values select diagonals below it.
- dtype
MLX dtype to use ("float32" or "float64"). Used by mlx_tri() only.
- device
Execution target: supply "gpu", "cpu", or an mlx_stream created via mlx_new_stream(). Defaults to the current mlx_default_device() unless noted otherwise (helpers that act on an existing array typically reuse that array's device or stream). Used by mlx_tri() only.
- x
Object coercible to mlx.
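The sign convention for k is easy to get backwards, so the base-R sketch below prints the diagonal index j - i of every entry in a 3 x 4 matrix and then shows which entries a lower-triangular mask keeps for k = -1, 0 and 1. It only illustrates the convention stated above for k; it does not call the package.

```r
# Diagonal index of each entry: 0 = main diagonal,
# positive = above it, negative = below it.
z <- matrix(0, nrow = 3, ncol = 4)
diag_index <- col(z) - row(z)
print(diag_index)

# Entries a lower-triangular mask keeps for a few offsets k
# (TRUE where the entry lies on or below diagonal k).
for (k in -1:1) {
  cat("k =", k, "\n")
  print(diag_index <= k)
}
```

An upper-triangular selection uses the opposite comparison, diag_index >= k, so the same k picks the same boundary diagonal for both directions.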
Examples
mlx_tri(3) # 3x3 lower-triangular mask
#> mlx array [3 x 3]
#> dtype: float32
#> device: gpu
#> values:
#> [,1] [,2] [,3]
#> [1,] 1 0 0
#> [2,] 1 1 0
#> [3,] 1 1 1
mlx_tril(diag(3) + 2) # keep lower part of a matrix
#> mlx array [3 x 3]
#> dtype: float32
#> device: gpu
#> values:
#> [,1] [,2] [,3]
#> [1,] 3 0 0
#> [2,] 2 3 0
#> [3,] 2 2 3
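For k = 0 the outputs above can be cross-checked with base R's lower.tri() and upper.tri(), which is handy when porting existing matrix code. This sketch works on ordinary R matrices only and says nothing about how the package computes its results.

```r
x <- diag(3) + 2

# Base-R counterpart of mlx_tril(x) for k = 0:
# zero everything above the main diagonal.
lower <- x * lower.tri(x, diag = TRUE)

# Base-R counterpart of mlx_triu(x) for k = 0:
# zero everything below the main diagonal.
upper <- x * upper.tri(x, diag = TRUE)

# 3 x 3 lower-triangular mask, matching mlx_tri(3).
mask <- lower.tri(diag(3), diag = TRUE) * 1

print(lower)
print(upper)
print(mask)
```

lower.tri() and upper.tri() have no offset argument, so this shortcut only covers the default k = 0.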