

from __future__ import annotations

import abc
import argparse
import copy
import itertools
from typing import Any, ClassVar, List

import deepspeed
import optree
import torch
import torch.distributed as dist
import torch.nn as nn
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    get_scheduler,
)
from transformers.integrations.deepspeed import HfDeepSpeedConfig

from safe_rlhf.configs import ADAM_BETAS
from safe_rlhf.datasets import DummyDataset, PromptOnlyBatch, PromptOnlyDataset, SupervisedDataset
from safe_rlhf.models import AutoModelForScore, load_pretrained_models
from safe_rlhf.trainers.base import TrainerBase
from safe_rlhf.utils import (
    get_all_reduce_mean,
    get_optimizer_grouped_parameters,
    is_main_process,
    is_same_tokenizer,
    masked_mean,
    to_device,
    batch_retokenize,
    gather_log_probabilities,
    get_all_reduce_max,
    gather_input_ids,
    gather_scores
)
import os
import json

class PPOEnsembleTrainer(TrainerBase):
    """PPO trainer that scores rollouts with a list of reward models plus a gold model.

    The attribute declarations below describe the state that the ``init_*``
    methods populate: one trainable actor, one reference actor, a list of
    reward models with their tokenizers, one reward critic, and one gold
    model with its tokenizer.
    """

    # NOTE(review): TRAINING_TYPE is assigned twice. The second assignment wins,
    # so the effective class value is 'rl'; the 'ppo_ensemble' value is never
    # observable. Confirm which one is intended and drop the other.
    TRAINING_TYPE = 'ppo_ensemble'

    TRAINING_TYPE: ClassVar[str] = 'rl'

    # Policy being optimised and its reference copy.
    actor_model: deepspeed.DeepSpeedEngine
    actor_reference_model: deepspeed.DeepSpeedEngine
    # Ensemble of reward models (one engine per entry) and the critic.
    reward_model: List[deepspeed.DeepSpeedEngine]
    reward_critic_model: deepspeed.DeepSpeedEngine

    # Tokenizers matching `reward_model` entry-by-entry, and the critic's tokenizer.
    reward_tokenizer: List[PreTrainedTokenizerBase]
    reward_critic_tokenizer: PreTrainedTokenizerBase

    # Gold scoring model and its tokenizer.
    gold_model: deepspeed.DeepSpeedEngine
    gold_tokenizer: PreTrainedTokenizerBase

    # DeepSpeed configurations for trainable and inference-only engines.
    ds_train_config: dict[str, Any]
    ds_eval_config: dict[str, Any]

    # NOTE(review): this list is created once on the class, so every instance
    # shares the same object unless an instance rebinds it. Nothing in the
    # visible part of the file reads or writes it — confirm it is still needed.
    train_generated_data = []

    def __init__(
        self,
        args: argparse.Namespace,
        ds_train_config: dict[str, Any],
        ds_eval_config: dict[str, Any],
    ) -> None:
        """Build models, datasets, engines, logger and PPO hyper-parameters.

        Args:
            args: Namespace holding every training option read by this trainer.
            ds_train_config: DeepSpeed configuration for trainable engines.
            ds_eval_config: DeepSpeed configuration for inference-only engines.
        """
        self.args = args
        self.ds_train_config = ds_train_config
        self.ds_eval_config = ds_eval_config
        self.global_step = 0

        # Every setup stage is followed by a barrier across all ranks.
        for setup_stage in (self.init_models, self.init_datasets, self.init_engines):
            setup_stage()
            dist.barrier()
        self.init_logger()

        # Sampling settings shared by every rollout generation call.
        tokenizer = self.tokenizer
        sampling_options = {
            'max_length': self.args.max_length,
            'num_return_sequences': self.args.num_return_sequences,
            'temperature': self.args.temperature,
            'top_p': self.args.top_p,
            'repetition_penalty': self.args.repetition_penalty,
            'do_sample': True,
            'bos_token_id': tokenizer.bos_token_id,
            'eos_token_id': tokenizer.eos_token_id,
            'pad_token_id': tokenizer.pad_token_id,
        }
        self.generation_config = GenerationConfig(**sampling_options)

        # These values are copied onto the trainer so they can be changed later
        # without touching `args`.
        for option in (
            'kl_coeff',
            'clip_range_ratio',
            'clip_range_score',
            'clip_range_value',
            'ptx_coeff',
        ):
            setattr(self, option, getattr(self.args, option))

        # Fixed discount factor and GAE lambda used by the advantage estimator.
        self.gamma = 1.0
        self.gae_lambda = 0.95

    def init_models(self) -> None:
        """Load every model/tokenizer pair the trainer works with.

        Loads, in order: the actor, the reference actor (same checkpoint), each
        reward model listed in ``args.reward_model_name_or_path``, the reward
        critic and the gold model. Tokenizers that `is_same_tokenizer` reports
        as matching the actor tokenizer are replaced by the actor tokenizer
        object itself, so later code can compare them by identity.

        Raises:
            ValueError: If no reward critic path is configured, or if the
                reward critic tokenizer differs from the actor tokenizer.
        """
        args = self.args

        # Keep an HfDeepSpeedConfig on the instance whenever ZeRO stage 3 is configured.
        train_config = self.ds_train_config
        if train_config is not None and train_config['zero_optimization']['stage'] == 3:
            self.dstchf_train = HfDeepSpeedConfig(train_config)

        eval_config = self.ds_eval_config
        if eval_config is not None and eval_config['zero_optimization']['stage'] == 3:
            self.dsechf_eval = HfDeepSpeedConfig(eval_config)

        # Actor and reference actor are two independent loads of one checkpoint.
        actor_load_options = {
            'model_max_length': args.max_length,
            'padding_side': 'left',
            'auto_model_type': AutoModelForCausalLM,
            'trust_remote_code': args.trust_remote_code,
        }
        self.actor_model, self.tokenizer = load_pretrained_models(
            args.actor_model_name_or_path,
            **actor_load_options,
        )
        self.actor_reference_model, _ = load_pretrained_models(
            args.actor_model_name_or_path,
            **actor_load_options,
        )

        # Reward ensemble: one normalising scorer per configured path.
        self.reward_model, self.reward_tokenizer = [], []
        for ensemble_path in args.reward_model_name_or_path:
            member_model, member_tokenizer = load_pretrained_models(
                ensemble_path,
                model_max_length=args.max_length,
                auto_model_type=AutoModelForScore,
                padding_side='right',
                trust_remote_code=args.trust_remote_code,
                auto_model_kwargs={
                    'score_type': 'reward',
                    'do_normalize': True,
                },
            )
            member_model.set_normalize(True)
            self.reward_model.append(member_model)
            self.reward_tokenizer.append(member_tokenizer)

        # Reward critic (value model): mandatory, not normalised.
        critic_path = args.reward_critic_model_name_or_path
        if critic_path is None:
            raise ValueError('Reward critic model must be provided.')
        self.reward_critic_model, self.reward_critic_tokenizer = load_pretrained_models(
            critic_path,
            model_max_length=args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='left',
            trust_remote_code=args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'critic',
                'do_normalize': False,
            },
        )
        self.reward_critic_model.set_normalize(False)

        # Gold scorer: not normalised.
        self.gold_model, self.gold_tokenizer = load_pretrained_models(
            args.gold_model_name_or_path,
            model_max_length=args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='right',
            trust_remote_code=args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'reward',
                'do_normalize': False,
            },
        )
        self.gold_model.set_normalize(False)

        # Alias equivalent tokenizers to the actor tokenizer object.
        actor_tokenizer = self.tokenizer
        if is_same_tokenizer(actor_tokenizer, self.gold_tokenizer):
            self.gold_tokenizer = actor_tokenizer
        self.reward_tokenizer = [
            actor_tokenizer if is_same_tokenizer(actor_tokenizer, candidate) else candidate
            for candidate in self.reward_tokenizer
        ]

        # The critic must share the actor's tokenizer.
        if not is_same_tokenizer(actor_tokenizer, self.reward_critic_tokenizer):
            mismatch_template = (
                'Reward critic tokenizer must be the same as actor tokenizer. '
                'Expected {0.__module__}.{0.__qualname__}(vocab_size={1}), '
                'but got {2.__module__}.{2.__qualname__}(vocab_size={3}). '
                'Please consider pass `--reward_critic_model_name_or_path` from the command line.'
            )
            raise ValueError(
                mismatch_template.format(
                    type(actor_tokenizer),
                    len(actor_tokenizer),
                    type(self.reward_critic_tokenizer),
                    len(self.reward_critic_tokenizer),
                ),
            )
        self.reward_critic_tokenizer = actor_tokenizer

    def init_datasets(self) -> None:
        """Create the prompt, evaluation and PTX dataloaders.

        Also stores ``self.use_ptx`` and writes the total number of training
        steps to ``args.total_training_steps``.

        Raises:
            ValueError: If the number of sequences generated per prompt batch is
                not a multiple of the training micro-batch size, or if
                evaluation is requested without exactly one of
                ``eval_datasets`` / ``eval_split_ratio``.
        """
        args = self.args

        # Sequences produced from one prompt batch on one device.
        rollouts_per_prompt_batch = args.per_device_prompt_batch_size * args.num_return_sequences
        if rollouts_per_prompt_batch % args.per_device_train_batch_size != 0:
            raise ValueError(
                'The number of prompt-only samples must be divisible by the micro batch size.',
            )

        prompt_only_dataset = PromptOnlyDataset(
            args.train_datasets,
            tokenizer=self.tokenizer,
        )

        if not args.need_eval:
            self.eval_dataloader = None
        else:
            has_eval_datasets = args.eval_datasets is not None
            has_split_ratio = args.eval_split_ratio is not None
            if has_split_ratio and not has_eval_datasets:
                # Carve the evaluation set out of the training prompts.
                prompt_only_dataset, eval_dataset = prompt_only_dataset.split_train_test(
                    split_ratio=args.eval_split_ratio,
                )
            elif has_eval_datasets and not has_split_ratio:
                eval_dataset = PromptOnlyDataset(
                    args.eval_datasets,
                    tokenizer=self.tokenizer,
                )
            else:
                raise ValueError('Either `eval_datasets` or `eval_split_ratio` should be provided.')

            self.eval_dataloader = DataLoader(
                eval_dataset,
                collate_fn=eval_dataset.get_collator(),
                sampler=DistributedSampler(eval_dataset, shuffle=True),
                batch_size=args.per_device_eval_batch_size,
            )

        self.prompt_only_dataloader = DataLoader(
            prompt_only_dataset,
            collate_fn=prompt_only_dataset.get_collator(),
            sampler=DistributedSampler(prompt_only_dataset, shuffle=True),
            batch_size=args.per_device_prompt_batch_size,
        )

        self.use_ptx = args.ptx_datasets is not None
        if not self.use_ptx:
            # Placeholder loader with as many items as there are prompt batches.
            self.ptx_dataloader = DataLoader(DummyDataset(len(self.prompt_only_dataloader)))
        else:
            ptx_dataset = SupervisedDataset(
                args.ptx_datasets,
                tokenizer=self.tokenizer,
            )
            self.ptx_dataloader = DataLoader(
                ptx_dataset,
                collate_fn=ptx_dataset.get_collator(),
                sampler=DistributedSampler(ptx_dataset, shuffle=True),
                batch_size=rollouts_per_prompt_batch,
            )

        # Micro-batch updates over the whole run (multiplication order kept as-is).
        sample_updates = (
            len(self.prompt_only_dataloader)
            * args.epochs
            * args.update_iters
            * args.per_device_prompt_batch_size
            * args.num_return_sequences
        )
        args.total_training_steps = int(sample_updates // args.per_device_train_batch_size)
    def _init_train_engine(
        self,
        model: nn.Module,
        weight_decay: float,
        lr: float,
        lr_scheduler_type: str,
        lr_warmup_ratio: float,
        total_training_steps: int,
        ds_config: dict[str, Any],
    ) -> deepspeed.DeepSpeedEngine:
        """Wrap `model` in a DeepSpeed engine with an Adam optimizer and LR scheduler.

        Args:
            model: Module to train.
            weight_decay: Weight decay passed to the parameter-grouping helper.
            lr: Learning rate for the optimizer.
            lr_scheduler_type: Scheduler name forwarded to `get_scheduler`.
            lr_warmup_ratio: Fraction of scheduler steps spent warming up.
            total_training_steps: Number of micro-batch steps over the run.
            ds_config: DeepSpeed configuration for this engine.

        Returns:
            The first element returned by `deepspeed.initialize`.
        """
        parameter_groups = get_optimizer_grouped_parameters(model, weight_decay)

        # Use the CPU Adam variant only when the optimizer is offloaded.
        offload_target = (
            ds_config['zero_optimization'].get('offload_optimizer', {}).get('device', 'none')
        )
        optimizer_class = FusedAdam if offload_target == 'none' else DeepSpeedCPUAdam
        optimizer = optimizer_class(parameter_groups, lr=lr, betas=ADAM_BETAS)

        # The scheduler advances once per optimizer update, not once per micro-batch.
        scheduler_steps = total_training_steps // ds_config['gradient_accumulation_steps']
        lr_scheduler = get_scheduler(
            name=lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=int(scheduler_steps * lr_warmup_ratio),
            num_training_steps=scheduler_steps,
        )

        initialized = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            config=ds_config,
        )
        return initialized[0]

    def _init_eval_engine(
        self,
        model: nn.Module,
        ds_config: dict[str, Any],
    ) -> deepspeed.DeepSpeedEngine:
        """Wrap `model` in a DeepSpeed engine without optimizer or scheduler.

        Args:
            model: Module to wrap.
            ds_config: DeepSpeed configuration for this engine.

        Returns:
            The first element returned by `deepspeed.initialize`.
        """
        initialized = deepspeed.initialize(model=model, config=ds_config)
        return initialized[0]

    def init_engines(self) -> None:
        """Turn every loaded model into a DeepSpeed engine.

        The actor and the reward critic become trainable engines. The reference
        actor, every reward model and the gold model become inference engines
        that are put into eval mode immediately.
        """
        args = self.args

        def as_frozen_engine(module):
            """Build an inference engine from `module` and switch it to eval mode."""
            frozen = self._init_eval_engine(model=module, ds_config=self.ds_eval_config)
            frozen.eval()
            return frozen

        actor_config = copy.deepcopy(self.ds_train_config)
        actor_steps = args.total_training_steps
        if self.use_ptx:
            # With PTX enabled the actor's accumulation, batch size and step
            # budget are doubled (presumably one PTX update per RL update).
            args.gradient_accumulation_steps *= 2
            for doubled_key in ('train_batch_size', 'gradient_accumulation_steps'):
                actor_config[doubled_key] *= 2
            actor_steps *= 2

        self.actor_model = self._init_train_engine(
            model=self.actor_model,
            weight_decay=args.actor_weight_decay,
            lr=args.actor_lr,
            lr_scheduler_type=args.actor_lr_scheduler_type,
            lr_warmup_ratio=args.actor_lr_warmup_ratio,
            total_training_steps=actor_steps,
            ds_config=actor_config,
        )

        self.actor_reference_model = as_frozen_engine(self.actor_reference_model)

        self.reward_critic_model = self._init_train_engine(
            model=self.reward_critic_model,
            weight_decay=args.critic_weight_decay,
            lr=args.critic_lr,
            lr_scheduler_type=args.critic_lr_scheduler_type,
            lr_warmup_ratio=args.critic_lr_warmup_ratio,
            total_training_steps=args.total_training_steps,
            ds_config=self.ds_train_config,
        )

        # Replace each raw reward model by its inference engine, keeping order.
        self.reward_model = [as_frozen_engine(member) for member in self.reward_model]

        self.gold_model = as_frozen_engine(self.gold_model)

        if args.actor_gradient_checkpointing:
            self.actor_model.gradient_checkpointing_enable()
        if args.critic_gradient_checkpointing:
            self.reward_critic_model.gradient_checkpointing_enable()

    def set_train(self, mode: bool = True) -> None:
        """Set training mode for all models."""
        if mode:
            self.actor_model.train()
            self.reward_critic_model.train()

            if self.args.actor_gradient_checkpointing:
                self.actor_model.gradient_checkpointing_enable()
        else:
            self.actor_model.eval()
            self.reward_critic_model.eval()

            if self.args.actor_gradient_checkpointing:
                self.actor_model.gradient_checkpointing_disable()

    def split_rl_micro_batches(
        self,
        prompt_only_batch: PromptOnlyBatch,
    ) -> list[PromptOnlyBatch]:
        """Slice a prompt batch into micro-batches and roll each one out.

        Args:
            prompt_only_batch: Tree of tensors that includes ``'input_ids'``.

        Returns:
            The concatenation of the lists returned by `rollout` for each slice
            of ``args.per_device_train_batch_size`` rows.
        """
        num_rows = prompt_only_batch['input_ids'].size(0)
        rows_per_slice = self.args.per_device_train_batch_size

        rollouts = []
        for first_row in range(0, num_rows, rows_per_slice):
            window = slice(first_row, first_row + rows_per_slice)
            # `window` is consumed right away, so the closure never sees a later value.
            prompt_slice = optree.tree_map(
                lambda tensor: tensor[window],  # noqa: B023
                prompt_only_batch,
            )
            rollouts.extend(self.rollout(prompt_slice))
        return rollouts

    def split_ptx_micro_batches(
        self,
        ptx_batch: dict[str, torch.Tensor],
    ) -> list[dict[str, torch.Tensor]]:
        """Slice a PTX batch into micro-batches.

        Args:
            ptx_batch: Tree of tensors that includes ``'input_ids'``.

        Returns:
            One sliced tree per chunk of ``args.per_device_train_batch_size`` rows.
        """
        num_rows = ptx_batch['input_ids'].size(0)
        rows_per_slice = self.args.per_device_train_batch_size

        def take_rows(first_row: int) -> dict[str, torch.Tensor]:
            """Return the rows ``[first_row, first_row + rows_per_slice)`` of every leaf."""
            window = slice(first_row, first_row + rows_per_slice)
            return optree.tree_map(lambda tensor: tensor[window], ptx_batch)

        return [take_rows(first_row) for first_row in range(0, num_rows, rows_per_slice)]

    @torch.no_grad()
    def rollout(self, prompt_only_batch: PromptOnlyBatch) -> list[dict[str, Any]]:
        """Generate responses for a prompt batch and post-process each sample set.

        Args:
            prompt_only_batch: Mapping with ``'input_ids'`` and ``'attention_mask'``.

        Returns:
            One `post_rollout` result per returned sequence index, i.e.
            ``args.num_return_sequences`` entries.
        """
        prompt_ids = prompt_only_batch['input_ids']
        generated = self.actor_model.module.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_only_batch['attention_mask'],
            generation_config=self.generation_config,
            synced_gpus=True,
            do_sample=True,
        )

        # Regroup rows as (prompt, return index, token), then put the return
        # index first so each iteration below covers one sample per prompt.
        grouped = generated.contiguous().view(
            prompt_ids.size(0),
            self.args.num_return_sequences,
            -1,
        )
        per_return_index = grouped.transpose(0, 1)

        experiences = []
        for seq in per_return_index:
            # Positions holding pad or unk tokens are excluded from attention.
            is_not_pad = seq.not_equal(self.tokenizer.pad_token_id)
            is_not_unk = seq.not_equal(self.tokenizer.unk_token_id)
            experiences.append(
                self.post_rollout(
                    prompt_ids,
                    seq,
                    attention_mask=torch.logical_and(is_not_pad, is_not_unk),
                ),
            )
        return experiences

    def ptx_step(self, ptx_batch: dict[str, torch.Tensor]) -> dict[str, Any]:
        """Run one actor update on a supervised (PTX) micro-batch.

        Args:
            ptx_batch: Mapping with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.

        Returns:
            ``{'train/ptx_loss': <loss after get_all_reduce_mean, as a float>}``.
        """
        actor = self.actor_model
        outputs = actor(
            input_ids=ptx_batch['input_ids'],
            attention_mask=ptx_batch['attention_mask'],
            labels=ptx_batch['labels'],
        )
        loss = outputs.loss

        # The PTX coefficient scales only the back-propagated loss, not the logged one.
        actor.backward(self.ptx_coeff * loss)
        actor.step()

        reduced_loss = get_all_reduce_mean(loss)
        return {'train/ptx_loss': reduced_loss.item()}

    def train(self) -> None:
        """Run the full PPO training loop.

        For every epoch, each prompt batch is rolled out in eval mode, then the
        resulting micro-batches are trained on ``args.update_iters`` times.
        Metrics are logged after every micro-batch step. The actor is
        checkpointed every ``args.save_interval`` steps and, when
        ``args.need_eval`` is set, evaluated according to ``args.eval_strategy``.
        """
        self.logger.print('***** Running training *****')

        # Progress is only displayed on the main process.
        progress_bar = tqdm(
            total=self.args.total_training_steps,
            desc=f'Training 1/{self.args.epochs} epoch',
            position=0,
            leave=True,
            disable=not is_main_process(),
        )

        if self.args.need_eval:
            self.logger.print('\n***** Evaluating at the beginning *****')
            self.logger.log(self.eval(), step=0)

        # Repeat the PTX loader enough times (ceiling division) that zip() below
        # is never cut short by it.
        num_prompt_only_batches = len(self.prompt_only_dataloader)
        num_ptx_batches = len(self.ptx_dataloader)
        num_ptx_replicas = (num_prompt_only_batches + num_ptx_batches - 1) // num_ptx_batches
        for epoch in range(self.args.epochs):
            for prompt_only_batch, ptx_batch in zip(
                self.prompt_only_dataloader,
                itertools.chain.from_iterable([self.ptx_dataloader] * num_ptx_replicas),
            ):
                # generate batches
                self.set_eval()
                prompt_only_batch = to_device(prompt_only_batch, self.args.device)
                rl_batches = self.split_rl_micro_batches(prompt_only_batch)
                if self.use_ptx:
                    ptx_batch = to_device(ptx_batch, self.args.device)
                    ptx_batches = self.split_ptx_micro_batches(ptx_batch)
                else:
                    # Placeholders so the zip() over micro-batches still works without PTX.
                    ptx_batches = [None for _ in range(len(rl_batches))]
                torch.cuda.empty_cache()

                # train
                self.set_train()
                for _ in range(self.args.update_iters):
                    for rl_batch, ptx_batch in zip(rl_batches, ptx_batches):
                        rl_info = self.rl_step(rl_batch)
                        torch.cuda.empty_cache()
                        self.logger.log(rl_info, step=self.global_step)
                        if self.use_ptx:
                            ptx_info = self.ptx_step(ptx_batch)
                            torch.cuda.empty_cache()
                            self.logger.log(ptx_info, step=self.global_step)

                        # One global step corresponds to one RL micro-batch update.
                        self.global_step += 1
                        progress_bar.set_description(
                            f'Training {epoch + 1}/{self.args.epochs} epoch '
                            f'(reward {rl_info["train/reward"]:.4f})',
                        )
                        progress_bar.update(1)

                        # Periodic actor checkpoint, stored in a directory named after the step.
                        if self.global_step % self.args.save_interval == 0:
                            self.logger.print(f'Saving checkpoint at step {self.global_step} ...')
                            save_dir = os.path.join(self.args.output_dir, f'{self.global_step}')
                            os.makedirs(save_dir, exist_ok=True)
                            if self.args.save_16bit:
                                self.actor_model.save_16bit_model(save_dir)
                            else:
                                self.actor_model.save_checkpoint(save_dir)
                            self.logger.print('Checkpoint saved.')

                        # Step-based evaluation.
                        if (
                            self.args.need_eval
                            and self.args.eval_strategy == 'steps'
                            and self.global_step % self.args.eval_interval == 0
                        ):
                            self.logger.print(
                                f'\n***** Evaluating at step {self.global_step} *****',
                            )
                            self.logger.log(self.eval(), step=self.global_step)

            # Epoch-based evaluation.
            if self.args.need_eval and self.args.eval_strategy == 'epoch':
                self.logger.print(
                    f'\n***** Evaluating at epoch {epoch + 1}/{self.args.epochs} *****',
                )
                self.logger.log(self.eval(), step=self.global_step)
    def eval(self) -> dict[str, Any]:
        """Generate on the evaluation set and summarise the per-sample metrics.

        Returns:
            An empty dict when no evaluation dataloader exists. Otherwise a dict
            mapping each metric name produced by `eval_step` to a float obtained
            by averaging locally and then reducing to rank 0 with ``AVG``.
        """
        if self.eval_dataloader is None:
            return {}

        self.set_eval()
        all_prompts: list[str] = []
        all_generations: list[str] = []
        scores: dict[str, list[torch.Tensor]] = {}

        batches = tqdm(
            self.eval_dataloader,
            desc='Evaluating',
            disable=not is_main_process(),
        )

        for batch in batches:
            batch = to_device(batch, self.args.device)
            prompt_ids = batch['input_ids']
            with torch.no_grad():
                seq = self.actor_model.module.generate(
                    input_ids=prompt_ids,
                    attention_mask=batch['attention_mask'],
                    max_length=self.args.max_length,
                    synced_gpus=True,
                    do_sample=True,
                )

            dist.barrier()

            # Positions holding pad or unk tokens are excluded from attention.
            attention_mask = torch.logical_and(
                seq.not_equal(self.tokenizer.pad_token_id),
                seq.not_equal(self.tokenizer.unk_token_id),
            )
            batch_metrics = self.eval_step(prompt_ids, seq, attention_mask)
            for metric_name, metric_values in batch_metrics.items():
                scores.setdefault(metric_name, []).append(metric_values)

            prompt_texts = self.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
            full_texts = self.tokenizer.batch_decode(seq, skip_special_tokens=True)
            # Drop the decoded prompt prefix to keep only the generated part.
            response_texts = [
                full_text[len(prompt_texts[row]) :] for row, full_text in enumerate(full_texts)
            ]
            all_prompts.extend(prompt_texts)
            all_generations.extend(response_texts)

        # Display result in main process
        if is_main_process():
            per_sample_values = {
                metric_name: torch.cat(chunks, dim=0).to(torch.float32).tolist()
                for metric_name, chunks in scores.items()
            }
            self.logger.print_table(
                title='Evaluating...',
                columns=['Prompt', 'Generated', *scores],
                rows=list(zip(all_prompts, all_generations, *per_sample_values.values())),
                max_num_rows=5,
            )

        # Gather results from all processes
        for metric_name in list(scores):
            local_mean = torch.cat(scores[metric_name], dim=0).mean()
            dist.reduce(local_mean, dst=0, op=dist.ReduceOp.AVG)
            scores[metric_name] = local_mean.mean().item()
        dist.barrier()

        self.set_train()

        return scores

    def get_advantages_and_returns(
        self,
        values: torch.Tensor,
        rewards: torch.Tensor,
        sequence_mask: torch.BoolTensor,
        start: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute advantages and returns using Generalized Advantage Estimation (GAE)."""
        # Modified from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py
        last_gae_lambda = 0.0
        advantages_reversed = []
        values = values * sequence_mask
        rewards = rewards * sequence_mask
        length = rewards.size(-1)
        for t in reversed(range(start, length)):  # pylint: disable=invalid-name
            next_values = values[:, t + 1] if t < length - 1 else 0.0
            delta = rewards[:, t] + self.gamma * next_values - values[:, t]
            last_gae_lambda = delta + self.gamma * self.gae_lambda * last_gae_lambda
            advantages_reversed.append(last_gae_lambda)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + values[:, start:]
        return advantages.detach(), returns

    def critic_loss_fn(
        self,
        values: torch.Tensor,  # size = (B, L - S)
        old_values: torch.Tensor,  # size = (B, L - S)
        returns: torch.Tensor,  # size = (B, L - S)
        mask: torch.BoolTensor,  # size = (B, L - S)
    ) -> torch.Tensor:  # size = ()
        """Compute the clipped value-function loss.

        The new values are clamped to within ``clip_range_value`` of the old
        values. The loss is half the masked mean of the element-wise larger of
        the clipped and unclipped squared errors.
        """
        # size = (B, L - S)
        lower_bound = old_values - self.clip_range_value
        upper_bound = old_values + self.clip_range_value
        clipped_values = torch.clamp(values, lower_bound, upper_bound)

        unclipped_error = torch.square(values - returns)
        clipped_error = torch.square(clipped_values - returns)
        worst_case_error = torch.maximum(unclipped_error, clipped_error)
        return 0.5 * masked_mean(worst_case_error, mask)  # size = ()

    def save(
        self,
        model: deepspeed.DeepSpeedEngine | PreTrainedModel | None = None,
        ds_config: dict | None = None,
    ) -> None:
        """Save a model through the base-class `save`.

        Args:
            model: Model to save; the actor engine is used when ``None``.
            ds_config: DeepSpeed configuration; the training configuration is
                used when ``None``.
        """
        model_to_save = self.actor_model if model is None else model
        config_to_use = self.ds_train_config if ds_config is None else ds_config
        super().save(model=model_to_save, ds_config=config_to_use)


    @torch.no_grad()
    def post_rollout(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, Any]:
        """Score generated sequences and collect everything an RL step needs.

        Args:
            prompt: Prompt token ids, returned unchanged under ``'prompt'``.
            sequence: Generated token ids (prompt plus response).
            attention_mask: Mask over `sequence`.

        Returns:
            A dict with the prompt, actor and reference log-probabilities, the
            stacked per-model reward scores, the gold model's end scores, the
            critic values (last position dropped), the sequence and its mask.
        """

        def inputs_for(dest_tokenizer):
            """Return ``(input_ids, attention_mask)`` for a model using `dest_tokenizer`."""
            if dest_tokenizer is self.tokenizer:
                # Same tokenizer object as the actor: feed the sequence as-is.
                return sequence, attention_mask
            retokenized = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=dest_tokenizer,
                truncation=True,
                skip_special_tokens=True,
                device=self.args.device,
            )
            return retokenized['input_ids'], retokenized['attention_mask']

        # One end-of-sequence score per reward model.
        member_scores = []
        for member_tokenizer, member_model in zip(self.reward_tokenizer, self.reward_model):
            member_ids, member_mask = inputs_for(member_tokenizer)
            member_output = member_model(member_ids, attention_mask=member_mask)
            member_scores.append(member_output.end_scores.squeeze(dim=-1))
        ensemble_scores = torch.stack(member_scores, dim=-1)  # (B, N)

        gold_input_ids, gold_attention_mask = inputs_for(self.gold_tokenizer)

        # Log-probability of each token given its prefix, for actor and reference.
        next_tokens = sequence[:, 1:]
        actor_logits = self.actor_model(sequence, attention_mask=attention_mask).logits
        log_probs = gather_log_probabilities(actor_logits[:, :-1], next_tokens)
        reference_logits = self.actor_reference_model(
            sequence,
            attention_mask=attention_mask,
        ).logits
        ref_log_probs = gather_log_probabilities(reference_logits[:, :-1], next_tokens)

        gold_output = self.gold_model(gold_input_ids, attention_mask=gold_attention_mask)
        gold_reward = gold_output.end_scores

        # Critic values, with the final position dropped.
        critic_output = self.reward_critic_model(sequence, attention_mask=attention_mask)
        reward_values = critic_output.scores.squeeze(dim=-1)[:, :-1]

        return {
            'prompt': prompt,
            'log_probs': log_probs,
            'ref_log_probs': ref_log_probs,
            'reward': ensemble_scores,
            'gold_score': gold_reward,
            'reward_values': reward_values,
            'input_ids': sequence,
            'attention_mask': attention_mask,
        }

    @torch.no_grad()
    def eval_step(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, torch.Tensor]:
        """Compute per-sample evaluation metrics for generated sequences.

        Args:
            prompt: Prompt token ids; only its last-dimension size is used.
            sequence: Generated token ids (prompt plus response).
            attention_mask: Mask over `sequence`.

        Returns:
            A dict with the combined ensemble reward, the gold reward, the
            summed actor-vs-reference log-probability gap over the response,
            and the number of unmasked response tokens.

        Raises:
            ValueError: If ``args.ensemble_type`` is not one of ``'mean'``,
                ``'worst'`` or ``'uncertainty-weighted'``.
        """

        def inputs_for(dest_tokenizer):
            """Return ``(input_ids, attention_mask)`` for a model using `dest_tokenizer`."""
            if dest_tokenizer is self.tokenizer:
                return sequence, attention_mask
            retokenized = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=dest_tokenizer,
                truncation=True,
                skip_special_tokens=True,
                device=self.args.device,
            )
            return retokenized['input_ids'], retokenized['attention_mask']

        # One end-of-sequence score per reward model.
        member_scores = []
        for member_tokenizer, member_model in zip(self.reward_tokenizer, self.reward_model):
            member_ids, member_mask = inputs_for(member_tokenizer)
            member_output = member_model(member_ids, attention_mask=member_mask)
            member_scores.append(member_output.end_scores.squeeze(dim=-1))
        ensemble_scores = torch.stack(member_scores, dim=-1)  # (B, N)

        # Both statistics are always computed, whichever combiner is selected.
        reward_var = torch.var(ensemble_scores, dim=-1)
        reward_mean = torch.mean(ensemble_scores, dim=-1)
        combiners = {
            'mean': lambda: reward_mean,
            'worst': lambda: ensemble_scores.min(dim=-1).values,
            'uncertainty-weighted': lambda: reward_mean - self.args.UWO_coeff * reward_var,
        }
        if self.args.ensemble_type not in combiners:
            raise ValueError(f'Unknown ensemble type: {self.args.ensemble_type}')
        reward = combiners[self.args.ensemble_type]()

        gold_input_ids, gold_attention_mask = inputs_for(self.gold_tokenizer)
        gold_output = self.gold_model(
            gold_input_ids,
            attention_mask=gold_attention_mask,
        )
        gold = gold_output.end_scores.squeeze(dim=-1)

        # Log-probability of each token given its prefix, for actor and reference.
        next_tokens = sequence[:, 1:]
        actor_logits = self.actor_model(
            sequence,
            attention_mask=attention_mask,
            use_cache=False,
        ).logits
        log_probs = gather_log_probabilities(actor_logits[:, :-1], next_tokens)

        reference_logits = self.actor_reference_model(
            sequence,
            attention_mask=attention_mask,
            use_cache=False,
        ).logits
        ref_log_probs = gather_log_probabilities(reference_logits[:, :-1], next_tokens)

        # Log-probs are shifted by one relative to token positions, hence the
        # `- 1` when slicing them to the response region.
        prompt_length = prompt.size(-1)
        response_mask = attention_mask[:, prompt_length:]
        response_lengths = response_mask.sum(dim=-1).float()
        log_ratio = (log_probs - ref_log_probs)[:, prompt_length - 1 :]
        kl_per_sequence = (log_ratio * response_mask).sum(dim=-1)

        return {
            'eval/reward': reward,
            'eval/gold_reward': gold,
            'eval/kl_divergence': kl_per_sequence,
            'eval/mean_generated_length': response_lengths,
        }

    def add_kl_divergence_regularization(
        self,
        reward: torch.Tensor,  # size = (B,)
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        log_probs: torch.Tensor,  # size = (B, L)
        ref_log_probs: torch.Tensor,  # size = (B, L)
        sequence_mask: torch.BoolTensor,  # size = (B, L)
    ) -> torch.Tensor:  # size = (B, L)
        """Build per-token rewards: a KL penalty everywhere plus the score at the last token.

        Every position receives ``-self.kl_coeff * (log_probs - ref_log_probs)``.
        The scalar ``reward`` of each row is then added at the index of the last
        non-zero entry of that row's ``sequence_mask``, and the result is clamped
        to ``[-self.clip_range_score, self.clip_range_score]``.

        Args:
            reward: One scalar score per sequence.
            prompt: Unused; kept for interface compatibility.
            log_probs: Per-token log-probabilities under the current policy.
            ref_log_probs: Per-token log-probabilities under the reference policy.
            sequence_mask: Marks valid positions; each row needs at least one
                non-zero entry, otherwise indexing the last one raises.

        Returns:
            The clamped per-token reward tensor, same shape as ``log_probs``.
        """
        # Index of the last valid position in every row, size = (B,)
        last_valid = torch.cat(tuple(row.nonzero()[-1] for row in sequence_mask))

        # Per-token KL penalty, size = (B, L)
        log_ratio = log_probs - ref_log_probs
        token_penalty = -self.kl_coeff * log_ratio

        # Drop the sequence-level score onto the final valid token (out-of-place).
        terminal_score = reward.to(token_penalty.dtype).unsqueeze(dim=-1)
        shaped_rewards = token_penalty.scatter_add(
            -1,
            last_valid.unsqueeze(dim=-1),
            terminal_score,
        )

        bound = self.clip_range_score
        return shaped_rewards.clamp(min=-bound, max=bound)

    def actor_loss_fn(
        self,
        log_probs: torch.Tensor,  # size = (B, L - S)
        old_log_probs: torch.Tensor,  # size = (B, L - S)
        advantages: torch.Tensor,  # size = (B, L - S)
        mask: torch.BoolTensor,  # size = (B, L - S)
    ) -> torch.Tensor:  # size = ()
        """Compute the clipped-ratio surrogate loss for the actor.

        The probability ratio ``exp(log_probs - old_log_probs)`` is clamped to
        ``[1 - self.clip_range_ratio, 1 + self.clip_range_ratio]``; the
        element-wise minimum of the clipped and unclipped advantage-weighted
        ratios is reduced with ``masked_mean`` over ``mask`` and negated.

        Args:
            log_probs: Log-probabilities under the current policy.
            old_log_probs: Log-probabilities recorded at rollout time.
            advantages: Advantage estimates for the same positions.
            mask: Positions that contribute to the loss.

        Returns:
            The scalar loss to minimise.
        """
        # size = (B, L - S)
        ratios = (log_probs - old_log_probs).exp()
        epsilon = self.clip_range_ratio
        clipped_ratios = ratios.clamp(1.0 - epsilon, 1.0 + epsilon)

        unclipped_objective = advantages * ratios
        clipped_objective = advantages * clipped_ratios
        # Take the more pessimistic of the two objectives at every position.
        pessimistic_objective = torch.minimum(unclipped_objective, clipped_objective)
        return -masked_mean(pessimistic_objective, mask)  # size = ()

    def rl_step(self, rl_batch: dict[str, torch.Tensor]) -> dict[str, Any]:
        """Run one PPO update of the actor and the reward critic on a rollout batch.

        The per-model scores in ``rl_batch['reward']`` are collapsed into one
        reward per sequence according to ``self.args.ensemble_type``, combined
        with the KL penalty, and turned into advantages/returns. The actor is
        then updated, followed by the reward critic, and finally the training
        statistics are averaged (or maxed) across processes.

        Args:
            rl_batch: Rollout data. The keys read here are ``'prompt'``,
                ``'log_probs'``, ``'ref_log_probs'``, ``'reward'``,
                ``'gold_score'``, ``'reward_values'``, ``'input_ids'`` and
                ``'attention_mask'``.
                NOTE(review): ``'reward'`` is assumed to be (B, N) with one
                column per reward model -- confirm against the rollout code.

        Returns:
            A dict of Python scalars keyed by ``'train/...'`` metric names.

        Raises:
            ValueError: If ``self.args.ensemble_type`` is not one of
                ``'mean'``, ``'worst'`` or ``'uncertainty-weighted'``.
        """
        prompt = rl_batch['prompt']
        old_log_probs = rl_batch['log_probs']
        ref_log_probs = rl_batch['ref_log_probs']
        ensemble_scores = rl_batch['reward']
        gold_score = rl_batch['gold_score']
        old_reward_values = rl_batch['reward_values']
        input_ids = rl_batch['input_ids']
        attention_mask = rl_batch['attention_mask']

        # Log-probs/values are indexed by "next token", hence the shift by one.
        start = prompt.size(-1) - 1
        sequence_mask = attention_mask[:, 1:]
        mask = sequence_mask[:, start:]

        with torch.no_grad():
            reward, reward_mean, reward_var = self._aggregate_ensemble_reward(ensemble_scores)
            old_rewards = self.add_kl_divergence_regularization(
                reward,
                prompt,
                old_log_probs,
                ref_log_probs,
                sequence_mask,
            )
            reward_advantages, reward_returns = self.get_advantages_and_returns(
                old_reward_values,
                old_rewards,
                sequence_mask,
                start,
            )

        # Actor update.
        logits = self.actor_model(input_ids, attention_mask=attention_mask, use_cache=False).logits
        log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:])
        actor_loss = self.actor_loss_fn(
            log_probs[:, start:],
            old_log_probs[:, start:],
            reward_advantages,
            mask,
        )
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        # Reward critic update.
        reward_values = self.reward_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        reward_critic_loss = self.critic_loss_fn(
            reward_values[:, start:],
            old_reward_values[:, start:],
            reward_returns,
            mask,
        )
        self.reward_critic_model.backward(reward_critic_loss)
        self.reward_critic_model.step()

        with torch.no_grad():
            kl_divergence = ((old_log_probs - ref_log_probs)[:, start:] * mask).sum(dim=-1).mean()
            generated_lengths = mask.sum(dim=-1).float()
            mean_generated_length = generated_lengths.mean()
            max_generated_length = generated_lengths.max()

            if self.args.save_train_data:
                self._record_train_generated_data(
                    input_ids,
                    start,
                    gold_score,
                    reward,
                    kl_divergence,
                    mean_generated_length,
                )

            # Local (per-process) batch statistics.
            reward = reward.mean()
            gold_score = gold_score.mean()
            reward_mean = reward_mean.mean()
            reward_var = reward_var.mean()
            reward_with_kl_penalty = (old_rewards[:, start:] * mask).sum(dim=-1).mean()
            reward_advantage = masked_mean(reward_advantages, mask)
            reward_return = masked_mean(reward_returns, mask)
            reward_value = masked_mean(reward_values[:, start:], mask)

            # Combine the statistics across processes.
            actor_loss = get_all_reduce_mean(actor_loss)
            reward_critic_loss = get_all_reduce_mean(reward_critic_loss)
            reward = get_all_reduce_mean(reward)
            gold_score = get_all_reduce_mean(gold_score)
            reward_mean = get_all_reduce_mean(reward_mean)
            reward_var = get_all_reduce_mean(reward_var)
            reward_with_kl_penalty = get_all_reduce_mean(reward_with_kl_penalty)
            reward_advantage = get_all_reduce_mean(reward_advantage)
            reward_return = get_all_reduce_mean(reward_return)
            reward_value = get_all_reduce_mean(reward_value)
            kl_divergence = get_all_reduce_mean(kl_divergence)
            mean_generated_length = get_all_reduce_mean(mean_generated_length)
            max_generated_length = get_all_reduce_max(max_generated_length)

        dist.barrier()

        return {
            'train/actor_loss': actor_loss.item(),
            'train/reward_critic_loss': reward_critic_loss.item(),
            'train/reward': reward.item(),
            'train/gold_reward': gold_score.item(),
            'train/reward_mean': reward_mean.item(),
            'train/reward_var': reward_var.item(),
            'train/reward_with_kl_penalty': reward_with_kl_penalty.item(),
            'train/reward_advantage': reward_advantage.item(),
            'train/reward_return': reward_return.item(),
            'train/reward_value': reward_value.item(),
            'train/kl_divergence': kl_divergence.item(),
            'train/actor_lr': self.actor_model.optimizer.param_groups[0]['lr'],
            'train/reward_critic_lr': self.reward_critic_model.optimizer.param_groups[0]['lr'],
            'train/mean_generated_length': mean_generated_length.item(),
            'train/max_generated_length': max_generated_length.item(),
        }

    def _aggregate_ensemble_reward(
        self,
        ensemble_scores: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Collapse per-model scores along the last dim into ``(reward, mean, variance)``.

        Raises:
            ValueError: If ``self.args.ensemble_type`` is not recognised.
        """
        reward_mean = torch.mean(ensemble_scores, dim=-1)
        if ensemble_scores.size(-1) > 1:
            reward_var = torch.var(ensemble_scores, dim=-1)
        else:
            # torch.var applies Bessel's correction, so a single sample gives NaN,
            # which would poison the uncertainty-weighted reward. One model has
            # no disagreement to measure, so report zero variance instead.
            reward_var = torch.zeros_like(reward_mean)

        ensemble_type = self.args.ensemble_type
        if ensemble_type == 'mean':
            reward = reward_mean
        elif ensemble_type == 'worst':
            reward = ensemble_scores.min(dim=-1).values
        elif ensemble_type == 'uncertainty-weighted':
            reward = reward_mean - self.args.UWO_coeff * reward_var
        else:
            raise ValueError(f'Unknown ensemble type: {self.args.ensemble_type}')
        return reward, reward_mean, reward_var

    def _record_train_generated_data(
        self,
        input_ids: torch.Tensor,
        start: int,
        gold_score: torch.Tensor,
        reward: torch.Tensor,
        kl_divergence: torch.Tensor,
        mean_generated_length: torch.Tensor,
    ) -> None:
        """Gather this step's prompts, responses and scores and log them on the main process."""
        prompt_ids = input_ids[:, : start + 1].contiguous()
        response_ids = input_ids[:, start + 1 :].contiguous()

        # The gathers are called by every process, in a fixed order.
        pad_token_id = self.tokenizer.pad_token_id
        gathered_prompt_ids = gather_input_ids(prompt_ids, pad_token_id)
        gathered_response_ids = gather_input_ids(response_ids, pad_token_id)
        gathered_gold = gather_scores(gold_score)
        gathered_reward = gather_scores(reward)

        if not is_main_process():
            return

        # Only the main process keeps the record, so only it needs to decode.
        prompt_texts = self.tokenizer.batch_decode(gathered_prompt_ids, skip_special_tokens=True)
        response_texts = self.tokenizer.batch_decode(
            gathered_response_ids,
            skip_special_tokens=True,
        )
        assert len(prompt_texts) == len(response_texts) == len(gathered_gold) == len(gathered_reward)
        self.train_generated_data.append({
            'step': self.global_step,
            'kl_divergence': kl_divergence.item(),
            'mean_generated_length': mean_generated_length.item(),
            'generated_data': [
                {
                    'prompt': prompt_text,
                    'response': response_text,
                    'gold': gold.item(),
                    'proxy': proxy.item(),
                }
                for prompt_text, response_text, gold, proxy in zip(
                    prompt_texts,
                    response_texts,
                    gathered_gold,
                    gathered_reward,
                )
            ],
        })
