## Internal covariate shift
As training progresses, early layer weights change, shifting the distribution of inputs to later layers. Layer 2 might receive inputs with mean=0.5 at epoch 1 but mean=2.1 at epoch 10. Each layer must constantly re-adapt to a moving target -- slowing convergence and causing instability.
| Epoch | Layer 2 input mean | Layer 2 input std | Effect |
|---|---|---|---|
| 1 | 0.5 | 1.2 | Layer 2 calibrated for this range |
| 10 | 2.1 | 3.8 | Distribution shifted -- layer must re-adapt |
| 20 | -0.3 | 0.4 | Shifted again -- training is unstable |
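The drift in the table can be sketched numerically. This is a minimal simulation (hypothetical numbers, not a real training run, assuming NumPy): a fixed batch is pushed through a first layer whose weights drift over "epochs", and the statistics of the activations feeding layer 2 move with them.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=(1024, 10))   # fixed input batch
w = rng.normal(0.0, 0.3, size=(10, 256))    # layer-1 weights at initialisation

stats = []
for epoch in (1, 10, 20):
    drift = 0.05 * epoch                     # stand-in for accumulated weight updates
    h = np.maximum(0.0, x @ (w + drift))     # ReLU activations that feed layer 2
    stats.append((epoch, h.mean(), h.std()))
    print(f"epoch {epoch:2d}: layer-2 input mean={h.mean():.2f}, std={h.std():.2f}")
```

Even though the input batch never changes, the mean and standard deviation seen by layer 2 grow as the weights drift: exactly the moving target described above.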
## What BatchNormalization does
For each mini-batch, BatchNormalization() normalises each feature to mean=0, std=1, then applies learned scale and shift parameters:
| Step | Operation | Formula |
|---|---|---|
| 1 | Compute batch statistics | mean and variance of each feature across the batch |
| 2 | Normalise | z = (x - mean) / sqrt(var + epsilon) |
| 3 | Scale and shift | output = gamma * z + beta (gamma and beta are learned) |
The result: each layer always receives inputs with approximately mean=0 and std=1, regardless of what happened in earlier layers.
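The three steps in the table are short enough to write out by hand. A sketch in NumPy (illustrative values; a real layer learns gamma and beta during training, here they are left at their usual initial values of 1 and 0):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(2.1, 3.8, size=(32, 4))   # a mini-batch of 32 samples, 4 features

# Step 1: batch statistics, per feature
mean = x.mean(axis=0)
var = x.var(axis=0)

# Step 2: normalise (epsilon guards against division by zero)
eps = 1e-3
z = (x - mean) / np.sqrt(var + eps)

# Step 3: learned scale and shift (initial values: gamma=1, beta=0)
gamma, beta = np.ones(4), np.zeros(4)
out = gamma * z + beta

print(out.mean(axis=0).round(3))  # approximately 0 per feature
print(out.std(axis=0).round(3))   # approximately 1 per feature
```

Whatever the incoming mean and spread (here 2.1 and 3.8, echoing the earlier table), the output statistics land at roughly 0 and 1.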
## Placement: Dense then BatchNorm then activation
The recommended placement is between the Dense layer and its activation function. This normalises the values before the non-linearity is applied.
```python
from tensorflow import keras

model = keras.Sequential([
    # Block 1
    keras.layers.Dense(256, input_shape=(10,)),  # no activation here
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.Dropout(0.3),
    # Block 2
    keras.layers.Dense(256),
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.Dropout(0.3),
    # Block 3
    keras.layers.Dense(256),
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.Dropout(0.3),
    # Output
    keras.layers.Dense(1, activation='sigmoid'),
])
```
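To see the model in action, here is a hedged sketch of compiling and fitting a smaller single-block version on synthetic stand-in data (the learning rate of 1e-2, the data, and the labelling rule are all illustrative assumptions, not values from this lesson):

```python
import numpy as np
from tensorflow import keras

model = keras.Sequential([
    keras.Input(shape=(10,)),
    keras.layers.Dense(256),                     # no activation here
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.Dense(1, activation='sigmoid'),
])

# BatchNorm stabilises layer inputs, so a higher learning rate is usually tolerable
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-2),
              loss='binary_crossentropy', metrics=['accuracy'])

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 10)).astype('float32')   # synthetic features
y = (X.sum(axis=1) > 0).astype('float32')          # toy labelling rule
history = model.fit(X, y, epochs=3, batch_size=32, verbose=0)
```

The same compile/fit pattern applies unchanged to the full three-block model above.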
## BatchNorm benefits
| Benefit | How | Practical impact |
|---|---|---|
| Faster training | Stable input distributions allow higher learning rates | Reach convergence in fewer epochs |
| Reduced sensitivity to initialisation | BatchNorm re-centres values regardless of starting weights | Less time tuning initialisation |
| Mild regularisation | Batch statistics add noise (each batch has slightly different mean/var) | Works alongside dropout |
| Smoother loss curves | Gradients flow more evenly through normalised layers | Easier to diagnose training issues |
## Think Deeper
Without BatchNorm, layer 2 receives inputs with mean=0.5 at epoch 1 but mean=2.1 at epoch 10. How does this 'internal covariate shift' affect a malware classifier in practice?