Information
Title: Progressive Distillation for Fast Sampling of Diffusion Models (ICLR 2022)
Reference
Author: Sangwoo Jo
Last updated on Nov. 14, 2023
Progressive Distillation for Fast Sampling of Diffusion Models#
1. Introduction#
Diffusion model 이 ImageNet generation task 에서 기존 BigGAN-deep 그리고 VQ-VAE-2 모델보다 FID/CAS score 기준으로 더 좋은 성능을 보여주며 많은 각광을 받고 있습니다. 그러나 sampling 속도가 느리다는 치명적인 단점을 가지고 있습니다.
이를 해결하기 위해, 논문에서는 Progressive Distillation 기법을 소개하게 됩니다. 간략히 설명하자면 사전학습된
2. Background - Diffusion model in continuous time#
2.1. Definition#
Continuous 한 time domain 에서의 diffusion model 을 다음과 같은 요소들로 정의합니다.
Training data
Latent variables
여기서

Fig. 400 Markovian Forward Process#
where
2.2. Objective#
Diffusion model 의 objective 는

Fig. 401 Objective#
where
2.3. Sampling#
Diffusion model 에서 sampling 하는 방식은 다양하게 존재합니다.
2.3.1. Ancestral Sampling - DDPM#
첫번째로는 DDPM 논문에서 소개하는 discrete time ancestral sampling 방식입니다. 위에 소개했던 notation 기준으로 reverse process 를 다음과 같이 수식적으로 표현 가능합니다.

Fig. 402 Reverse Process#
이를 기반으로

Fig. 403 Ancestral Sampler#
2.3.2. Probability Flow ODE#
반면에, Song et al. (2021c) 에서 forward diffusion process 를 SDE 로 표현할 수 있고, 이를 통한 sampling process 를 probabiility flow ODE 로 표현해서 구할 수 있다고 제시합니다.

Fig. 404 Probability flow ODE#
이때,
다시 말해

Fig. 405 FID scores on 128 × 128 ImageNet for various probability flow ODE integrators#
참고로 DDIM sampler 를 ODE solver 문제로 해석하면 다음과 같이 표현할 수 있고, 이 수식은 앞으로 자주 보게 될 예정입니다.

Fig. 406 DDIM sampler#
3. Progressive Distillation#
Diffusion model 을 더 효율적으로 sampling 하기 위해 소개한 progressive distillation 기법은 다음과 같은 절차로 진행됩니다.

Fig. 407 Progressive Distillation#
Standard diffusion training 기법으로 Teacher Diffusion Model 학습
Student Model 정의 - Teacher Model 로부터 모델 구조 및 parameter 복사
Student Model 학습
이때, original data
대신에 를 target 로 student model 을 학습합니다. 에 대한 공식은 아래 pseudocode 에 소개되는데, 이는 one-step student sample 과 two-step teacher sample 를 일치시키기 위해 나온 공식입니다.2 DDIM steps of teacher model 결과와 1 DDIM step of student model 결과를 일치시키는 것이 핵심입니다. 여기서
에서 로 넘어가는 과정을 1 DDIM step 라 정의하고, 은 총 진행되는 student sampling steps 입니다.기존 denoising model 학습 시,
가 에 대해 deterministic 하지 않기 때문에 (다른 값들에 대해 동일한 생성 가능) 모델은 사실상 가 아닌 weighted average of possible values 를 예측하는 모델이라고 합니다. 따라서, 에 대해 deterministic 한 를 예측하도록 학습한 student model 은 teacher model 보다 더 sharp 한 prediction 을 할 수 있다고 주장합니다.
Student Model 이 새로운 Teacher Model 이 되고 sampling steps
→ 로 줄어드는 이 과정을 계속 반복
이에 대한 pseudocode 도 확인해보겠습니다.
PseudoCode
Fig. 408 Pseudocode for Progresssive Distillation#
4. Diffusion Model Parameterization and Training Loss#
이제 denoising model
DDPM 을 비롯한 대다수의 논문에서 이미지

Fig. 409 Training loss on
따라서, 이는 이미지
Standard diffusion training 기법에서는 다양한 범위 내에서의 signal-to-noise ratio
그래서 논문에서는 다음과 같은 세가지 방법으로 stable 한

Fig. 410 Different parameterizations#
Weighting function

Fig. 411 Different loss weighting functions#

Fig. 412 Visualization of different loss weighting functions#
5. Experiments#
논문에서 32x32 부터 128x128 까지 다양한 resolution 에서 모델 성능을 확인했습니다. 또한, cosine schedule
5.1. Model Parametrization and Training Loss#
아래 지표는 unconditional CIFAR-10 데이터셋에 앞써 소개드린

Fig. 413 Ablation Study on Parameterizations and Loss Weightings#
성능을 비교해본 결과
위 실험결과를 바탕으로 progressive distillation 진행시 CIFAR-10 데이터셋에는
5.2. Progressive Distillation#
논문에서 CIFAR-10, 64x64 downsampled ImageNet, 128 × 128 LSUN bedrooms, 그리고 128 × 128 LSUN Church-Outdoor 데이터셋에 progressive distillation 을 적용하여 모델 성능을 측정합니다. CIFAR-10 데이터셋 기준으로 teacher model 로부터 progressive distillation 진행 시 8192 steps 부터 시작하였고 batch size=128 로 설정하였습니다. 그 외 resolution 이 큰 데이터셋에 대해서는 1024 steps 부터 시작하고 batch size=2048 로 실험을 진행했습니다. 또한, 매 iteration 마다
FID 성능을 확인해본 결과, 실험을 진행한 모든 4개의 데이터셋에 대해 progressive distillation 을 통해 4-8 sampling steps 만 진행해도 undistilled DDIM 그리고 stochastic sampler 에 준하는 성능을 보여주는 것을 확인할 수 있습니다. 4 sampling steps 까지 progressive distillation 진행하면서 발생하는 computational cost 가 baseline 모델 학습하는 것과 비슷한 부분을 생각했을때 엄청난 장점이라고 생각합니다.

Fig. 414 Comparison between Distilled, DDIM, and Stochastic Sampler#
추가적으로 CIFAR-10 데이터셋에서 타 fast sampling method 들과 FID 성능을 비교해본 결과입니다.

Fig. 415 Comparison of fast sampling results#
그리고 64x64 ImageNet 데이터셋에 distilled 모델로 생성한 예시 이미지들입니다. 동일한 seed 에 대해서 input noise 로부터 output image 까지 mapping 이 잘되는 부분을 확인할 수 있습니다.

Fig. 416 Random samples from distilled 64 × 64 ImageNet models#
마지막으로 distillation scheduling 에 대한 ablation study 도 논문에서 진행했습니다. 첫번째 ablation study 로는 매 distillation iteration 마다 parameter update 횟수를

Fig. 417 Ablation study on fast sampling schedule#
동일하게 CIFAR-10 외 ImageNet 그리고 LSUN 데이터셋에서 fast sampling schedule 을 적용한 성능 결과도 공유합니다.

Fig. 418 50k updates vs 10k updates on ImageNet/LSUN datasets#