[RL] 8.1. DPO(Direct Preference Optimization): Your Language Model is Secretly a Reward Model

2024. 5. 23. 16:13RL

Preview RLHF

preview PPO

 

Reward Model Loss function

$$\begin{matrix} \mathbb (r_\phi, \mathcal D) = - \mathbb E_{(x, y_w, y_l) \sim \mathcal D} \left[ \log \sigma(r_\phi(x, y_w) - r_\phi(x, y_l))\right] \\ \text{where } \mathcal D = \{x^{(i)}, y_w^{(i)}, y_l^{(i)}\}_{i=1}^N\end{matrix}$$

 

RLHF Objective function

$$\begin{matrix} \text{objective} (\theta) &=& \mathbb E_{x \sim \mathcal D, y \sim \pi_\theta(y|x)} [r_\phi (x, y) - \beta \log (\pi_\theta(y|x)/\pi_\text{ref} (y|x))] \\ &=& \mathbb E_{x \sim \mathcal D, y \sim \pi_\theta(y|x)} [r_\phi (x, y)] - \beta D_\text{KL} [\pi_\theta(y|x) || \pi_\text{ref}(y|x)] \end{matrix}$$

 

PPO의 문제점

1. 구현 과정이 복잡하다.

2. 알고리즘이 매우 무겁다. 다시 말해 학습에 필요한 요소가 많고, 학습 과정이 복잡하다.

3. reward의 variance가 커서 학습이 불안정하다. 

DPO

DPO 아이디어

RLHF에서는 preference data로 reward model를 fitting 한 후, 이로 LM Policy($\pi_\theta$)를 RL로 학습시킨다.

반면, DPO에서는 preference data로 final LM($\pi_\theta$)를 직접 학습시킨다.

 

DPO 목표 함수 도출 과정

RLHF Objective function을 기반으로 최적 정책의 수식을 추론해보자.

$$\begin{matrix} \underset{\theta} \max \text{objective} (\theta) &=& \underset{\theta} \max \mathbb E_{x \sim \mathcal D, y \sim \pi_\theta(y|x)} [r (x, y)] - \beta D_\text{KL} [\pi_\theta(y|x) || \pi_\text{ref}(y|x)] \\ &=& \underset{\theta} \max \mathbb E_{x \sim \mathcal D}\mathbb E_{y \sim \pi_\theta(y|x)} [r (x, y) - \beta \log (\pi_\theta(y|x)/\pi_\text{ref} (y|x))] \\ &=& \underset{\theta} \min \mathbb E_{x \sim \mathcal D}\mathbb E_{y \sim \pi_\theta(y|x)} \left[ \log {\pi_\theta(y|x) \over \pi_\text{ref} (y|x)} -{1 \over \beta} r(x, y)\right] \\ &=& \underset{\theta} \min \mathbb E_{x \sim \mathcal D} \mathbb E_{y \sim \pi_\theta(y|x)}\left[\log {\pi_\theta(y|x) \over {1 \over Z(x)} \pi_\text{ref}(y|x)\exp\left({1 \over \beta} r(x, y) \right)} - \log Z(x) \right] \end{matrix}$$

$$Z(x) = \sum_y \pi_\text{ref}(y|x)\exp\left({1\over \beta} r(x,y)\right)$$

 

이때, (위 전개 과정에서 착안해) 최적 정책을 다음과 같이 정의해 볼 수 있다.

$$\pi^*(y|x) = \frac{1}{Z(x)} \pi_\text{ref}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right), \\ (\text{where } \pi^*(y|x) \le 0 \text{ for all } y \text{ and } \sum_y \pi^*(y|x) = 1)$$

 

그렇게 되면, 목표 함수가 $\underset{\theta} \min \mathbb E_{x \sim \mathcal D}\left[ D_\text{KL} \left(\pi_\theta(y|x) || \pi^*(y|x)\right) - \log Z(x)\right]$로 표현되고,

$Z(x)$는 $\pi$에 영향을 받지 않기 때문에 해당 식이 최솟값이 되는 경우는 $\pi_\theta$와 $\pi^*$가 동일할 때가 되므로, 위와 같이 정의한 것이다.

 

수식을 "높은 확률로 생성된 응답이 높은 reward를 가질 경우, 해당 응답을 높은 확률로 선택하는 정책"으로 해석할 수 있다. 

 

이로 보상 함수를 reparameterize하면 다음과 같다.

$$r(x, y) = \beta \log {\pi^*(y|x) \over \pi_\text{ref} (y|x)} + \beta \log Z(x)$$

 

해당 수식은 "특정 응답의 생성 확률이 reference policy(=model)보다 높으면 보상이 양수고 낮으면 보상이 마이너스를 준다".

 

해당 식을 Reward Model Loss function에 대입하면 다음과 같다.

$$\begin{matrix} - \mathbb E_{(x, y_w, y_l) \sim \mathcal D} \left[ \log \sigma(r(x, y_w) - r(x, y_l))\right] \quad (\text{where } \mathcal D = \{x^{(i)}, y_w^{(i)}, y_l^{(i)}\}_{i=1}^N) \\ = - \mathbb E_{(x, y_w, y_l) \sim \mathcal D} \left[ \log \sigma \left(\beta \log {\pi_\theta(y_w|x) \over \pi_\text{ref}(y_w|x)} - \beta \log  {\pi_\theta(y_l|x) \over \pi_\text{ref}(y_l|x)} \right)\right] \end{matrix}$$

 

$$\mathcal L_\text{DPO}(\pi_\theta, \pi_\text{ref}) = - \mathbb E_{(x, y_w, y_l) \sim \mathcal D} \left[ \log \sigma \left(\underbrace{\beta \log {\pi_\theta(y_w|x) \over \pi_\text{ref}(y_w|x)}}_{\begin{matrix} \text{implicit reward of} \\ \text{preferred response}\end{matrix}} - \underbrace{\beta \log  {\pi_\theta(y_l|x) \over \pi_\text{ref}(y_l|x)}}_{\begin{matrix} \text{implicit reward of} \\ \text{dispreferred response}\end{matrix}} \right)\right]$$

 

이때, reward의 variance가 확률 범위(0~1)로 normalized 되어 학습이 상대적으로 안정적이다.

 

DPO 목표 함수 미분 해석

$$\nabla_\theta \mathcal L_\text{DPO}(\pi_\theta, \pi_\text{ref}) = - \mathbb E_{(x, y_w, y_l) \sim \mathcal D} \left[ \underbrace{\sigma(\hat r_\theta(x, y_l) - \hat r_\theta(x, y_w))}_{\begin{matrix}\text{higher weight when} \\ \text{reward estimate is wrong}\end{matrix}} \left[ \underbrace{\nabla_\theta \log \pi_\theta(y_w|x)}_{\text{increrase likelihood of }y_w} - \underbrace{\nabla_\theta \log \pi_\theta(y_l|x)}_{\text{decrease likelihood of } y_l} \right]\right]$$

($\hat r_\theta(x, y)$ is $\beta \log{\pi_\theta(y|x) \over \pi_\text{ref}(y|x)}$, implicit reward of $y$)

 

결론, reward model을 reparameterize해서 optimal policy를 RL 단계 없이 도출시킬 수 있게 되었다.

다시 말해, reward model과 prompts 없이 preference data로 $\pi_\theta$를 직접 학습시켜 optimal policy를 도출할 수 있게 되었다.

저자들은 reward 모델이 implicit하게 존재하고 있으므로, optimal policy 도출 과정을 implicit reward 모델을 human preference에 fit하는 과정으로 볼 수 있다고 주장한다.

장점

1. simplify, stable, low cost 되었다.

PPO는 KL-constraint에 따라 불안정하게 학습된다. 하지만, DPO는 안정적으로 학습이 이뤄진다. 뿐만 아니라, temperature에도 강건하게 반응한다.

2. reward reparameterization을 통해 RL 단계 없이 optimal policy 학습 가능

 

단점

1. generalization에 대한 이론적 근거 미약

RL(PPO)는 unlabeled prompt를 사용해서 좀 더 다양한 탐색을 해가지고 일반화 성능을 올리는 방식인데 DPO는 unlabeled prompt를 사용하지 않음에도 성능이 좋다?

2. reward over-optimization(=overfitting)