Computes beta * input + alpha * (mat1 %*% mat2) in a single MLX kernel.
All operands are promoted to a common dtype/device prior to evaluation.
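To make the formula concrete, here is a minimal sketch that reproduces it with plain base R matrices. It is not part of the package. `addmm_reference()` is a hypothetical helper introduced only for illustration. It assumes `input` already has the same shape as `mat1 %*% mat2`, and it ignores the dtype promotion and device placement that the MLX version performs.

```r
# Hypothetical base-R reference for the addmm formula (plain R matrices, no MLX).
# Assumes input has the same shape as mat1 %*% mat2; no broadcasting or dtype promotion.
addmm_reference <- function(input, mat1, mat2, alpha, beta) {
  # beta scales the additive term; alpha scales the matrix product
  beta * input + alpha * (mat1 %*% mat2)
}

set.seed(1)
input <- diag(3)
mat1 <- matrix(rnorm(9), 3, 3)
mat2 <- matrix(rnorm(9), 3, 3)

# Same arguments as the mlx_addmm() example: 2 * I + 0.5 * (mat1 %*% mat2)
ref <- addmm_reference(input, mat1, mat2, alpha = 0.5, beta = 2)
```

With `input = diag(3)` and `beta = 2`, every diagonal entry of the result is shifted by 2 relative to `0.5 * (mat1 %*% mat2)`, while the off-diagonal entries come from the scaled product alone.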
Examples
input <- as_mlx(diag(3))
mat1 <- as_mlx(matrix(rnorm(9), 3, 3))
mat2 <- as_mlx(matrix(rnorm(9), 3, 3))
mlx_addmm(input, mat1, mat2, alpha = 0.5, beta = 2)
#> mlx array [3 x 3]
#> dtype: float32
#> device: gpu
#> values:
#>            [,1]      [,2]       [,3]
#> [1,]  3.4694786 -1.947971  0.7083434
#> [2,] -0.1372439  2.180840 -0.1373530
#> [3,] -0.8607467  1.066743  1.3697486