Triangular helpers for MLX arrays
mlx_tri() creates a lower-triangular mask (ones on and below a diagonal,
zeros elsewhere). mlx_tril() and mlx_triu() retain only the lower or
upper triangular part of an existing array, respectively.
Usage
mlx_tri(
  n,
  m = NULL,
  k = 0L,
  dtype = c("float32", "float64"),
  device = mlx_default_device()
)
mlx_tril(x, k = 0L)
mlx_triu(x, k = 0L)
Arguments
- n
- Number of rows. 
- m
- Optional number of columns (defaults to n for square output).
- k
- Diagonal offset: 0 selects the main diagonal, positive values select diagonals above it, and negative values select diagonals below it.
- dtype
- MLX dtype to use ("float32" or "float64").
- device
- Execution target: provide "gpu", "cpu", or an mlx_stream created via mlx_new_stream(). Defaults to the current mlx_default_device().
- x
- Object coercible to mlx.
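The effect of k can be sketched in plain base R, without MLX. This is only an illustration: it assumes the mask follows the common convention that entry (i, j) is 1 when j - i <= k, which is what the description of k suggests. tri_mask_sketch() is a hypothetical helper, not part of the package.

```r
# Hypothetical base-R helper (not part of the package).
# Assumption: entry (i, j) is 1 when column index minus row index is <= k.
tri_mask_sketch <- function(n, m = n, k = 0L) {
  outer(seq_len(n), seq_len(m), function(i, j) as.numeric(j - i <= k))
}

tri_mask_sketch(3)           # ones on and below the main diagonal
tri_mask_sketch(3, k = 1L)   # also includes the first diagonal above it
tri_mask_sketch(3, k = -1L)  # only entries strictly below the main diagonal
tri_mask_sketch(2, 4)        # rectangular case: 2 rows, 4 columns
```

Under this assumption, raising k by one adds one more diagonal of ones above the previous boundary, and lowering it removes one.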
Examples
mlx_tri(3)          # 3x3 lower-triangular mask
#> mlx array [3 x 3]
#>   dtype: float32
#>   device: gpu
#>   values:
#>      [,1] [,2] [,3]
#> [1,]    1    0    0
#> [2,]    1    1    0
#> [3,]    1    1    1
mlx_tril(diag(3) + 2)  # keep lower part of a matrix
#> mlx array [3 x 3]
#>   dtype: float32
#>   device: gpu
#>   values:
#>      [,1] [,2] [,3]
#> [1,]    3    0    0
#> [2,]    2    3    0
#> [3,]    2    2    3
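For comparison, the values shown for mlx_tril(diag(3) + 2) can be reproduced on an ordinary R matrix. This sketch uses only base R's upper.tri() and lower.tri() and covers the k = 0 case; it makes no claim about how mlx_tril() or mlx_triu() are implemented.

```r
x <- diag(3) + 2

# Lower part: zero out everything strictly above the main diagonal.
lower_part <- x
lower_part[upper.tri(lower_part)] <- 0
lower_part

# Upper part: zero out everything strictly below the main diagonal.
upper_part <- x
upper_part[lower.tri(upper_part)] <- 0
upper_part
```

Here lower_part holds the same values as the mlx_tril() output above, as a plain R matrix rather than an mlx array.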