Redesigning a classic multimodal learning tool for the age of distributed data

Monday, May 04, 2026

When two hospitals want to learn how MRI scans relate to genomic profiles, they face an uncomfortable choice. They can pool their data in one place and run a standard statistical analysis, which works well but requires patients’ sensitive records to leave the institutions that collected them. Or they can keep everything local and accept that certain kinds of joint analysis simply can’t be done. For decades, this tradeoff has been treated as more or less fundamental.

A team of researchers from MBZUAI has proposed a way around this, with a framework called FedCCA, which will be presented at AISTATS 2026 in Tangier, Morocco. It tackles a specific but broadly useful statistical problem: canonical correlation analysis, or CCA, which finds the shared structure between two different representations of the same data. Think of it as asking, “What do these two views of the world agree on?” The technique dates back to Harold Hotelling’s work in the 1930s and remains a workhorse in fields from computer vision to neuroscience.

The trouble is that CCA, in its classical form, requires you to invert large covariance matrices, which is expensive even on a single machine. Prior attempts at distributed CCA either demanded that a central server perform heavy linear algebra on pooled statistics, or applied to a different data arrangement entirely: the “vertical” setting, where each party holds different features for the same people, rather than the “horizontal” case, where each party holds different people described by the same features.
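For reference, the textbook regularized CCA problem for a single pair of directions u and v can be written as follows. This is the standard formulation; the notation (covariances C_xx, C_yy, C_xy, ridge terms r_x, r_y) is ours and may not match the paper's.

```latex
\max_{u,\,v}\; u^\top C_{xy}\, v
\quad \text{s.t.} \quad
u^\top \tilde C_{xx}\, u = 1, \qquad v^\top \tilde C_{yy}\, v = 1,
\qquad \tilde C_{xx} = C_{xx} + r_x I, \quad \tilde C_{yy} = C_{yy} + r_y I

\tilde C_{xx}^{-1}\, C_{xy}\, \tilde C_{yy}^{-1}\, C_{yx}\, u = \rho^2\, u,
\qquad v \propto \tilde C_{yy}^{-1}\, C_{yx}\, u
```

Here ρ is the canonical correlation and C_yx is the transpose of C_xy. The two inverses in the second line are the expensive, non-decomposable step described above.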

The algebraic insight

The core idea in FedCCA is not a new neural architecture or a novel optimization algorithm. Instead, the authors replace the matrix inversions at the heart of CCA with a truncated von Neumann series, a classical result from functional analysis that expresses the inverse of a matrix A as the infinite sum I + (I − A) + (I − A)² + …, valid whenever the powers of (I − A) shrink toward zero. Cut that sum off after a few terms and you get an approximation whose error shrinks geometrically with each additional term you keep.
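The truncation can be made concrete with a minimal pure-Python sketch (our own illustration, not the paper's code). It assumes a symmetric positive definite A rescaled by a step size α with 0 < α < 2/λ_max(A), so that α·Σ_{i=0}^{m} (I − αA)^i converges to A⁻¹; the rescaling by α and the indexing of the truncation order m are assumptions for illustration. The series is applied to a vector b using only matrix-vector products, and the error against the exact solution is recorded for each m.

```python
import math

def matvec(A, v):
    """Dense matrix-vector product for a list-of-rows matrix."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]

def neumann_solve(A, b, alpha, m):
    """Approximate A^{-1} b by alpha * sum_{i=0}^{m} (I - alpha*A)^i b.

    Converges when every eigenvalue of I - alpha*A has magnitude < 1, which for
    symmetric positive definite A means 0 < alpha < 2 / lambda_max(A).
    Only matrix-vector products with A are used; A is never inverted.
    """
    term = [alpha * x for x in b]            # i = 0 term: alpha * b
    total = term[:]
    for _ in range(m):
        A_term = matvec(A, term)
        term = [t - alpha * at for t, at in zip(term, A_term)]  # multiply by (I - alpha*A)
        total = [s + t for s, t in zip(total, term)]
    return total

# A small symmetric positive definite example with a closed-form 2x2 inverse.
A = [[2.0, 0.5],
     [0.5, 1.0]]
b = [1.0, 1.0]
det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
exact = [(A[1][1] * b[0] - A[0][1] * b[1]) / det,
         (-A[1][0] * b[0] + A[0][0] * b[1]) / det]

alpha = 2.0 / 3.0   # eigenvalues of I - alpha*A are about +/-0.471 here
errors = []
for m in range(21):
    approx = neumann_solve(A, b, alpha, m)
    errors.append(math.dist(approx, exact))
    print(f"m={m:2d}  error={errors[-1]:.2e}")
```

For this particular A, both eigenvalues of I − αA have magnitude √2/3, so each extra term multiplies the error by exactly that factor: the geometric decay described above. A smaller spectral radius means fewer terms for the same accuracy.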

This might sound like a minor algebraic rearrangement, and in a sense it is, but the consequences are significant. Matrix inversions are monolithic operations: you need the whole matrix in one place, and the computation doesn’t decompose neatly. A truncated power series, on the other hand, can be applied as a sequence of matrix-vector multiplications. Each multiplication can be performed locally by a client on its own data, and only the low-dimensional result needs to be sent to a central server for aggregation. The server sees only compressed projections whose dimensionality, k, is chosen to be far smaller than the data dimension or the number of samples on any client.
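Why the series federates can be checked in a small sketch. Assume horizontally partitioned, pre-centered data: each client holds some rows of the same d features. The regularized covariance-vector product (XᵀX/n + rI)v is then a sum of per-client terms X_cᵀ(X_c v), so each term of the truncated series needs one round of aggregation and no client reveals its rows. The function names, the plain-sum aggregation, and the regularizer r are illustrative assumptions. To keep the example tiny, each client returns a d-dimensional vector; FedCCA's actual messages are k-dimensional projections whose exact form we do not reproduce.

```python
def client_product(X_c, v):
    """Computed locally by one client: X_c^T (X_c v). Raw rows never leave the client."""
    Xv = [sum(a * b for a, b in zip(row, v)) for row in X_c]
    return [sum(X_c[r][j] * Xv[r] for r in range(len(X_c))) for j in range(len(v))]

def federated_matvec(clients, v, reg):
    """Server side: sum client contributions to form (X^T X / n + reg*I) v."""
    n = sum(len(X_c) for X_c in clients)
    total = [0.0] * len(v)
    for X_c in clients:
        total = [t + p for t, p in zip(total, client_product(X_c, v))]
    return [t / n + reg * x for t, x in zip(total, v)]

def federated_neumann_solve(clients, b, reg, alpha, m):
    """alpha * sum_{i=0}^{m} (I - alpha*C)^i b with C = X^T X / n + reg*I,
    using one federated matrix-vector product per additional term."""
    term = [alpha * x for x in b]
    total = term[:]
    for _ in range(m):
        C_term = federated_matvec(clients, term, reg)
        term = [t - alpha * c for t, c in zip(term, C_term)]
        total = [s + t for s, t in zip(total, term)]
    return total

# Two clients holding different rows (samples) of the same two centered features.
client_a = [[1.0, 0.5], [-1.0, 0.0]]
client_b = [[0.5, -1.0], [-0.5, 0.5]]
pooled = client_a + client_b   # what a centralized analysis would see

b = [1.0, 2.0]
fed = federated_neumann_solve([client_a, client_b], b, reg=0.1, alpha=1.0, m=40)
central = federated_neumann_solve([pooled], b, reg=0.1, alpha=1.0, m=40)
print(fed)
print(central)
```

The federated and pooled runs agree up to floating-point rounding, because summing client terms reproduces XᵀX exactly; only the order of additions differs.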

Zhiqiang Xu, Assistant Professor of Machine Learning and one of the paper’s authors, calls this the “alternating matrix-vector multiplication” scheme, or AMVM. At each step, the server broadcasts the current estimate of the projection matrices, clients multiply their local data by those matrices and send back the results, and the server assembles the next iterate. The data in both views stays fixed throughout optimization, while the algorithm alternates between updating the two corresponding projection matrices: the projection for one view is updated using the current projection of the other view, and vice versa, until convergence.
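The alternation can be illustrated for a single pair of directions (k = 1). The sketch below is a simplified alternating power iteration of our own, not the paper's AMVM: each update multiplies by the cross-covariance and then applies a truncated von Neumann series in place of an explicit inverse, so every step is a matrix-vector product. For brevity it works directly with covariance matrices; in a horizontally federated run, each such product would be a sum of client contributions. The step sizes `alpha_x`, `alpha_y` and the series order `m` are illustrative choices, and the covariances are constructed so that the top canonical correlation is exactly 0.8.

```python
import math

def matvec(A, v):
    """Dense matrix-vector product for a list-of-rows matrix."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def neumann_solve(A, b, alpha, m):
    """Approximate A^{-1} b by alpha * sum_{i=0}^{m} (I - alpha*A)^i b."""
    term = [alpha * x for x in b]
    total = term[:]
    for _ in range(m):
        A_term = matvec(A, term)
        term = [t - alpha * at for t, at in zip(term, A_term)]
        total = [s + t for s, t in zip(total, term)]
    return total

def normalize(C, w):
    """Rescale w so that w^T C w = 1, the usual CCA constraint."""
    scale = math.sqrt(sum(x * y for x, y in zip(w, matvec(C, w))))
    return [x / scale for x in w]

def alternating_cca(Cxx, Cyy, Cxy, alpha_x, alpha_y, m=40, iters=100):
    """Top canonical pair (k = 1) by alternating updates, with no explicit inverses."""
    if iters < 1:
        raise ValueError("iters must be at least 1")
    Cyx = transpose(Cxy)
    v = normalize(Cyy, [1.0] * len(Cyy))
    for _ in range(iters):
        # Update the x-view direction from the current y-view direction...
        u = normalize(Cxx, neumann_solve(Cxx, matvec(Cxy, v), alpha_x, m))
        # ...then the y-view direction from the freshly updated x-view direction.
        v = normalize(Cyy, neumann_solve(Cyy, matvec(Cyx, u), alpha_y, m))
    rho = sum(a * b for a, b in zip(u, matvec(Cxy, v)))
    return u, v, rho

# Covariances built so the whitened cross-covariance is diag(0.8, 0.3),
# i.e. the canonical correlations are exactly 0.8 and 0.3.
s = math.sqrt(2.0)
Cxx = [[1.5, 0.5], [0.5, 1.5]]      # eigenvalues 2 and 1
Cyy = [[1.0, 0.0], [0.0, 4.0]]      # eigenvalues 1 and 4
Cxy = [[0.4 * (s + 1), 0.3 * (s - 1)],
       [0.4 * (s - 1), 0.3 * (s + 1)]]

u, v, rho = alternating_cca(Cxx, Cyy, Cxy, alpha_x=2 / 3, alpha_y=0.4)
print(f"top canonical correlation = {rho:.6f}")
```

Composing the two half-steps gives power iteration on C̃xx⁻¹Cxy C̃yy⁻¹Cyx, whose eigenvalues are the squared canonical correlations (0.64 and 0.09 here), so the estimate converges to the top pair at a rate set by their ratio.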

Privacy, formally

Low-dimensional projections already make it hard for a curious server to reconstruct raw data. A single round of communication gives the server far fewer equations than unknowns. But “hard” is not “impossible.” A malicious server that issues enough carefully chosen queries to the same client could, in principle, reconstruct that client’s data. The underdetermined system becomes determined once the server accumulates enough independent projections.
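The reconstruction risk shows up even in a toy model, which is an illustrative assumption far simpler than the real protocol: a client holds one private vector x in ℝ^d and answers each query with k linear measurements. One round is underdetermined when k < d, but after ⌈d/k⌉ rounds of random Gaussian queries the stacked system is square and, with probability one, invertible, so plain Gaussian elimination recovers x. The functions `client_answer`, `solve`, and `reconstruct` are hypothetical names.

```python
import random

def client_answer(x, P):
    """Client's reply to one query: k inner products <p, x>, one per query direction."""
    return [sum(p_i * x_i for p_i, x_i in zip(p, x)) for p in P]

def solve(A, b):
    """Solve the square system A z = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [row[:] + [bi] for row, bi in zip(A, b)]   # augmented matrix
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    z = [0.0] * n
    for r in range(n - 1, -1, -1):
        z[r] = (M[r][n] - sum(M[r][c] * z[c] for c in range(r + 1, n))) / M[r][r]
    return z

def reconstruct(x_secret, k, seed=0):
    """Curious server: issue random k-dimensional queries until it holds d equations."""
    rng = random.Random(seed)
    d = len(x_secret)
    rows, answers, rounds = [], [], 0
    while len(rows) < d:
        P = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(k)]
        answers += client_answer(x_secret, P)
        rows += P
        rounds += 1
    return solve(rows[:d], answers[:d]), rounds

secret = [3.0, -1.0, 2.5, 0.5, 4.0]          # d = 5 private values
estimate, rounds = reconstruct(secret, k=2)  # each round reveals only 2 numbers
print(rounds, [round(e, 6) for e in estimate])
```

Three rounds of two numbers each are enough to pin down all five values, which is exactly the gap that calibrated noise is meant to close.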

To close this gap, FedCCA offers an extension called FedCCA-DP, where each client adds calibrated Gaussian noise to its outgoing projections. The framework comes with two theorems that together define a feasibility window for the noise. Theorem 2 establishes a lower bound: how much noise you need to guarantee (ε, δ)-differential privacy, the standard formal definition that bounds what any adversary can infer about any individual record. Theorem 3 establishes an upper bound: how much noise the optimization can tolerate before convergence breaks down. If the lower bound exceeds the upper bound, you’re stuck. If it doesn’t, you have a workable range, and practitioners can choose where within that range to operate based on how much they value privacy versus accuracy.
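The client-side noise step can be sketched with the standard Gaussian mechanism. The sketch uses the classical calibration σ ≥ Δ·√(2 ln(1.25/δ))/ε, valid for 0 < ε < 1, which is a textbook bound and not necessarily the paper's Theorem 2. `sigma_upper` is a hypothetical placeholder for the convergence bound of Theorem 3, whose exact form is not given here, and the sensitivity Δ of a client's projection is likewise an assumed input.

```python
import math
import random

def gaussian_sigma_lower(sensitivity, eps, delta):
    """Classical Gaussian-mechanism calibration (requires 0 < eps < 1, 0 < delta < 1)."""
    if not (0.0 < eps < 1.0 and 0.0 < delta < 1.0):
        raise ValueError("classical bound assumes 0 < eps < 1 and 0 < delta < 1")
    return sensitivity * math.sqrt(2.0 * math.log(1.25 / delta)) / eps

def noise_window(sensitivity, eps, delta, sigma_upper):
    """Return (sigma_min, sigma_max) if some noise level meets both requirements,
    else None. sigma_upper is a hypothetical stand-in for a convergence bound."""
    sigma_min = gaussian_sigma_lower(sensitivity, eps, delta)
    return (sigma_min, sigma_upper) if sigma_min <= sigma_upper else None

def privatize(projection, sigma, rng):
    """Client side: add i.i.d. N(0, sigma^2) noise to each outgoing coordinate."""
    return [p + rng.gauss(0.0, sigma) for p in projection]

window = noise_window(sensitivity=1.0, eps=0.5, delta=1e-5, sigma_upper=20.0)
print(window)
if window is not None:
    noisy = privatize([0.3, -1.2, 0.7], sigma=window[0], rng=random.Random(42))
    print(noisy)
print(noise_window(sensitivity=1.0, eps=0.5, delta=1e-5, sigma_upper=5.0))  # infeasible: None
```

Choosing σ at the low end of the window favors accuracy; moving toward `sigma_upper` adds noise beyond what the privacy target requires.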

The privacy-accuracy tradeoff has a clean structure. The series order m controls approximation quality: higher m means less truncation error but also more queries per round, which means more noisy releases must be composed across the protocol. The total number of private queries scales as 4(2+2m)TM, where T is the number of outer iterations and M is the number of clients. The per-query noise doesn’t depend on m, but the accumulated privacy loss does, growing linearly with m. So m acts as a dial that simultaneously tunes two competing objectives.
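The query count is simple arithmetic. The sketch below evaluates 4(2+2m)TM and, as an assumption for illustration, totals the privacy loss with basic sequential composition, under which per-query ε and δ simply add. Basic composition is a loose upper bound and the paper's own accounting may be tighter; under this assumption the total grows linearly with the number of queries and therefore linearly with m. The per-query budget and the values of T and M are hypothetical.

```python
def num_private_queries(m, T, M):
    """Total private queries 4(2 + 2m)TM for series order m, T outer iterations, M clients."""
    return 4 * (2 + 2 * m) * T * M

def basic_composition(eps_per_query, delta_per_query, q):
    """Basic sequential composition: q releases that are each (eps, delta)-DP
    are jointly (q*eps, q*delta)-DP."""
    return q * eps_per_query, q * delta_per_query

T, M = 10, 10   # illustrative values
for m in (1, 2, 4, 8):
    q = num_private_queries(m, T, M)
    eps_total, delta_total = basic_composition(1e-3, 1e-9, q)
    print(f"m={m}: queries={q}, total eps={eps_total:.2f}, total delta={delta_total:.1e}")
```

Each additional order of the series adds 8TM queries (800 with these illustrative values), which is the per-step cost of buying a smaller truncation error.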

What the experiments show

The authors test on five datasets spanning multimedia annotation (Mediamill), speech articulation (JW11), handwritten digits (MNIST), multi-feature object recognition (MFEAT), and natural image classification (Caltech101). These are standard CCA benchmarks, not large-scale industrial deployments.

Against the two main baselines that are actually feasible in a horizontal federated setting with differential privacy (alternating least squares and its “truly alternating” variant), FedCCA performs well. On MNIST and several other datasets, it achieves sub-optimality gaps that are orders of magnitude lower, meaning its estimates of the canonical correlations are much closer to the true optimum. It also converges faster, requiring roughly 20% fewer passes through the data on average.

On JW11, MNIST, and MFEAT, FedCCA with differential privacy noise still outperforms the baselines running without any noise at all. This isn’t because the noise somehow helps; it’s because the underlying AMVM scheme starts from a much better position than alternating least squares, and the noise penalty doesn’t erase that advantage.

On MNIST with 10 clients, FedCCA uses 0.15 GFLOPs of computation compared to 28 GFLOPs for the distributed differentially private CCA method of Imtiaz and Sarwate (2019), roughly 187 times less, and its communication cost drops from 9.28 MB to 5.11 MB, a reduction of about 45%. These gains come from avoiding the centralized SVD that prior methods required.

What it doesn’t do

All experiments use balanced, identically distributed data splits across clients. In practice, hospitals don’t have the same number of patients, and their patient populations may look very different. How AMVM behaves under such heterogeneity is an open question. The client counts tested go up to 10, which is far from the thousands or millions of devices in mobile federated learning. And the privacy accounting, while rigorous, is still at the level of individual rounds; a tighter composition analysis across the full protocol lifetime could change the practical noise requirements.

Additionally, FedCCA solves linear CCA, while the deep learning community has largely moved to nonlinear variants, such as deep CCA, that learn neural network transformations of each view before correlating them. Extending the von Neumann series trick to that setting is not straightforward, since the series relies on the auto-covariance matrices having a specific spectral structure that neural networks would disrupt.

Why it matters

The significance of this work lies in its methodological contribution to federated multi-view learning. By reformulating regularized CCA through a truncated von Neumann series, the paper replaces explicit matrix inversion and centralized covariance decomposition with alternating matrix-vector multiplications that are naturally compatible with client-side computation and low-dimensional communication.

This is important not merely as an implementation improvement, but because it changes the computational form of CCA in a way that makes the method deployable in privacy-constrained federated settings while retaining explicit control over approximation error and privacy-convergence tradeoffs. In this sense, the paper offers a principled framework for adapting correlation-based spectral methods to distributed environments.

For practitioners in healthcare, telecommunications, and other fields where multimodal data is valuable but siloed, FedCCA suggests that the choice between utility and privacy may be less stark than it appears. The gap between what you can learn with centralized data and what you can learn with federated data, at least for this family of problems, may be surprisingly small.
