I noticed that batch normalization is missing from the `MLP` and `ShapedMLP` backbones. We should probably add it to both implementations.
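
For illustration, here is a rough sketch of what an opt-in batch norm layer could look like, assuming the backbones are built from PyTorch `nn.Linear` layers. The helper `make_mlp` and the `use_batch_norm` flag are hypothetical names, not the project's actual API, and placing the normalization before the activation is just one common choice.

```python
import torch
from torch import nn


def make_mlp(in_features, hidden_sizes, out_features, use_batch_norm=True):
    """Build a plain MLP (hypothetical helper, not the project's code).

    When use_batch_norm is True, a BatchNorm1d layer is inserted after
    each hidden Linear layer and before its activation.
    """
    layers = []
    prev = in_features
    for width in hidden_sizes:
        layers.append(nn.Linear(prev, width))
        if use_batch_norm:
            layers.append(nn.BatchNorm1d(width))
        layers.append(nn.ReLU())
        prev = width
    # Output layer: no normalization or activation here.
    layers.append(nn.Linear(prev, out_features))
    return nn.Sequential(*layers)


if __name__ == "__main__":
    model = make_mlp(in_features=8, hidden_sizes=[16, 4], out_features=2)
    batch = torch.randn(5, 8)  # BatchNorm1d needs more than one sample in training mode
    print(model(batch).shape)
```

Keeping it behind a flag would let existing configurations stay unchanged while the option is evaluated.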