Normalizes each feature of the input using statistics computed across the batch dimension.
Examples
set.seed(1)
bn <- mlx_batch_norm(4)
x <- as_mlx(matrix(rnorm(12), 3, 4))
mlx_forward(bn, x)
#> mlx array [3 x 4]
#> dtype: float32
#> values:
#>            [,1]        [,2]       [,3]       [,4]
#> [1,] -0.4556868  1.24383128 -1.0877743 -1.1186367
#> [2,]  1.3872330 -0.03912285  1.3256620  1.3086261
#> [3,] -0.9315463 -1.20470834 -0.2378886 -0.1899893
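
To see what the layer computes, here is a rough base R sketch of the same normalization, written without Rmlx. It makes several assumptions that this page does not state: rows are samples and columns are features, each column is centered by its batch mean and scaled by its biased (divide by n) batch variance, a stabilizer eps = 1e-5 is added to the variance, and the layer's learnable scale and shift start at 1 and 0. Under those assumptions the result should agree with the output above to within float32 precision.

```r
# Base R sketch of per-feature batch normalization (Rmlx is not needed here).
set.seed(1)
x <- matrix(rnorm(12), 3, 4)   # 3 samples (rows), 4 features (columns)

eps <- 1e-5                    # assumed stabilizer; not stated on this page

mu       <- colMeans(x)                    # per-feature batch mean
centered <- sweep(x, 2, mu, "-")           # subtract each column's mean
sigma2   <- colMeans(centered^2)           # biased variance (divide by n)
y        <- sweep(centered, 2, sqrt(sigma2 + eps), "/")

round(y, 4)
```

Each column of y has mean zero (up to rounding) and variance slightly below one because of eps. A trained layer would also apply its learned scale and shift, which this sketch leaves out.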