
SQRLS

Autoencoders as Associative Memory


An overparameterized autoencoder, trained only to reconstruct its inputs, does something interesting when you iterate it: feed it a novel image and loop the output back as input, and it converges to the nearest training example. Reconstruction becomes retrieval. Radhakrishnan et al. (2020) showed that these networks develop stable attractors at their training points, despite never being trained for convergence.
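The retrieval loop itself is just repeated application of the network. The minimal sketch below replaces the trained autoencoder with a hypothetical toy map, `toy_autoencoder`, which moves its input halfway toward the nearest stored pattern. This stand-in is an assumption for illustration. The real network has no explicit nearest-neighbor step, and its basins come out of training. The loop structure is the same, though: the stored patterns are fixed points, and a novel input converges to the closest one.

```python
import math

# Hypothetical stand-ins for training images: three 2-D "patterns".
TRAIN = [(0.0, 0.0), (4.0, 0.0), (0.0, 4.0)]


def nearest(x, points):
    """Return the stored point closest to x in Euclidean distance."""
    return min(points, key=lambda p: math.dist(x, p))


def toy_autoencoder(x, step=0.5):
    """Toy map (assumption, not the trained network): move x a fraction
    `step` toward its nearest stored pattern, so every pattern is a fixed point."""
    target = nearest(x, TRAIN)
    return tuple(xi + step * (ti - xi) for xi, ti in zip(x, target))


def iterate(f, x, n_iters=10):
    """Feed each output back in as the next input; return all iterates."""
    trajectory = [x]
    for _ in range(n_iters):
        x = f(x)
        trajectory.append(x)
    return trajectory


if __name__ == "__main__":
    path = iterate(toy_autoencoder, (3.0, 1.0))
    print(path[-1])  # near (4.0, 0.0), the closest stored pattern
```

The 10-step default mirrors the 10 inference iterations used in the demo.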

The setup below trains an autoencoder (~466k parameters) on just 16 MNIST digits with heavy augmentation. It is implemented in jax-js, so you can watch the attractors form live in your browser.

Input 28×28 → Conv2D 7×7, stride 7 → Flatten (256 dims) → GLU block + LayerNorm → Latent z (256 dims) → GLU block + LayerNorm → Linear 256→784 → Output 28×28, which loops back to the input on each iteration.
Encoder-decoder with a 256-d bottleneck. At inference, the output is fed back in as the next input, for 10 iterations.
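The ~466k figure can be sanity-checked with some arithmetic. The diagram does not give the channel count or the GLU's internal width, so the sketch below assumes one configuration that fits the stated size. It uses 16 output channels for the 7×7, stride-7 convolution (28/7 = 4 patches per side, and 4×4×16 = 256). Each GLU block is taken to be a single 256→512 linear layer whose halves are combined as value × sigmoid(gate), followed by a LayerNorm over 256 dims. Under these assumptions the total is 466,480. The actual layer widths are not confirmed by the text.

```python
def conv2d_params(k, c_in, c_out):
    """k x k kernel weights plus one bias per output channel."""
    return k * k * c_in * c_out + c_out


def linear_params(d_in, d_out):
    """Weight matrix plus bias vector."""
    return d_in * d_out + d_out


def layernorm_params(d):
    """Learned scale (gamma) and shift (beta), one of each per feature."""
    return 2 * d


def glu_block_params(d):
    """Assumed GLU block: one d -> 2d projection (value and gate halves), then LayerNorm(d)."""
    return linear_params(d, 2 * d) + layernorm_params(d)


# Assumed configuration (not stated in the text): 16 conv output channels.
IMG, K, CHANNELS = 28, 7, 16
spatial = IMG // K                       # stride 7 with no overlap: 4 patches per side
latent = spatial * spatial * CHANNELS    # 4 * 4 * 16 = 256, the flattened width

total = (
    conv2d_params(K, 1, CHANNELS)        # encoder conv: 800
    + glu_block_params(latent)           # encoder GLU + LayerNorm: 132,096
    + glu_block_params(latent)           # decoder GLU + LayerNorm: 132,096
    + linear_params(latent, IMG * IMG)   # 256 -> 784 readout: 201,488
)
print(latent, total)  # 256 466480
```

Almost half of the parameters sit in the final 256→784 readout, which maps the latent code back to pixels.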

To see retrieval in action, we track how far the first test image sits from its nearest training attractor at each of 10 iteration steps, plotted over the course of training and colored from iteration 1 (dark) to iteration 10 (light). Early on, the traces sit on top of each other: the attractors haven't formed yet, so iterating doesn't move the input anywhere useful. As training progresses and the learning rate decays, the traces fan out. Later iterations drop well below iteration 1, which is what convergence into a basin looks like from this view.
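The quantity in this plot can be sketched like this: roll the network out for 10 steps and, at each step, record the distance from the current iterate to the nearest attractor. The names here (`distance_trace`, the toy `contract` map, the attractor list) are hypothetical, and a toy map stands in for the network. The text does not say whether distance is measured in pixel space or latent space, so the sketch assumes plain Euclidean distance on whatever vectors are passed in.

```python
import math


def distance_trace(f, x, attractors, n_iters=10):
    """Distance from each iterate f(x), f(f(x)), ... to its nearest attractor.

    Entry i is measured after iteration i + 1, matching the plot's
    iteration 1 (dark) to iteration 10 (light) coloring.
    """
    trace = []
    for _ in range(n_iters):
        x = f(x)
        trace.append(min(math.dist(x, a) for a in attractors))
    return trace


def contract(x):
    """Toy stand-in (assumption): halve every coordinate, so the origin is the only attractor."""
    return tuple(0.5 * xi for xi in x)


trace = distance_trace(contract, (3.0, 4.0), attractors=[(0.0, 0.0)])
print([round(d, 4) for d in trace[:3]])  # [2.5, 1.25, 0.625]
```

A large gap between the first and last entries of the trace corresponds to the fan-out described above. Flat traces mean iteration is not moving the input toward any basin.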

Distance to the nearest attractor at each iteration (first test image), plotted across training. Widening spread = stronger convergence.

That's one image's trace. The scatter below zooms out to show every training attractor and every test path at once, animated across training steps. The attractors aren't fixed. They form gradually and keep shifting as training continues, with inputs landing in different basins from step to step, only settling as the learning rate decays. You can basically watch catastrophic forgetting play out in real time.

Latent space (256-d) projected onto three orthogonal planes. Colored dots: training attractors. Paths: test images converging through iteration.

That clean basin structure isn't automatic. Without heavy augmentation, intermediate outputs drift out of distribution and the network develops spurious attractors: fixed points the decoder produces that aren't any of the training images. Augmentation smooths the input distribution and pulls the basins back onto the real data. Below, each test image iterates toward the training example it most resembles.
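One way to make "spurious attractor" concrete is to iterate until the output stops changing, then check whether the resulting fixed point lies near any training image. The helper names and the two tolerances below are illustrative assumptions, not values from the original demo. A one-dimensional toy map stands in for the decoder: it has a basin around 1.0, which plays the training example, and a second basin around 3.0, which plays the spurious attractor.

```python
import math


def find_fixed_point(f, x, tol=1e-6, max_iters=1000):
    """Iterate f until successive outputs differ by less than tol."""
    for _ in range(max_iters):
        nxt = f(x)
        if math.dist(nxt, x) < tol:
            return nxt
        x = nxt
    raise RuntimeError("iteration did not converge")


def is_spurious(fixed_point, train, radius=0.1):
    """A fixed point is spurious if no training example lies within `radius` of it."""
    return min(math.dist(fixed_point, t) for t in train) > radius


# Toy setup (assumption): only 1.0 is a training example, but the map also
# pulls inputs above 2.0 toward 3.0, creating a spurious fixed point there.
TRAIN = [(1.0,)]


def toy_map(x):
    target = 1.0 if x[0] < 2.0 else 3.0
    return (x[0] + 0.5 * (target - x[0]),)


for start in [(0.2,), (2.6,)]:
    fp = find_fixed_point(toy_map, start)
    print(start, "->", fp, "spurious" if is_spurious(fp, TRAIN) else "training")
```

In the real network the same check would run on decoder outputs after the rollout settles. Augmentation aims to leave no basins of the second kind.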

Figure: Tests 1–4, each shown as the test input, its rollout through the iterations, and the attractor it reaches.

The reconstruction loss puts fixed points at the training images by construction. Heavy augmentation and the smoothness of the learned map combine to give those fixed points wide basins, wide enough to catch novel inputs and pull them toward the training image they most resemble in latent space. Iterate the network and reconstruction becomes retrieval.
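The difference between a fixed point and an attractor can be stated precisely. Write f for the trained autoencoder, x_i for a training image, and J_f(x_i) for the Jacobian of f at x_i (notation introduced here, not taken from the original). Reconstruction training pushes toward f(x_i) = x_i. Whether x_i also pulls in nearby inputs is the standard local-stability question for a discrete-time map, and it is settled by the eigenvalues λ_k of that Jacobian:

```latex
x_{t+1} = f(x_t), \qquad f(x_i) = x_i \quad \text{(fixed point)}

x_{t+1} - x_i = f(x_t) - f(x_i) \approx J_f(x_i)\,(x_t - x_i)

\rho\!\left(J_f(x_i)\right) = \max_k \left|\lambda_k\!\left(J_f(x_i)\right)\right| < 1
\;\Longrightarrow\; x_t \to x_i \ \text{for } x_0 \text{ close enough to } x_i
```

Near x_i the error e_t = x_t − x_i evolves as e_{t+1} ≈ J_f(x_i) e_t, so it shrinks geometrically when every eigenvalue has magnitude below 1. This linearization only covers inputs very close to x_i. How wide the basin is, which decides whether a novel digit gets caught, is a nonlocal property it does not capture.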

References

  • Radhakrishnan, A., Belkin, M. & Uhler, C. Memorization in overparameterized autoencoders. arXiv (2018).
  • Radhakrishnan, A., Belkin, M. & Uhler, C. Overparameterized neural networks implement associative memory. PNAS 117(44), 27162–27170 (2020).