Processing math: 73%

[RL] ColossalAI PPO 코드 리뷰 (작성중..)

2025. 2. 21. 12:44RL

배경

InstructGPT 논문에서는 RLHF의 학습에 PPO 알고리즘이 사용된다고 설명한다.

다만, 개인적으로 RLHF의 목표 함수와 PPO의 손실 함수 간의 연관성을 파악하지 못해 학습 과정이 어떻게 이루어지는지 알기 어려웠다. 이를 이해하기 위해 ColossalAI의 PPO 코드를 살펴봤다.

코드 리뷰

0. PPO 학습 파이프라인

1. 학습 준비: PPOTrainer의 생성자

PPO 학습에 필요한 네 가지 모델이 인자로 전달되는 것을 확인할 수 있다.

- inital_model: SFT 모델

- actor: Policy 모델 (=RL 모델)

- critic: Value 모델 (= Vπ, 추정 상태 가치 함수)

- reward_model: Reward 모델
  * 엄밀히 말하면, reward_model은 Qπ, 추정 행동 가치 함수(추정 Q-함수)다.

2. 학습 과정: PPOTrainer의 fit()

학습 과정은 다음과 같다.

1. 경험 데이터를 수집한다. (L209)

self._collect_phase() → self._make_experience(), self.data_buffer.append()

2. 경험 데이터를 Dataloader에 담는다. (L211)

self._setup_update_phrase_dataload()

3. Dataloader에 담은 경험 데이터를 학습한다. (L213)

self._update_phase() → self._learn()

2.1. 경험 데이터 수집: NaiveExperienceMaker의 make_experience()

여기서 주목해야 할 부분은 RLHF의 목표 함수인 보상(reward) 계산 (L273)PPO 손실 함수에서 사용되는 advantage 계산(L276)이다.

 

2.1.1. 보상 계산: compute_reward()

각 토큰(xt)의 reward는 RseqΔKL(xt)다.

(Rseq: 해당 시퀀스의 reward, ΔKL(xt): xt 토큰의 KL 거리)

reward[i, : action_mask[i].sum()] += r_clip[i]

이와 다르게, HuggingFace의 trl는 마지막 토큰의 reward만 RseqΔKL(xt)로 계산하고,

그 외 reward는  ΔKL(xt)로 계산한다. (개인적으로, reward 계산 방식은 trl가 더 적절한 거 같다.)

kl = logprobs - ref_logprobs
non_score_reward = -args.kl_coef * kl
rewards = non_score_reward.clone()
actual_start = torch.arange(rewards.size(0), device=rewards.device)
# actual_end를 왜 이렇게 계산했는지 의문이다. 
# 내가 이해하기론 sequence_lengths_p1는 stop_token 인덱스 + 1이다. 즉, pad_token 인덱스의 reward에 적용하는 거다.
# 그래서 해당 이슈가 있는지 찾아봤는데, 동일한 의문을 제기한 이슈를 발견했다. 근데 답변이 없다...
# https://github.com/huggingface/trl/issues/1893
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
rewards[[actual_start, actual_end]] += scores

 

trl은 내 예상과 달리, 모든 토큰의 advantage value를 계산하여 학습에 사용한다.

그냥, 거리 차이(ΔKL(xt))만 역전파하면 되지 않을까? InstructGPT 논문에서도 그렇게 손실 함수를 정의한거 같은데...

의도를 추측해보자면, 거리 차이만 역전파하는 것보다 advantage value를 역전파하는 것이 더 좋은 학습 정보를 제공하기 때문이 아닐까?

직관적으로, 현재 보상 총합(=ΔKL(xt)+γVt)보다 평균 보상 총합(=Vt1)의 차이가 크면, 학습량이 많아진다. 즉, 생성한 토큰 xt가 적절치 않다는 뜻이다. 반대로, 차이가 작으면 학습량이 작이지며, xt가 적절하다는 뜻이다.

 

2.1.2. advantage 계산: calculate_advantage()

각 토큰의 advantage은 T에서 1까지 거꾸로(reversed) 진행되며, AGAEt+1AGAEt를 누적하는 방식으로 계산된다.
1. δt=Rt+γVt+1Vt (근데, 코드 reward[:, t]는 G(τ)라서 reward[:, t] - value[:, t]이어야 하지 않아?)

step 2. AGAEt=δt+γAGAEt+1

for t in reversed(range(num_actions)):
  delta = reward[:, t] + self.gamma * nextvalues - value[:, t]	# step 1
  lastgaelam = delta + self.gamma * self.lam * lastgaelam	# step 2

2.2. 경험 데이터 학습: PPOTrainer의 _training_step()

여기서 주목해야 할 부분은 actor loss 계산(L233)critic loss 계산(L253)이다.

 

2.2.1. PPO loss 계산:  PolicyLoss의 forward()

step 1. rt

step 2. J_t^\text{clip} = \min \{r_tA(s_t, a_t), \text{clip}(r_t, 1-\epsilon, 1+\epsilon)A(s_t, a_t)\}

# step 1
ratio_ = (log_probs - old_log_probs).exp()
# step 2
surr1 = ratio * advantages						
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages 
loss = -torch.min(surr1, surr2)

 

2.2.2. Value loss 계산:  ValueLoss의 forward()

step 1. target = A_t^\text{GAE} + V_{t_\text{old}} \approx R_t + V_{t+1}
step 2. predict = V_t 이때, 학습 안전성을 위해 V_{t_\text{old}}V_t의 차이가 심할 경우 clipping 해준다.

# step 1
returns = advantage + old_values
# step 2
values_clipped = old_values + \
	(values - old_values).clamp(-self.clip_eps, self.clip_eps
surr1 = (values_clipped - returns) ** 2,
surr2 = (values - returns) ** 2,
loss = torch.mean(torch.max(surr1, surr2))

PPO는 actor와 critic 모델만 학습하며, reward 모델은 학습하지 않는다. 따라서, PPOTrainer 역시 reward 모델을 학습하지 않는다.

그렇기 때문에, 따로 학습을 해줘야 한다.

3. Reward 모델: RewardModel

- reward 모델은 각 스퀀스마다 reward를 계산한다.

- Reward 모델은 선호 시퀀스와 비선호 시퀀스의 보상값 차이를 극대화하는 방식으로 학습되며, 이를 위해 LogExpLoss를 사용한다.

E_{(x, y_w, y_l) \sim \mathcal D} \left[ \log \sigma(r_\phi(x, y_w) - r_\phi(x, y_l))\right]

loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()