Feature learning networks are/become balanced
Feature learning made simple!
I’ve been going through the literature on feature-learning linear networks, with the hope that it will prove useful for non-linear networks. Along the way, I came across “From Lazy to Rich: Exact Learning Dynamics in Deep Linear Networks” (Dominé et al., 2025), and while reading through, it seemed like they had identified a different type of feature learning than that of the maximum-update parameterization (μP) feature learning that I knew.
For reference, μP (Yang & Hu, 2021) tells you how you need to scale all of your network’s (scale/learning-rate) hyperparameters in order for networks to stably learn non-trivial functions even at infinite width. The new perspective is instead based on the idea of network layers being ‘balanced’ (dating all the way back to (Du et al., 2018)), with (Kunin et al., 2024) describing how different amounts of balancedness affect feature learning. Note the balancedness picture’s lack of any mention of stability for large-width networks: as long as you make two subsequent layers have ‘a similar scale,’ that’s all you need!
In this post, I want to discuss how the ideas of μP align with those of deep linear network balancedness in a unified way, and use this perspective to make predictions for how other ‘ML trickery’ that is often used in practice promotes feature learning.
Setup
Unless otherwise stated, all networks will be 2-layer (1 hidden layer) networks of the form $f(x) = W_2\, \phi(W_1 x)$, with input dimension $d$ ($x \in \mathbb{R}^d$, $W_1 \in \mathbb{R}^{n \times d}$) and output dimension $o$ ($W_2 \in \mathbb{R}^{o \times n}$). The nonlinearity $\phi$ is to be taken as the identity function $\phi(z) = z$. The hidden dimension $n$ is a parameter we can scale, where we wish for the network to be trainable for any value of $n$. The network is to be updated iteratively using gradient flow (without add-ons such as weight decay/momentum/etc.), $\dot{W}_1 = -\eta_1 \nabla_{W_1} \mathcal{L}$, $\dot{W}_2 = -\eta_2 \nabla_{W_2} \mathcal{L}$ (with layerwise learning rates $\eta_1, \eta_2$), under a square loss $\mathcal{L}$.1
Occasionally, I will denote $X \in \mathbb{R}^{P \times d}$ and $Y \in \mathbb{R}^{P \times o}$ later during discussion. These are matrices with rows consisting of individual samples: $X = (x_1, \ldots, x_P)^\top$, $Y = (y_1, \ldots, y_P)^\top$.
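To make this concrete, here is a minimal sketch of the setup in plain NumPy (my own notation; the dimensions, initialization scales, and step size are placeholder choices, with gradient flow approximated by small forward-Euler steps):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, o, P = 5, 64, 2, 50            # input dim, width, output dim, num samples
X = rng.normal(size=(P, d))          # rows are individual samples
Y = rng.normal(size=(P, o))

W1 = rng.normal(size=(n, d)) / np.sqrt(d)   # placeholder init scales
W2 = rng.normal(size=(o, n)) / np.sqrt(n)

def loss(W1, W2):
    E = Y - X @ W1.T @ W2.T          # residual matrix, shape (P, o)
    return 0.5 * np.sum(E ** 2)

eta1, eta2, dt = 1.0, 1.0, 1e-4      # layerwise learning rates, Euler step size
print("initial loss:", loss(W1, W2))
for _ in range(5000):
    E = Y - X @ W1.T @ W2.T
    dW1 = eta1 * (W2.T @ E.T @ X)    # -grad wrt W1 (times eta1)
    dW2 = eta2 * (E.T @ X @ W1.T)    # -grad wrt W2 (times eta2)
    W1 += dt * dW1                   # forward-Euler discretization of gradient flow
    W2 += dt * dW2
print("final loss:", loss(W1, W2))
```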
How to achieve feature learning
μP
Based on (Yang et al., 2023), we should set up our network with the following conditions:

$$\sigma_1^2 = \Theta(1), \qquad \sigma_2^2 = \Theta\!\left(\frac{1}{n^2}\right), \qquad \eta_1 = \Theta(n), \qquad \eta_2 = \Theta\!\left(\frac{1}{n}\right),$$

where $W_i \sim \mathcal{N}(0, \sigma_i^2)$ entrywise, and all $\Theta(\cdot)$’s are taken with respect to the width $n$. The broad intuition for this setup is that we want (1) non-trivial updates to our weights (more specifically, to the first weight matrix $W_1$), else we end up with an NNGP, and (2) outputs that do not blow up during training. Up to moving factors of $n$ around (à la the mean-field parameterization, MFP), this has been identified as the only $(a, b, c)$-unique scaling of our network parameters.
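In code, my reading of these conditions looks like the following sketch; the `base_*` constants are arbitrary $\Theta(1)$ knobs, and the function name and defaults are my own:

```python
import numpy as np

def mup_init(d, n, o, rng, base_sigma1=1.0, base_sigma2=1.0,
             base_eta1=1.0, base_eta2=1.0):
    """Width-n instantiation of the (reconstructed) muP scalings:
    sigma1^2 = Theta(1), sigma2^2 = Theta(1/n^2),
    eta1 = Theta(n),     eta2 = Theta(1/n).
    The base_* constants are arbitrary Theta(1) knobs."""
    W1 = rng.normal(scale=base_sigma1, size=(n, d))
    W2 = rng.normal(scale=base_sigma2 / n, size=(o, n))
    eta1 = base_eta1 * n
    eta2 = base_eta2 / n
    return W1, W2, eta1, eta2
```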
Balancedness
If that is true, then why does (Dominé et al., 2025) have a seemingly different rule for achieving feature learning that is removed from the μP scaling? Their paper finds feature learning for a (deep) linear network can be achieved through the following condition,2 modified slightly to accommodate layerwise learning rates:

$$\frac{1}{\eta_2} W_2^\top W_2 - \frac{1}{\eta_1} W_1 W_1^\top = \lambda I, \tag{1}$$

with data being whitened: $X^\top X = I$.
A proof for this quantity remaining invariant can be found in both (Kunin et al., 2024) and below. With Gaussian initialization, this conserved equation is trivially achieved when the width $n$ is large, making

$$\lambda \approx \frac{\sigma_2^2\, o}{\eta_2} - \frac{\sigma_1^2\, d}{\eta_1},$$

where $\approx$ is used due to sub-leading order corrections not being included; from here on, I will be using $=$, with the understanding that this is to leading order. Thus far, we haven’t needed to worry about making sure update sizes won’t cause trivial or blown-up updates at all, since the trajectory of our network lies on an invariant manifold defined by Eq. 1, so what gives? The trick is in the value that $\lambda$ takes: the closer to 0, the more our solutions appear to be in the feature-learning regime.
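As a sanity check on the invariance (under my reconstruction of the setup above), we can train with small Euler steps and confirm the conserved matrix barely moves:

```python
import numpy as np

rng = np.random.default_rng(1)
d, n, o, P = 3, 32, 2, 40
X = rng.normal(size=(P, d)); Y = rng.normal(size=(P, o))
W1 = rng.normal(size=(n, d)) / np.sqrt(d)
W2 = rng.normal(size=(o, n)) / np.sqrt(n)
eta1, eta2, dt = 1.0, 1.0, 1e-4

def conserved(W1, W2):
    # the balancedness quantity of Eq. 1
    return W2.T @ W2 / eta2 - W1 @ W1.T / eta1

C0 = conserved(W1, W2)
for _ in range(5000):
    E = Y - X @ W1.T @ W2.T
    dW1 = eta1 * (W2.T @ E.T @ X)
    dW2 = eta2 * (E.T @ X @ W1.T)
    W1 += dt * dW1
    W2 += dt * dW2
drift = np.max(np.abs(conserved(W1, W2) - C0))
print(drift)   # small; vanishes as dt -> 0 (exact under gradient flow)
```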
While our solution space for feature learning is quite large, with 3 different independent knobs to turn, if we rewrite our feature learning condition in the following way,

$$\frac{\sigma_2^2\, o}{\eta_2} - \frac{\sigma_1^2\, d}{\eta_1} = \frac{\epsilon}{n},$$

for some small $\epsilon$ that’ll represent how close to the “true” feature-learning network we get, we can see a ‘natural choice’ for how we should scale our hyperparameters; taking both portions of the equation to be invariant to our network parameter $n$ forces $\eta_1 = \Theta(n \sigma_1^2)$, $\eta_2 = \Theta(n \sigma_2^2)$, representing a perfect match with how (Yang et al., 2023) told us to scale, up to actually choosing the layerwise learning rates and reintroducing a factor of $n$ (which I interpret as a byproduct of moving from an $n$-rescaled diagonal identity matrix to a single scalar $\epsilon$).
For completeness, let’s plug in the μP scaling laws to see how small $\lambda$ is:

$$\lambda = \frac{\sigma_2^2\, o}{\eta_2} - \frac{\sigma_1^2\, d}{\eta_1} = \frac{c_2 - c_1}{n}, \tag{2}$$

where $c_1$ and $c_2$ represent scalars that are $\Theta(1)$ with respect to the width. This quantity tends towards 0 as $n \to \infty$, and tells us that any scaling done to $\sigma_1^2/\eta_1$ or $\sigma_2^2/\eta_2$ should be reflected in an equivalent change to the other: $\sigma_1^2/\eta_1 = \Theta(\sigma_2^2/\eta_2)$.
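Plugging the μP scalings into the conserved quantity at initialization, we can watch λ shrink like $1/n$ (a sketch; reading λ off as the mean diagonal entry of the conserved matrix is my own choice):

```python
import numpy as np

rng = np.random.default_rng(2)
d, o = 5, 2
for n in [64, 256, 1024, 4096]:
    W1 = rng.normal(scale=1.0, size=(n, d))        # sigma1^2 = Theta(1)
    W2 = rng.normal(scale=1.0 / n, size=(o, n))    # sigma2^2 = Theta(1/n^2)
    eta1, eta2 = float(n), 1.0 / n                 # muP layerwise learning rates
    C = W2.T @ W2 / eta2 - W1 @ W1.T / eta1
    lam = np.mean(np.diag(C))                      # scalar lambda estimate
    print(n, lam, n * lam)                         # n * lam should be ~constant
```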
Lazy & ultra-rich scaling
At first glance, we have more freedom in the balancedness picture than in the μP picture, so what gives? Lazy networks have thus far been more approachable from a theoretical perspective, so surely balancedness should be our preferred lens through which to view feature learning. Why are people so much more drawn to μP, outside of it being the first to introduce many of the feature-learning concepts we know today?
Taking from the MFP literature, if we scale our network outputs as $\frac{1}{\gamma} f$ and the overall learning rate as $\gamma^2 \eta_0$ for $\gamma \le 1$ and $\gamma \eta_0$ otherwise, a scaling law discovered in (Atanasov et al., 2025), we can transition a network into the lazy regime with $\gamma \to 0$ and into a new ‘ultra-rich’ regime of $\gamma \to \infty$. In the balancedness picture, both the network output prefactor and learning rate rescaling can be absorbed globally; it’s unclear where to put the network output prefactor, but the learning rate can be easily included:

$$\lambda(\gamma) = \frac{1}{g(\gamma)} \left( \frac{\sigma_2^2\, o}{\eta_2} - \frac{\sigma_1^2\, d}{\eta_1} \right), \qquad g(\gamma) = \begin{cases} \gamma^2 & \gamma \le 1 \\ \gamma & \gamma > 1; \end{cases}$$

plugging in our μP relations, this gets

$$\lambda(\gamma) = \frac{c_2 - c_1}{\gamma^2\, n} \;\; (\gamma \le 1), \qquad \lambda(\gamma) = \frac{c_2 - c_1}{\gamma\, n} \;\; (\gamma > 1).$$
Some insight that comes from this setup is that lazy networks should scale their richness parameter such that $\gamma = O(n^{-1/2})$, or else we run the risk of not leaving the rich regime. No such insight appears to be gleanable from the ultra-rich case, where both the width and the richness parameter contribute to lowering $\lambda$. Finally, our equation recovers Eq. 2 as $\gamma \to 1$, as expected.
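Here is a sketch of that prediction under my reconstruction: absorbing the γ-dependent learning-rate rescaling (the $\gamma \le 1$ branch) into the layerwise learning rates, λ stays $\Theta(1)$, i.e. lazy, only when γ shrinks like $n^{-1/2}$:

```python
import numpy as np

rng = np.random.default_rng(3)
d, o = 5, 2

def lam_at_init(n, gamma, rng):
    W1 = rng.normal(scale=1.0, size=(n, d))       # muP: sigma1^2 = Theta(1)
    W2 = rng.normal(scale=1.0 / n, size=(o, n))   # muP: sigma2^2 = Theta(1/n^2)
    eta1 = gamma ** 2 * n                         # global LR rescaled by gamma^2
    eta2 = gamma ** 2 / n                         #   (the gamma <= 1 branch)
    C = W2.T @ W2 / eta2 - W1 @ W1.T / eta1
    return np.mean(np.diag(C))

for n in [64, 256, 1024]:
    print(n, lam_at_init(n, 1.0, rng),            # gamma = 1: lambda ~ 1/n
             lam_at_init(n, n ** -0.5, rng))      # gamma = n^{-1/2}: lambda ~ Theta(1)
```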
This suggests that μP is not the only way to see feature-learning networks; we can also look at how balanced a network is throughout training in order to gain a deeper understanding of its feature-learning properties.
Big picture: why (insert ML trick) lets you learn features
Up until now, we’ve considered an extremely simple network with no special bells or whistles. If we want to understand how different ML levers affect feature learning, we need to move beyond this, and now that we have the balancedness-μP picture, we have a greater capacity to add on extra parts.
Weight decay
Take $L_2$ weight decay, for example: its purpose is to decrease the weights of a network. Previous papers [links] have investigated …
There doesn’t exist an invariant when training is done with weight decay (except in the explicit case that the layerwise learning rates are identical and additionally the weight decay constant is global), but we can see how our earlier constant varies with the weight penalty hyperparameter $\kappa$:

$$\frac{d}{dt}\left( \frac{1}{\eta_2} W_2^\top W_2 - \frac{1}{\eta_1} W_1 W_1^\top \right) = -2\kappa \left( W_2^\top W_2 - W_1 W_1^\top \right),$$
where we note that weight decay lowers the overall scale of the weights by forcing the weights between layers to become balanced, driving the network parameters onto a ‘more feature-learning’ manifold.
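A sketch of this effect (same toy training loop as before, with the penalty κ added and, for simplicity, identical layerwise learning rates): the raw imbalance $\|W_2^\top W_2 - W_1 W_1^\top\|$ should decay over training.

```python
import numpy as np

rng = np.random.default_rng(4)
d, n, o, P = 3, 32, 2, 40
X = rng.normal(size=(P, d)); Y = rng.normal(size=(P, o))
W1 = rng.normal(size=(n, d)) / np.sqrt(d)
W2 = rng.normal(size=(o, n)) / np.sqrt(n)
eta1, eta2, dt, kappa = 1.0, 1.0, 1e-4, 5.0

for step in range(20001):
    E = Y - X @ W1.T @ W2.T
    dW1 = eta1 * (W2.T @ E.T @ X - kappa * W1)   # gradient flow + weight decay
    dW2 = eta2 * (E.T @ X @ W1.T - kappa * W2)
    W1 += dt * dW1
    W2 += dt * dW2
    if step % 5000 == 0:
        imbalance = np.linalg.norm(W2.T @ W2 - W1 @ W1.T)
        print(step, imbalance)   # decays as weight decay balances the layers
```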
Learning rate warmup
To analyze learning rate increases (and subsequent decays), we need to move away from gradient flow. We’ll still let $\phi(z) = z$ (and, for clarity, specialize to the scalar case $f = w_2 w_1 x$), and now have the discrete updates $w_i^{t+1} = w_i^t - \eta_t \nabla_{w_i} \mathcal{L}$ for residual $r^t = y - w_2^t w_1^t x$. If we define $\delta^t = (w_1^t)^2 - (w_2^t)^2$, we can rewrite as,

$$\delta^{t+1} = \delta^t \left( 1 - \eta_t^2\, (r^t)^2\, x^2 \right),$$

with the general interpretation being that, for any initial $\delta^0$ far from 0, ramping up the learning rate allows the weights to make much faster progress towards balance ($\delta = 0$); this needs to be followed up by a decrease in the learning rate, or else $\delta^t$ will end up “bouncing” between the positive and negative of some final value $|\delta^\infty|$.
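Since the recursion above is an exact identity for discrete gradient descent (in my scalar reconstruction), it’s easy to check numerically; the measured imbalance and the closed-form prediction should agree to machine precision:

```python
import numpy as np

x, y = 1.5, 2.0
w1, w2 = 3.0, 0.1                     # deliberately unbalanced initialization
eta = 0.05
for t in range(8):
    r = y - w2 * w1 * x               # residual
    delta = w1 ** 2 - w2 ** 2         # imbalance before the step
    w1, w2 = w1 + eta * r * x * w2, w2 + eta * r * x * w1   # one GD step
    predicted = delta * (1 - eta ** 2 * r ** 2 * x ** 2)
    print(t, w1 ** 2 - w2 ** 2, predicted)   # the two columns match
```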
Deeper network balancedness
If we want to consider deeper networks, the story is largely the same. Take $f(x) = W_L \cdots W_2 W_1 x$. Between two subsequent layers, we have

$$\frac{1}{\eta_{l+1}} W_{l+1}^\top W_{l+1} - \frac{1}{\eta_l} W_l W_l^\top = \lambda_l I \approx 0;$$

that is, $W_{l+1}^\top W_{l+1}$ and $W_l W_l^\top$ are simultaneously diagonalizable, giving

$$\frac{1}{\eta_{l+1}}\, \sigma_k(W_{l+1})^2 \approx \frac{1}{\eta_l}\, \sigma_k(W_l)^2;$$

while we can’t generally compare $W_l$ and $W_{l'}$ directly ($l \neq l'$) due to possible shape inconsistency, we can say

$$\sigma_k(W_{l'}) \approx \sqrt{\frac{\eta_{l'}}{\eta_l}}\; \sigma_k(W_l),$$

with the implication being that there is a general balancedness between all layers that feature-learning networks satisfy: the $k$-th singular values of all layers in a feature-learning network are all roughly equivalent (up to learning-rate factors).
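A sketch of this claim: train a depth-3 linear network from a small (hence nearly balanced) initialization with a shared learning rate, then compare per-layer singular values; all dimensions and scales here are arbitrary choices of mine.

```python
import numpy as np

rng = np.random.default_rng(6)
dims = [4, 32, 32, 3]                         # d -> n -> n -> o
P = 60
X = rng.normal(size=(P, dims[0]))
M_true = rng.normal(size=(dims[-1], dims[0]))
Y = X @ M_true.T                              # planted linear teacher
Ws = [1e-2 * rng.normal(size=(dims[i + 1], dims[i])) for i in range(3)]
dt = 2e-3                                     # shared learning rate * Euler step

for _ in range(40000):
    hs = [X.T]                                # forward pass; columns are samples
    for W in Ws:
        hs.append(W @ hs[-1])
    B = Y.T - hs[-1]                          # residual, backpropagated below
    for i in reversed(range(3)):
        g = B @ hs[i].T                       # minus the gradient wrt Ws[i]
        B = Ws[i].T @ B                       # push residual through (old) Ws[i]
        Ws[i] += dt * g

for i, W in enumerate(Ws):
    print("layer", i + 1, np.linalg.svd(W, compute_uv=False)[:3].round(2))
# the leading singular values approximately coincide across layers
```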
Takeaway & Future Efforts
This blog post was written as a dive into something I found that didn’t make sense: how can balanced nets fit into the μP picture? While discussing my thoughts with Dan Kunin (whose work is the core of this post), we noted how I’m using the known μP to get balancedness. Right now, I’m working on extending this in the opposite direction: can we get a different perspective on feature learning through balancedness alone, no μP involved? This is (part of) my current research, so if that sounds interesting, please feel free to reach out!
Appendix
A1: Conserved balancedness throughout training
Take $f(x) = W_2 W_1 x$. Let $Z = X W_1^\top W_2^\top$ (the network outputs across the dataset) and $E = Y - Z$ (the residuals). For a MSE loss,

$$\mathcal{L} = \frac{1}{2} \left\| Y - X W_1^\top W_2^\top \right\|_F^2,$$

we have

$$\nabla_{W_1} \mathcal{L} = -W_2^\top E^\top X, \qquad \nabla_{W_2} \mathcal{L} = -E^\top X W_1^\top;$$

taking these, we can see that $\nabla_{W_1}\mathcal{L}\; W_1^\top = W_2^\top\, \nabla_{W_2}\mathcal{L}$; if we take the gradient flow limit and remember the chain rule, we have:

$$\frac{1}{\eta_1} \frac{d}{dt}\left( W_1 W_1^\top \right) = \frac{1}{\eta_1}\left( \dot{W}_1 W_1^\top + W_1 \dot{W}_1^\top \right) = \frac{1}{\eta_2}\left( W_2^\top \dot{W}_2 + \dot{W}_2^\top W_2 \right) = \frac{1}{\eta_2} \frac{d}{dt}\left( W_2^\top W_2 \right),$$

meaning (if we ignore my very loose conflation of a constant matrix and a scalar $\lambda$),

$$\frac{1}{\eta_2} W_2^\top W_2 - \frac{1}{\eta_1} W_1 W_1^\top = \lambda I = \text{const}.$$
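And a quick numerical confirmation of the gradient identity at a random point (a sketch with arbitrary shapes):

```python
import numpy as np

rng = np.random.default_rng(7)
d, n, o, P = 3, 8, 2, 20
X = rng.normal(size=(P, d)); Y = rng.normal(size=(P, o))
W1 = rng.normal(size=(n, d)); W2 = rng.normal(size=(o, n))

E = Y - X @ W1.T @ W2.T               # residuals
g1 = -W2.T @ E.T @ X                  # gradient wrt W1
g2 = -E.T @ X @ W1.T                  # gradient wrt W2
print(np.max(np.abs(g1 @ W1.T - W2.T @ g2)))   # ~1e-14: the identity holds
```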
References

- Atanasov, A., Bordelon, B., & Pehlevan, C. (2025). The Lazy (NTK) and Rich (μP) Regimes: A Gentle Tutorial.
- Dominé, C., et al. (2025). From Lazy to Rich: Exact Learning Dynamics in Deep Linear Networks.
- Du, S. S., Hu, W., & Lee, J. D. (2018). Algorithmic Regularization in Learning Deep Homogeneous Models: Layers are Automatically Balanced.
- Kunin, D., et al. (2024). Get Rich Quick: Exact Solutions Reveal How Unbalanced Initializations Promote Rapid Feature Learning.
- Yang, G., & Hu, E. J. (2021). Feature Learning in Infinite-Width Neural Networks (Tensor Programs IV).
- Yang, G., Simon, J. B., & Bernstein, J. (2023). A Spectral Condition for Feature Learning.
Footnotes
-
It should be mentioned that in my experiments, I use a finite batch size instead of population gradient descent, although I suspect this minimally changes results. ↩
-
The paper itself details two other requirements: non-bottleneckedness and non-overparameterization. The first condition is trivially satisfied in our discussion, as we always assume $n > d, o$; from what I can tell, the non-overparameterization condition can be dropped for our discussion. ↩