Normalizes each feature of the input using statistics computed across the batch dimension.

Usage

mlx_batch_norm(
  num_features,
  eps = 1e-05,
  momentum = 0.1,
  device = mlx_default_device()
)

Arguments

num_features

Number of feature channels.

eps

Small constant for numerical stability (default: 1e-5).

momentum

Momentum used when updating the running mean and variance (default: 0.1).

device

Execution target: supply "gpu", "cpu", or an mlx_stream created via mlx_new_stream(). Default: mlx_default_device().

Value

An mlx_module applying batch normalization.

Examples

set.seed(1)
bn <- mlx_batch_norm(4)
x <- as_mlx(matrix(rnorm(12), 3, 4))
mlx_forward(bn, x)
#> mlx array [3 x 4]
#>   dtype: float32
#>   device: gpu
#>   values:
#>            [,1]        [,2]       [,3]       [,4]
#> [1,] -0.4556868  1.24383128 -1.0877743 -1.1186367
#> [2,]  1.3872330 -0.03912285  1.3256620  1.3086261
#> [3,] -0.9315463 -1.20470834 -0.2378886 -0.1899893
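
To make the printed values easier to interpret, the following base R sketch recomputes the normalization by hand. It does not call the package. It assumes that a freshly created layer normalizes each column with the batch mean and the biased variance (divide by n), with scale 1 and shift 0. That assumption appears consistent with the output above but is not confirmed by this page. `batch_norm_sketch` is a hypothetical helper name, not part of the package.

```r
# Hypothetical helper, base R only (no Rmlx calls).
# Assumption: per-feature batch mean, biased variance, scale = 1, shift = 0.
batch_norm_sketch <- function(x, eps = 1e-5) {
  mu <- colMeans(x)                       # per-feature mean over the batch (rows)
  centered <- sweep(x, 2, mu, "-")        # subtract each column's mean
  sigma2 <- colMeans(centered^2)          # biased variance: divide by n, not n - 1
  sweep(centered, 2, sqrt(sigma2 + eps), "/")
}

set.seed(1)
x <- matrix(rnorm(12), 3, 4)              # same input as the example above
y <- batch_norm_sketch(x)
print(y)
```

Each column of `y` has a mean of approximately zero. Small differences from the layer's output are expected, since the sketch runs in double precision while the printed array is float32.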