SimplexLab/TorchJD

TorchJD is a PyTorch library for training neural networks with multiple losses. It supports two complementary approaches:

  • Scalarization: combine losses into a single scalar before backprop, using methods from the literature (geometric mean, softmax weighting, etc.). This is often a good baseline.
  • Jacobian descent: compute the Jacobian matrix of the losses with respect to the parameters and aggregate it into an update direction using state-of-the-art aggregators (UPGrad, MGDA, CAGrad, and many more). In particular, this makes it possible to take conflict-free optimization directions, which can solve problems that may be impossible to solve with standard scalarizers.
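
To make the notion of conflict concrete, here is a minimal sketch in plain PyTorch (it does not use TorchJD). The Jacobian values are made up for illustration: each row stands for the gradient of one loss. Summing the losses before backprop amounts to stepping along the sum of the rows, and the sign of each entry of `J @ direction` tells whether the corresponding loss decreases to first order.

```python
import torch

# Hypothetical Jacobian: row i is the gradient of loss i w.r.t. three parameters.
J = torch.tensor([[-4.0, 1.0, 1.0],
                  [6.0, 1.0, 1.0]])

# Summing the losses before backprop yields the sum of the rows as update direction.
direction = J.sum(dim=0)  # tensor([2., 2., 2.])

# For a step `-lr * direction`, the first-order change of the losses is
# `-lr * J @ direction`. A negative entry of `J @ direction` therefore means
# that the corresponding loss increases (to first order).
alignment = J @ direction
print(alignment)  # tensor([-4., 16.])

has_conflict = bool((alignment < 0).any())
print(has_conflict)  # True: the summed direction hurts the first loss
```

A conflict-free direction would be one for which every entry of `J @ direction` is non-negative.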

The full documentation is available at torchjd.org.

Installation

TorchJD can be installed directly with pip:

pip install "torchjd[quadprog_projector]"

This includes the dependencies required by UPGrad and DualProj. Some other aggregators may have additional dependencies. Please refer to the installation documentation for them.

Usage

Scalarization

Scalarization methods combine losses into a single scalar before backprop. Here is how to change a standard training loop to use scalarization:

  import torch
  from torch.nn import Linear, MSELoss, ReLU, Sequential
  from torch.optim import SGD

+ from torchjd.scalarization import GeometricMean

  shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
  task1_module = Linear(3, 1)
  task2_module = Linear(3, 1)
  params = [*shared_module.parameters(), *task1_module.parameters(), *task2_module.parameters()]

  loss_fn = MSELoss()
  optimizer = SGD(params, lr=0.1)
+ scalarizer = GeometricMean()

  inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
  task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
  task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

  for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
      features = shared_module(input)
      loss1 = loss_fn(task1_module(features), target1)
      loss2 = loss_fn(task2_module(features), target2)

-     loss = loss1 + loss2
-     loss.backward()
+     loss = scalarizer(torch.stack([loss1, loss2]))
+     loss.backward()
      optimizer.step()
      optimizer.zero_grad()
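
For intuition about what a scalarizer does, the sketch below writes a geometric mean by hand in plain PyTorch. It uses the textbook formula and is an assumption about the general idea, not a copy of how `GeometricMean` is implemented in TorchJD. The helper name `geometric_mean` and the loss values are hypothetical.

```python
import torch


def geometric_mean(losses: torch.Tensor) -> torch.Tensor:
    """Geometric mean of a 1D tensor of strictly positive losses."""
    # exp(mean(log(x))) equals prod(x) ** (1 / n) and avoids large products.
    return torch.exp(torch.log(losses).mean())


# Hypothetical loss values; in a training loop these would come from the model.
losses = torch.tensor([4.0, 1.0], requires_grad=True)
loss = geometric_mean(losses)  # sqrt(4 * 1), i.e. approximately 2
loss.backward()

# The derivative w.r.t. loss i is (geometric mean) / (n * loss_i),
# here approximately [0.25, 1.0]: the smaller loss gets the larger weight.
print(loss.item(), losses.grad)
```

Note that this formula requires strictly positive losses, since the logarithm is undefined at zero.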

Jacobian descent

Jacobian descent computes per-loss gradients individually and aggregates them into a single update direction. Some aggregators, like UPGrad, are specifically designed to find directions that are beneficial to all losses simultaneously. Here is how to change a standard multi-task training loop to use Jacobian descent:

  import torch
  from torch.nn import Linear, MSELoss, ReLU, Sequential
  from torch.optim import SGD

+ from torchjd.autojac import jac_to_grad, mtl_backward
+ from torchjd.aggregation import UPGrad

  shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
  task1_module = Linear(3, 1)
  task2_module = Linear(3, 1)
  params = [*shared_module.parameters(), *task1_module.parameters(), *task2_module.parameters()]

  loss_fn = MSELoss()
  optimizer = SGD(params, lr=0.1)
+ aggregator = UPGrad()

  inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
  task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
  task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

  for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
      features = shared_module(input)
      loss1 = loss_fn(task1_module(features), target1)
      loss2 = loss_fn(task2_module(features), target2)

-     loss = loss1 + loss2
-     loss.backward()
+     mtl_backward([loss1, loss2], features=features)
+     jac_to_grad(shared_module.parameters(), aggregator)
      optimizer.step()
      optimizer.zero_grad()
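
The sketch below spells out one step of Jacobian descent by hand on a toy problem, using only `torch.autograd`. It is meant to show the mechanics (one gradient per loss, stacked into a Jacobian, then aggregated), not to reproduce TorchJD's internals. The toy losses are hypothetical, and the row mean is used as a stand-in aggregator; it is not UPGrad.

```python
import torch

# Hypothetical toy problem: one parameter vector and two losses.
w = torch.tensor([1.0, 2.0], requires_grad=True)
loss1 = (w ** 2).sum()      # gradient: 2 * w = [2, 4]
loss2 = (w[0] - w[1]) ** 2  # gradient: [2 * (w0 - w1), -2 * (w0 - w1)] = [-2, 2]

# One backward pass per loss gives one row of the Jacobian.
rows = [torch.autograd.grad(loss, w, retain_graph=True)[0] for loss in (loss1, loss2)]
jacobian = torch.stack(rows)  # shape: (number of losses, number of parameters)
print(jacobian)

# Stand-in aggregator: the mean of the rows, here [0, 3].
# The aggregated vector is stored in .grad so that a regular optimizer can use it.
w.grad = jacobian.mean(dim=0)
torch.optim.SGD([w], lr=0.1).step()
print(w)  # approximately [1.0, 1.7]
```

Swapping the mean for a different aggregation rule changes only the line that turns `jacobian` into a single vector.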

The autojac engine

The torchjd.autojac engine provides a way to compute Jacobians (generally of the losses with respect to the parameters). Its interface is very similar to that of torch.autograd: autojac.jac is analogous to autograd.grad but returns Jacobians instead of gradients, and autojac.backward is analogous to autograd.backward but accumulates Jacobians in the .jac fields of parameters instead of gradients in the .grad fields. These Jacobians can then be aggregated into gradients and moved to the .grad fields by calling autojac.jac_to_grad. Lastly, the mtl_backward function can be used for multi-task learning to compute and accumulate gradients with respect to task-specific parameters and Jacobians with respect to shared parameters.
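
The split between task-specific gradients and a shared Jacobian can be illustrated with plain `torch.autograd`. The sketch below is only an illustration of that split under made-up module sizes and losses; it does not call TorchJD and makes no claim about how `mtl_backward` is implemented.

```python
import torch
from torch.nn import Linear

torch.manual_seed(0)
shared = Linear(4, 3)
head1, head2 = Linear(3, 1), Linear(3, 1)

x = torch.randn(5, 4)
features = shared(x)
loss1 = head1(features).pow(2).mean()
loss2 = head2(features).pow(2).mean()

# Task-specific parameters: each head only affects its own loss,
# so an ordinary gradient is enough.
head1_grads = torch.autograd.grad(loss1, list(head1.parameters()), retain_graph=True)
head2_grads = torch.autograd.grad(loss2, list(head2.parameters()), retain_graph=True)

# Shared parameters: both losses depend on them, so one gradient per loss is
# kept and the flattened gradients are stacked into a Jacobian.
shared_params = list(shared.parameters())
jac_rows = []
for loss in (loss1, loss2):
    grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
    jac_rows.append(torch.cat([g.reshape(-1) for g in grads]))
shared_jacobian = torch.stack(jac_rows)

# 2 losses, and 3 * 4 weights + 3 biases = 15 shared parameters.
print(shared_jacobian.shape)  # torch.Size([2, 15])
```

Only the shared Jacobian needs an aggregation step; the head gradients can be used as they are.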

The autogram engine

TorchJD also provides the autogram engine, which computes the Gramian of the Jacobian incrementally without ever storing the full Jacobian in memory. This makes Jacobian descent feasible on large models where the full Jacobian would be too expensive to store. See the autogram examples for more details.
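
The reason a Gramian can be built incrementally is the identity G = J Jᵀ = Σₖ Jₖ Jₖᵀ, where the Jₖ are blocks of columns of the Jacobian (for instance one block per layer). The sketch below checks this identity numerically with random, hypothetical blocks. It only illustrates the identity; the actual algorithm used by autogram is not described here and may differ.

```python
import torch

torch.manual_seed(0)
num_losses = 3

# Hypothetical per-layer Jacobian blocks: one row per loss and one column per
# parameter of the layer.
block_sizes = [50, 200, 10]
blocks = [torch.randn(num_losses, size, dtype=torch.float64) for size in block_sizes]

# Accumulate the Gramian one block at a time: G = sum_k J_k @ J_k.T.
gramian = torch.zeros(num_losses, num_losses, dtype=torch.float64)
for block in blocks:
    gramian += block @ block.T  # the block is no longer needed after this line

# Reference: the full Jacobian, which the incremental approach avoids storing.
full_jacobian = torch.cat(blocks, dim=1)  # shape: (3, 260)
reference = full_jacobian @ full_jacobian.T

print(gramian.shape)  # torch.Size([3, 3])
print(torch.allclose(gramian, reference))  # True
```

The Gramian has one row and one column per loss, so its size does not grow with the number of parameters.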

More usage examples, including instance-wise risk minimization and partial Jacobian descent, can be found in the docs.

Supported Scalarizers

Scalarizer | Publication
Constant | -
COSMOS | COSMOS: Enhancing Multi-Objective Optimization with Scalarization
DWA | End-to-End Multi-Task Learning with Attention
FAMO | FAMO: Fast Adaptive Multitask Optimization
GeometricMean | MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for Multi-Task Learning
IMTL-L | Towards Impartial Multi-task Learning
Mean | -
PBI | A Decomposition-Based Evolutionary Algorithm for Many Objective Optimization
Random | Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning
STCH | Smooth Tchebycheff Scalarization for Multi-Objective Optimization
Sum | -
UW | Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics

Supported Aggregators and Weightings

TorchJD provides many existing aggregators from the literature, listed in the following table.

Aggregator | Weighting | Publication
UPGrad (recommended) | UPGradWeighting | Jacobian Descent For Multi-Objective Optimization
AlignedMTL | AlignedMTLWeighting | Independent Component Alignment for Multi-Task Learning
CAGrad | CAGradWeighting | Conflict-Averse Gradient Descent for Multi-task Learning
ConFIG | - | ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks
Constant | ConstantWeighting | -
- | CRMOGMWeighting | On the Convergence of Stochastic Multi-Objective Gradient Manipulation and Beyond
DualProj | DualProjWeighting | Gradient Episodic Memory for Continual Learning
ExcessMTL | ExcessMTLWeighting | Robust Multi-Task Learning with Excess Risks
FairGrad | FairGradWeighting | Fair Resource Allocation in Multi-Task Learning
GradDrop | - | Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout
GradVac | GradVacWeighting | Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models
IMTLG | IMTLGWeighting | Towards Impartial Multi-task Learning
Krum | KrumWeighting | Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent
Mean | MeanWeighting | -
MGDA | MGDAWeighting | Multiple-gradient descent algorithm (MGDA) for multiobjective optimization
- | MoDoWeighting | Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance
NashMTL | - | Multi-Task Learning as a Bargaining Game
PCGrad | PCGradWeighting | Gradient Surgery for Multi-Task Learning
Random | RandomWeighting | Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning
- | SDMGradWeighting | Direction-oriented Multi-objective Learning: Simple and Provable Stochastic Algorithms
Sum | SumWeighting | -
Trimmed Mean | - | Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates

Release Methodology

We try to make a release whenever we have something worth sharing with users (bug fix, minor or large feature, etc.). TorchJD follows semantic versioning. Since the library is still in beta (0.x.y), we sometimes make interface changes in minor versions. We prioritize the long-term quality of the library, which occasionally means introducing breaking changes. Whenever a release contains breaking changes, the changelog and the GitHub release notes always include clear instructions on how to migrate.

Contribution

Please read the Contribution page and join our Discord to get involved!

Thanks to our amazing contributors for making this project possible.

Citation

If you use TorchJD for your research, please cite:

@article{jacobian_descent,
  title={Jacobian Descent For Multi-Objective Optimization},
  author={Quinton, Pierre and Rey, Valérian},
  journal={arXiv preprint arXiv:2406.16232},
  year={2024}
}
