Gradient Equivalence in Siamese Self-Supervised Learning
Jan 3, 2022 by Matthieu Lin.
Following [1], we derive the gradient of different siamese self-supervised learning methods and show that although these methods appear to be quite different, they have similar gradient formulas. In particular, the gradient consists of three terms:
- A positive gradient, i.e., the representation of another augmented view from the same image, pulls positive samples together.
- A negative gradient, i.e., a weighted combination of the representations from different images, pushes negative samples apart.
- A balancing factor weights the two terms.
Building on this unified view, [1] empirically shows that these methods perform similarly, and that only the momentum encoder dramatically improves the final performance.
1. Introduction
1.1 What is "Self" in Self-Supervised Learning?
The word self refers to the ability to generate labels from the data by leveraging its underlying structure. For instance, the task could be predicting a missing word in a sentence. Because of the natural structure of human language, solving this task requires a high-level understanding of the sentence. In solving it, we hope that our model learns a generic representation.
1.2 What is "Siamese" in Siamese Self-Supervised Learning?
A siamese network is composed of two branches: an online branch and a target branch, where the target branch either shares its weights with the online branch or keeps an exponential moving average of them. Given two different augmented views of an input image, each branch computes a representation of one view, and the output of the target branch serves as the training target for the online branch. The loss maximizes the agreement between the two views, i.e., we want the network to be invariant to image augmentations. By optimizing this objective, we hope the online branch learns a generic representation transferable to downstream tasks. It is important to note that these methods usually rely on heavy, hand-engineered data augmentation.
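To make this concrete, here is a minimal sketch of the two branches in PyTorch (the toy encoder, dimensions, and momentum value are illustrative assumptions, not the setup of any particular paper):

```python
import copy
import torch
import torch.nn as nn

# Toy encoder standing in for e.g. a ResNet-50.
online = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
target = copy.deepcopy(online)           # target branch starts as a copy
for p in target.parameters():
    p.requires_grad = False              # no backprop through the target

x1, x2 = torch.randn(8, 32), torch.randn(8, 32)  # two augmented views
u = online(x1)                           # online representation
with torch.no_grad():
    v = target(x2)                       # training target (stop-gradient)

# Maximize agreement between the two views.
loss = -nn.functional.cosine_similarity(u, v, dim=1).mean()
loss.backward()

# Momentum (EMA) update of the target branch; with m = 0 the target
# simply copies the online weights at every step.
m = 0.99
with torch.no_grad():
    for pt, po in zip(target.parameters(), online.parameters()):
        pt.mul_(m).add_((1 - m) * po)
```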
1.3 What is the goal of Self-Supervised Learning?
Self-supervised learning acts as a pretext task whose goal is to learn a generic representation transferable to downstream tasks. In contrast to previous methods that pre-train a network on annotated datasets, self-supervised learning methods do not require labels.
Note: Although self-supervised learning techniques do not rely on labels, they still rely on the highly curated ImageNet dataset, where images usually contain a single object at the center.
When transferring to downstream tasks, we use the online branch. For instance, those downstream tasks can be classification, object detection, or semantic segmentation. We evaluate the quality of the learned representation of the online network on annotated datasets, where we either fine-tune the network or train a classifier on top of frozen features. For fair comparisons, these methods use the ImageNet dataset and a randomly initialized ResNet50. A good representation makes the classes linearly separable.
1.4 Key Concepts
A trivial solution exists for these siamese networks, where the network outputs the same embedding for all images to minimize the loss. When the network learns this trivial solution, we call this feature collapse.
We can roughly split current self-supervised methods into three frameworks, where each proposes different ways to prevent this collapse:
- Contrastive learning methods (e.g., MoCo [3], SimCLR [2]) contrast positive samples with negative samples to prevent collapse. Specifically, the target branch outputs the representation of a positive sample and a set of negative samples, and the loss explicitly pulls the positive pair together while pushing the negative samples apart.
- Asymmetric networks (e.g., BYOL [5], SimSiam [4]) only rely on positive samples. They introduce an asymmetry between the target and the online branch to prevent collapse. In this case, the loss explicitly pulls together the positive sample pairs.
- Feature decorrelation methods (e.g., Barlow Twins [6]) prevent collapse by pushing the cross-correlation matrix of the two views close to the identity matrix. Optimizing this objective makes each feature invariant under data augmentation while reducing the redundancy across features.
2. Equivalence of Gradients
Notations: throughout, $u$ denotes the output of the online branch for one augmented view and $v$ the output of the target branch for the other view; $v^+$ is the positive sample (another view of the same image) and $v^-_k$ are the negative samples (views of different images). Unless stated otherwise, all representations are $\ell_2$-normalized and $\tau$ denotes a temperature hyper-parameter.
2.1 Contrastive Methods
2.1.1 MoCo:
Method.
The contributions of this paper are twofold. First, they introduce a queue of negative samples instead of relying on a large mini-batch. Second, the target branch uses a momentum encoder, i.e., its weights $\theta_t$ are an exponential moving average of the online weights $\theta_o$:

$$\theta_t \leftarrow m\,\theta_t + (1-m)\,\theta_o,$$

with a momentum coefficient $m$ close to 1.
Let $u$ denote the $\ell_2$-normalized output of the online branch for one augmented view, $v^+$ the output of the momentum encoder for the other view of the same image, and $\{v^-_k\}_{k=1}^{K}$ the negative samples stored in the queue.
The loss function can be thought of as a $(K+1)$-way softmax classification problem where the positive key $v^+$ is the correct class:

$$\mathcal{L} = -\log\frac{\exp(u^\top v^+/\tau)}{\exp(u^\top v^+/\tau) + \sum_{k=1}^{K}\exp(u^\top v^-_k/\tau)}.$$
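A minimal sketch of this loss (batch size, queue size, feature dimension, and temperature are placeholder values):

```python
import torch
import torch.nn.functional as F

tau = 0.2                                            # temperature (placeholder)
u = F.normalize(torch.randn(8, 128), dim=1)          # online outputs
v_pos = F.normalize(torch.randn(8, 128), dim=1)      # positive keys (momentum encoder)
queue = F.normalize(torch.randn(4096, 128), dim=1)   # K negative keys from the queue

l_pos = (u * v_pos).sum(dim=1, keepdim=True)         # (N, 1) positive logits
l_neg = u @ queue.T                                  # (N, K) negative logits
logits = torch.cat([l_pos, l_neg], dim=1) / tau

# (K+1)-way classification where index 0 (the positive key) is the label.
labels = torch.zeros(u.size(0), dtype=torch.long)
loss = F.cross_entropy(logits, labels)
```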
Gradient.
Let $s^+ = u^\top v^+/\tau$ and $s_k = u^\top v^-_k/\tau$ denote the scaled similarities.
Let $p^+ = \frac{\exp(s^+)}{\exp(s^+) + \sum_k \exp(s_k)}$ and $p_k = \frac{\exp(s_k)}{\exp(s^+) + \sum_k \exp(s_k)}$ denote the corresponding softmax probabilities.
Since we do not backpropagate through the target branch, we only receive gradient from $u$:

$$\frac{\partial\mathcal{L}}{\partial u} = \frac{\partial\mathcal{L}}{\partial s^+}\frac{\partial s^+}{\partial u} + \sum_{k=1}^{K}\frac{\partial\mathcal{L}}{\partial s_k}\frac{\partial s_k}{\partial u},$$

where, $\frac{\partial s^+}{\partial u} = v^+/\tau$ and $\frac{\partial s_k}{\partial u} = v^-_k/\tau$,
and if the sample is the positive one, $\frac{\partial\mathcal{L}}{\partial s^+} = p^+ - 1$;
else, $\frac{\partial\mathcal{L}}{\partial s_k} = p_k$.
Hence,

$$\frac{\partial\mathcal{L}}{\partial u} = \frac{1}{\tau}\Big((p^+ - 1)\,v^+ + \sum_{k=1}^{K} p_k\, v^-_k\Big).$$
From this, we observe two things: (1) the first term pulls positive samples together, (2) the second term pushes negative samples apart.
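As a sanity check, we can verify the derived formula against autograd on random vectors (a toy check, not an experiment from [1]):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tau = 0.2
u = F.normalize(torch.randn(128), dim=0).requires_grad_()
v_pos = F.normalize(torch.randn(128), dim=0)
v_neg = F.normalize(torch.randn(16, 128), dim=1)

logits = torch.cat([(u @ v_pos).view(1), v_neg @ u]) / tau
loss = -torch.log_softmax(logits, dim=0)[0]          # InfoNCE for one sample
loss.backward()

# Derived gradient: ((p+ - 1) v+ + sum_k p_k v_k-) / tau
p = torch.softmax(logits.detach(), dim=0)
grad = ((p[0] - 1) * v_pos + p[1:] @ v_neg) / tau
print(torch.allclose(u.grad, grad, atol=1e-5))       # True
```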
2.1.2 SimCLR:
Method.
This paper introduces four ingredients that substantially improve the learned representation: (1) a non-linear projection head on top of the encoder during pre-text training, (2) a set of heavy data augmentations, (3) a cosine similarity function between pairs, and (4) an extended training schedule.
Unlike MoCo, in SimCLR the target and the online branch share the same weights. Therefore, it is easier to think of the two branches' outputs as a single set of variables $\{z_i\}_{i=1}^{2N}$: each mini-batch of $N$ images yields $2N$ augmented views, and $(z_i, z_j)$ denotes the positive pair coming from the same image.
Gradient.
As in MoCo, we assume the target branch computes the positive and the negative samples. Thus, if we stop the gradient through the target branch, the two methods are equivalent; [1] verifies this empirically.
Let the loss for anchor $z_i$ with positive $z_j$ be

$$\mathcal{L}_i = -\log\frac{\exp(z_i^\top z_j/\tau)}{\sum_{k\neq i}\exp(z_i^\top z_k/\tau)},$$

where, the total loss sums over all $2N$ anchors, $\mathcal{L} = \sum_i \mathcal{L}_i$,
and $p_{i,k} = \frac{\exp(z_i^\top z_k/\tau)}{\sum_{l\neq i}\exp(z_i^\top z_l/\tau)}$ denotes the softmax probability of $z_k$ given anchor $z_i$,
and all $z_i$ are $\ell_2$-normalized.
Then the gradient w.r.t. $z_i$ is

$$\frac{\partial\mathcal{L}}{\partial z_i} = \frac{\partial\mathcal{L}_i}{\partial z_i} + \sum_{k\neq i}\frac{\partial\mathcal{L}_k}{\partial z_i},$$

where the first term is

$$\frac{\partial\mathcal{L}_i}{\partial z_i} = \frac{1}{\tau}\Big((p_{i,j}-1)\,z_j + \sum_{k\neq i,j} p_{i,k}\, z_k\Big),$$

exactly as in MoCo, and by observing that $z_i$ appears as the positive sample in $\mathcal{L}_j$, and by observing that all the other losses $\mathcal{L}_k$ ($k\neq i,j$) use $z_i$ as a negative sample, the second term is

$$\sum_{k\neq i}\frac{\partial\mathcal{L}_k}{\partial z_i} = \frac{1}{\tau}\Big((p_{j,i}-1)\,z_j + \sum_{k\neq i,j} p_{k,i}\, z_k\Big).$$

Hence, the full gradient contains two symmetric contributions: one from $z_i$ acting as the anchor and one from $z_i$ acting as a key for the other anchors.
And if we stop the gradient through the target branch, then the second term vanishes and

$$\frac{\partial\mathcal{L}}{\partial z_i} = \frac{1}{\tau}\Big((p_{i,j}-1)\,z_j + \sum_{k\neq i,j} p_{i,k}\, z_k\Big).$$
This gradient is similar to MoCo's, i.e., the first term pulls positive samples together and the second term pushes negative samples apart.
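The same kind of toy autograd check works here (pairings, sizes, and temperature are arbitrary); note how the second term is exactly the first with the roles of anchor and key transposed:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tau = 0.2
z = F.normalize(torch.randn(8, 16), dim=1).requires_grad_()
pair = torch.arange(8) ^ 1                   # (0,1), (2,3), ... are positive pairs

sim = (z @ z.T) / tau
sim = sim.masked_fill(torch.eye(8, dtype=torch.bool), float('-inf'))  # exclude k = i
loss = -torch.log_softmax(sim, dim=1)[torch.arange(8), pair].sum()
loss.backward()

# Derived gradient: anchor term (p @ z - z[pair]) plus key term (p.T @ z - z[pair]);
# with a stop-gradient on the target branch only the anchor term survives.
with torch.no_grad():
    p = torch.softmax(sim, dim=1)            # p[i, i] = 0 by construction
    grad = (p @ z + p.T @ z - 2 * z[pair]) / tau
print(torch.allclose(z.grad, grad, atol=1e-5))       # True
```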
2.2 Asymmetric Methods
2.2.1 BYOL
Practically, we can think of BYOL as SimSiam with a momentum encoder.
Method.
SimSiam introduces an asymmetric architecture with a stop gradient on the target branch to prevent collapse. In particular, it appends a predictor $h$ to the online branch, and the loss maximizes the cosine similarity between the predictor output and the (stop-gradiented) target output, symmetrized over the two views.
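A minimal sketch of this asymmetric setup (toy linear modules stand in for the actual encoder and MLP predictor):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def D(p, v):
    # Negative cosine similarity with a stop-gradient on the target.
    return -F.cosine_similarity(p, v.detach(), dim=1).mean()

f = nn.Linear(32, 16)                        # toy shared encoder
h = nn.Linear(16, 16)                        # predictor, online branch only

x1, x2 = torch.randn(8, 32), torch.randn(8, 32)   # two augmented views
u1, u2 = f(x1), f(x2)
loss = D(h(u1), u2) / 2 + D(h(u2), u1) / 2   # symmetrized loss
loss.backward()
```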
Gradient.
Let $u_1$ be the output of the online branch, $v_2$ the output of the target branch for the other view, and $W_p$ the weight of a linear predictor; the loss is the negative cosine similarity

$$\mathcal{L} = -\frac{(W_p u_1)^\top v_2}{\|W_p u_1\|\,\|v_2\|}.$$

The authors of [7] propose an analytical solution for the predictor $W_p$: its eigenspace aligns with that of the feature correlation matrix $F = \mathbb{E}[u u^\top]$, with eigenvalues given by the square roots of those of $F$.
Motivated by this, they directly set $W_p = F^{1/2}$,
where $F$ is estimated as a moving average over training batches,
and, $F \leftarrow \rho\, F + (1-\rho)\,\frac{1}{N}\sum_{b=1}^{N} u_b u_b^\top$.
Note that the authors of [7] stop the gradient through the target branch.
Let $p_1 = W_p u_1$ denote the predictor output and assume $v_2$ is $\ell_2$-normalized;
thus,

$$\frac{\partial\mathcal{L}}{\partial u_1} = \frac{1}{\|p_1\|}\Big(-W_p v_2 + \frac{p_1^\top v_2}{\|p_1\|^2}\, W_p p_1\Big) = \frac{1}{\|p_1\|}\Big(-F^{1/2} v_2 + \frac{p_1^\top v_2}{\|p_1\|^2}\, F u_1\Big),$$

using $W_p^\top W_p = F$.
Let us split the first term as $-F^{1/2} v_2 = -v_2 - (F^{1/2} - I)\, v_2$.
Empirically, the second term can be safely removed, giving

$$\frac{\partial\mathcal{L}}{\partial u_1} \approx \frac{1}{\|p_1\|}\Big(-v_2 + \frac{p_1^\top v_2}{\|p_1\|^2}\, F u_1\Big).$$

By further observing that $F u_1 = \mathbb{E}_b[(u_b^\top u_1)\, u_b]$ is a weighted combination of the representations of all samples, the gradient is again a positive term plus a weighted negative term.
At first glance, it seems counter-intuitive that the gradient is also a combination of positive and negative samples since no negative samples appear in the loss function explicitly. However, the derived gradient formula suggests that the weights of the feature correlation matrix encode the negative samples. Specifically, [7] suggests that the eigenspace of the predictor aligns with that of the feature correlation matrix $F$, so the predictor implicitly carries the information about other samples that contrastive methods obtain from explicit negatives.
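A toy illustration of these two points, using random untrained features (sizes are arbitrary): the analytical predictor $W_p = F^{1/2}$ built from the eigendecomposition of $F$, and the fact that $F u_1$ is a weighted combination of the other samples' representations:

```python
import torch

torch.manual_seed(0)
N, d = 512, 32
u = torch.randn(N, d)                        # random stand-ins for representations

# Correlation matrix F (a moving average over batches in practice).
Fmat = u.T @ u / N

# Analytical predictor: same eigenspace as F, square-rooted eigenvalues.
lam, U = torch.linalg.eigh(Fmat)
W_p = U @ torch.diag(lam.clamp(min=0).sqrt()) @ U.T
print(torch.allclose(W_p @ W_p, Fmat, atol=1e-4))    # W_p^2 == F

# The negative gradient F u_1 is a weighted sum of all representations.
u1 = u[0]
weights = u @ u1 / N                         # w_b = u_b^T u_1 / N
print(torch.allclose(Fmat @ u1, weights @ u, atol=1e-4))  # F u_1 == sum_b w_b u_b
```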
2.3 Feature Decorrelation Methods
2.3.1 Barlow Twins
Compared to the previous methods, Barlow Twins does not require a large batch, an asymmetric architecture, gradient stopping, or a moving average on the weight update.
Method.
Inspired by Horace Barlow's efficient coding hypothesis, this paper proposes to reduce redundancy instead of maximizing similarity. In particular, we want each neuron to satisfy (1) invariance under data augmentation and (2) independence from the other neurons, i.e., reduced redundancy. Property one means that a neuron behaves the same way under different data augmentations, and property two means that all neurons should encode different information. Forcing each neuron to be different prevents the feature collapse where all neurons are the same. This is equivalent to pushing the cross-correlation matrix of the two views close to the identity matrix. In particular, the cross-correlation matrix can be written as a sum of outer products:
so, $\mathcal{C} = \frac{1}{N}\sum_{b=1}^{N} z^1_b\,(z^2_b)^\top$, i.e., $\mathcal{C}_{ij} = \frac{1}{N}\sum_b z^1_{b,i}\, z^2_{b,j}$, where $z^1_b$ and $z^2_b$ are the standardized outputs for the two views of sample $b$,
and the objective function is:

$$\mathcal{L}_{\mathcal{BT}} = \sum_i (1-\mathcal{C}_{ii})^2 + \lambda \sum_i \sum_{j\neq i} \mathcal{C}_{ij}^2.$$
The first term is the invariance term, and the second term is the redundancy reduction term. The parameter $\lambda$ balances the two terms.
Note: there might be another type of collapse where the network satisfies the two properties, i.e., invariance and redundancy reduction, but outputs representations that are constant across the batch dimension. To prevent this, we standardize the output along the batch dimension; if the representations were identical across the batch, $\mathcal{C}$ would be a zero matrix.
Intriguingly, Barlow Twins keeps improving as the projector's output dimension grows, even when it far exceeds the encoder's output dimension, while other methods saturate much earlier.
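A minimal sketch of the loss (the $\lambda$ value and the shapes are placeholders):

```python
import torch

def barlow_twins_loss(z1, z2, lam=5e-3):
    # Standardize each feature along the batch dimension; this is what
    # prevents the constant-across-batch collapse discussed above.
    N = z1.size(0)
    z1 = (z1 - z1.mean(0)) / z1.std(0)
    z2 = (z2 - z2.mean(0)) / z2.std(0)
    c = z1.T @ z2 / N                                 # cross-correlation (d x d)
    invariance = (1 - c.diagonal()).pow(2).sum()      # push diagonal to 1
    redundancy = (c - torch.diag(c.diagonal())).pow(2).sum()  # push off-diagonal to 0
    return invariance + lam * redundancy

loss = barlow_twins_loss(torch.randn(256, 128), torch.randn(256, 128))
```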
Gradient.
Let $z^1_b$ and $z^2_b$ denote the standardized outputs of the two branches for the $b$-th sample, so that $\mathcal{C} = \frac{1}{N}\sum_{b} z^1_b\,(z^2_b)^\top$, and write $\mathcal{C}_{\mathrm{diag}}$ for the diagonal part of $\mathcal{C}$ and $\mathcal{C}_{\mathrm{off}} = \mathcal{C} - \mathcal{C}_{\mathrm{diag}}$ for its off-diagonal part (treating the standardization as fixed).
Then the gradient of the invariance term w.r.t. $z^1_b$ is

$$-\frac{2}{N}\,(I - \mathcal{C}_{\mathrm{diag}})\, z^2_b,$$

and the gradient of the redundancy reduction term is

$$\frac{2\lambda}{N}\, \mathcal{C}_{\mathrm{off}}\, z^2_b.$$

Hence,

$$\frac{\partial\mathcal{L}_{\mathcal{BT}}}{\partial z^1_b} = \frac{2}{N}\Big(-(I - \mathcal{C}_{\mathrm{diag}})\, z^2_b + \lambda\, \mathcal{C}_{\mathrm{off}}\, z^2_b\Big).$$

Let us expand the second term with the definition of $\mathcal{C}$: $\mathcal{C}\, z^2_b = \frac{1}{N}\sum_{b'} \big((z^2_{b'})^\top z^2_b\big)\, z^1_{b'}$, a weighted combination of the representations of all samples in the batch.
Similar to other methods, the first term is the positive sample $z^2_b$ (weighted per channel by $1-\mathcal{C}_{ii}$), which pulls the two views together, and the second term pushes $z^1_b$ away from a weighted combination of the other samples' representations, i.e., it acts as a negative gradient.
Note: the authors of [1] empirically show that removing the per-channel weights $1-\mathcal{C}_{ii}$ on the positive term barely affects performance, bringing the gradient even closer to the contrastive form.
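A toy autograd check of the derived formula (differentiating w.r.t. the already-standardized outputs, i.e., treating the standardization as fixed as in the derivation above):

```python
import torch

torch.manual_seed(0)
N, d, lam = 64, 16, 5e-3
z1 = torch.randn(N, d).requires_grad_()      # assumed already standardized
z2 = torch.randn(N, d)

c = z1.T @ z2 / N
loss = (1 - c.diagonal()).pow(2).sum() + lam * (c - torch.diag(c.diagonal())).pow(2).sum()
loss.backward()

# Derived gradient: (2/N) (-(I - C_diag) z2_b + lam * C_off z2_b)
with torch.no_grad():
    c_diag = torch.diag(c.diagonal())
    c_off = c - c_diag
    grad = (2 / N) * ((-(torch.eye(d) - c_diag) + lam * c_off) @ z2.T).T
print(torch.allclose(z1.grad, grad, atol=1e-5))      # True
```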
3. Conclusion
The authors of [1] provide some interesting insights into the success of siamese self-supervised learning. In particular, they find that:
- Increasing the depth of the projector from 1 to 3 boosts the linear evaluation accuracy significantly.
- Increasing the projector's width boosts the performance and does not seem to saturate even when the dimension increases to 16384.
- A consistent and slowly updated positive key is both sufficient and essential for self-supervised learning. Hence, contrary to MoCo's findings, the slowly updating memory bank of negative samples is unnecessary.
- The representations learned by these different methods behave similarly. SimCLR and BYOL also learn to decorrelate channels, and Barlow Twins can also discriminate between positive and negative samples. This supports the claim that these methods share similar gradient formulas.
Although siamese self-supervised methods have shown remarkable performance, these methods still rely on hand-crafted invariance (data augmentation). Therefore, those methods may greatly benefit from learned data augmentation.
4. References
[1] Exploring the Equivalence of Siamese Self-Supervised Learning Via A Unified Gradient Framework.
[2] A Simple Framework for Contrastive Learning of Visual Representations.
[3] Momentum Contrast for Unsupervised Visual Representation Learning.
[4] Exploring Simple Siamese Representation Learning.
[5] Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning.
[6] Barlow Twins: Self-Supervised Learning via Redundancy Reduction.
[7] Understanding Self-Supervised Learning Dynamics without Contrastive Pairs.