Federated Learning with Trainable Prototypes
Federated learning usually works by averaging model weights across clients. But that only makes sense if all clients have the same architecture. In the heterogeneous case — each organization running whatever model fits their hardware and data — weight averaging is off the table.
One natural fallback: share class prototypes, the per-class mean feature vectors from each client's encoder. Every model can produce them regardless of architecture, as long as all encoders emit embeddings of the same dimension. The server averages them across clients and sends them back as supervision. Clients then train their local embeddings to agree with the global prototypes.
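Here is a minimal sketch of the client-side step. It assumes each client has already run its encoder and holds the embeddings as equal-length tuples (2D in the demo). `class_prototypes` is a hypothetical helper, not an API from the paper.

```python
from collections import defaultdict


def class_prototypes(embeddings, labels):
    """Per-class mean embedding: the only thing a client uploads.

    embeddings: list of equal-length tuples (e.g. the demo's 2D encoder outputs).
    labels: list of integer class labels, aligned with embeddings.
    Returns {class: mean embedding as a tuple}.
    """
    sums = {}
    counts = defaultdict(int)
    for z, y in zip(embeddings, labels):
        if y not in sums:
            sums[y] = [0.0] * len(z)
        for d, value in enumerate(z):
            sums[y][d] += value
        counts[y] += 1
    return {y: tuple(s / counts[y] for s in vec) for y, vec in sums.items()}


print(class_prototypes([(1.0, 0.0), (3.0, 2.0), (0.0, -1.0)], [0, 0, 1]))
# {0: (2.0, 1.0), 1: (0.0, -1.0)}
```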
The problem is that simple weighted averaging of per-client prototypes tends to be suboptimal. Each client's encoder settles at a different local optimum, so even class-zero prototypes from two different clients may point in very different directions in the shared prototype space. The naive average lands somewhere in the middle — not ideal for either client.
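To make the baseline concrete, here is a sketch of count-weighted averaging. It assumes each client also reports how many examples it saw per class. The name `weighted_average_prototypes` and the count-based weights are illustrative assumptions. In the usage, the two clients place each class along orthogonal directions.

```python
def weighted_average_prototypes(client_protos, client_counts):
    """Baseline aggregation: per-class mean of client prototypes, weighted by sample count.

    client_protos: list of {class: vector tuple}, one dict per client.
    client_counts: list of {class: number of examples}, aligned with client_protos.
    """
    global_protos = {}
    for k in sorted(set().union(*client_protos)):
        total = sum(counts.get(k, 0) for counts in client_counts)
        dim = len(next(p[k] for p in client_protos if k in p))
        acc = [0.0] * dim
        for protos, counts in zip(client_protos, client_counts):
            if k in protos:
                w = counts[k] / total  # share of class-k examples held by this client
                for d in range(dim):
                    acc[d] += w * protos[k][d]
        global_protos[k] = tuple(acc)
    return global_protos


g = weighted_average_prototypes(
    client_protos=[{0: (1.0, 0.0), 1: (-1.0, 0.0)},   # Client A
                   {0: (0.0, 1.0), 1: (0.0, -1.0)}],  # Client B
    client_counts=[{0: 48, 1: 48}, {0: 16, 1: 16}],
)
print(g)  # {0: (0.75, 0.25), 1: (-0.75, -0.25)}
```

With a 48/16 skew, class 0 lands at (0.75, 0.25). It is pulled toward Client A's prototype but matches neither client's geometry.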
Zhang et al. (2024) make a small but effective observation: the aggregation step doesn't have to be passive. Instead of computing a weighted mean, the server can train the global prototypes — running a short optimization loop that pulls same-class prototypes from different clients together and pushes different-class prototypes apart.
Adaptive-Margin Contrastive Learning
The server receives per-class prototypes from each client. It then minimizes:
$$ \mathcal{L} = \underbrace{\sum_k \left[\|g_k - p_k^A\|^2 + \|g_k - p_k^B\|^2\right]}_{\text{attraction}} + \underbrace{\sum_{i \neq j} \max\!\left(0,\; m_{ij} - \|g_i - g_j\|\right)^2}_{\text{repulsion}} $$
where $g_k$ are the learnable global prototypes, $p_k^A$ and $p_k^B$ are the client prototypes for class $k$, and $m_{ij}$ is an adaptive margin between class $i$ and class $j$. The attraction term keeps global prototypes anchored near the client evidence; the repulsion term ensures different classes stay separated, with larger margins for classes that are harder to distinguish.
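The objective is small enough to optimize with hand-written gradients. Below is a sketch of its value and gradient in plain Python, plus a gradient-descent loop over the global prototypes. Several things are assumptions:

- $\sum_{i \neq j}$ is read as a sum over ordered pairs.
- The margins $m_{ij}$ are supplied as a dictionary. The text does not say how they adapt, so the usage passes a constant.
- $g_k$ starts at the unweighted mean of the client prototypes.
- Step size, step count, and function names are illustrative choices, not values from the paper.

When a hinge is active, the repulsion gradient with respect to $g_i$ is $-2(m_{ij} - d_{ij})(g_i - g_j)/d_{ij}$, where $d_{ij} = \|g_i - g_j\|$.

```python
import math


def acl_loss_and_grad(g, client_protos, margins):
    """Value and gradient of the attraction + repulsion loss.

    g: {class: list of floats}, the learnable global prototypes.
    client_protos: list of {class: vector}, one dict per client (A, B, ...).
    margins: {(i, j): m_ij} for each ordered pair of classes i != j.
    Returns (loss, {class: gradient list}).
    """
    loss = 0.0
    grad = {k: [0.0] * len(v) for k, v in g.items()}
    # Attraction: ||g_k - p_k^c||^2 per client c, with gradient 2 (g_k - p_k^c).
    for protos in client_protos:
        for k, p in protos.items():
            for d, p_d in enumerate(p):
                diff = g[k][d] - p_d
                loss += diff * diff
                grad[k][d] += 2.0 * diff
    # Repulsion: max(0, m_ij - ||g_i - g_j||)^2 for every ordered pair (i, j).
    for (i, j), m in margins.items():
        delta = [a - b for a, b in zip(g[i], g[j])]
        dist = math.sqrt(sum(x * x for x in delta))
        slack = m - dist
        if slack <= 0.0:
            continue  # pair already separated by at least m_ij: no penalty
        loss += slack * slack
        if dist == 0.0:
            continue  # coincident prototypes: no defined push direction
        for d, delta_d in enumerate(delta):
            step = 2.0 * slack * delta_d / dist
            grad[i][d] -= step  # a descent step moves g_i away from g_j
            grad[j][d] += step  # ...and g_j away from g_i
    return loss, grad


def train_global_prototypes(client_protos, margins, steps=200, lr=0.05):
    """Plain gradient descent on the global prototypes, starting from the
    unweighted mean of the client prototypes (an assumed initialisation)."""
    classes = sorted(set().union(*client_protos))
    g = {}
    for k in classes:
        ps = [protos[k] for protos in client_protos if k in protos]
        g[k] = [sum(col) / len(ps) for col in zip(*ps)]
    for _ in range(steps):
        _, grad = acl_loss_and_grad(g, client_protos, margins)
        for k in classes:
            g[k] = [a - lr * b for a, b in zip(g[k], grad[k])]
    return g


protos = {0: (0.1, 0.0), 1: (-0.1, 0.0)}  # both clients agree here
margins = {(0, 1): 1.0, (1, 0): 1.0}
g = train_global_prototypes([protos, protos], margins)
print({k: [round(v, 3) for v in vec] for k, vec in g.items()})
# {0: [0.367, 0.0], 1: [-0.367, 0.0]}
```

Attraction alone would keep the prototypes at ±0.1, and the repulsion term wants a gap of 1.0. The loop settles at a compromise near ±0.367.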
After training, clients receive the refined global prototypes and add an alignment term to their local loss — pulling each example's embedding toward its class's global prototype. Over rounds, the client embeddings and the global prototypes converge on a shared representation.
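Here is a sketch of that alignment term. It assumes a squared Euclidean distance averaged over the batch and scaled by a weight `lam`. Both the form and `lam` are assumptions, since the text only says "alignment term". The function also returns the gradient with respect to each embedding. A client would add this to its task-loss gradient before backpropagating through the encoder.

```python
def alignment_loss(embeddings, labels, global_protos, lam=1.0):
    """lam * mean squared distance from each embedding to its class's global prototype.

    Returns (loss, per-embedding gradients). A client would add this loss to its
    task loss and these gradients to the task-loss gradients on the embeddings.
    """
    n = len(embeddings)
    loss = 0.0
    grads = []
    for z, y in zip(embeddings, labels):
        diff = [a - b for a, b in zip(z, global_protos[y])]
        loss += lam * sum(d * d for d in diff) / n
        # d/dz of lam * ||z - g_y||^2 / n is 2 * lam * (z - g_y) / n.
        grads.append([2.0 * lam * d / n for d in diff])
    return loss, grads


loss, grads = alignment_loss([(1.0, 0.0), (0.0, 2.0)], [0, 1],
                             {0: (0.0, 0.0), 1: (0.0, 0.0)}, lam=0.5)
print(loss)   # 1.25
print(grads)  # [[0.5, 0.0], [0.0, 1.0]]
```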
The full federated round looks like this:
```mermaid
graph LR
    A["Client A<br/>trains locally"] -->|"prototypes {p_k^A}"| S["Aggregator<br/>runs ACL"]
    B["Client B<br/>trains locally"] -->|"prototypes {p_k^B}"| S
    S -->|"global prototypes {g_k}"| A
    S -->|"global prototypes {g_k}"| B
    A -. alignment loss .-> A
    B -. alignment loss .-> B
```
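The round in the diagram can be written as a small protocol sketch. Everything here is hypothetical scaffolding. `ClientUpload`, `run_round`, and the stub client are illustrative names, and the aggregator is a plain unweighted mean so the data flow stays visible. What matters is the interface: clients send only per-class prototypes and counts, and the server sends back one vector per class.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

Protos = Dict[int, Tuple[float, ...]]


@dataclass
class ClientUpload:
    """Everything a client sends up: prototypes and per-class counts, no raw data."""
    client_id: str
    prototypes: Protos
    counts: Dict[int, int]


# A client is modelled as a callable: given the latest global prototypes (None in
# round 0), it trains locally with its alignment loss and returns its upload.
ClientStep = Callable[[Optional[Protos]], ClientUpload]
Aggregator = Callable[[List[ClientUpload]], Protos]


def run_round(clients: List[ClientStep], aggregate: Aggregator,
              global_protos: Optional[Protos]) -> Protos:
    """One round: clients -> prototypes -> aggregator -> new global prototypes."""
    uploads = [client(global_protos) for client in clients]
    return aggregate(uploads)


def mean_aggregate(uploads: List[ClientUpload]) -> Protos:
    """Simplest aggregator: unweighted per-class mean of the uploaded prototypes."""
    out = {}
    for k in sorted(set().union(*(u.prototypes for u in uploads))):
        ps = [u.prototypes[k] for u in uploads if k in u.prototypes]
        out[k] = tuple(sum(col) / len(ps) for col in zip(*ps))
    return out


def make_stub_client(name: str, protos: Protos, counts: Dict[int, int]) -> ClientStep:
    """Hypothetical stand-in for a real client: ignores the global prototypes
    and always uploads the same fixed prototypes."""
    def step(global_protos: Optional[Protos]) -> ClientUpload:
        return ClientUpload(name, protos, counts)
    return step


clients = [make_stub_client("A", {0: (1.0, 0.0)}, {0: 48}),
           make_stub_client("B", {0: (0.0, 1.0)}, {0: 16})]
print(run_round(clients, mean_aggregate, global_protos=None))  # {0: (0.5, 0.5)}
```

Replacing `mean_aggregate` with an optimizer over the global prototypes, such as one minimizing the loss above, changes only the server side. The messages stay the same.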
Demo
The simulation below runs two federated systems side by side: one aggregates prototypes by weighted averaging (labeled FedAvg) and one trains them (FedTGP). Both start from the same initial weights and use the same client data.
Setup: 4 MNIST digit classes (0–3), two clients with skewed distributions. Client A trains on 48 examples each of digits 0 and 1, and 16 each of 2 and 3. Client B has the opposite skew. Each client runs a small encoder (784 → 64 → 16 → 2) that maps images directly into 2D, so the embeddings and prototypes are visible without projection.
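The client split and the encoder size can be sketched as follows. The sketch assumes the encoder is a stack of fully connected layers with biases; the arrow notation suggests this, but it is an assumption. It also assumes examples are drawn without replacement from per-digit index pools. `skewed_split` and `mlp_param_count` are hypothetical helpers.

```python
import random


def skewed_split(indices_by_class, major=(0, 1), n_major=48, n_minor=16, seed=0):
    """Mirror-image class skew for two clients.

    indices_by_class: {digit: list of example indices available for that digit}.
    Client A gets n_major examples of each digit in `major` and n_minor of the rest;
    client B gets the opposite. Sampling is without replacement, so in this sketch
    the clients never share an example. Returns two lists of (index, digit) pairs.
    """
    rng = random.Random(seed)
    client_a, client_b = [], []
    for digit, pool in sorted(indices_by_class.items()):
        picks = rng.sample(pool, n_major + n_minor)
        n_a = n_major if digit in major else n_minor
        client_a += [(i, digit) for i in picks[:n_a]]
        client_b += [(i, digit) for i in picks[n_a:]]
    return client_a, client_b


def mlp_param_count(widths):
    """Weights plus biases of a fully connected stack, e.g. 784 -> 64 -> 16 -> 2."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths, widths[1:]))


pools = {d: list(range(100 * d, 100 * d + 100)) for d in range(4)}  # fake index pools
a, b = skewed_split(pools)
print(len(a), len(b))                    # 128 128
print(mlp_param_count([784, 64, 16, 2]))  # 51314
```

That is about 51k parameters in total. The final 16 → 2 layer contributes only 34 of them, yet it is what makes the embeddings plottable without projection.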
In the scatter plots: small dots are test image embeddings colored by true class; triangles are Client A's local prototypes; squares are Client B's; large circles with white outlines are the global prototypes. Watch what happens to those circles over rounds.
What to look for: In FedAvg, the four global prototypes (large circles) settle at weighted averages of the client prototypes, often bunched together or misaligned with the actual test clusters. In FedTGP, the contrastive training pushes them apart: the circles end up in the center of each class cluster rather than partway between the two clients' local prototypes. The accuracy gap between the two lines on the chart is the direct consequence.
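The text does not say how the demo turns prototypes into accuracy. This sketch assumes one common choice, nearest-prototype classification: predict the class whose global prototype is closest to the embedding. Under that assumption, prototype placement maps directly to accuracy. `nearest_prototype_accuracy` is a hypothetical name.

```python
import math


def nearest_prototype_accuracy(embeddings, labels, global_protos):
    """Fraction of embeddings whose closest global prototype carries their true label."""
    correct = 0
    for z, y in zip(embeddings, labels):
        pred = min(global_protos, key=lambda k: math.dist(z, global_protos[k]))
        correct += int(pred == y)
    return correct / len(labels)


test_z = [(1.0, 0.0), (1.2, 0.1), (-1.0, 0.0), (-0.9, -0.1)]
test_y = [0, 0, 1, 1]
bunched = {0: (0.5, 0.0), 1: (0.6, 0.5)}     # both prototypes on class 0's side
separated = {0: (1.0, 0.0), 1: (-1.0, 0.0)}  # one prototype per cluster
print(nearest_prototype_accuracy(test_z, test_y, bunched))    # 0.5
print(nearest_prototype_accuracy(test_z, test_y, separated))  # 1.0
```

The embeddings are identical in both calls. Only the prototype placement changes, and with it half of the test points flip class.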
The margin slider controls how aggressively the repulsion acts. Too small and classes bleed into each other; too large and the prototypes get pushed so far apart they lose contact with the actual data.
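A one-dimensional toy case makes the slider's trade-off explicit. Assume two classes, two clients that both place class 0 at $+a$ and class 1 at $-a$, and one shared margin $m$. Read $\sum_{i \neq j}$ as a sum over ordered pairs. By symmetry the global prototypes sit at $\pm x$, and the loss from the equation above reduces to $4(x - a)^2 + 2\max(0, m - 2x)^2$. The sketch computes the minimizing gap in closed form and checks it against a brute-force grid search. It illustrates the trade-off under these assumptions and does not describe the demo's code.

```python
def toy_equilibrium_gap(a, m):
    """Closed-form minimiser of 4(x - a)^2 + 2 max(0, m - 2x)^2, returned as the gap 2x.

    If m <= 2a the hinge is inactive at x = a, so x = a. Otherwise setting the
    derivative 8(x - a) - 8(m - 2x) to zero gives x = (a + m) / 3.
    """
    x = a if m <= 2 * a else (a + m) / 3
    return 2 * x


def brute_force_gap(a, m, x_max=10.0, steps=200_000):
    """Grid search over x in [0, x_max] as an independent check of the closed form."""
    def loss(x):
        return 4 * (x - a) ** 2 + 2 * max(0.0, m - 2 * x) ** 2
    best_x = min((x_max * i / steps for i in range(steps + 1)), key=loss)
    return 2 * best_x


for margin in (0.5, 1.0, 2.0, 8.0):
    print(margin, toy_equilibrium_gap(0.5, margin), round(brute_force_gap(0.5, margin), 3))
```

With $a = 0.5$ the client prototypes are 1.0 apart, so any margin up to 1.0 leaves the gap at 1.0. Past that, the gap grows as $2(a + m)/3$, reaching about 1.67 for $m = 2$ and 5.67 for $m = 8$. A large margin therefore drags the prototypes well away from the client data while never fully meeting the margin itself.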
Results
Across twelve heterogeneous model architectures and multiple datasets, FedTGP outperforms state-of-the-art methods by up to 9.08% in accuracy. Like earlier prototype-based methods, it communicates only prototypes between clients and server and never accesses raw client data.
The mechanism is simple enough to implement in an afternoon, which is part of why it's compelling. The server-side training loop is just gradient descent on a small set of learned vectors. The insight — that aggregation can be an optimization problem rather than a closed-form average — generalizes well beyond the specific contrastive objective used here.
References
- Zhang, J., Liu, Y., Hua, Y., & Cao, J. FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning. AAAI (2024). arXiv:2401.03230