
Gradient Equivalence in Siamese Self-Supervised Learning

Jan 3, 2022 by Matthieu Lin.

Following [1], we derive the gradients of different siamese self-supervised learning methods and show that, although these methods appear quite different, they have similar gradient formulas. In particular, the gradient consists of three components:

  1. A positive gradient, i.e., the representation of another augmented view of the same image, which pulls positive samples together.
  2. A negative gradient, i.e., a weighted combination of the representations of different images, which pushes negative samples apart.
  3. A balancing factor that weights the two terms.

Under this unified view, [1] empirically shows that these methods perform similarly and that only the momentum encoder dramatically improves the final performance.

1. Introduction

1.1 What is "Self" in Self-Supervised Learning?

The word self refers to the ability to generate labels from the data by leveraging the underlying structure of the data. For instance, it could be predicting a missing word in a sentence. Because of the natural structure of human language, solving this task requires a high-level understanding of the sentence. From this, we hope that our model learns a generic representation by solving the task.

1.2 What is "Siamese" in Siamese Self-Supervised Learning?

A siamese network is composed of two branches: an online branch and a target branch, where the target branch either shares weights with the online branch or keeps an exponential moving average of the online branch. In particular, given two augmented views of the input image, each branch computes the representation of one view, and the output of the target branch serves as the training target for the online branch. The loss maximizes the agreement between the two augmented views, i.e., we want our network to be invariant to image augmentations. By optimizing this objective, we hope the online branch learns a generic representation transferable to downstream tasks. It is important to note that these methods usually rely on heavy, hand-engineered data augmentation.
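As a rough sketch of this setup (PyTorch; the class name, the `momentum` value, and the EMA update rule are illustrative assumptions, matching the momentum-encoder variant):

```python
# Siamese setup: an online branch and an EMA target branch with stop-gradient.
import copy
import torch

class Siamese(torch.nn.Module):
    def __init__(self, encoder, momentum=0.99):
        super().__init__()
        self.online = encoder
        self.target = copy.deepcopy(encoder)   # same architecture, EMA weights
        self.m = momentum
        for p in self.target.parameters():
            p.requires_grad = False            # no gradient through the target branch

    @torch.no_grad()
    def update_target(self):
        # theta_k <- m * theta_k + (1 - m) * theta_q
        for po, pt in zip(self.online.parameters(), self.target.parameters()):
            pt.mul_(self.m).add_(po, alpha=1 - self.m)

    def forward(self, view1, view2):
        q = self.online(view1)                 # online representation
        with torch.no_grad():
            k = self.target(view2)             # training target (stop-gradient)
        return q, k
```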

1.3 What is the goal of Self-Supervised Learning?

Self-Supervised Learning acts as a pretext task where the goal is to learn a generic representation transferable to downstream tasks. In contrast to previous methods that pre-train a network on annotated datasets, self-supervised learning methods do not require labels.

Note: Although self-supervised learning techniques do not rely on labels, they rely on the highly curated ImageNet dataset, e.g., images usually contain a single object at the center.

When transferring to downstream tasks, we use the online branch. For instance, those downstream tasks can be classification, object detection, or semantic segmentation. We evaluate the quality of the learned representation of the online network on annotated datasets, where we either fine-tune the network or train a classifier on top of frozen features. For fair comparisons, these methods use the ImageNet dataset and a randomly initialized ResNet50. A good representation makes the classes linearly separable.

1.4 Key Concepts

A trivial solution exists for these siamese networks, where the network outputs the same embedding for all images to minimize the loss. When the network learns this trivial solution, we call this feature collapse.

We can roughly split current self-supervised methods into three frameworks, where each proposes different ways to prevent this collapse:

  1. Contrastive Learning methods (e.g., MoCo [3], SimCLR [2]) contrast positive samples with negative samples to prevent collapse. Specifically, the target branch outputs the representation of a positive sample and a set of negative samples, and the loss explicitly pulls the pair of positive samples together while pushing apart the pair of negative samples.
  2. Asymmetric networks (e.g., BYOL [5], SimSiam [4]) only rely on positive samples. They introduce an asymmetry between the target and the online branch to prevent collapse. In this case, the loss explicitly pulls together the positive sample pairs.
  3. Feature decorrelation methods (e.g., Barlow Twins [6]) prevent collapse by pushing the cross-correlation matrix of the two views close to the identity matrix. Optimizing this objective makes each feature invariant under data augmentation while reducing the redundancy between features.

2. Equivalence of Gradients

Notations:

$\underline{x}$: we use underlines to denote vectors.

$\underline{x}^{(n)}$ and $(\underline{x}^{(n)})_i$ refer to the $n$-th element in a batch and the $i$-th element of this vector.

$X$: we use capital letters to denote matrices.

$\nabla$ denotes the gradient operator.

$\langle \cdot, \cdot \rangle$ denotes the inner product operator.

2.1 Contrastive Methods

2.1.1 MoCo:

$\underline{q}$ denotes the output of the online branch.

$\underline{k}$ denotes the output of the target branch, and $\underline{k}^{(+)}$ denotes the positive sample corresponding to $\underline{q}$ in the batch.

$\theta_k$ are the parameters of the target branch.

$\theta_q$ are the parameters of the online branch.

Method.

The contributions of this paper are twofold. First, it introduces a queue of negative samples instead of using a large mini-batch. Second, the target branch uses a momentum encoder, i.e., $\theta_k$ is an exponential moving average of $\theta_q$. Hence, the two branches are different.

Let $\tau$ be the temperature hyper-parameter. Then, we define the loss for a single pair as:

$$\mathcal{L}(\underline{q}) = -\log\frac{\exp(\langle\underline{q},\underline{k}^{(+)}\rangle/\tau)}{\sum_{i=0}^{K}\exp(\langle\underline{q},\underline{k}^{(i)}\rangle/\tau)}.$$

The loss function can be thought of as a $(K+1)$-way softmax classification. It pulls $\underline{q}$ and $\underline{k}^{(+)}$ together while pushing $\underline{q}$ and the $\underline{k}^{(i)}$ apart.
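This loss is a cross-entropy over similarity logits; a minimal sketch (PyTorch; the function name and tensor shapes are our own, following the structure of MoCo's published pseudocode):

```python
# InfoNCE loss: a (K+1)-way classification where index 0 is the positive key.
import torch
import torch.nn.functional as F

def info_nce(q, k_pos, queue, tau=0.2):
    # q: (N, d) queries, k_pos: (N, d) positive keys, queue: (K, d) negatives
    l_pos = torch.einsum('nd,nd->n', q, k_pos).unsqueeze(1)  # <q, k^(+)> : (N, 1)
    l_neg = torch.einsum('nd,kd->nk', q, queue)              # <q, k^(i)> : (N, K)
    logits = torch.cat([l_pos, l_neg], dim=1) / tau
    labels = torch.zeros(q.size(0), dtype=torch.long)        # positive is class 0
    return F.cross_entropy(logits, labels)
```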

Gradient.

Let $\mathcal{B}$ be the set of elements in the memory bank. Then, we define the loss function as:

$$\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N}\mathcal{L}(\underline{q}^{(n)}),$$

Let $z_j^{(n)}$ be the softmax of the similarity between the $j$-th key and the query. Then, if the positive key associated with $\underline{q}^{(n)}$, denoted $\underline{k}^{(+)}$, is at index 0, we can write the cross-entropy loss as:

$$\mathcal{L}(\underline{q}^{(n)}) = -\sum_{j=0}^{K}\mathbb{1}_{[j=0]}\log(z_j^{(n)}),$$
$$z_j^{(n)} = \frac{\exp(s_j^{(n)})}{\sum_{m\in\mathcal{B}}\exp(s_m^{(n)})},$$
$$s_m^{(n)} = \langle\underline{q}^{(n)},\underline{k}^{(m)}\rangle/\tau.$$

Since we do not backpropagate through the target branch, we only receive gradients through $\underline{q}^{(n)}$, and the gradient w.r.t. $\underline{q}^{(n)}$ is:

$$(\nabla_{\underline{q}^{(n)}}\mathcal{L})_i = \frac{1}{N}\sum_{j=0}^{K}\sum_{m\in\mathcal{B}}\frac{\partial\mathcal{L}(\underline{q}^{(n)})}{\partial z_j^{(n)}}\frac{\partial z_j^{(n)}}{\partial s_m^{(n)}}\frac{\partial s_m^{(n)}}{\partial(\underline{q}^{(n)})_i} = \frac{1}{N}\left(\frac{\partial\mathcal{L}}{\partial z_0^{(n)}}\frac{\partial z_0^{(n)}}{\partial s_0^{(n)}}\frac{\partial s_0^{(n)}}{\partial(\underline{q}^{(n)})_i} + \sum_{m\in\mathcal{B}\setminus\{\underline{k}^{(0)}\}}\frac{\partial\mathcal{L}}{\partial z_0^{(n)}}\frac{\partial z_0^{(n)}}{\partial s_m^{(n)}}\frac{\partial s_m^{(n)}}{\partial(\underline{q}^{(n)})_i}\right),$$

(only the $j=0$ terms survive because the loss depends on $z_0^{(n)}$ alone)

where,

$$\frac{\partial\mathcal{L}}{\partial z_j^{(n)}} = -\frac{1}{z_j^{(n)}}, \qquad \frac{\partial s_m^{(n)}}{\partial(\underline{q}^{(n)})_i} = \frac{1}{\tau}(\underline{k}^{(m)})_i,$$

and if $m \neq j$,

$$\frac{\partial z_j^{(n)}}{\partial s_m^{(n)}} = \frac{\partial}{\partial s_m^{(n)}}\frac{\exp(s_j^{(n)})}{\sum_{m'\in\mathcal{B}}\exp(s_{m'}^{(n)})} = -\frac{\exp(s_j^{(n)})\exp(s_m^{(n)})}{\left(\sum_{m'\in\mathcal{B}}\exp(s_{m'}^{(n)})\right)^2} = -z_j^{(n)}z_m^{(n)},$$

else,

$$\frac{\partial z_j^{(n)}}{\partial s_j^{(n)}} = \frac{\partial}{\partial s_j^{(n)}}\frac{\exp(s_j^{(n)})}{\sum_{m\in\mathcal{B}}\exp(s_m^{(n)})} = \frac{\exp(s_j^{(n)})\sum_{m\in\mathcal{B}}\exp(s_m^{(n)}) - \exp(s_j^{(n)})\exp(s_j^{(n)})}{\left(\sum_{m\in\mathcal{B}}\exp(s_m^{(n)})\right)^2} = z_j^{(n)}(1 - z_j^{(n)}).$$

Hence,

$$(\nabla_{\underline{q}^{(n)}}\mathcal{L})_i = \frac{1}{N}\left((z_0^{(n)}-1)\frac{(\underline{k}^{(0)})_i}{\tau} + \sum_{m\in\mathcal{B}\setminus\{\underline{k}^{(0)}\}}z_m^{(n)}\frac{(\underline{k}^{(m)})_i}{\tau}\right) = \frac{1}{N}\left(-\frac{(\underline{k}^{(0)})_i}{\tau} + \sum_{m\in\mathcal{B}}z_m^{(n)}\frac{(\underline{k}^{(m)})_i}{\tau}\right),$$
$$\nabla_{\underline{q}^{(n)}}\mathcal{L} = \frac{1}{N}\left(-\frac{\underline{k}^{(0)}}{\tau} + \sum_{m\in\mathcal{B}}z_m^{(n)}\frac{\underline{k}^{(m)}}{\tau}\right).$$

From this, we observe two things: (1) the first term pulls positive samples together, (2) the second term pushes negative samples apart.
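We can check this formula against autograd; a minimal sketch (PyTorch, toy sizes, giving each query its own set of $K+1$ keys with the positive at index 0):

```python
# Numerical check of the derived MoCo gradient:
# grad_q = (1/N) * (-k^(0)/tau + sum_m z_m^(n) k^(m)/tau).
import torch

torch.manual_seed(0)
N, K, d, tau = 4, 8, 16, 0.2

q = torch.randn(N, d, requires_grad=True)
k = torch.randn(N, K + 1, d)                             # k[:, 0] is the positive key

logits = torch.einsum('nd,nmd->nm', q, k) / tau          # s_m^(n) = <q^(n), k^(m)>/tau
loss = -torch.log_softmax(logits, dim=1)[:, 0].mean()    # cross-entropy, positive at 0
loss.backward()

z = torch.softmax(logits.detach(), dim=1)                # z_m^(n)
manual = (-k[:, 0] + torch.einsum('nm,nmd->nd', z, k)) / (tau * N)
print(torch.allclose(q.grad, manual, atol=1e-6))         # True
```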

2.1.2 SimCLR:

Method.

This paper introduces four ingredients that substantially improve the learned representation: (1) a non-linear projection head on top of the encoder during pretext training, (2) a set of heavy data augmentations, (3) a cosine similarity function between pairs, and (4) an extended training schedule.

Unlike MoCo, in SimCLR the target and the online branch share the same weights. Therefore, it is easier to think of the two branches' outputs as a single variable $\underline{z}^{(n)}$ instead of $\underline{q}^{(n)}$ and $\underline{k}^{(n)}$. In particular, given an image, the loss pushes its two views to be similar to each other while dissimilar to the other $2N-2$ images in the batch, as sketched below.
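A minimal sketch of this loss (PyTorch; the function name and the convention that the two views of image $i$ sit at rows $i$ and $i+N$ are our own choices):

```python
# NT-Xent: each of the 2N rows is attracted to its paired view and repelled
# from the remaining 2N - 2 embeddings.
import torch
import torch.nn.functional as F

def nt_xent(z, tau=0.5):
    # z: (2N, d) l2-normalized embeddings, views paired as (i, i + N)
    two_n = z.size(0)
    sim = z @ z.t() / tau                                    # c_m^(n) for all pairs
    sim = sim.masked_fill(torch.eye(two_n, dtype=torch.bool), float('-inf'))
    pos = torch.arange(two_n).roll(two_n // 2)               # index of each row's positive
    return F.cross_entropy(sim, pos)
```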

Gradient.

As in MoCo, we assume the target branch computes the positive and the negative samples. Thus, by stopping the gradient on the target branch, those two methods are equivalent. In [1], they empirically verified this.

Let $n^+$ denote the positive sample associated with a sample $n$. Because SimCLR does not stop the gradient on the target branch, $\underline{z}^{(n)}$ receives gradients from all other samples in the batch. Given the loss:

$$\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N}\mathcal{L}(\underline{z}^{(n)}),$$

where

$$\mathcal{L}(\underline{z}^{(n)}) = -\log(s_{n^+}^{(n)}),$$

and

$$s_{n^+}^{(n)} = \frac{\exp c_{n^+}^{(n)}}{\sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(n)}\}}\exp c_m^{(n)}},$$

and

$$c_m^{(n)} = \langle\underline{z}^{(n)},\underline{z}^{(m)}\rangle/\tau.$$

Then the gradient w.r.t. $\underline{z}^{(n)}$ is:

$$(\nabla_{\underline{z}^{(n)}}\mathcal{L})_i = \sum_{k=1}^{N}\sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(k)}\}}\frac{\partial\mathcal{L}}{\partial s_{k^+}^{(k)}}\frac{\partial s_{k^+}^{(k)}}{\partial c_m^{(k)}}\frac{\partial c_m^{(k)}}{\partial(\underline{z}^{(n)})_i}$$
$$= \sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(n)}\}}\frac{\partial\mathcal{L}}{\partial s_{n^+}^{(n)}}\frac{\partial s_{n^+}^{(n)}}{\partial c_m^{(n)}}\frac{\partial c_m^{(n)}}{\partial(\underline{z}^{(n)})_i} + \sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(n^+)}\}}\frac{\partial\mathcal{L}}{\partial s_{n}^{(n^+)}}\frac{\partial s_{n}^{(n^+)}}{\partial c_m^{(n^+)}}\frac{\partial c_m^{(n^+)}}{\partial(\underline{z}^{(n)})_i} + \sum_{k\neq n,n^+}\sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(k)}\}}\frac{\partial\mathcal{L}}{\partial s_{k^+}^{(k)}}\frac{\partial s_{k^+}^{(k)}}{\partial c_m^{(k)}}\frac{\partial c_m^{(k)}}{\partial(\underline{z}^{(n)})_i},$$

where the first term is

$$\sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(n)}\}}\frac{\partial\mathcal{L}}{\partial s_{n^+}^{(n)}}\frac{\partial s_{n^+}^{(n)}}{\partial c_m^{(n)}}\frac{\partial c_m^{(n)}}{\partial(\underline{z}^{(n)})_i} = \frac{1}{\tau N}\left[-(\underline{z}^{(n^+)})_i + \sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(n)}\}}s_m^{(n)}(\underline{z}^{(m)})_i\right],$$

and by observing that all $m \neq n$ terms are 0, the second term is

$$\sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(n^+)}\}}\frac{\partial\mathcal{L}}{\partial s_{n}^{(n^+)}}\frac{\partial s_{n}^{(n^+)}}{\partial c_m^{(n^+)}}\frac{\partial c_m^{(n^+)}}{\partial(\underline{z}^{(n)})_i} = \frac{1}{\tau N}\left[-(\underline{z}^{(n^+)})_i + s_n^{(n^+)}(\underline{z}^{(n^+)})_i\right],$$

and by observing that all $m \neq n$ terms are 0, the third term is

$$\sum_{k\neq n,n^+}\sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(k)}\}}\frac{\partial\mathcal{L}}{\partial s_{k^+}^{(k)}}\frac{\partial s_{k^+}^{(k)}}{\partial c_m^{(k)}}\frac{\partial c_m^{(k)}}{\partial(\underline{z}^{(n)})_i} = \frac{1}{\tau N}\left[\sum_{k\neq n,n^+}s_n^{(k)}(\underline{z}^{(k)})_i\right].$$

Hence,

$$(\nabla_{\underline{z}^{(n)}}\mathcal{L})_i = \frac{1}{\tau N}\left[-(\underline{z}^{(n^+)})_i + \sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(n)}\}}s_m^{(n)}(\underline{z}^{(m)})_i\right] + \frac{1}{\tau N}\left[-(\underline{z}^{(n^+)})_i + \sum_{k\neq n}s_n^{(k)}(\underline{z}^{(k)})_i\right].$$

And if we stop the gradient through the target branch, then the second term vanishes and

$$\nabla_{\underline{z}^{(n)}}\mathcal{L} = \frac{1}{\tau N}\left[-\underline{z}^{(n^+)} + \sum_{m\in\mathcal{B}\setminus\{\underline{z}^{(n)}\}}s_m^{(n)}\underline{z}^{(m)}\right].$$

This gradient is similar to MoCo's, i.e., the first term pulls positive samples together and the second term pushes negative samples apart.
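Both the formula and the role of the stop-gradient can be verified numerically; a minimal sketch (PyTorch, toy sizes; we detach the target-branch side, and since the loss here averages over all $2N$ views the prefactor becomes $1/(2N\tau)$):

```python
# With a stop-gradient on the target branch, autograd reproduces the
# MoCo-like gradient: (1/(2N*tau)) * (-z^(n+) + sum_m s_m^(n) z^(m)).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d, tau = 3, 5, 0.5
z = torch.randn(2 * N, d, requires_grad=True)
zt = z.detach()                                          # stop-gradient target branch

logits = (z @ zt.t() / tau).masked_fill(torch.eye(2 * N, dtype=torch.bool), float('-inf'))
pos = torch.arange(2 * N).roll(N)                        # row n's positive view
loss = F.cross_entropy(logits, pos)                      # mean over the 2N views
loss.backward()

s = torch.softmax(logits.detach(), dim=1)                # s_m^(n), zero at m = n
manual = (-zt[pos] + s @ zt) / (tau * 2 * N)
print(torch.allclose(z.grad, manual, atol=1e-6))         # True
```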

2.2 Asymmetric Methods

2.2.1 BYOL

Practically, we can think of BYOL as SimSiam with a momentum encoder.

Method.

SimSiam introduces an asymmetric architecture with a stop-gradient on the target branch to prevent collapse. In particular, it appends a predictor $h(\cdot)$ to the encoder $f$ of the online branch, and it stops the gradient through the target branch. Both SimSiam and BYOL implement the predictor $h(\cdot)$ as a multi-layer perceptron. Finally, the objective function pushes the two augmented views of an image to be similar; hence it requires neither a large batch nor negative sample pairs. In particular, the objective function is the negative cosine similarity between the two augmented views. For example, given the output of the predictor $\underline{p}_1$ and the output of the encoder $\underline{z}_2$, the objective function is:

$$\mathcal{L} = -\left\langle\frac{\underline{p}_1}{\|\underline{p}_1\|_2}, \frac{\underline{z}_2}{\|\underline{z}_2\|_2}\right\rangle.$$
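A minimal sketch of this objective (PyTorch; `f`, `h`, and the symmetrized form are illustrative, following SimSiam's published pseudocode):

```python
# Negative cosine similarity with a stop-gradient on the target branch.
import torch
import torch.nn.functional as F

def simsiam_loss(f, h, x1, x2):
    z1, z2 = f(x1), f(x2)                  # encoder outputs
    p1, p2 = h(z1), h(z2)                  # predictor outputs (online branch only)
    # z.detach() implements the stop-gradient on the target branch
    loss = -(F.cosine_similarity(p1, z2.detach(), dim=-1).mean()
             + F.cosine_similarity(p2, z1.detach(), dim=-1).mean()) / 2
    return loss
```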
Gradient.

Let $\mathcal{B}$ be the set of all previous samples, and let $\rho_z$ be the moving-average weight of each sample according to its batch order.

The authors of [7] propose an analytical solution for the predictor $h(\cdot)$. In particular, they claim that the eigenspace of the predictor $W_h$ gradually aligns with that of the feature correlation matrix $F$. The feature correlation matrix is obtained via a moving average, i.e.,

$$F = \sum_{\underline{z}_1\in\mathcal{B}}\rho_z\,\underline{z}_1\underline{z}_1^T.$$

Motivated by this, they directly set $W_h$ to be a function of $F$. In particular, computing $W_h$ requires an eigendecomposition, i.e.,

$$W_h = U\Lambda_h U^T, \qquad \Lambda_h = \Lambda_F^{1/2} + \epsilon\lambda_{\max}I,$$

where $U$ and $\Lambda_F$ are the eigenvectors and eigenvalues of $F$, $\lambda_{\max}$ is the largest eigenvalue of $F$, and $\epsilon$ is a hyper-parameter that helps boost small eigenvalues.
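As a sketch of how this analytical predictor might be computed (PyTorch; the function name and the `eps` default are our own, and we assume $F$ is symmetric positive semi-definite so `eigh` applies):

```python
# Analytical predictor from [7]: W_h shares the eigenvectors of F, with
# eigenvalues sqrt(lambda) boosted by eps * lambda_max.
import torch

def analytical_predictor(F_corr, eps=0.1):
    lam, U = torch.linalg.eigh(F_corr)        # F = U diag(lam) U^T
    lam = lam.clamp(min=0)                    # guard against tiny negative eigenvalues
    lam_h = lam.sqrt() + eps * lam.max()      # Lambda_h = Lambda_F^(1/2) + eps * lam_max * I
    return U @ torch.diag(lam_h) @ U.t()      # W_h = U Lambda_h U^T
```

Equipped with this analytical solution, we rewrite the loss as: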

$$\mathcal{L} = \frac{1}{N}\sum_{b=1}^{N}\mathcal{L}(\underline{p}_1^{(b)}),$$

and,

$$\mathcal{L}(\underline{p}_1^{(b)}) = -\langle\hat{\underline{p}}_1^{(b)}, \hat{\underline{z}}_2^{(b)}\rangle, \quad\text{where}\quad \hat{\underline{z}}_2^{(b)} = \frac{\underline{z}_2^{(b)}}{\|\underline{z}_2^{(b)}\|_2} \quad\text{and}\quad \hat{\underline{p}}_1^{(b)} = \frac{\underline{p}_1^{(b)}}{\|\underline{p}_1^{(b)}\|_2}.$$

Note that the authors of [7] stop the gradient through the target branch.

Let $\underline{p}_1^{(b)} = W_h\underline{z}_1^{(b)}$ with $\underline{p}_1^{(b)}\in\mathbb{R}^{d_p}$ and $\underline{z}_1^{(b)}\in\mathbb{R}^{d}$; then the gradient w.r.t. $\underline{z}_1^{(n)}$ is:

$$(\nabla_{\underline{z}_1^{(n)}}\mathcal{L})_i = \frac{1}{N}\sum_{b=1}^{N}\sum_{j=1}^{d_p}\frac{\partial\mathcal{L}(\underline{p}_1^{(b)})}{\partial(\hat{\underline{p}}_1^{(b)})_j}\sum_{k=1}^{d_p}\frac{\partial(\hat{\underline{p}}_1^{(b)})_j}{\partial(\underline{p}_1^{(b)})_k}\frac{\partial(\underline{p}_1^{(b)})_k}{\partial(\underline{z}_1^{(n)})_i} = \frac{1}{N}\sum_{j=1}^{d_p}\frac{\partial\mathcal{L}(\underline{p}_1^{(n)})}{\partial(\hat{\underline{p}}_1^{(n)})_j}\sum_{k=1}^{d_p}\frac{\partial(\hat{\underline{p}}_1^{(n)})_j}{\partial(\underline{p}_1^{(n)})_k}\frac{\partial(\underline{p}_1^{(n)})_k}{\partial(\underline{z}_1^{(n)})_i}$$
$$= \frac{1}{N}\sum_{j=1}^{d_p}\left(-(\hat{\underline{z}}_2^{(n)})_j\right)\sum_{k=1}^{d_p}\frac{1}{\|\underline{p}_1^{(n)}\|_2}\left(\mathbb{1}_{[j=k]} - \frac{(\underline{p}_1^{(n)})_j(\underline{p}_1^{(n)})_k}{\|\underline{p}_1^{(n)}\|_2^2}\right)(W_h)_{ki},$$

thus,

$$\nabla_{\underline{z}_1^{(n)}}\mathcal{L} = \frac{1}{N}\left[W_h^T\frac{1}{\|\underline{p}_1^{(n)}\|_2}\left(I - \frac{\underline{p}_1^{(n)}\underline{p}_1^{(n)T}}{\underline{p}_1^{(n)T}\underline{p}_1^{(n)}}\right)\left(-\hat{\underline{z}}_2^{(n)}\right)\right], \quad\text{where}\quad \underline{p}_1^{(n)} = W_h\underline{z}_1^{(n)}$$
$$= \frac{1}{\|W_h\underline{z}_1^{(n)}\|_2 N}\left[-W_h^T\left(I - \frac{W_h\underline{z}_1^{(n)}\underline{z}_1^{(n)T}W_h^T}{\underline{z}_1^{(n)T}W_h^TW_h\underline{z}_1^{(n)}}\right)\hat{\underline{z}}_2^{(n)}\right]$$
$$= \frac{1}{\|W_h\underline{z}_1^{(n)}\|_2 N}\left[-W_h^T\hat{\underline{z}}_2^{(n)} + W_h^TW_h\underline{z}_1^{(n)}\frac{\underline{z}_1^{(n)T}W_h^T\hat{\underline{z}}_2^{(n)}}{\underline{z}_1^{(n)T}W_h^TW_h\underline{z}_1^{(n)}}\right],$$
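Before simplifying further, this expression can be sanity-checked against autograd; a minimal sketch (PyTorch, single sample so the $1/N$ factor drops, with a random $W_h$ and a detached, normalized $\hat{\underline{z}}_2$):

```python
# Numerical check of the gradient formula above (N = 1).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 8
Wh = torch.randn(d, d)
z1 = torch.randn(d, requires_grad=True)
z2 = F.normalize(torch.randn(d), dim=0)     # hat{z}_2, stop-gradient side

p = Wh @ z1                                 # p_1 = W_h z_1
loss = -(p / p.norm()) @ z2                 # -<hat{p}_1, hat{z}_2>
loss.backward()

with torch.no_grad():
    lam = (z1 @ Wh.t() @ z2) / (z1 @ Wh.t() @ (Wh @ z1))
    manual = (-(Wh.t() @ z2) + lam * (Wh.t() @ (Wh @ z1))) / p.norm()
print(torch.allclose(z1.grad, manual, atol=1e-6))   # True
```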

Let $\lambda = \frac{\underline{z}_1^{(n)T}W_h^T\hat{\underline{z}}_2^{(n)}}{\underline{z}_1^{(n)T}W_h^TW_h\underline{z}_1^{(n)}}\in\mathbb{R}$. We know that

$$W_h^TW_h = U\Lambda_hU^TU\Lambda_hU^T = U\Lambda_h\Lambda_hU^T = U\left(\Lambda_F + 2\epsilon\lambda_{\max}\Lambda_F^{1/2} + \epsilon^2\lambda_{\max}^2I\right)U^T = F + 2\epsilon\lambda_{\max}F^{1/2} + \epsilon^2\lambda_{\max}^2I.$$

Empirically, the second term, $2\epsilon\lambda_{\max}F^{1/2}$, can be safely removed, giving

$$\nabla_{\underline{z}_1^{(n)}}\mathcal{L} = \frac{1}{\|W_h\underline{z}_1^{(n)}\|_2 N}\left(-W_h^T\hat{\underline{z}}_2^{(n)} + \lambda\left(F\underline{z}_1^{(n)} + \epsilon^2\lambda_{\max}^2\underline{z}_1^{(n)}\right)\right).$$

By further observing that $\underline{z}_1^{(n)}$ is $\ell_2$-normalized, we can neglect the $\epsilon^2\lambda_{\max}^2\underline{z}_1^{(n)}$ term, as the component of the gradient along the direction of $\underline{z}_1^{(n)}$ has no effect. Hence, we rewrite the gradient w.r.t. $\underline{z}_1^{(n)}$ as:

$$\nabla_{\underline{z}_1^{(n)}}\mathcal{L} = \frac{1}{\|W_h\underline{z}_1^{(n)}\|_2 N}\left(-W_h^T\hat{\underline{z}}_2^{(n)} + \lambda F\underline{z}_1^{(n)}\right) = \frac{1}{\|W_h\underline{z}_1^{(n)}\|_2 N}\left(-W_h^T\hat{\underline{z}}_2^{(n)} + \lambda\left(\sum_{b\in\mathcal{B}}\rho_z\,\underline{z}_1^{(b)}\underline{z}_1^{(b)T}\right)\underline{z}_1^{(n)}\right).$$

At first glance, it seems counter-intuitive that the gradient is also a combination of positive and negative samples since no negative samples appear in the loss function explicitly. However, the derived gradient formula suggests that the weights of the feature correlation matrix encode the negative samples. Specifically, [7] suggests that the eigenspace of the predictor Wh will gradually align with that of the feature correlation matrix F.

2.3 Feature Decorrelation Methods

2.3.1 Barlow Twins

$\underline{q}$ denotes the standardized output of the online branch.

$\underline{k}$ denotes the standardized output of the target branch.

$C\in\mathbb{R}^{D\times D}$ is the cross-correlation matrix, where $D$ is the feature dimension.

Compared to the previous methods, Barlow Twins requires neither a large batch, an asymmetric architecture, gradient stopping, nor a moving average of the weights.

Method.

Inspired by Horace Barlow's efficient coding hypothesis, this paper proposes to reduce redundancy instead of maximizing similarity. In particular, we want each neuron to satisfy (1) invariance under data augmentation and (2) independence from other neurons, i.e., reduced redundancy. Property one means that a neuron behaves the same way under different data augmentations, and property two means that all neurons should be different. By forcing each neuron to be different, we prevent the feature collapse in which all neurons are the same. This is equivalent to pushing the cross-correlation matrix close to the identity matrix. In particular, we write the cross-correlation matrix as a sum of outer products:

$$C = \frac{1}{N}\sum_{b=1}^{N}\underline{q}^{(b)}\underline{k}^{(b)T},$$

so,

$$C_{ij} = \frac{1}{N}\sum_{b=1}^{N}(\underline{q}^{(b)})_i(\underline{k}^{(b)})_j,$$

and the objective function is:

$$\mathcal{L} = \sum_{d=1}^{D}(C_{dd}-1)^2 + \lambda\sum_{d=1}^{D}\sum_{j\neq d}C_{dj}^2,$$

The first term is the invariance term, and the second term is the redundancy-reduction term. The parameter $\lambda$ balances the two terms, since the second term has $D^2 - D$ elements while the first has only $D$.

Note: there might be another type of collapse where the network satisfies the two properties, e.g., invariance and redundancy reduction, but outputs representations that are constant across the batch dimension. To prevent this, we standardize the outputs; if the representations were the same across the batch dimension, $C$ would be a zero matrix.

Intriguingly, Barlow Twins keeps improving with higher dimensions, even at $D = 16384$.
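A minimal sketch of the loss (PyTorch; the function name and `lam` default are our own, following the structure of the paper's pseudocode):

```python
# Barlow Twins loss: standardize along the batch dimension, form the
# cross-correlation matrix, and push it toward the identity.
import torch

def barlow_twins_loss(q, k, lam=5e-3):
    N, D = q.shape
    q = (q - q.mean(0)) / q.std(0)           # standardize across the batch
    k = (k - k.mean(0)) / k.std(0)
    C = q.t() @ k / N                        # cross-correlation matrix, (D, D)
    on_diag = (torch.diagonal(C) - 1).pow(2).sum()               # invariance term
    off_diag = (C - torch.diag(torch.diagonal(C))).pow(2).sum()  # redundancy term
    return on_diag + lam * off_diag
```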

Gradient.

Let $\mathcal{L}_1 = \sum_{d=1}^{D}(C_{dd}-1)^2$ and $\mathcal{L}_2 = \sum_{d=1}^{D}\sum_{j\neq d}C_{dj}^2$; then $\mathcal{L} = \mathcal{L}_1 + \lambda\mathcal{L}_2$.

The gradient of $\mathcal{L}_1$ w.r.t. $\underline{q}^{(n)}$ is:

$$(\nabla_{\underline{q}^{(n)}}\mathcal{L}_1)_i = \frac{\partial\mathcal{L}_1}{\partial C_{ii}}\frac{\partial C_{ii}}{\partial(\underline{q}^{(n)})_i} = \frac{2(C_{ii}-1)(\underline{k}^{(n)})_i}{N},$$

and the gradient of $\lambda\mathcal{L}_2$ w.r.t. $\underline{q}^{(n)}$ is:

$$(\nabla_{\underline{q}^{(n)}}\lambda\mathcal{L}_2)_i = \lambda\sum_{j\neq i}^{D}\frac{\partial\mathcal{L}_2}{\partial C_{ij}}\frac{\partial C_{ij}}{\partial(\underline{q}^{(n)})_i} = \lambda\sum_{j\neq i}^{D}\frac{2C_{ij}(\underline{k}^{(n)})_j}{N} = \frac{2\lambda}{N}\left[-C_{ii}(\underline{k}^{(n)})_i + \sum_{j=1}^{D}C_{ij}(\underline{k}^{(n)})_j\right]$$
$$= \frac{2\lambda}{N}\left[-C_{ii}(\underline{k}^{(n)})_i + \sum_{j=1}^{D}\frac{1}{N}\sum_{b=1}^{N}(\underline{q}^{(b)})_i(\underline{k}^{(b)})_j(\underline{k}^{(n)})_j\right] = \frac{2\lambda}{N}\left[-C_{ii}(\underline{k}^{(n)})_i + \frac{1}{N}\sum_{b=1}^{N}(\underline{q}^{(b)})_i\sum_{j=1}^{D}(\underline{k}^{(b)})_j(\underline{k}^{(n)})_j\right]$$
$$= \frac{2\lambda}{N}\left[-C_{ii}(\underline{k}^{(n)})_i + \sum_{b=1}^{N}\frac{\underline{k}^{(n)T}\underline{k}^{(b)}}{N}(\underline{q}^{(b)})_i\right],$$

Hence,

$$(\nabla_{\underline{q}^{(n)}}\mathcal{L})_i = \frac{2}{N}\left[(\underline{k}^{(n)})_i\left(C_{ii} - 1 - \lambda C_{ii}\right) + \lambda\sum_{b=1}^{N}\frac{\underline{k}^{(n)T}\underline{k}^{(b)}}{N}(\underline{q}^{(b)})_i\right] = \frac{2}{N}\left[-(\underline{k}^{(n)})_i\left(1 - (1-\lambda)C_{ii}\right) + \lambda\sum_{b=1}^{N}\frac{\underline{k}^{(n)T}\underline{k}^{(b)}}{N}(\underline{q}^{(b)})_i\right].$$

Let $A = I - (1-\lambda)C_{\mathrm{diag}}$, where $(C_{\mathrm{diag}})_{ij} = \delta_{ij}C_{ij}$ and $\delta_{ij}$ is the Kronecker delta; then

$$\nabla_{\underline{q}^{(n)}}\mathcal{L} = \frac{2}{N}\left[-A\underline{k}^{(n)} + \lambda\sum_{b=1}^{N}\frac{\underline{k}^{(n)T}\underline{k}^{(b)}}{N}\underline{q}^{(b)}\right].$$

As with the other methods, the first term involves the positive sample $\underline{k}^{(n)}$ and the second term involves the negative samples $\underline{q}^{(b)}$, as checked below.
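This formula can be verified numerically against autograd; a minimal sketch (PyTorch, toy sizes, treating $\underline{q}$ and $\underline{k}$ as already-standardized stand-ins):

```python
# Numerical check of the Barlow Twins gradient derived above.
import torch

torch.manual_seed(0)
N, D, lam = 4, 6, 0.5
q = torch.randn(N, D, requires_grad=True)
k = torch.randn(N, D)

C = q.t() @ k / N
loss = ((torch.diagonal(C) - 1).pow(2).sum()
        + lam * (C - torch.diag(torch.diagonal(C))).pow(2).sum())
loss.backward()

with torch.no_grad():
    A = torch.eye(D) - (1 - lam) * torch.diag(torch.diagonal(C))
    manual = (2 / N) * (-(k @ A) + lam * (k @ k.t() / N) @ q)
print(torch.allclose(q.grad, manual, atol=1e-6))   # True
```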

Note: the authors of [1] empirically show that removing $A$ from the first term does not harm performance. Furthermore, replacing batch normalization with $\ell_2$ normalization on the representations $\underline{k}$ and $\underline{q}$ does not hurt performance either.

3. Conclusion

The authors of [1] provide some interesting insights into the success of siamese Self-Supervised Learning. In particular, they find that:

  • Increasing the depth of the projector from 1 to 3 boosts the linear evaluation accuracy significantly.
  • Increasing the projector's width boosts the performance and does not seem to saturate even when the dimension increases to 16384.
  • A consistently and slowly updated positive key is both sufficient and essential for self-supervised learning. Hence, contrary to MoCo's findings, the slowly updating memory bank of negative samples is unnecessary.
  • The representations learned by these different methods behave similarly: SimCLR and BYOL also learn to decorrelate different channels, and Barlow Twins can also learn to discriminate between positive and negative samples. This result supports the claim that these methods have similar gradient formulas.

Although siamese self-supervised methods have shown remarkable performance, these methods still rely on hand-crafted invariance (data augmentation). Therefore, those methods may greatly benefit from learned data augmentation.

4. References

[1] Exploring the Equivalence of Siamese Self-Supervised Learning Via A Unified Gradient Framework.

[2] A Simple Framework for Contrastive Learning of Visual Representations.

[3] Momentum Contrast for Unsupervised Visual Representation Learning.

[4] Exploring Simple Siamese Representation Learning.

[5] Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning.

[6] Barlow Twins: Self-Supervised Learning via Redundancy Reduction.

[7] Understanding Self-Supervised Learning Dynamics without Contrastive Pairs.