Computes beta * input + alpha * (mat1 %*% mat2) in a single MLX kernel.
All operands are promoted to a common dtype/device prior to evaluation.
Examples
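# rnorm() is not seeded here, so the printed values differ between runs.
input <- as_mlx(diag(3))
mat1 <- as_mlx(matrix(rnorm(9), 3, 3))
mat2 <- as_mlx(matrix(rnorm(9), 3, 3))
# Evaluates 2 * input + 0.5 * (mat1 %*% mat2).
mlx_addmm(input, mat1, mat2, alpha = 0.5, beta = 2)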
#> mlx array [3 x 3]
#> dtype: float32
#> device: gpu
#> values:
#>             [,1]      [,2]       [,3]
#> [1,]  1.62649703 -1.907793  1.3535182
#> [2,] -0.28594804  1.693377 -0.3935536
#> [3,] -0.05022212 -3.515596  3.9770052
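
For checking results, the same formula can be written out in base R. The sketch below is only a reference computation, not the package's implementation: `addmm_reference` is a hypothetical helper name, and its defaults of `alpha = 1` and `beta = 1` are an assumption rather than something stated on this page.

```r
# Hypothetical base-R reference for beta * input + alpha * (mat1 %*% mat2).
# It works on ordinary R matrices and does not call MLX at all.
addmm_reference <- function(input, mat1, mat2, alpha = 1, beta = 1) {
  beta * input + alpha * (mat1 %*% mat2)
}

set.seed(1)  # seed the generator so the inputs are reproducible
input <- diag(3)
mat1 <- matrix(rnorm(9), 3, 3)
mat2 <- matrix(rnorm(9), 3, 3)

# Same scaling factors as the example above.
expected <- addmm_reference(input, mat1, mat2, alpha = 0.5, beta = 2)
dim(expected)
```

Because the MLX result is float32 while base R computes in double precision, a comparison between the two would likely need a tolerance (for example via `all.equal()`) rather than exact equality.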