SQRLS

The Benna-Fusi Model

In modern neural networks, each learned parameter is represented as a single value: the scalar weight between an input unit and an output unit. The Benna-Fusi model asks questions like "What if it wasn't just one value?" and "What if those values could interact with each other?" The model replaces each parameter with a chain of coupled variables, with fast variables at the front and slow variables at the back. New information enters at the fast end and gradually diffuses into slower, more protected variables. When the network computes with the parameter, it uses only the frontmost, fastest variable in the chain, but that variable is now naturally anchored to its own past.

Intuition

Communicating vessels

Benna and Fusi describe their model with an analogy to communicating vessels. Picture a row of liquid-filled beakers of increasing diameter, connected by tubes of decreasing diameter. Each beaker represents a variable in the chain. Initially, all beakers have equal fluid levels. Using a pipette, we randomly add or remove a little liquid from the first beaker, and the liquid flows through the tubes towards equilibrium. This balancing process varies in speed: it is fast at the beginning of the chain, with narrow beakers and wide tubes, and slower at the end, with wide beakers and narrow tubes. The last beaker in the chain slowly leaks fluid through a drain. With this setup, a change to the system persists for as long as the beakers' fluid levels remain out of equilibrium.

This animation shows the fluid dynamics of the chain:

Definition

The core equation describes how each synaptic variable $u_k$ evolves over time, driven by its differences from its two neighbors:

$$C_k \frac{du_k}{dt} = g_{k-1,k}(u_{k-1} - u_k) + g_{k,k+1}(u_{k+1} - u_k)$$

for $k = 1, \ldots, m$, where $C_k$ is the capacity and $g_{k,k+1}$ is the coupling strength between neighboring variables. The first variable $u_1$ is the synaptic weight itself. For $u_1$, the inflow term is replaced by an external input $\mathcal{I}(t)$ (potentiation or depression). For the last variable $u_m$, we set $u_{m+1} = 0$, introducing a leak that prevents unbounded growth.
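As a minimal sketch (in Python, with a hypothetical function name `cascade_derivative` of our own and 0-based lists), the right-hand side of the equation can be computed for arbitrary capacities and couplings. Here `g[j]` holds $g_{j+1,j+2}$ in the 1-based notation above, and its last entry couples $u_m$ to the implicit $u_{m+1} = 0$, so it sets the strength of the leak.

```python
from typing import List


def cascade_derivative(u: List[float], C: List[float], g: List[float],
                       inp: float) -> List[float]:
    """Return du_k/dt for every variable of a continuous Benna-Fusi chain.

    Lists are 0-based: u[j] is u_{j+1}, C[j] is C_{j+1}, and g[j] is the
    coupling g_{j+1,j+2}. The last coupling g[m-1] links u_m to u_{m+1} = 0.
    """
    m = len(u)
    du = []
    for j in range(m):
        # The external input I(t) replaces the inflow term for u_1.
        inflow = inp if j == 0 else g[j - 1] * (u[j - 1] - u[j])
        # Exchange with the next variable; for u_m the neighbor is 0 (leak).
        nxt = u[j + 1] if j + 1 < m else 0.0
        exchange = g[j] * (nxt - u[j])
        du.append((inflow + exchange) / C[j])
    return du


# Two variables, a unit kick sitting in u_1, no external input.
print(cascade_derivative([1.0, 0.0], C=[1.0, 2.0], g=[0.5, 0.25], inp=0.0))
# -> [-0.5, 0.25]
```

Weighting each entry by its capacity and summing gives $-0.5 \cdot 1 + 0.25 \cdot 2 = 0$: with no input and nothing yet in $u_m$, the couplings only move liquid between beakers, and only the input or the leak can change the total.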

The capacity and coupling coefficients scale geometrically with depth, which produces the optimal power-law memory decay of $1/\sqrt{t}$:

$$C_k = 2^{k-1} \qquad g_{k,k+1} = 2^{-k-2}$$

Capacities grow exponentially (deeper variables are slower to change), while couplings shrink exponentially (deeper variables interact more weakly). As a result, fast variables at the front of the chain capture new information quickly, while slow variables at the back protect older memories. Because the flow is bidirectional, these older memories continually pull back on the new information.
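To make the scaling concrete, the Python sketch below (the helper name is ours) tabulates $C_k$ and $g_{k,k+1}$ for a five-variable chain, together with the ratio $g_{k,k+1}/C_k = 2^{-2k-1}$, the rate at which $u_k$ exchanges with its deeper neighbor. Reading $C_k / g_{k,k+1}$ as a characteristic time for that single coupling is a rough heuristic, not a property of the full chain.

```python
from typing import List, Tuple


def benna_fusi_coefficients(m: int) -> Tuple[List[float], List[float]]:
    """C_k = 2^(k-1) and g_{k,k+1} = 2^(-k-2) for k = 1..m."""
    C = [2.0 ** (k - 1) for k in range(1, m + 1)]
    g = [2.0 ** (-k - 2) for k in range(1, m + 1)]
    return C, g


C, g = benna_fusi_coefficients(5)
for k, (c_k, g_k) in enumerate(zip(C, g), start=1):
    rate = g_k / c_k  # equals 2^(-2k-1): four times slower per level
    print(f"k={k}  C_k={c_k:g}  g_k,k+1={g_k:g}  "
          f"g/C={rate:g}  C/g={1 / rate:g} steps")
```

The characteristic times come out as 8, 32, 128, 512 and 2048 steps: each level is four times slower than the one before it, because $C_k$ doubles while $g_{k,k+1}$ halves.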

For simulation, we discretize with base $n = 2$ and step size $\alpha = 1/4$:

$$u_k(t{+}1) = u_k(t) + n^{-2k+2}\,\alpha\,(u_{k-1} - u_k) - n^{-2k+1}\,\alpha\,(u_k - u_{k+1})$$

The first term is the inflow from the previous variable; the second is the outflow to the next. For $u_1$, the inflow is replaced by an external input $\mathcal{I}(t)$ drawn from an Ornstein-Uhlenbeck process.
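As a quick consistency check between the two parameterizations, substituting $n = 2$ and $\alpha = 1/4$ into the discrete coefficients recovers the ratios $g/C$ of the continuous model:

$$n^{-2k+2}\,\alpha = 2^{-2k} = \frac{2^{-(k-1)-2}}{2^{k-1}} = \frac{g_{k-1,k}}{C_k}, \qquad n^{-2k+1}\,\alpha = 2^{-2k-1} = \frac{2^{-k-2}}{2^{k-1}} = \frac{g_{k,k+1}}{C_k}$$

So, on this reading, the discrete update is a forward-Euler step of unit length applied to the continuous equation, with the step size $\alpha$ already folded into the couplings, since $g_{k,k+1} = 2^{-k-2} = \alpha\, n^{-k}$. Because $C_1 = 1$, adding $\mathcal{I}(t)$ directly to $u_1$ also matches the term $\mathcal{I}(t)/C_1$ of the continuous form.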

Simulation

Below, both models receive the same temporally correlated input from an Ornstein-Uhlenbeck process. The single-variable model ($m = 1$) tracks the noisy input and retains it only over the single short timescale set by its leak. In the cascade ($m = 5$), $u_1$ is still noisy, but $u_2$ through $u_5$ are progressively smoother, because each deeper variable integrates over a longer timescale. By $u_5$, the signal barely moves: it is the slow anchor toward which the shallower variables are regularized.
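A rough, non-interactive version of this comparison can be sketched in Python. The Ornstein-Uhlenbeck parameters (`theta`, `sigma`, the seed), the Euler-Maruyama discretization with unit time step, and the smoothness measure (mean absolute step-to-step change) are our own assumptions, not values from the original simulation; `cascade_step` implements the discrete update from the Definition section with $n = 2$ and $\alpha = 1/4$.

```python
import random
from typing import List


def cascade_step(u: List[float], inp: float, n: float = 2.0,
                 alpha: float = 0.25) -> List[float]:
    """One discrete update of the chain; every term uses values at time t."""
    m = len(u)
    padded = u + [0.0]  # u_{m+1} = 0 turns the last outflow into a leak
    new = []
    for k in range(1, m + 1):  # k is the 1-based index used in the text
        u_k = padded[k - 1]
        # The external input replaces the inflow for u_1.
        inflow = inp if k == 1 else n ** (-2 * k + 2) * alpha * (padded[k - 2] - u_k)
        outflow = n ** (-2 * k + 1) * alpha * (u_k - padded[k])
        new.append(u_k + inflow - outflow)
    return new


def ou_process(steps: int, theta: float = 0.1, sigma: float = 0.1,
               seed: int = 0) -> List[float]:
    """Zero-mean Ornstein-Uhlenbeck samples via Euler-Maruyama (dt = 1)."""
    rng = random.Random(seed)
    x, xs = 0.0, []
    for _ in range(steps):
        x += -theta * x + sigma * rng.gauss(0.0, 1.0)
        xs.append(x)
    return xs


def simulate(m: int, inputs: List[float]) -> List[List[float]]:
    """Run an m-variable chain from rest; return one trace per variable."""
    u = [0.0] * m
    traces: List[List[float]] = [[] for _ in range(m)]
    for i_t in inputs:
        u = cascade_step(u, i_t)
        for j, value in enumerate(u):
            traces[j].append(value)
    return traces


def mean_abs_change(trace: List[float]) -> float:
    """Average |x(t+1) - x(t)|, a crude measure of how much a trace jitters."""
    return sum(abs(b - a) for a, b in zip(trace, trace[1:])) / (len(trace) - 1)


inputs = ou_process(2000)
single = simulate(1, inputs)
cascade = simulate(5, inputs)
print("m=1 u_1:", round(mean_abs_change(single[0]), 5))
for k, trace in enumerate(cascade, start=1):
    print(f"m=5 u_{k}:", round(mean_abs_change(trace), 5))
```

With these settings, $u_5$ should change far less per step than $u_1$; the exact numbers depend on the assumed input parameters and seed.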

Connection to Deep Learning

When $m = 1$, there is no chain, just a single variable with a leak. The update becomes:

$$u_1(t{+}1) = u_1(t) + \mathcal{I}(t) - n^{-1}\alpha \cdot u_1(t)$$

If we interpret the external input as a negative gradient step ($\mathcal{I}(t) = -\nabla_t$, with the learning rate folded into $\nabla_t$), this is exactly SGD with weight decay:

$$w_{t+1} = w_t - \nabla_t - \lambda\, w_t$$

where $\lambda = n^{-1}\alpha$ is the weight decay strength. The leak on the last variable is weight decay.
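A small numerical check of this equivalence, as a Python sketch: the gradient values are arbitrary, the function names are ours, and the learning rate is folded into the gradient. With $n = 2$ and $\alpha = 1/4$, the implied weight decay is $\lambda = 1/8$.

```python
import math

n, alpha = 2.0, 0.25
lam = n ** -1 * alpha  # weight decay strength lambda = n^{-1} * alpha


def cascade_m1_step(u1: float, inp: float) -> float:
    """m = 1 chain: external input in, leak n^{-1} alpha (u_1 - u_2) with u_2 = 0."""
    return u1 + inp - n ** -1 * alpha * (u1 - 0.0)


def sgd_weight_decay_step(w: float, grad: float) -> float:
    """SGD with weight decay; the learning rate is already folded into grad."""
    return w - grad - lam * w


u = w = 0.7
for grad in [0.3, -0.1, 0.05, 0.2]:  # arbitrary example gradients
    u = cascade_m1_step(u, -grad)  # interpret I(t) = -grad_t
    w = sgd_weight_decay_step(w, grad)
    assert math.isclose(u, w, abs_tol=1e-12)

print(lam)  # 0.125
```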

The full cascade ($m > 1$) extends this in an interesting way. The outflow from $u_1$ no longer pulls it toward zero. It pulls it toward $u_2$, which itself is pulled toward $u_3$, and so on. Only the final variable $u_m$ decays toward zero. Each level of the chain regularizes the weight not toward zero but toward a progressively longer-timescale average of its past values.
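The difference is easy to see with a chain at rest and no input. In the Python sketch below, `cascade_step` is our hypothetical name for the discrete update from the Definition section ($n = 2$, $\alpha = 1/4$), and every variable starts at 1.0: the single variable is pulled toward zero, while in the five-variable cascade $u_1$ stays put because $u_2$ holds the same value, and only $u_5$ leaks.

```python
from typing import List


def cascade_step(u: List[float], inp: float, n: float = 2.0,
                 alpha: float = 0.25) -> List[float]:
    """One discrete Benna-Fusi update; every term uses values at time t."""
    m = len(u)
    padded = u + [0.0]  # u_{m+1} = 0 provides the leak on u_m
    new = []
    for k in range(1, m + 1):  # 1-based index as in the text
        u_k = padded[k - 1]
        inflow = inp if k == 1 else n ** (-2 * k + 2) * alpha * (padded[k - 2] - u_k)
        outflow = n ** (-2 * k + 1) * alpha * (u_k - padded[k])
        new.append(u_k + inflow - outflow)
    return new


print(cascade_step([1.0], 0.0))      # m = 1: [0.875], decayed toward zero
print(cascade_step([1.0] * 5, 0.0))  # m = 5: [1.0, 1.0, 1.0, 1.0, 0.99951171875]
```

Only the deepest variable moves, by $n^{-2m+1}\alpha = 2^{-11}$; the shallower ones are anchored to each other rather than to zero.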

This should feel familiar. We already do something like this in practice: keeping an exponential moving average (EMA) of model weights for inference, or maintaining a slow target network in reinforcement learning. The cascade is a similar idea, generalized and adapted for continual learning through synaptic consolidation. Beyond this difference in purpose, an EMA decays exponentially, forgetting on a single timescale. The cascade's geometrically scaled coefficients give it $1/\sqrt{t}$ decay, which Benna and Fusi show is the slowest forgetting rate compatible with bounded weights.
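The contrast in forgetting can be sketched by kicking $u_1$ once and watching how fast the kick fades. In this Python sketch (function names are ours), the single-variable chain stands in for an EMA-style memory with one exponential timescale, since its trace is exactly $(7/8)^t$, while the five-variable cascade spreads the kick across slower variables. It is only a qualitative illustration: it does not fit the $1/\sqrt{t}$ exponent, and with $m = 5$ the slow regime is cut off near the slowest variable's timescale.

```python
from typing import List


def cascade_step(u: List[float], inp: float, n: float = 2.0,
                 alpha: float = 0.25) -> List[float]:
    """One discrete Benna-Fusi update; every term uses values at time t."""
    m = len(u)
    padded = u + [0.0]  # u_{m+1} = 0 provides the leak on u_m
    new = []
    for k in range(1, m + 1):
        u_k = padded[k - 1]
        inflow = inp if k == 1 else n ** (-2 * k + 2) * alpha * (padded[k - 2] - u_k)
        outflow = n ** (-2 * k + 1) * alpha * (u_k - padded[k])
        new.append(u_k + inflow - outflow)
    return new


def impulse_response(m: int, steps: int) -> List[float]:
    """Trace of u_1 after a unit kick at t = 0, with zero input afterwards."""
    u = [1.0] + [0.0] * (m - 1)
    trace = [u[0]]
    for _ in range(steps):
        u = cascade_step(u, 0.0)
        trace.append(u[0])
    return trace


single = impulse_response(1, 1000)   # one exponential timescale: (7/8)^t
cascade = impulse_response(5, 1000)  # kick diffuses into slower variables
for t in (10, 100, 1000):
    print(f"t={t:4d}  single: {single[t]:.3g}  cascade: {cascade[t]:.3g}")
```

By $t = 1000$ the single-variable trace has fallen below $10^{-57}$, while the cascade should still hold a clearly nonzero fraction of the kick in $u_1$.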

Discussion

Kaplanis et al. (2018) applied this model to deep reinforcement learning and noted that the hidden variables have the effect of "regularising the value of the weight by the history of its modifications." They compare it to elastic weight consolidation and synaptic intelligence, observing that the cascade "constrains parameters to be close to their previous values" but does so over a range of timescales, without importance factors, and without knowledge of task boundaries.

The obvious limitation is specificity: every parameter in the network gets the same cascade dynamics. Some weights should consolidate quickly because they've converged to something useful, while others should stay plastic because they're still adapting. Methods like EWC address this by weighting consolidation by each parameter's importance, but they require knowing when tasks change. The cascade doesn't, which is both its appeal and its weakness. A natural extension would be to modulate the coupling strengths per-parameter based on some importance signal, though this starts to lose the model's simplicity.

There's also a question of which side of the plasticity-stability tradeoff this really helps. The standard framing is catastrophic forgetting, where too much plasticity destroys old knowledge. But the cascade might also help with the opposite problem: maintaining plasticity late in training, when weights have settled into a basin. The slow variables act as a stabilizing anchor, while the fast variable $u_1$ remains free to explore.

References

  • Benna, M.K. & Fusi, S. Computational principles of synaptic memory consolidation. Nature Neuroscience 19, 1697–1706 (2016).
  • Kaplanis, C., Shanahan, M. & Clopath, C. Continual reinforcement learning with complex synapses. ICML (2018).
  • Kirkpatrick, J. et al. Overcoming catastrophic forgetting in neural networks. PNAS 114(13), 3521–3526 (2017).
  • Zenke, F., Poole, B. & Ganguli, S. Continual learning through synaptic intelligence. ICML (2017).