RotoGrad - Gradient Homogenization in Multitask Learning
Authors: Adrian Javaloy & Isabel Valera from Saarland University
Submitted to ICLR 2022; scores: 8, 8, 8, 8.
Notations:
$\nabla$ denotes the gradient operator.
Underlined letters, e.g., $\underline{v}$, denote vectors.
Capital letters, e.g., $R$, denote matrices.
$\langle \cdot, \cdot \rangle$ denotes the dot product operator.
$SO(d)$ denotes the special orthogonal group, i.e., $SO(d) = \{ R \in \mathbb{R}^{d \times d} : R^\top R = I, \ \det(R) = 1 \}$.
$\hat{\cdot}$ denotes the hat operator, which maps a vector to its associated skew-symmetric matrix, i.e., $d \mapsto \hat{d}$.
$\mathfrak{so}(d)$ denotes the Lie algebra associated with $SO(d)$, which constitutes the set of skew-symmetric matrices.
$\vee$ denotes the vee operator, the inverse of the hat operator, i.e., $\hat{d}^{\vee} = d$.
1. What did the authors try to accomplish?
RotoGrad tackles the problem of negative transfer in multi-task learning (MTL). In particular, negative transfer in MTL is caused by:
Varying gradient magnitudes across tasks, e.g., a task with a relatively larger gradient magnitude dominates the overall gradient direction, resulting in poor performance on the other tasks.
Varying gradient directions across tasks, e.g., gradients from different tasks may cancel each other out, resulting in slow learning.
1.1 Main Claims
On the one hand, they propose to homogenize gradient magnitudes so that the least-converged tasks guide the step size. On the other hand, a task-specific rotation matrix rotates the shared feature space so that the task gradients align with a common direction.
1.2 Contributions
The authors claim that previous work in the literature primarily focused on the first problem. Instead, they propose to tackle both problems. Specifically, they propose a novel way to resolve gradient direction conflicts by rotating the feature space such that the gradients align towards a common direction. To the best of my knowledge, they are the first to apply rotation transformations, parameterized via a Lie group, to multi-task learning.
2. What were the key elements of the approach?
2.1 Idea
Given an input $x$, let $z = f(x; \theta)$ be the shared backbone for all tasks and $h_k(\cdot; \phi_k)$ the task-specific head for task $k$, with task-specific loss $L_k$. Moreover, for the combined objective $L = \sum_k L_k$, the total gradient is the linear combination of task gradients $\nabla_\theta L = \sum_k \nabla_\theta L_k$, and by the chain rule $\nabla_\theta L_k = \nabla_\theta z \cdot \nabla_z L_k$. Since the tasks only compete for resources on the shared parameters $\theta$, i.e., through the shared feature $z$, we can ignore the common factor $\nabla_\theta z$ and focus on $\nabla_z L_k$.
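To make this setup concrete, here is a minimal PyTorch sketch (my own illustration, not the authors' code) with a toy backbone, two regression heads, and made-up dimensions, computing the per-task gradients $\nabla_z L_k$ that the rest of the method manipulates:

```python
import torch
import torch.nn as nn

feature_dim = 64  # dimension of the shared representation z (made up for the example)

# Shared backbone f(.; theta) and task-specific heads h_k(.; phi_k), K = 2 tasks
backbone = nn.Sequential(nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, feature_dim))
heads = nn.ModuleList([nn.Linear(feature_dim, 1) for _ in range(2)])
loss_fns = [nn.MSELoss(), nn.MSELoss()]

x = torch.randn(32, 10)
targets = [torch.randn(32, 1), torch.randn(32, 1)]

z = backbone(x)  # shared feature: the tasks only compete through z (and theta)

# Per-task gradients w.r.t. z, i.e., grad_z L_k: the quantity RotoGrad manipulates
task_grads = []
for head, loss_fn, y in zip(heads, loss_fns, targets):
    L_k = loss_fn(head(z), y)
    g_k = torch.autograd.grad(L_k, z, retain_graph=True)[0]  # shape: (batch, feature_dim)
    task_grads.append(g_k)
```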
2.1.1 Gradient Magnitude
Goal: homogenize the magnitude of each task gradient $\nabla_z L_k$.
Let $g_{k,n} = \nabla_z L_k(x_n)$ be the gradient for the $k$-th task and the $n$-th data point. Then, its batch version is $g_k = \frac{1}{N} \sum_{n=1}^{N} g_{k,n}$.
The authors propose to normalize each gradient to a unit vector and then rescale it such that the tasks that have converged the least set the step size. Concretely, all tasks share a common gradient magnitude, a weighted combination of the task-wise gradient norms whose weights sum up to one. Let $\omega_k$ be this weight for each task (larger for less-converged tasks, e.g., as measured by the ratio of the current to the initial gradient norm); then
$$\hat{g}_k = \frac{g_k}{\lVert g_k \rVert} \sum_{j} \omega_j \lVert g_j \rVert, \qquad \sum_{j} \omega_j = 1.$$
A side effect of this is that slowly converging tasks will force quickly converging tasks to escape from saddle points.
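A small sketch of this magnitude homogenization, as I read it (not the reference implementation); in particular, measuring convergence by the ratio of the current to the initial gradient norm is my assumption for the choice of weights:

```python
import torch

def homogenize_magnitudes(task_grads, initial_norms, eps=1e-12):
    """Rescale each task gradient (w.r.t. z) to a common, weighted norm.

    task_grads:    list of tensors grad_z L_k, one per task
    initial_norms: gradient norms recorded at the first step, used here as a
                   proxy for how much each task has converged (assumption)
    """
    norms = [g.norm() for g in task_grads]
    # Weight each task by its relative lack of convergence; weights sum to 1.
    ratios = torch.stack([n / (n0 + eps) for n, n0 in zip(norms, initial_norms)])
    weights = ratios / ratios.sum()
    # Common magnitude: convex combination of the per-task gradient norms.
    common_norm = sum(w * n for w, n in zip(weights, norms))
    # Unit-normalize each gradient, then rescale it to the common magnitude.
    return [g / (n + eps) * common_norm for g, n in zip(task_grads, norms)]
```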
2.1.2 Gradient Direction
Let $R_k \in SO(d)$ be the task-specific rotation matrix; then, instead of optimizing $L_k(h_k(z))$, we optimize $L_k(h_k(r_k))$, where $r_k = R_k z$. By rotating the feature space we can ensure that the gradient directions from different tasks do not conflict; instead, they align with a common direction given by $v$. In particular, each $R_k$ has to minimize the following objective:
$$R_k^{*} = \operatorname*{arg\,min}_{R \in SO(d)} \, - \sum_{n} \big\langle R^\top \nabla_{r_k} L_k(x_n), \, v_n \big\rangle,$$
where $v_n = \sum_{k'} \nabla_z L_{k'}(x_n) / \lVert \nabla_z L_{k'}(x_n) \rVert$ is a sum of the normalized gradient directions from the different tasks, and $\nabla_{r_k} L_k$ is the gradient flowing down from the head, before being rotated back to the $z$-space.
Learning a task-specific rotation matrix requires considering a constrained minimization problem over $R_k \in SO(d)$. In practice, $z$ is high-dimensional; hence optimizing this objective directly is unfeasible, as it would require enforcing the constraints of $SO(d)$, e.g., computing the determinant of $R_k$, since we want $\det(R_k) = 1$ (along with $R_k^\top R_k = I$). Hence, to turn it into an unconstrained minimization problem, we consider the associated Lie algebra $\mathfrak{so}(d)$, which allows us to work with rotation matrices through a vector space of (skew-symmetric) matrices. A vector space of matrices is much nicer to work with, since it is closed under summation and scalar multiplication: we do not need to maintain the constraint explicitly and can use regular gradient descent to optimize $R_k$.
One might think that we could directly optimize the free parameters of the rotation matrices; however, directly constructing such a rotation matrix in high dimensions is unfeasible.
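Anticipating the exponential-map parameterization derived below, here is a minimal PyTorch sketch of how such a rotation layer and its alignment objective could look (my own illustration, not the authors' code; `TaskRotation`, `rotation_alignment_loss`, and the row-vector convention are assumptions):

```python
import torch
import torch.nn as nn

class TaskRotation(nn.Module):
    """R_k = exp(A - A^T) in SO(d), parameterized by an unconstrained square matrix A."""

    def __init__(self, dim):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(dim, dim))  # exp(0) = I, so we start at the identity

    def rotation(self):
        skew = self.A - self.A.T                 # skew-symmetric, i.e., an element of so(d)
        return torch.linalg.matrix_exp(skew)     # exponential map so(d) -> SO(d)

    def forward(self, z):
        # Rows of z are samples, so z @ R^T implements r_k = R_k z per sample.
        return z @ self.rotation().T

def rotation_alignment_loss(rot, grad_rk, v):
    """Train R_k so that the gradient mapped back to z-space aligns with the target v.

    grad_rk: gradient of L_k w.r.t. the rotated feature r_k (treated as a constant here)
    v:       common target direction, e.g., the sum of normalized task gradients (constant)
    """
    grad_z = grad_rk.detach() @ rot.rotation()   # pull the gradient back through R_k
    return -(grad_z * v.detach()).sum()          # maximize batch-wise alignment with v
```

Note that only the rotation parameters receive gradients from this alignment loss; the backbone and the heads keep minimizing their own task losses.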
Lie Group and Lie Algebra
At a high level, the set of all $d$-dimensional rotation matrices (with determinant 1) constitutes a group, as it respects the four group axioms (i.e., identity, associativity, inverse, and closure) under matrix multiplication. In addition, since these rotation matrices form a smooth (differentiable) manifold, it is also a Lie group. For the special orthogonal group, denoted $SO(d)$, one can smoothly rotate any matrix into another; hence it is a connected Lie group. This group is "special" because its matrices have determinant 1.
We consider $SO(d)$ because every Lie group has an associated Lie algebra. In particular, the Lie algebra, denoted $\mathfrak{so}(d)$, is the set of skew-symmetric matrices, which forms a vector space. These skew-symmetric matrices can be mapped back to their corresponding element of the Lie group via the exponential map $\exp: \mathfrak{so}(d) \to SO(d)$. This new space allows us to work with a local parameterization of the rotation matrices $R_k$.
More formally, consider a family of rotation matrices $R(t)$, with $R(0) = I$, which continuously transforms a point from its original location ($x(0)$) to a different one:
$$x(t) = R(t)\, x(0).$$
Since $R(t) \in SO(d)$, i.e., $R(t) R(t)^\top = I$ for all $t$, differentiating with respect to $t$ we have
$$\dot{R}(t) R(t)^\top + R(t) \dot{R}(t)^\top = 0,$$
then,
$$\dot{R}(t) R(t)^\top = -\big( \dot{R}(t) R(t)^\top \big)^\top,$$
i.e., $\dot{R}(t) R(t)^\top$ is skew-symmetric.
We know that skew-symmetric (or anti-symmetric) matrices $A$ have the two following properties: $A^\top = -A$ and $a_{ij} = -a_{ji}$. The diagonal must be 0, as $a_{ii} = -a_{ii}$ implies $a_{ii} = 0$.
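For concreteness, in three dimensions the hat operator from the notation section maps a vector $d = (d_1, d_2, d_3)$ to the familiar cross-product matrix:
$$\hat{d} = \begin{pmatrix} 0 & -d_3 & d_2 \\ d_3 & 0 & -d_1 \\ -d_2 & d_1 & 0 \end{pmatrix}, \qquad \hat{d}\, x = d \times x,$$
which is skew-symmetric with a zero diagonal; the vee operator recovers $d = \hat{d}^{\vee}$.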
Let $\hat{d}$ be the skew-symmetric matrix of a vector $d$ (via the hat operator). Then the above result suggests that there exists a vector $d(t)$ such that:
$$\dot{R}(t) R(t)^\top = \hat{d}(t),$$
and since $R(t)^\top = R(t)^{-1}$, it follows that $\dot{R}(t) = \hat{d}(t) R(t)$. This is a simple ordinary differential equation; for simplicity, assume $\hat{d}$ is constant in $t$. Then the solution is the matrix exponential:
$$R(t) = e^{\hat{d} t} R(0) = e^{\hat{d} t},$$
yielding the exponential map $\exp: \mathfrak{so}(d) \to SO(d)$; $\hat{d} \mapsto e^{\hat{d}}$. In fact, every rotation matrix $R$ has (infinitely many) exponential coordinates $d$ such that $R = e^{\hat{d}}$. The exponential coordinates provide a local parameterization for rotation matrices.
As a sanity check, since $\hat{d}$ is skew-symmetric, it holds that $\big(e^{\hat{d}}\big)^\top = e^{\hat{d}^\top} = e^{-\hat{d}} = \big(e^{\hat{d}}\big)^{-1}$, hence:
$$e^{\hat{d}} \big(e^{\hat{d}}\big)^\top = I,$$
so $e^{\hat{d}}$ is indeed orthogonal.
Furthermore, we observe that the skew-symmetric matrix gives the first-order approximation of a rotation at $t = 0$:
$$R(t) \approx I + \hat{d}\, t.$$
Hence, the Lie algebra $\mathfrak{so}(d)$ is the tangent space to the Lie group $SO(d)$ at the identity element $I$. We also observe that $\dot{R}(0) = \hat{d}$.
To conclude, we showed that any rotation $R \in SO(d)$ can be written as $R = e^{\hat{d}}$ for some skew-symmetric matrix $\hat{d} \in \mathfrak{so}(d)$; hence, we can optimize the unconstrained coordinates $d$ with regular gradient descent and recover a valid rotation via the exponential map.
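As a quick numerical check of the above (a standalone snippet, not tied to the paper's code), the exponential of a random skew-symmetric matrix is indeed orthogonal with determinant 1, and $I + t\,\hat{d}$ approximates it well for small $t$:

```python
import torch

torch.manual_seed(0)
d = 5
A = torch.randn(d, d)
skew = A - A.T                                  # element of so(d): skew-symmetric

R = torch.linalg.matrix_exp(skew)               # exponential map so(d) -> SO(d)
print(torch.allclose(R @ R.T, torch.eye(d), atol=1e-4))  # True: R is orthogonal
print(torch.det(R))                             # ~1.0: R is a proper rotation

# First-order approximation around the identity: exp(t * skew) ≈ I + t * skew for small t
t = 1e-3
R_small = torch.linalg.matrix_exp(t * skew)
print((R_small - (torch.eye(d) + t * skew)).abs().max())  # on the order of t^2, i.e., tiny
```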
Note: The authors also tried using a learned non-linear transformation instead of a rotation matrix. However, such a choice results in numerical issues related to the scaling of the features in the forward pass, i.e., it affects the effective learning rate of the different heads.
2.2 Limitations
The method scales linearly with the number of tasks $K$: it induces $K$ additional $d \times d$ rotation matrices, i.e., $O(K d^2)$ extra parameters. The induced time complexity is dominated by the $K$ matrix exponentials, i.e., $O(K d^3)$ per step.
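As a rough, illustrative calculation (the numbers are my own assumptions, not from the paper): with $K = 10$ tasks and a feature dimension of $d = 512$, the per-task $d \times d$ rotations add on the order of $K d^2 \approx 2.6$M extra parameters, and each training step performs $K$ matrix exponentials at $O(d^3)$ cost each.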