Feature learning networks are/become balanced

Feature learning made simple!

Oct 18, 2025

I’ve been going through the literature on feature-learning linear networks, with the hope that it will be useful for non-linear networks. Along the way, I came across “From Lazy to Rich: Exact Learning Dynamics in Deep Linear Networks” (Dominé et al., 2025), and while reading through, it seemed like they had identified a different route to feature learning than the maximal-update parameterization ($\mu$P) feature learning that I knew.

For reference, $\mu$P (Yang & Hu, 2021) tells you how to scale all of your network’s (initialization-scale and learning-rate) hyperparameters so that the network stably learns non-trivial functions even at infinite width. The new perspective, by contrast, is based on the idea of network layers being ‘balanced’ (drawing back all the way to (Du et al., 2018)), with (Kunin et al., 2024) describing how different amounts of balancedness affect feature learning. Note that the balancedness picture makes no mention of stability for large-width networks: as long as you make two subsequent layers have ‘a similar scale,’ that’s all you need!

In this post, I want to discuss how the ideas of $\mu$P align with those of deep-linear-network balancedness in a unified way, and use this perspective to make predictions for how other ‘ML trickery’ that is often used in practice promotes feature learning.

Setup

Unless otherwise stated, all networks will be 2-layer (1 hidden layer) networks of the form $\mathbf{f}(\mathbf{x}) = \sum_{i} \mathbf{a}_i \sigma\left(\sum_j W_{ij}x_j\right)$, with input dimension $d$ ($\mathbf{x} \in \mathbb{R}^{d}$) and output dimension $d_{out}$ ($\mathbf{f}(\mathbf{x}),\, \mathbf{y} \in \mathbb{R}^{d_{out}}$). The nonlinearity is taken to be the identity function $\sigma: \mathbf{x}\mapsto \mathbf{x}$. The hidden dimension $w$ (so $W \in \mathbb{R}^{w \times d}$) is a parameter we can scale, and we wish for the network to be trainable for any value of $w$. The network is updated iteratively using gradient flow (without add-ons such as weight decay/momentum/etc.): $\Delta \mathbf{a} = -\eta_a \frac{\partial \mathcal{L}}{\partial \mathbf{a}}$, $\Delta W = -\eta_W \frac{\partial \mathcal{L}}{\partial W}$ ($\eta_a, \eta_W \ll 1$), under a square loss $\mathcal{L} = (\mathbf{f}(\mathbf{x})-\mathbf{y})^2$.¹

Occasionally during the discussion, I will use data matrices $X \in \mathbb{R}^{N\times d}$ and $Y \in \mathbb{R}^{N\times d_{out}}$. These are matrices whose rows are individual samples: $X_i = \mathbf{x}^{(i)}$, $Y_i = \mathbf{y}^{(i)}$.
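
To make this concrete, here is a minimal NumPy sketch of the setup, assuming the conventions above (samples as rows); the function name `train_linear_net` and the particular defaults are my own, not taken from any of the cited papers.

```python
import numpy as np

def train_linear_net(X, Y, w, eta_a, eta_W, sigma_a, sigma_W, steps=1000, seed=0):
    """Full-batch gradient descent on the 2-layer linear net f(x) = a W x.

    X: (N, d) with samples as rows; Y: (N, d_out) with targets as rows.
    """
    rng = np.random.default_rng(seed)
    N, d = X.shape
    d_out = Y.shape[1]
    W = sigma_W * rng.standard_normal((w, d))      # first layer, (w, d)
    a = sigma_a * rng.standard_normal((d_out, w))  # second layer, (d_out, w)
    for _ in range(steps):
        R = X @ W.T @ a.T - Y                      # residuals f(X) - Y, (N, d_out)
        grad_a = 2 * R.T @ X @ W.T                 # dL/da for L = ||f(X) - Y||_F^2
        grad_W = 2 * a.T @ R.T @ X                 # dL/dW
        a, W = a - eta_a * grad_a, W - eta_W * grad_W
    return a, W
```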

How to achieve feature learning

$\mu$P

Based on (Yang et al., 2023), we should set up our network with the following conditions:

\begin{align*}
\sigma_W &= \Theta\left(\sqrt{\frac{1}{d}}\right)\\
\sigma_a &= \Theta\left(\sqrt{\frac{d_{out}}{w^2}}\right)\\
\eta_W &= \Theta\left(\frac{w}{d}\right)\\
\eta_a &= \Theta\left(\frac{d_{out}}{w}\right)
\end{align*}

where $\mathbf{a}_i \sim \mathcal{N}(0, \sigma_a^2)$ and $W_{i,j} \sim \mathcal{N}(0, \sigma_W^2)$. The broad intuition for this setup is that we want (1) non-trivial updates to our weights (more specifically, to the first weight matrix $W$), otherwise we end up with an NNGP, and (2) outputs that do not blow up during training. Up to moving factors around (à la the mean-field parameterization), this has been identified as the unique scaling of our network hyperparameters in ($d$, $w$, $d_{out}$).
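
As a sketch, these scalings can be instantiated on top of `train_linear_net` above. The helper name `mup_hparams` is hypothetical, and all $\Theta(1)$ prefactors are arbitrarily folded into a single `base_lr`:

```python
def mup_hparams(d, d_out, w, base_lr=0.1):
    """muP-style scalings from the display above, with Theta(1) prefactors folded into base_lr."""
    sigma_W = (1.0 / d) ** 0.5
    sigma_a = (d_out / w**2) ** 0.5
    eta_W = base_lr * w / d
    eta_a = base_lr * d_out / w
    return sigma_W, sigma_a, eta_W, eta_a
```

Feeding the returned values into `train_linear_net` then gives a width-$w$ network trained under (one choice of) the $\mu$P scalings.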

Balancedness

If that is true, then why does (Dominé et al., 2025) have a seemingly different rule for achieving feature learning, one removed from the $\mu$P scaling? Their paper finds that feature learning in a (deep) linear network can be achieved through the following condition,² modified slightly to accommodate layerwise learning rates:

\begin{equation}
\text{At init, } \frac{1}{\eta_a}\mathbf{a}^\top \mathbf{a} - \frac{1}{\eta_W} WW^\top = \lambda I_w,
\end{equation}

with the data being whitened: $X^\top X = I_d$.

A proof that this quantity remains invariant can be found both in (Kunin et al., 2024) and in the appendix below. With Gaussian initialization, the condition is trivially satisfied (to leading order) when the width $w$ is large, since

\begin{equation*}
\frac{1}{\eta_a}\mathbf{a}^\top\mathbf{a}-\frac{1}{\eta_W}WW^\top\sim\left(\frac{1}{\eta_a} d_{out}\sigma_a^2-\frac{1}{\eta_W} d\sigma_W^2\right)I_w \equiv \lambda I_w,
\end{equation*}

where $\sim$ is used because sub-leading-order corrections are not included; from here on, I will write $=$, with the understanding that everything holds to leading order. Thus far, we haven’t needed to worry about update sizes causing trivial or blown-up updates at all, since the trajectory of our network lies on the invariant manifold defined by Eq. 1, so what gives? The trick is in the value that $\lambda$ takes: the closer it is to 0, the more our solutions appear to be in the feature-learning regime.
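
Here is a quick numerical sketch of the conservation of Eq. 1, under my own choices of data, initialization scales, and a shared learning rate $\eta_a = \eta_W = \eta$:

```python
import numpy as np

rng = np.random.default_rng(1)
d, d_out, w, N = 8, 4, 64, 256
X, _ = np.linalg.qr(rng.standard_normal((N, d)))   # whitened data: X^T X = I_d
Y = X @ rng.standard_normal((d, d_out))            # random linear teacher

eta = 1e-3                                         # shared eta_a = eta_W for simplicity
W = (1.0 / d) ** 0.5 * rng.standard_normal((w, d))
a = 0.3 * rng.standard_normal((d_out, w))

def balancedness(a, W):
    """The w-by-w matrix (1/eta_a) a^T a - (1/eta_W) W W^T from Eq. 1 (shared eta here)."""
    return a.T @ a / eta - W @ W.T / eta

lam0 = balancedness(a, W)
for _ in range(20000):
    R = X @ W.T @ a.T - Y                          # residuals, (N, d_out)
    a, W = a - eta * 2 * R.T @ X @ W.T, W - eta * 2 * a.T @ R.T @ X
drift = np.linalg.norm(balancedness(a, W) - lam0) / np.linalg.norm(lam0)
print(drift)   # small, and it vanishes in the gradient-flow limit eta -> 0
```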

While our solution space for feature learning is quite large, with three independent knobs to turn, if we rewrite our feature-learning condition in the following way,

\begin{equation*}
\lambda = \frac{1}{\eta_a} d_{out}\sigma_a^2 - \frac{1}{\eta_W} d\sigma_W^2 = \frac{d_{out}}{\eta_a}\sigma_a^2 - \frac{d}{\eta_W}\sigma_W^2 \leq \epsilon
\end{equation*}

for some small $\epsilon$ that represents how close we get to the “true” feature-learning network, we can see a ‘natural choice’ for how we should scale our hyperparameters: taking both terms of the equation to be invariant to our network parameters $(d, d_{out}, w)$ forces $\sigma_a^2 \sim \eta_a/d_{out}$ and $\sigma_W^2 \sim \eta_W/d$. This is a perfect match with how (Yang et al., 2023) told us to scale, up to actually choosing the layerwise learning rates and reintroducing a factor of $w$ (which I interpret as a byproduct of moving from a $\lambda$-rescaled $w\times w$ identity matrix to a single scalar $\lambda$).

For completeness, let’s plug in the $\mu$P scaling laws to see how small $\lambda$ is:

\begin{equation*}
\lambda = \Theta\left(\frac{w}{d_{out}}\,d_{out}\,\frac{d_{out}}{w^2}-\frac{d}{w}\,d\,\frac{1}{d}\right)=\frac{f_2d_{out}}{w}-\frac{f_1d}{w}=\frac{f_2d_{out}-f_1d}{w}
\end{equation*}

where $f_2$ and $f_1$ are scalars that are $\Theta(1)$ with respect to the width. This quantity tends towards 0 as $w\rightarrow \infty$, and it tells us that any scaling done to $d$ XOR $d_{out}$ should be reflected in an equivalent change to $w$: $w \sim d_{(out)}$.
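
A small sanity check that $\lambda$ at a random $\mu$P-style initialization shrinks like $1/w$; this reuses the hypothetical `mup_hparams` helper from above, and the particular $d$ and $d_{out}$ are arbitrary.

```python
import numpy as np

def lam_at_init(d, d_out, w, seed=0):
    """Scalar lambda = d_out*sigma_a^2/eta_a - d*sigma_W^2/eta_W at a random muP-style init
    (the mean diagonal entry of the balancedness matrix)."""
    rng = np.random.default_rng(seed)
    sigma_W, sigma_a, eta_W, eta_a = mup_hparams(d, d_out, w)
    W = sigma_W * rng.standard_normal((w, d))
    a = sigma_a * rng.standard_normal((d_out, w))
    return np.trace(a.T @ a / eta_a - W @ W.T / eta_W) / w

for w in [64, 256, 1024, 4096]:
    print(w, lam_at_init(d=16, d_out=4, w=w))   # magnitude shrinks roughly like 1/w
```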

Lazy & ultra-rich scaling

At first glance, we have more freedom in the balancedness picture than in the $\mu$P picture, so what gives? Lazy networks have thus far been more approachable from a theoretical perspective, so surely balancedness should be our preferred lens through which to view feature learning. Why are people so much more drawn to $\mu$P, outside of it being the first to introduce many of the feature-learning concepts we know today?

Taking from the mean-field parameterization (MFP) literature, if we scale our network outputs as $f \mapsto f/\gamma$ and the overall learning rate as $\eta \mapsto \eta\cdot\gamma^2$ for $\gamma \leq 1$ and $\eta \mapsto \eta\cdot\gamma$ otherwise, a scaling law discovered in (Atanasov et al., 2025), we can transition a network into the lazy regime with $\gamma \ll 1$ and into a new ‘ultra-rich’ regime with $\gamma \gg 1$. In the balancedness picture, it’s unclear where to put the network-output prefactor, but the learning-rate rescaling acts globally and is easy to include:

\begin{equation*}
\lambda \sim \begin{cases} \left(\frac{1}{\eta_a} d_{out}\sigma_a^2-\frac{1}{\eta_W} d\sigma_W^2\right)/\gamma^2, & \gamma \leq 1\\ \left(\frac{1}{\eta_a} d_{out}\sigma_a^2-\frac{1}{\eta_W} d\sigma_W^2\right)/\gamma, & \gamma > 1. \end{cases}
\end{equation*}

Plugging in our $\mu$P relations, this gives

\begin{equation*}
\lambda(\omega) = \Theta\left(\frac{d_{out}-d}{\omega}\right), \qquad \omega \equiv \begin{cases} w\gamma^2, & \gamma \leq 1\\ w\gamma, & \gamma > 1. \end{cases}
\end{equation*}

Some insight that comes from this setup is that lazy networks should scale their richness parameter $\gamma$ such that $\gamma \ll 1/\sqrt{w}$, or else we run the risk of not leaving the rich regime. No such insight seems to be gleaned from the ultra-rich case, where both the width and the richness parameter contribute to lowering $\lambda$. Finally, our equation recovers $\lambda \rightarrow 0\ (\pm\infty)$ as $\gamma \rightarrow \infty\ (0)$, as expected.
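
As a sketch of this bookkeeping (again reusing the hypothetical `mup_hparams` helper), here is the scalar $\lambda$ after rescaling the learning rates by $\gamma^2$ (for $\gamma \leq 1$) or $\gamma$ (for $\gamma > 1$):

```python
import numpy as np

def lam_scaled(d, d_out, w, gamma, base_lr=0.1):
    """Scalar lambda when the learning rates are rescaled by gamma^2 (gamma <= 1) or gamma (gamma > 1),
    keeping the muP values for everything else."""
    sigma_W, sigma_a, eta_W, eta_a = mup_hparams(d, d_out, w, base_lr)
    scale = gamma**2 if gamma <= 1 else gamma
    eta_W, eta_a = eta_W * scale, eta_a * scale
    return d_out * sigma_a**2 / eta_a - d * sigma_W**2 / eta_W

for gamma in [1e-2, 1.0, 1e2]:
    print(gamma, lam_scaled(d=16, d_out=4, w=256, gamma=gamma))  # lazy: |lambda| large; ultra-rich: |lambda| small
```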

This suggests that $\mu$P is not the only way to see feature-learning networks; we can also look at how balanced a network is throughout training in order to gain a deeper understanding of its feature-learning properties.

Big picture: why (insert ML trick) lets you learn features

Up until now, we’ve considered an extremely simple network with no special bells or whistles. If we want to understand how different ML levers affect feature learning, we need to move beyond this, and now that we have the balancedness-$\mu$P picture, we are better equipped to add on extra parts.

Weight decay

Take $L_2$ weight decay, for example: its purpose is to shrink the weights of a network. Previous papers [links] have investigated …

There is no invariant when training with weight decay (except in the special case where the layerwise learning rates are identical and the weight decay constant is global), but we can see how our previously conserved $\lambda$ varies with the weight-penalty hyperparameter $\zeta$:

\begin{equation*}
\frac{\text{d}\lambda}{\text{d}t} = -2\zeta\left(\lVert \mathbf{a}\rVert_F^2-\lVert W\rVert_F^2\right)
\end{equation*}

where we note that, beyond lowering the overall scale of the weights, weight decay forces the weights of the two layers to become balanced, pushing the network parameters onto a ‘more feature-learning’ manifold.
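
A sketch of this effect with a coupled $L_2$ penalty, a shared learning rate, and a deliberately imbalanced initialization of my choosing; for a shared $\eta$ and global $\zeta$, the scalar $\lambda$ should decay roughly like $e^{-2\zeta\eta t}$:

```python
import numpy as np

rng = np.random.default_rng(2)
d, d_out, w, N = 8, 4, 64, 256
X, _ = np.linalg.qr(rng.standard_normal((N, d)))
Y = X @ rng.standard_normal((d, d_out))

eta, zeta = 1e-3, 1e-1
W = 0.5 * rng.standard_normal((w, d))
a = 0.05 * rng.standard_normal((d_out, w))           # strongly imbalanced init
for step in range(20001):
    R = X @ W.T @ a.T - Y
    grad_a = 2 * R.T @ X @ W.T + zeta * a            # gradient of loss + (zeta/2)||a||^2
    grad_W = 2 * a.T @ R.T @ X + zeta * W            # gradient of loss + (zeta/2)||W||^2
    a, W = a - eta * grad_a, W - eta * grad_W
    if step % 5000 == 0:
        lam = (np.sum(a**2) - np.sum(W**2)) / eta    # scalar lambda with shared eta
        print(step, lam)                             # |lambda| shrinks as training proceeds
```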

Learning rate warmup

To analyze learning-rate increases (and subsequent decays), we need to move away from gradient flow. We’ll still track the scalar $\lambda = \lVert \mathbf{a}\rVert_F^2/\eta_a-\lVert W\rVert_F^2/\eta_W$ (the trace of the earlier balancedness matrix) and now have $\Delta\lambda = \eta_a\lVert R(WX)^\top\rVert_F^2-\eta_W\lVert \mathbf{a}^\top R X^\top\rVert_F^2$ for the residual $R \equiv Y-\mathbf{a}WX$ (where it is convenient to put samples in the columns of $X$ and $Y$). If we define $M = XR^\top$, we can rewrite $\Delta \lambda$ as

\begin{align*}
\Delta \lambda(t) &= \eta_a(t)\, \alpha_t - \eta_W(t)\,\beta_t \leq -\eta_a\eta_W \lVert M\rVert_F^2\, \lambda\\
\alpha_t &= \lVert WM\rVert_F^2\\
\beta_t &= \lVert M\mathbf{a}\rVert_F^2
\end{align*}

with the general interpretation being that, for any initial value of $\lambda$ far from 0, ramping up the learning rate lets the weights make much faster progress towards $\lambda=0$; this needs to be followed up by a decrease in the learning rate, or else $\lambda$ will end up “bouncing” between the positive and negative of some final value $\pm \lambda_f$.
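
A sketch for playing with this, tracking the raw imbalance $\lVert \mathbf{a}\rVert_F^2 - \lVert W\rVert_F^2$ (proportional to $\lambda$ for a shared, fixed learning rate) under a learning-rate schedule; the schedules and constants below are arbitrary choices of mine:

```python
import numpy as np

def run(schedule, steps=3000, seed=3):
    """Discrete GD on the 2-layer linear net, recording ||a||_F^2 - ||W||_F^2 each step."""
    rng = np.random.default_rng(seed)
    d, d_out, w, N = 8, 4, 64, 256
    X, _ = np.linalg.qr(rng.standard_normal((N, d)))
    Y = X @ rng.standard_normal((d, d_out))
    W = 0.5 * rng.standard_normal((w, d))
    a = 0.05 * rng.standard_normal((d_out, w))          # imbalanced init
    imbalance = []
    for t in range(steps):
        eta = schedule(t)
        R = X @ W.T @ a.T - Y
        a, W = a - eta * 2 * R.T @ X @ W.T, W - eta * 2 * a.T @ R.T @ X
        imbalance.append(np.sum(a**2) - np.sum(W**2))
    return imbalance

warmup_decay = lambda t: 1e-3 * min(1.0, t / 500) * (0.999 ** max(0, t - 1500))
constant = lambda t: 1e-4
print(run(warmup_decay)[-1], run(constant)[-1])
# the move towards 0 is second order in the step size, so the warmed-up schedule drives it faster
```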

Deeper network balancedness

If we want to consider deeper networks, the story is largely the same. Take $Y=W_LW_{L-1}\cdots W_2W_1X$. Between two subsequent layers, we have

\begin{equation*}
\frac{1}{\eta_{\ell+1}}W_{\ell+1}^\top W_{\ell+1}-\frac{1}{\eta_\ell}W_{\ell} W_{\ell}^\top = \lambda_\ell I_{w_\ell}
\end{equation*}

that is, $W_{\ell+1}^\top W_{\ell+1}$ and $W_{\ell} W_{\ell}^\top$ are simultaneously diagonalizable, giving

\begin{equation*}
\frac{\sigma_i(W_{\ell+1})^2}{\eta_{\ell+1}} = \frac{\sigma_i(W_{\ell})^2}{\eta_\ell}+\lambda_\ell
\end{equation*}

While we can’t generally compare $\frac{1}{\eta_i} W_{i}^\top W_{i} - \frac{1}{\eta_j} W_{j} W_{j}^\top$ ($j \neq i-1$) due to possible shape inconsistency, we can say

\begin{equation*}
\frac{\sigma_i(W_{\ell+1})^2}{\eta_{\ell+1}} = \frac{\sigma_i(W_{k})^2}{\eta_k}+\sum_{j=k}^{\ell}\lambda_j
\end{equation*}

with the implication being that feature-learning networks satisfy a general balancedness between all layers: since each $\lambda_\ell \approx 0$, the (learning-rate-rescaled) $i$-th singular values of all layers in a feature-learning network are roughly equivalent.
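
As a numerical sketch of the deep-network conservation law behind this (my own choice of a 3-layer net, shared learning rate, and small random initialization), the layerwise balancedness matrices drift only slightly over training:

```python
import numpy as np

rng = np.random.default_rng(4)
d, h1, h2, d_out, N = 8, 32, 32, 4, 256
X, _ = np.linalg.qr(rng.standard_normal((N, d)))               # X^T X = I_d
Y = X @ rng.standard_normal((d, d_out))

Ws = [0.1 * rng.standard_normal(s) for s in [(h1, d), (h2, h1), (d_out, h2)]]
eta = 2e-3

def balance(Ws):
    """Per-pair matrices (1/eta) W_{l+1}^T W_{l+1} - (1/eta) W_l W_l^T."""
    return [Ws[l + 1].T @ Ws[l + 1] / eta - Ws[l] @ Ws[l].T / eta for l in range(len(Ws) - 1)]

bal0 = balance(Ws)
for _ in range(20000):
    H = [X]
    for W in Ws:
        H.append(H[-1] @ W.T)                                  # forward pass, samples in rows
    G = 2 * (H[-1] - Y)                                        # dL/d(outputs) for the squared loss
    grads = []
    for W, Hin in zip(reversed(Ws), reversed(H[:-1])):
        grads.append(G.T @ Hin)                                # dL/dW_l
        G = G @ W                                              # backprop through the linear layer
    for W, g in zip(Ws, reversed(grads)):
        W -= eta * g
for b0, b1 in zip(bal0, balance(Ws)):
    print(np.linalg.norm(b1 - b0) / np.linalg.norm(b0))        # small relative drift; shrinks further as eta -> 0
```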

Takeaway & Future Efforts

This blogpost was written as a dive into something I found that didn’t make sense: how do balanced nets fit into the $\mu$P picture? While discussing my thoughts with Dan Kunin (whose work is the core of this post), we noted that I’m using the known $\mu$P results to get balancedness. Right now, I’m working on extending this in the opposite direction: can we get a different perspective on feature learning through balancedness alone, no $\mu$P involved? This is (part of) my current research, so if that sounds interesting, please feel free to reach out!

Appendix

A1: Conserved balancedness throughout training

Take $y=\mathbf{a}W\mathbf{x}$. Let $\delta W = -\eta_W\frac{\partial L}{\partial W}$ and $\delta \mathbf{a} = -\eta_a\frac{\partial L}{\partial \mathbf{a}}$. For an MSE loss,

\begin{equation*}
L = (y-\mathbf{a}W\mathbf{x})^2
\end{equation*}

we have

\begin{align*}
\frac{\partial L}{\partial W_{ij}} &= -2(y-\mathbf{a}W\mathbf{x})\, a_i x_j\\
\frac{\partial L}{\partial a_i} &= -2(y-\mathbf{a}W\mathbf{x})\sum_j W_{ij} x_j
\end{align*}

Taking these, we can see that $\frac{1}{\eta_W}\sum_{ij} W_{ij}\,\delta W_{ij} = \frac{1}{\eta_a}\sum_{i} a_{i}\,\delta a_{i}$ (both sides equal $2(y-\mathbf{a}W\mathbf{x})\,\mathbf{a}W\mathbf{x}$); if we take the gradient-flow limit and remember the chain rule, we have:

\begin{equation*}
\frac{\partial(W^2/\eta_W-a^2/\eta_a)}{\partial t} = 0
\end{equation*}

meaning (if you forgive my very loose use of $W^2$ and $a^2$ for a matrix and a vector),

\begin{equation*}
W^2/\eta_W-a^2/\eta_a = \text{const}.
\end{equation*}
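
For completeness, here is a sketch of the same calculation in the matrix form used in the main text (samples in the columns of $X$ and $Y$, residual $R \equiv Y - \mathbf{a}WX$), where gradient flow gives $\dot{\mathbf{a}} = 2\eta_a R(WX)^\top$ and $\dot{W} = 2\eta_W\,\mathbf{a}^\top RX^\top$:

\begin{equation*}
\frac{\text{d}}{\text{d}t}\left(\frac{1}{\eta_a}\mathbf{a}^\top\mathbf{a} - \frac{1}{\eta_W}WW^\top\right)
= 2\left(WXR^\top\mathbf{a} + \mathbf{a}^\top RX^\top W^\top\right) - 2\left(\mathbf{a}^\top RX^\top W^\top + WXR^\top\mathbf{a}\right) = 0.
\end{equation*}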

References

Atanasov, A., Meterez, A., Simon, J. B., & Pehlevan, C. (2025). The Optimization Landscape of SGD Across the Feature Learning Strength. https://arxiv.org/abs/2410.04642
Dominé, C. C. J., Anguita, N., Proca, A. M., Braun, L., Kunin, D., Mediano, P. A. M., & Saxe, A. M. (2025). From Lazy to Rich: Exact Learning Dynamics in Deep Linear Networks. https://arxiv.org/abs/2409.14623
Du, S. S., Hu, W., & Lee, J. D. (2018). Algorithmic Regularization in Learning Deep Homogeneous Models: Layers are Automatically Balanced. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, & R. Garnett (Eds.), Advances in Neural Information Processing Systems (Vol. 31). Curran Associates, Inc. https://proceedings.neurips.cc/paper_files/paper/2018/file/fe131d7f5a6b38b23cc967316c13dae2-Paper.pdf
Kunin, D., Raventós, A., Dominé, C., Chen, F., Klindt, D., Saxe, A., & Ganguli, S. (2024). Get rich quick: exact solutions reveal how unbalanced initializations promote rapid feature learning. https://arxiv.org/abs/2406.06158
Yang, G., & Hu, E. J. (2021). Tensor programs iv: Feature learning in infinite-width neural networks. International Conference on Machine Learning, 11727–11737. https://proceedings.mlr.press/v139/yang21c.html
Yang, G., Simon, J. B., & Bernstein, J. (2023). A spectral condition for feature learning. https://arxiv.org/abs/2310.17813

Footnotes

  1. It should be mentioned that in my experiments I use mini-batch gradient descent instead of population (full-batch) gradient descent, although I suspect this minimally changes the results.

  2. The paper itself details two other requirements: that the network be non-bottlenecked and non-overparameterized. The first condition is trivially satisfied for our discussion, since we always assume $w > \max(d, d_{out})$; from what I can tell, the non-overparameterization condition can be dropped for our discussion.