- The paper introduces an optimal transport-based formulation that quantifies transfer learning’s sample complexity benefits over direct learning in high-dimensional settings.
- It establishes explicit L2 estimation error bounds, showing that smoother input distributions can mitigate the challenges posed by non-smooth regressors.
- Empirical tests in image classification and medical diagnosis validate that transfer learning significantly enhances performance when labeled target data is scarce.
Sample Complexity of Transfer Learning: An Optimal Transport Perspective
Introduction and Motivation
The paper "Sample Complexity of Transfer Learning: An Optimal Transport Approach" (2605.20545) introduces a rigorous framework for quantifying the impact of transfer learning on sample complexity in supervised settings, with a particular focus on high-dimensional tasks and modern deep learning scenarios. Traditional approaches to sample complexity analysis focus on direct learning rates, emphasizing the role of the smoothness of the regression function. This work takes a different path by leveraging optimal transport (OT) theory to analyze transfer learning, thus providing explicit minimax rates and revealing hitherto under-explored dependencies between distributional smoothness, model complexity, and dimension.
Transfer learning is studied in settings where the source and target tasks are related, specifically when the source task supplies generic or domain-specific feature extractors pretrained on large datasets. Real-world applications in image classification and medical diagnosis, domains where data labeling is expensive, motivate a theoretically grounded understanding of the sample-efficiency gains afforded by transfer learning.
The authors frame supervised transfer learning as an optimal transport problem. The goal is to estimate the regression function $f_T(x) = \mathbb{E}[Y_T \mid X_T = x]$ for target data $(X_T, Y_T)$ from limited labeled examples. Transfer learning leverages a pretrained source model $f_S$ built on source data $(X_S, Y_S)$, which typically has much greater coverage.
The transfer process is formalized using transfer maps $T^{X}_{S}$ (input space) and $T^{Y}_{S \to T}$ (output space), interpreted as optimal transport maps (i.e., Brenier maps), which minimize quadratic costs and align distributions across domains. The learning objective becomes:
$$\min_{T^{X}_{S},\; T^{Y}_{S \to T}} \; \mathbb{E}\left[\ell\left(T^{Y}_{S \to T} \circ f_S \circ T^{X}_{S}(X_T),\; Y_T\right)\right]$$
For quadratic loss, convex optimal transport theory ensures existence, uniqueness (almost everywhere), and regularity of such transfer mappings under common smoothness and log-concavity assumptions.
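As a rough illustration of composing an output map, a source model, and an input map, the sketch below works in one dimension, where the quadratic-cost optimal transport map between two continuous distributions is the monotone rearrangement (quantile matching). The empirical quantile estimator, the toy Gaussian data, the stand-in source model, and the names `fit_monotone_map`, `f_source`, and `f_transfer` are all assumptions made for illustration; this is not the paper's estimator.

```python
import numpy as np

def fit_monotone_map(src_samples, dst_samples, n_knots=200):
    """Empirical 1D quadratic-cost OT map: send quantiles of src to quantiles of dst."""
    qs = np.linspace(0.0, 1.0, n_knots)
    src_q = np.quantile(src_samples, qs)
    dst_q = np.quantile(dst_samples, qs)

    def transport(x):
        # Piecewise-linear interpolation between matched quantiles;
        # inputs outside the observed range are clamped to the end values.
        return np.interp(x, src_q, dst_q)

    return transport

rng = np.random.default_rng(0)

# Hypothetical unlabeled inputs: source ~ N(0, 1), target ~ N(2, 0.5^2).
x_source = rng.normal(0.0, 1.0, 5000)
x_target = rng.normal(2.0, 0.5, 5000)

def f_source(x):
    """Stand-in for a pretrained, non-smooth source model."""
    return np.abs(x)

# Input map: carries target inputs into the source input space.
t_x = fit_monotone_map(x_target, x_source)

# A small labeled target set (hypothetical ground truth plus noise).
x_lab = rng.normal(2.0, 0.5, 30)
y_lab = 3.0 * np.abs((x_lab - 2.0) / 0.5) + 1.0 + rng.normal(0.0, 0.05, 30)

# Output map: carries source-model predictions onto target labels.
t_y = fit_monotone_map(f_source(t_x(x_lab)), y_lab, n_knots=30)

def f_transfer(x):
    """Transfer estimator: output map after source model after input map."""
    return t_y(f_source(t_x(x)))

print(f_transfer(2.5))
```

Only the two monotone maps are fitted on target data here; the non-smooth part of the prediction is inherited from the source model, which is the intuition behind the regressor's smoothness not driving the transfer rate.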
By contrast, direct learning estimates $f_T$ nonparametrically from $m$ samples of $(X_T, Y_T)$, and its sample complexity is dictated by the smoothness of $f_T$ and the data dimension.
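A standard heuristic for why smoothness and dimension jointly set the direct-learning rate is the bias-variance trade-off of a local-averaging estimator with bandwidth $h$. The symbols $\beta$ (Hölder smoothness of $f_T$), $d$ (input dimension), and $h$ are notation assumed here, and the display is the classical textbook calculation rather than a formula taken from the paper.

$$
\underbrace{h^{2\beta}}_{\text{squared bias}} \;+\; \underbrace{\frac{1}{m\,h^{d}}}_{\text{variance}}
\quad\xrightarrow{\;h \,\asymp\, m^{-1/(2\beta+d)}\;}\quad
\mathbb{E}\,\bigl\|\hat f_T - f_T\bigr\|_{L^2}^{2} \;\asymp\; m^{-2\beta/(2\beta+d)}
$$

For fixed $\beta$ the exponent shrinks toward zero as $d$ grows, which is the curse of dimensionality that motivates looking for a rate that does not depend on the regressor's smoothness.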
Theoretical Results: Sample Complexity Bounds
The core contribution is the derivation of upper bounds on the $L^2$ estimation error of the transfer learning estimator, expressed in terms of the smoothness of the joint source and target data distributions and the ambient dimension. The main findings are:
- Direct Learning: Minimizing risk w.r.t. $f_T$ using $m$ samples yields the minimax nonparametric rate, which is set by the regression function's smoothness and the dimensionality.
- Transfer Learning (OT-based): With the OT formalism, the error of the transfer estimator decays at a rate set by the (distributional) smoothness of the data and the dimension, in the regime the paper specifies; crucially, the regression function's smoothness drops out.
The optimality and tightness of these rates draw on recent statistical OT theory, with detailed dependence on transfer mapping regularity and log-concavity. In high dimensions, the separation between the direct and transfer rates can yield substantial improvements when the distributional smoothness exceeds the regressor smoothness. This is a common situation when the underlying data is smooth but the regression function is non-smooth (e.g., due to non-smooth activations in deep models).
Theoretical implications are summarized as follows:
| Scenario | Error Rate | Key Smoothness Parameter |
| --- | --- | --- |
| Direct Learning | Minimax nonparametric rate in $m$ | Regressor smoothness |
| Transfer Learning | OT-based rate, independent of regressor smoothness | Distributional smoothness |
When the input distributions are smooth (e.g., Gaussian mixtures) and the regressor is non-smooth (e.g., deep ReLU nets), transfer learning achieves superior sample efficiency.
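To get a feel for how strongly the governing smoothness parameter affects sample requirements, the sketch below evaluates the classical nonparametric exponent $2s/(2s+d)$ at two smoothness levels $s$. Applying this same functional form to the transfer case is an assumption made purely for illustration, and the paper's transfer exponent may differ. The function names and the values `0.5`, `4.0`, and `d = 20` are hypothetical, and constants are ignored throughout.

```python
import math

def rate_exponent(smoothness: float, dim: int) -> float:
    """Classical nonparametric exponent: error ~ m ** (-2s / (2s + d))."""
    return 2.0 * smoothness / (2.0 * smoothness + dim)

def samples_needed(target_error: float, smoothness: float, dim: int) -> float:
    """Sample size m at which m ** (-exponent) equals target_error (constants ignored)."""
    return target_error ** (-1.0 / rate_exponent(smoothness, dim))

dim = 20
rough = samples_needed(0.1, smoothness=0.5, dim=dim)   # rate driven by a rough regressor
smooth = samples_needed(0.1, smoothness=4.0, dim=dim)  # rate driven by a smooth distribution
print(f"rough: {rough:.3g} samples, smooth: {smooth:.3g} samples")
print(f"log10 ratio: {math.log10(rough / smooth):.1f}")
```

Under these illustrative numbers the gap spans many orders of magnitude, which is why it matters so much which smoothness parameter controls the rate in high dimensions.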
Numerical Experiments
Two experimental setups are used to validate the theoretical claims:
Image Classification (Office-31)
Transfer learning via ResNet-50 pretrained on ImageNet or source domains is compared with direct learning from randomly initialized weights, across varying fractions of the training data. In low-sample regimes (as little as 10% of the data), all reported metrics (AUROC, accuracy, precision, and sensitivity) favor transfer learning, with relative improvements of over 100% in precision and sensitivity in the smallest data regime.
Medical Diagnosis: Retinopathy of Prematurity (ROP)
A second evaluation transfers from diabetic retinopathy (DR) diagnosis, for which a large dataset is available, to retinopathy of prematurity, a data-scarce, high-stakes task. Transfer-learned classifiers substantially outperform direct learning at all data scales and exceed 0.9 AUROC and accuracy using only ~10% of the available data. In the extreme low-data regime (1%), transfer learning yields sensitivity and precision gains exceeding 46% relative to direct learning.
These empirical findings confirm that transfer learning, interpreted through the OT framework, achieves significant improvements when target data is scarce and the regression function is non-smooth.
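For readers who want to reproduce this kind of comparison, the sketch below shows how accuracy, precision, sensitivity, and a relative improvement can be computed from confusion-matrix counts for a binary classifier. The counts are made up for illustration and are not the paper's results; `binary_metrics` and `relative_gain` are hypothetical helper names. A multiclass benchmark would need per-class averaging, which is not shown.

```python
def binary_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Accuracy, precision, and sensitivity (recall) from confusion-matrix counts."""
    total = tp + fp + fn + tn
    return {
        "accuracy": (tp + tn) / total,
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        "sensitivity": tp / (tp + fn) if (tp + fn) else 0.0,
    }

def relative_gain(new: float, old: float) -> float:
    """Relative improvement of `new` over `old`; 1.0 means a 100% gain."""
    return (new - old) / old

# Made-up counts for a low-data regime (not taken from the paper).
direct = binary_metrics(tp=20, fp=30, fn=30, tn=120)
transfer = binary_metrics(tp=45, fp=10, fn=5, tn=140)

for name in direct:
    gain = relative_gain(transfer[name], direct[name])
    print(f"{name}: direct={direct[name]:.3f} transfer={transfer[name]:.3f} gain={gain:+.0%}")
```

A relative gain above 100% means the metric more than doubled, which can happen in very low-data regimes where the directly trained baseline starts from a low value.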
Implications and Future Directions
This work provides a precise statistical characterization of when and why transfer learning can yield marked sample complexity reductions. The results are strongest in high-dimensional settings where data distributions are smooth and the regression function is highly complex or non-smooth (e.g., deep models with ReLU activations).
Practical Implications
- Model selection: The findings offer formal guidance for practitioners: transfer learning is most advantageous for high-complexity models over smooth domains with scarce labeled data.
- Architecture design: Non-smooth activation functions, common in deep learning, align with the work's assumptions, justifying existing empirical heuristics.
- Medical and scientific AI: In high-stakes, data-limited settings such as medical image diagnosis, the OT framework can justify transfer-learning-based pipelines that maximize sample efficiency.
Theoretical Implications and Open Problems
- OT map regularity: The impact of non-Lipschitz or non-convex transport maps (beyond log-concave distributions) remains an open problem.
- Distributional shift: Extension to negative transfer or unreliable source tasks can leverage OT-based ambiguity measures.
- Generalization beyond quadratic loss: The framework extends to other strictly convex costs, with implications for domain adaptation and unsupervised pretraining.
- Scalable computation of high-dimensional OT maps: Continued progress in fast, scalable OT (e.g., entropic or sliced variants) will facilitate broader practical adoption.
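As one concrete instance of the entropic variant mentioned in the last item, here is a minimal NumPy sketch of Sinkhorn iterations between two small point clouds, followed by a barycentric projection that yields an approximate transport map. The regularization strength `eps`, the iteration count, the toy data, and the function names are assumptions for illustration. This plain (non log-domain) version can underflow when `eps` is small relative to the costs.

```python
import numpy as np

def sinkhorn_plan(x, y, eps=1.0, n_iters=500):
    """Entropic OT plan between uniform empirical measures on x and y (squared Euclidean cost)."""
    n, m = len(x), len(y)
    a = np.full(n, 1.0 / n)  # source weights
    b = np.full(m, 1.0 / m)  # target weights
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    kernel = np.exp(-cost / eps)
    u = np.ones(n)
    v = np.ones(m)
    for _ in range(n_iters):
        # Alternately rescale rows and columns to match the two marginals.
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u[:, None] * kernel * v[None, :]

def barycentric_map(plan, y):
    """Send each source point to the plan-weighted average of the target points."""
    return (plan @ y) / plan.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 2))                         # toy source cloud
y = rng.normal(size=(300, 2)) + np.array([3.0, 0.0])  # toy target cloud, shifted
plan = sinkhorn_plan(x, y)
mapped = barycentric_map(plan, y)
print(mapped.mean(axis=0), y.mean(axis=0))
```

Each iteration costs on the order of the number of source points times the number of target points, so larger problems usually call for mini-batching, low-rank kernels, or sliced approximations.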
Conclusion
The paper provides a rigorous statistical foundation for the sample efficiency of transfer learning grounded in optimal transport, identifying precise conditions under which transfer significantly outperforms direct learning. The results have substantial consequences for the design and deployment of learning systems in data-constrained, high-dimensional applications. The OT-centric approach opens promising avenues for further advances in statistical learning theory and the principled development of data-efficient AI.