Step 3: Batch Normalisation

Stabilise training with normalised activations


Internal covariate shift

As training progresses, early layer weights change, shifting the distribution of inputs to later layers. Layer 2 might receive inputs with mean=0.5 at epoch 1 but mean=2.1 at epoch 10. Each layer must constantly re-adapt to a moving target -- slowing convergence and causing instability.

| Epoch | Layer 2 input mean | Layer 2 input std | Effect |
|---|---|---|---|
| 1 | 0.5 | 1.2 | Layer 2 calibrated for this range |
| 10 | 2.1 | 3.8 | Distribution shifted -- layer must re-adapt |
| 20 | -0.3 | 0.4 | Shifted again -- training is unstable |
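The shift in the table can be simulated directly. This is a minimal NumPy sketch (the weight values and the 4x growth factor are illustrative assumptions, not taken from a real training run): the same fixed mini-batch passes through layer 1, but as layer 1's weights change between epochs, the distribution that layer 2 receives drifts.

```python
import numpy as np

# Illustrative sketch: the raw data never changes, but layer 1's weights do,
# so layer 2's input distribution drifts between epochs.
rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=(256, 10))        # fixed mini-batch

w_epoch1 = rng.normal(0.0, 0.1, size=(10, 8))   # small early weights
w_epoch10 = w_epoch1 * 4.0 + 0.2                # hypothetical weights after updates

layer2_in_epoch1 = np.maximum(0.0, x @ w_epoch1)    # ReLU activations fed to layer 2
layer2_in_epoch10 = np.maximum(0.0, x @ w_epoch10)

print(f"epoch 1:  mean={layer2_in_epoch1.mean():.2f}, std={layer2_in_epoch1.std():.2f}")
print(f"epoch 10: mean={layer2_in_epoch10.mean():.2f}, std={layer2_in_epoch10.std():.2f}")
```

Both the mean and the spread of layer 2's inputs grow even though the input data is identical -- this moving target is what the next section's normalisation removes.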

What BatchNormalization does

For each mini-batch, BatchNormalization() normalises each feature to mean=0, std=1, then applies learned scale and shift parameters:

| Step | Operation | Formula |
|---|---|---|
| 1 | Compute batch statistics | mean and variance of each feature across the batch |
| 2 | Normalise | z = (x - mean) / sqrt(var + epsilon) |
| 3 | Scale and shift | output = gamma * z + beta (gamma and beta are learned) |

The result: each layer always receives inputs with approximately mean=0 and std=1, regardless of what happened in earlier layers.
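The three steps can be written out in a few lines of NumPy. This is a sketch of the forward pass only (no gradients, no moving averages for inference); gamma and beta are shown at their usual initial values of 1 and 0:

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(2.1, 3.8, size=(64, 5))   # a shifted batch: 64 samples, 5 features

# Step 1: batch statistics, computed per feature across the batch
mean = x.mean(axis=0)
var = x.var(axis=0)

# Step 2: normalise (epsilon guards against division by zero)
epsilon = 1e-3
z = (x - mean) / np.sqrt(var + epsilon)

# Step 3: scale and shift with learned parameters
gamma = np.ones(5)    # learned scale, initialised to 1
beta = np.zeros(5)    # learned shift, initialised to 0
out = gamma * z + beta

print(out.mean(axis=0).round(3))   # ~0 for every feature
print(out.std(axis=0).round(3))    # ~1 for every feature
```

With gamma=1 and beta=0 the output is exactly the normalised batch; during training the network is free to learn other values, so it can undo the normalisation where that helps.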

Placement: Dense -> BatchNorm -> activation

The recommended placement is between the Dense layer and its activation function. This normalises the values before the non-linearity is applied.

from tensorflow import keras

model = keras.Sequential([
    # Block 1
    keras.layers.Dense(256, input_shape=(10,)),   # no activation here
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.Dropout(0.3),

    # Block 2
    keras.layers.Dense(256),
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.Dropout(0.3),

    # Block 3
    keras.layers.Dense(256),
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.Dropout(0.3),

    # Output
    keras.layers.Dense(1, activation='sigmoid'),
])

BatchNorm benefits

| Benefit | How | Practical impact |
|---|---|---|
| Faster training | Stable input distributions allow higher learning rates | Reach convergence in fewer epochs |
| Reduced sensitivity to initialisation | BatchNorm re-centres values regardless of starting weights | Less time tuning initialisation |
| Mild regularisation | Batch statistics add noise (each batch has slightly different mean/var) | Works alongside dropout |
| Smoother loss curves | Gradients flow more evenly through normalised layers | Easier to diagnose training issues |
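The "mild regularisation" row is easy to see numerically. In this hypothetical sketch, several mini-batches are drawn from the same population; each batch estimates a slightly different mean, so the normalisation an example receives jitters from batch to batch, which acts as a small source of noise during training:

```python
import numpy as np

rng = np.random.default_rng(7)
population = rng.normal(0.5, 1.2, size=10_000)   # one feature's true distribution

# Each mini-batch of 32 computes its own mean -- none match exactly
batch_means = [rng.choice(population, size=32).mean() for _ in range(8)]
print([round(m, 3) for m in batch_means])   # each batch sees a slightly different mean
```

At inference time BatchNorm switches to running averages of these statistics, so predictions stay deterministic; the noise exists only during training.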

Think Deeper

Without BatchNorm, layer 2 receives inputs with mean=0.5 at epoch 1 but mean=2.1 at epoch 10. How does this 'internal covariate shift' affect a malware classifier in practice?

Layer 2 must constantly re-adapt to a moving target distribution, slowing convergence and requiring a lower learning rate. For a malware classifier, this means longer training times and more sensitivity to weight initialisation. BatchNorm fixes the input distribution at each layer, allowing higher learning rates and faster, more stable training.
Cybersecurity tie-in: In security ML, models are often retrained on fresh data as the threat landscape evolves. BatchNorm makes retraining more stable and predictable -- you can confidently update a malware classifier weekly without worrying that shifting feature distributions will cause the model to diverge during fine-tuning.
