Processing math: 25%

[Paper Review] Neural Discrete Representation Learning

2022. 8. 2. 12:59Paper Review

1. Introduction

동기(Motivation)

 

challenging tasks such as few-shot learning, domian adaptation, or reinforcement learning heavily rely on learnt representations from raw data, but the usefulness of generic representations trained in an unsupervised fashion is still far from being the dominant approach.
Maximum likelihood and reconstruction error are two common objectives used to train unsupervised models in the pixel domain, however their usefulness depends on the particular application the features are used in.

 

여러 도전적인 과제는 학습된 표현(=learnt representation)을 필요로 한다. 하지만 비지도 학습(=ML or reconstruction error)으로 얻은 표현은 사용성 (or 유용성)면에서 한계가 있다.

 

목표(Objective)

Our goal is to achieve a model that conserves the important features of the data in its latent space while optimising for maximum likelihood.

 

논문 목표는 ML(=비지도 학습)로 최적화하는 동시에 잠재 공간에서 데이터의 중요 성질 (or 표현)을 유지하는 것이다.


Learning representations with continuous features have been the focus of many previous work however we concentrate on discrete representations which are potentially a more natural fit for many of the modalities we are interested in.

 

지난 연구들에서는 연속적인 표현(=continuous feature)을 학습 했다. 하지만 이번 연구에서는 많은 양식에 더 잘 어울리는 이산적인 표현(=discrete representation)을 학습하기로 했다.

 

In our work, we introduce a new family of generative models successfully combining the variational autoencoder (VAE) framework with discrete latent representations through a novel parameterisation of the posterior distribution of (discrete) latents given an observation.

 

이번 논문에서는 (사후 분포의 새로운 매개변수화를 통해) VAE와 이산 잠재 변수를 결합한 새로운 생성 모델들을 소개한다.

 

Our model, which relies on vector quantization (VQ), is simple to train, does not suffer from large variance, and avoids the "posterior collapse" issue which has been problematic with many VAE models that have a powerful decoder, often caused by latents being ignored.

 

소개한 생성 모델은 "posterior collapse" 문제(=강력한 디코더로 인해 잠재 변수가 무시되는 경우)가 거의 없다.

 

Lastly, once a good discrete latent structure of a modality is discovered by the VQ-VAE, we train a powerful prior over these discrete random variables, yielding interesting samples and useful applications.

 

VQ-VAE 학습을 통해 적절한 이산 잠재 구조를 찾은 다음, 강력한 사전 분포를 학습시킨다. 이를 가지고 흥미로운 샘플을 생성하거나 유용한 응용 프로그램을 만드는데 사용할 수 있다.

 

2.VQ-VAE

In this work, we introduce the VQ-VAE where we use discrete latent variables with a new way of training, inspired by vector quantisation (VQ). The posterior and prior distributions are categorical, and the samples drawn from these distributions index an embedding table.

 

VQ-VAE는 이산 잠재 공간을 사용함과 동시에 새로운 학습 방법을 제시한다. VAE와 다르게 사후 분포(p(z|x))와 사전 분포(p(z))는 이산 분포 형태를 띈다. 이 분포에서 샘플링한 z로 임베딩 테이블 e 인덱싱(indexing)을 한다. 

 

2.1. Discrete Latent variables

We define a latent embedding space eRK×D where K is the size of the discrete latent space (i.e., a K-way categorical), and D is the dimensionality of each latent embedding vector ei. As shown in Figure 1, the model takes an input x, that is passed through an encoder producing output ze(x). The discrete latent variables z are then calculated by a nearest neighbour look-up using the shared embedding space e as shown in equation 1. The input to the decoder is the corresponding embedding vector ek as given in equation 2.

 

잠재 임베딩 공간를 RK×D (K=dim of discrete latent space, D=dim of embedding vector)로 정의 했다. (1) ze(x): 인코더 입력 x를 인코더에 넣어 나온 결과. (2) q(z|x): ze(x)(인코더 출력)와 가장 가까운 임베딩 벡터의 인덱스(=z). (3) zq(x): 임베딩 테이블 z번째 벡터. (4) p(x|z): 디코더 입력 zq(x)를 디코더에 넣어 나온 결과.

 

The posterior categorical distribution q(z|x) probabilities are defined as one-hot as follows: q(z=k|x)={1fork=argmin(1)
where z_e(x) is the output of the encoder network.

 

z_q(x) = e_k, where k = \arg\min_j \lVert z_e(x) - e_j \rVert_2\quad(2)

 

사후 이산 분포(q(z|x))와 디코더 입력(z_q(x))을 위 수식과 같이 정의한다.

 

글쓴이 뇌피셜

VQ-VAE를 학습 시 수식 2를 사용한다. 사전 분포(p(z))를 학습 시 수식 1를 사용한다.

 

2.2. Learning

Note that there is no real gradient defined for equation 2, however we approximate the gradient similar to the straight-through estimator and just copy gradients from decoder input z_q(x) to encoder output z_e(x).

Since the output representation of the encoder and the input to the decoder share the same D dimensional space, the gradients contain useful information for how the encoder has to change its output to lower the reconstruction loss.

 

수식 2는 기울기를 계산할 수 없어 역전파 흐름이 막힌다. 이를 해결하기 위해 디코더 입력(z_q(x))의 기울기를 그대로 가져와 인코더 출력(z_e(x))의 기울기로 사용했다. 디코더 입력과 인코더 출력은 같은 D 차원공간을 공유하므로, 디코더 입력의 기울기에는 "인코더 출력을 (재구성 손실을 낮추기 위한 방향으로) 최적화(수정)하기에 유용한 정보"도 가지고 있다.

 

Equation 3 specifices the overall loss function. It has three components that are used to train different parts of VQ-VAE. The first term is the reconstruction loss (or the data term) which optimizes the decoder and the encoder (through the estimator explained above). Due to the straight-through gradient estimation of mapping from z_e(x) to z_q(x), the embeddings e_i receive no gradients from the reconstruction loss \log{p(x|z_q(x))}. Therefore, in order to learn the embedding space, we use one of the simplest dictionary learning algorithms, Vector Quantisation (VQ). The VQ objectives uses the l_2 error to move the embeddings vectors e_i towards the encoder outputs z_e(x) as shown in the second term of equation 3. Because this loss term is only used for updating the dictionary, one can alternatively also update the dictionary items as function of moving averages of z_e(x) (not used for the experiments in this work).

 

Finally, since the volume of the embedding space is dimensionless, it can grow arbitrarily if the embeddings e_i do not train as fast as the encoder parameters.(?) To make sure the encoder commits to an embedding and its output does not grow, we add a commitment loss, the third term in equation 3. Thus, the total training objective becomes: L = \log{p(x|z_q(x))} + \lVert sg[z_e(x)] - e \rVert_2^2 + \beta \lVert z_e(x) - sg[e] \rVert_2^2 \quad(3)
where sg stands for the stop-gradient operator that is defined as identity at forward computation time and has zero partial derivatives, thus effectively constraining it operand to be a non-updated constant. The decoder optimizes the first term only, the encoder optimises the first and the last loss terms, and the embeddings are optimised by the middle loss term.

 

손실 함수에는 3가지 구성 요소가 있다. (1) \log{p(x|z_q(x))}: 인코더와 디코더를 최적화하는 재구성 손실. (2) \lVert sg[z_e(x)] - e \rVert_2^2: 재구성 손실의 새로운 학습으로 인해 임베딩 벡터는 기울기 전달을 못 받아 학습 불가 상태다.  임베딩 벡터 학습을 위해 l_2 error 항을 추가했다. (3) \lVert z_e(x) - sg[e] \rVert_2^2: 임베딩 공간은 무한하다. 그래서 임베딩 벡터가 인코더 파라미터보다 느리게 학습하면 임베딩 벡터는 무지성 학습을 할 수 있다. (그래서 인코더 파라미터가 임베딩 벡터보다 느리게 학습하면 임베딩 벡터는 무지성 학습을 할 수 있다.) 따라 commitment loss를 추가했다.

sg는 stop-gradient 연산자다. 이 연산자는 역전파 때 피연산자한테 도함수 0을 준다. 즉, 특정 부분만 학습 가능케 하는 연산자다.

디코더는 첫 항만으로 학습되고 인코더는 첫항과 마지막 항으로 학습되며 임베딩 공간은 중간 항으로 학습된다. 

 

Since we assume a uniform prior for z, the KL term that usually appears in the ELBO is constant w.r.t. the encoder parameters and thus be ignored for training.

 

VQ-VAE 학습 시 사전 분포를 균등 분포로 가정했다. 따라서 ELBO의 LK항은 고정되어 고려하지 않아도 된다.

 

In our experiments we define N discrete latents (e.g., we use a field of 32 \times 32 latents for ImageNet, or 8 \times 8 \times 10 for CIFAR 10). The resulting loss L is identical, except that we get an average over N terms for K-means(?) and commitment loss - one for each latent.

 

실험에서 N개 잠재 변수를 사용한다. (ImageNet에서는 32 \times 32개 잠재 변수를, CIFAR10에서는 8 \times 8 \times 10개 잠재 변수를 사용한다.) N개 임베딩 벡터 평균으로 commitment loss와 l_2 error을 계산한다.

 

2.3. Prior

Whilst training the VQ-VAE, the prior is kept constant and uniform. After training, we fit an autoregressive distribution over z, p(z), so that we can generate x via ancestral sampling. We use a PixelCNN over the discrete latents for images.

 

VQ-VAE 학습 시 사전 분포를 균등 분포로 유지하다가, 학습 후 사전 분포를 autoregressive distribution으로 변경 후 autoregressive distribution 학습을 진행한다. 이를 통해 "ancestral sampling"로 x를 생성할 수 있다. 이미지 학습 시 pixelCNN를 autoregressive distribution 모델로 사용한다.

 

3.Experiments

3.1. Comparison with continuous variables

As a first experiment we compare VQ-VAE with normal VAEs (with continuous variables), as well as VIMCO with independent Gaussian or categorical priors. We train models using the same standard VAE architecture on CIFAR10, while varying the latent capacity. The encoder consists of 2 strided convolutional layer with stride 2 and window size 4 \times 4, followed by two residual 3 \times 3 blocks (implemented as ReLU, 3 \times 3 conv, ReLU, 1 \times 1 conv), all having 256 hidden units. The decoder similarly has two residual 3 \times 3 blocks, followed by two transposed convolutions with stride 2 and window size 4 \times 4. We use the ADAM optimiser with learning rate 2e-4 and evaluate the performance after 250000 steps with batch-size 128. For VIMCO we use 50 samples in the multi-sample training objective.

 

The VAE, VQ-VAE and VIMCO models obtain 4.51 bits/dim, 4.67 bits/dim and 5.14 respectively.

 

Our model is the first among those discrete latent variables which challenges the performance of continuous VAEs. Thus, we get very good reconstructions like regular VAEs provide, with the compressed representation that symbolic representation provide.

 

이번 실험에서는, VQ-VAE를 VAE, VIMCO와 비교할 것이다. 모든 모델 표준 VAE 구조를 사용할 것이다. 모델 구조 및 학습 과정은 위와 같다. 실험 결과, 3가지 모델 모두 비슷한 결과를 보여줬다. 이는 논문 목표를 달성했다. 다시 말해, VAE와 비슷한 재구성을 보이는 동시에 데이터 중요 성질(표현)을 유지했다.

 

 

3.2. Images

In this experiment we show that we can model x = 128 \times 128 \times 3 images by compressing them to a z = 32 \times 32 \times 1 discrete space (with K = 512) via a purely deconvolutional p(x|z).

 

We model images by learning a powerful prior (PixelCNN) over z. This allows to not only greatly speed up training and sampling, but also to use the PixelCNNs capacity to capture the global structure instead of the low-level statistics of images.

 

Reconstructions from the 32 \times 32 \times 1 space with discrete latents are shown in Figure 2. Even considering that we greatly reduce the dimensionality with discrete encoding, the reconstructions look only slightly blurrier than the originals.

 

Next, we train a PixelCNN prior on the discretised 32 \times 32 \times 1 latent space. As we only have 1 channel (not 3 as with colours), we only have to use spatial masking in the PixelCNN. 

 

Samples drawn from the PixelCNN were mapped to pixel-space with the decoder of the VQ-VAE and can be seen in Figure 3.

 

이번 실험에서, deconvoluational p(x|z)x = 128 \times 128 \times 3 이미지를  z = 32 \times 32 \times 1 잠재 공간으로 표현했다. 실험 결과, 재구성 시 원본보다 약간 흐릿하게 보인다. 

뿐만 아니라 강력한 사전 분포(PixelCNN)도 학습했다. PixelCNN로 샘플링된 z로 디코더 넣어 결과를 확인해보니 수준 높은 결과를 보여준다. 

 

We also repeat the same experiment for 84 \times 84 \times 3 frames drawn from the DeepMind Lab environment. The reconstructions looked nearly identical to their originals. Samples drawn from the PixelCNN prior trained on the 21 \times 21 \times 1 latent space and decoded to the pixel space using a deconvolutional model decoder can be seen in Figure 4.

 

Finally, we train a second VQ-VAE with a PixelCNN decoder on the top of the 21 \times 21 \times 1 latent space from the first VQ-VAE on DM-LAB frames. This setup typically breaks VAEs as they suffer from "posterior collapse", i.e., the latents are ignored as the decoder is powerful enough to model x perfectly. 

 

We use only three latent variables (each with K=512 and their own embedding space e) at the second stage for modeling the whole image and as such the model cannot reconstruct the image perfectly

 

다른 실험에서, x = 84 \times 84 \times 3 이미지를 z = 21 \times 21 \times 1 잠재 공간을 표현했다. 재구성 시 원본과 거의 같다.

마지막으로 21 \times 21 \times 1 잠재 공간에 pixelCNN decoder를 하나 올려 z = 21 \times 21 \times 1 잠재 공간을 z = 3 잠재 공간으로 한번 더 압축했다. 이 모델은 이미지를 더 많이 압축해 이미지를 완벽히 재구성 하지 못한다. 하지만 posterior collapse 문제를 억제한다.

 

5. Conclusion

In this work, we have introduced VQ-VAE, a new family of models that combine VAEs with vector quantisation to obtain a discrete latent representation. 

 

We have shown that VQ-VAEs are capable of modeling very long term dependencies through their compressed discrete latent space which we have demonstrated by generating 128 \times 128 colour images.

 

 

All these experiments demonstrated that the discrete latent space learnt by VQ-VAEs capture important features of the data in a completely unsupervised manner. Moreover, VQ-VAEs achieve likelihoods that are almost as good as their continuous latent variable counterparts on CIFAR10 data.

 

이번 연구에서는, 이산 잠재 표현을 얻기 위해 VAE와 VQ를 결합한 VQ-VAE를 소개했다.

VQ-VAE는 압축한 이산 잠재 공간을 통해 long term dependencies를 모델링 할 수 있다. 이는 128 \times 128 이미지를 생성을 통해 증명했다. 

실험을 통해, VQ-VAEs는 "이산 잠재 공간이 데이터의 중요 특성을 보존"한다. 뿐만 아니라 "연속 잠재 변수(=표준 VAE)와 거의 같은 likelihood를 유지"한다.