

from __future__ import annotations
from collections import deque


import abc
import argparse
import copy
import itertools
from typing import Any, ClassVar

import deepspeed
import optree
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    get_scheduler,
)
from transformers.integrations.deepspeed import HfDeepSpeedConfig

from safe_rlhf.configs import ADAM_BETAS
from safe_rlhf.datasets import DummyDataset, PromptOnlyBatch, PromptOnlyDataset, SupervisedDataset
from safe_rlhf.models import AutoModelForScore, load_pretrained_models
from safe_rlhf.trainers.base import TrainerBase
from safe_rlhf.utils import (
    get_all_reduce_mean,
    get_optimizer_grouped_parameters,
    is_main_process,
    is_same_tokenizer,
    masked_mean,
    batch_retokenize,
    gather_log_probabilities,
    get_all_reduce_max,
    to_device,
    gather_input_ids,
    gather_scores
)
import os

class ConstrainedPPOTrainer(TrainerBase):
    """PPO-style RL trainer that additionally tracks a cost signal.

    Besides the actor, its frozen reference copy and a reward critic, the
    trainer holds a cost model with its own critic, a "gold" scoring model
    used for monitoring, and a Lagrange multiplier created in ``__init__``.
    The attributes annotated below are populated by ``init_models`` and then
    wrapped into DeepSpeed engines by ``init_engines``.
    """
    TRAINING_TYPE = 'constrained_ppo'

    # Frozen scorer that produces the per-sequence cost, and the trainable
    # critic that estimates per-token cost values.
    cost_model: deepspeed.DeepSpeedEngine
    cost_critic_model: deepspeed.DeepSpeedEngine

    cost_tokenizer: PreTrainedTokenizerBase
    cost_critic_tokenizer: PreTrainedTokenizerBase

    # Policy being trained, its frozen reference copy (used for the KL term),
    # and the trainable critic for the reward signal.
    actor_model: deepspeed.DeepSpeedEngine
    actor_reference_model: deepspeed.DeepSpeedEngine
    reward_critic_model: deepspeed.DeepSpeedEngine

    reward_critic_tokenizer: PreTrainedTokenizerBase

    # Frozen scorer evaluated on every rollout / eval sample and reported
    # under the 'gold' / 'eval/gold_reward' keys.
    gold_model: deepspeed.DeepSpeedEngine
    gold_tokenizer: PreTrainedTokenizerBase

    # DeepSpeed configs for trainable engines and inference-only engines.
    ds_train_config: dict[str, Any]
    ds_eval_config: dict[str, Any]

    # NOTE(review): mutable class-level list, shared by every instance of the
    # class. It is not referenced in the visible part of the file; confirm
    # whether it is still needed or should become a per-instance attribute.
    train_generated_data = []

    def __init__(
        self,
        args: argparse.Namespace,
        ds_train_config: dict[str, Any],
        ds_eval_config: dict[str, Any],
    ) -> None:
        """Build the trainer: models, datasets, engines, logger and hyper-parameters.

        Args:
            args: Parsed command-line namespace; read throughout the trainer.
            ds_train_config: DeepSpeed config used for the trainable engines.
            ds_eval_config: DeepSpeed config used for the inference-only engines.

        Raises:
            ValueError: If ``args.squash_fn`` is not one of ``'tanh'``,
                ``'sigmoid'`` or ``'identity'``.
        """
        self.args = args
        self.ds_train_config = ds_train_config
        self.ds_eval_config = ds_eval_config
        self.global_step = 0

        # Every heavy setup stage is followed by a barrier so that all ranks
        # have finished it before any rank starts the next one.
        for setup_stage in (self.init_models, self.init_datasets, self.init_engines):
            setup_stage()
            dist.barrier()
        self.init_logger()

        # Sampling settings used by `rollout`; the special token ids come from
        # the actor tokenizer loaded in `init_models`.
        tokenizer = self.tokenizer
        self.generation_config = GenerationConfig(
            max_length=args.max_length,
            num_return_sequences=args.num_return_sequences,
            temperature=args.temperature,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            do_sample=True,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

        # These values can be changed.
        self.clip_range_ratio = args.clip_range_ratio
        self.clip_range_score = args.clip_range_score
        self.clip_range_value = args.clip_range_value
        self.ptx_coeff = args.ptx_coeff
        self.gamma = 1.0
        self.gae_lambda = 0.95
        self.kl_coeff = args.kl_coeff

        # Lagrange multiplier: pick the squashing function by name.
        squash_fns = {
            'tanh': torch.tanh,
            'sigmoid': torch.sigmoid,
            'identity': lambda x: x,
        }
        if self.args.squash_fn not in squash_fns:
            raise ValueError(f'Invalid squash function: {self.args.squash_fn}')
        self.squash_fn = squash_fns[self.args.squash_fn]

        # The multiplier is a learnable scalar with its own plain-SGD optimizer.
        initial_multiplier = torch.tensor(args.lagrange_init, device=args.device)
        self.lagrange = torch.nn.Parameter(initial_multiplier, requires_grad=True)
        self.lagrange_max = args.lagrange_max
        self.lagrange_optimizer = torch.optim.SGD([self.lagrange], lr=args.lagrange_lr)
        # Sliding window of the most recent per-sequence costs (filled in `post_rollout`).
        self.episode_costs = deque(maxlen=args.episode_cost_window_size)
        self.threshold = args.threshold

    def init_models(self) -> None:
        """Load every model and tokenizer used by the trainer.

        Loads, in this order: the actor and its reference copy (both from
        ``args.actor_model_name_or_path``), the reward critic, the gold model,
        the cost model and the cost critic. The critics must share the actor
        tokenizer; the gold and cost models may use a different one, in which
        case sequences are re-tokenized before scoring (see ``post_rollout``).

        Raises:
            ValueError: If the reward critic tokenizer or the cost critic
                tokenizer differs from the actor tokenizer.
        """
        # Under ZeRO stage 3 an `HfDeepSpeedConfig` is created before the models
        # are loaded and a reference is kept on `self`.
        # NOTE(review): per the transformers DeepSpeed integration docs the
        # object must stay alive while models are instantiated — confirm.
        if (
            self.ds_train_config is not None
            and self.ds_train_config['zero_optimization']['stage'] == 3
        ):
            self.dstchf_train = HfDeepSpeedConfig(self.ds_train_config)

        if (
            self.ds_eval_config is not None
            and self.ds_eval_config['zero_optimization']['stage'] == 3
        ):
            self.dsechf_eval = HfDeepSpeedConfig(self.ds_eval_config)

        # Actor and its reference copy start from the same checkpoint.
        self.actor_model, self.tokenizer = load_pretrained_models(
            self.args.actor_model_name_or_path,
            model_max_length=self.args.max_length,
            padding_side='left',
            auto_model_type=AutoModelForCausalLM,
            trust_remote_code=self.args.trust_remote_code,
        )
        self.actor_reference_model, _ = load_pretrained_models(
            self.args.actor_model_name_or_path,
            model_max_length=self.args.max_length,
            padding_side='left',
            auto_model_type=AutoModelForCausalLM,
            trust_remote_code=self.args.trust_remote_code,
        )

        # Reward critic.
        self.reward_critic_model, self.reward_critic_tokenizer = load_pretrained_models(
            self.args.reward_critic_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='left',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'critic',
                'do_normalize': False,
            },
        )
        self.reward_critic_model.set_normalize(False)

        if not is_same_tokenizer(self.tokenizer, self.reward_critic_tokenizer):
            raise ValueError(
                (
                    'Reward critic tokenizer must be the same as actor tokenizer. '
                    'Expected {0.__module__}.{0.__qualname__}(vocab_size={1}), '
                    'but got {2.__module__}.{2.__qualname__}(vocab_size={3}). '
                    'Please consider pass `--reward_critic_model_name_or_path` from the command line.'
                ).format(
                    type(self.tokenizer),
                    len(self.tokenizer),
                    type(self.reward_critic_tokenizer),
                    len(self.reward_critic_tokenizer),
                ),
            )
        # Share the actor tokenizer object so identity checks succeed later.
        self.reward_critic_tokenizer = self.tokenizer

        # Gold model: may use its own tokenizer.
        self.gold_model, self.gold_tokenizer = load_pretrained_models(
            self.args.gold_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='right',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'reward',
                'do_normalize': False,
            },
        )
        self.gold_model.set_normalize(False)

        if is_same_tokenizer(self.tokenizer, self.gold_tokenizer):
            self.gold_tokenizer = self.tokenizer

        # Cost model: may use its own tokenizer.
        self.cost_model, self.cost_tokenizer = load_pretrained_models(
            self.args.cost_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='right',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'reward',
                'do_normalize': False,
            },
        )
        self.cost_model.set_normalize(False)

        # The cost critic is initialised from the *reward critic* checkpoint;
        # no separate cost-critic path is read by this method.
        self.cost_critic_model, self.cost_critic_tokenizer = load_pretrained_models(
            self.args.reward_critic_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='left',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'critic',
                'do_normalize': False,
            },
        )
        self.cost_critic_model.set_normalize(False)

        if is_same_tokenizer(self.tokenizer, self.cost_tokenizer):
            self.cost_tokenizer = self.tokenizer
        if not is_same_tokenizer(self.tokenizer, self.cost_critic_tokenizer):
            # The hint names the flag that actually controls where the cost
            # critic is loaded from (see the load call above).
            raise ValueError(
                (
                    'Cost critic tokenizer must be the same as actor tokenizer. '
                    'Expected {0.__module__}.{0.__qualname__}(vocab_size={1}), '
                    'but got {2.__module__}.{2.__qualname__}(vocab_size={3}). '
                    'Please consider pass `--reward_critic_model_name_or_path` from the command line.'
                ).format(
                    type(self.tokenizer),
                    len(self.tokenizer),
                    type(self.cost_critic_tokenizer),
                    len(self.cost_critic_tokenizer),
                ),
            )
        self.cost_critic_tokenizer = self.tokenizer

    def init_datasets(self) -> None:
        """Create the prompt-only, evaluation and PTX dataloaders.

        Also stores ``args.total_training_steps``, the number of micro-batch
        updates the whole run performs.

        Raises:
            ValueError: If the number of generated samples per prompt batch is
                not a multiple of the training micro batch size, or if
                evaluation is requested without exactly one of
                ``eval_datasets`` / ``eval_split_ratio``.
        """
        args = self.args

        samples_per_prompt_batch = args.per_device_prompt_batch_size * args.num_return_sequences
        if samples_per_prompt_batch % args.per_device_train_batch_size != 0:
            raise ValueError(
                'The number of prompt-only samples must be divisible by the micro batch size.',
            )

        prompt_only_dataset = PromptOnlyDataset(args.train_datasets, tokenizer=self.tokenizer)

        if not args.need_eval:
            self.eval_dataloader = None
        else:
            has_eval_datasets = args.eval_datasets is not None
            has_split_ratio = args.eval_split_ratio is not None
            # Exactly one of the two sources must be given.
            if has_eval_datasets == has_split_ratio:
                raise ValueError('Either `eval_datasets` or `eval_split_ratio` should be provided.')

            if has_split_ratio:
                # Carve the evaluation set out of the training prompts.
                prompt_only_dataset, eval_dataset = prompt_only_dataset.split_train_test(
                    split_ratio=args.eval_split_ratio,
                )
            else:
                eval_dataset = PromptOnlyDataset(args.eval_datasets, tokenizer=self.tokenizer)

            self.eval_dataloader = DataLoader(
                eval_dataset,
                collate_fn=eval_dataset.get_collator(),
                sampler=DistributedSampler(eval_dataset, shuffle=True),
                batch_size=args.per_device_eval_batch_size,
            )

        self.prompt_only_dataloader = DataLoader(
            prompt_only_dataset,
            collate_fn=prompt_only_dataset.get_collator(),
            sampler=DistributedSampler(prompt_only_dataset, shuffle=True),
            batch_size=args.per_device_prompt_batch_size,
        )

        self.use_ptx = args.ptx_datasets is not None
        if not self.use_ptx:
            # Placeholder loader with one item per prompt batch so `train` can
            # still zip over both loaders.
            self.ptx_dataloader = DataLoader(DummyDataset(len(self.prompt_only_dataloader)))
        else:
            ptx_dataset = SupervisedDataset(args.ptx_datasets, tokenizer=self.tokenizer)
            # One PTX batch holds as many samples as one prompt batch generates.
            self.ptx_dataloader = DataLoader(
                ptx_dataset,
                collate_fn=ptx_dataset.get_collator(),
                sampler=DistributedSampler(ptx_dataset, shuffle=True),
                batch_size=args.per_device_prompt_batch_size * args.num_return_sequences,
            )

        # prompt batches x epochs x update passes, converted to micro batches.
        num_update_passes = len(self.prompt_only_dataloader) * args.epochs * args.update_iters
        args.total_training_steps = int(
            num_update_passes
            * args.per_device_prompt_batch_size
            * args.num_return_sequences
            // args.per_device_train_batch_size,
        )

    def _init_train_engine(
        self,
        model: nn.Module,
        weight_decay: float,
        lr: float,
        lr_scheduler_type: str,
        lr_warmup_ratio: float,
        total_training_steps: int,
        ds_config: dict[str, Any],
    ) -> deepspeed.DeepSpeedEngine:
        """Wrap ``model`` in a trainable DeepSpeed engine with optimizer and LR schedule.

        Args:
            model: Module to train.
            weight_decay: Weight decay passed to the parameter grouping helper.
            lr: Peak learning rate.
            lr_scheduler_type: Scheduler name understood by ``get_scheduler``.
            lr_warmup_ratio: Fraction of optimizer updates used for warm-up.
            total_training_steps: Number of micro-batch steps for this model.
            ds_config: DeepSpeed config for the engine.

        Returns:
            The initialized DeepSpeed engine.
        """
        param_groups = get_optimizer_grouped_parameters(model, weight_decay)

        # Use the CPU Adam implementation only when the optimizer state is offloaded.
        offload_device = (
            ds_config['zero_optimization'].get('offload_optimizer', {}).get('device', 'none')
        )
        optimizer_cls = FusedAdam if offload_device == 'none' else DeepSpeedCPUAdam
        optimizer = optimizer_cls(param_groups, lr=lr, betas=ADAM_BETAS)

        # The scheduler advances once per optimizer update, not per micro batch.
        num_updates = total_training_steps // ds_config['gradient_accumulation_steps']
        scheduler = get_scheduler(
            name=lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=int(num_updates * lr_warmup_ratio),
            num_training_steps=num_updates,
        )

        # `deepspeed.initialize` returns a tuple; only the engine is needed.
        return deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            config=ds_config,
        )[0]

    def _init_eval_engine(
        self,
        model: nn.Module,
        ds_config: dict[str, Any],
    ) -> deepspeed.DeepSpeedEngine:
        """Wrap ``model`` in a DeepSpeed engine without optimizer or scheduler.

        Args:
            model: Module that is only used for forward passes.
            ds_config: DeepSpeed config for the engine.

        Returns:
            The initialized DeepSpeed engine.
        """
        # `deepspeed.initialize` returns a tuple; only the engine is needed.
        return deepspeed.initialize(model=model, config=ds_config)[0]

    def init_engines(self) -> None:
        """Replace every loaded model with its DeepSpeed engine.

        The actor and both critics become trainable engines; the reference
        actor, the cost model and the gold model become inference-only engines
        that are switched to eval mode. Gradient checkpointing is enabled as
        requested by ``args``.
        """
        args = self.args

        actor_ds_config = copy.deepcopy(self.ds_train_config)
        actor_total_training_steps = args.total_training_steps
        if self.use_ptx:
            # With PTX enabled `train` runs a PTX step after every RL step, so
            # the actor sees twice as many micro batches.
            args.gradient_accumulation_steps *= 2
            for doubled_key in ('train_batch_size', 'gradient_accumulation_steps'):
                actor_ds_config[doubled_key] *= 2
            actor_total_training_steps *= 2

        def build_critic_engine(critic: nn.Module) -> deepspeed.DeepSpeedEngine:
            # Both critics share the same optimisation hyper-parameters.
            return self._init_train_engine(
                model=critic,
                weight_decay=args.critic_weight_decay,
                lr=args.critic_lr,
                lr_scheduler_type=args.critic_lr_scheduler_type,
                lr_warmup_ratio=args.critic_lr_warmup_ratio,
                total_training_steps=args.total_training_steps,
                ds_config=self.ds_train_config,
            )

        def build_frozen_engine(frozen: nn.Module) -> deepspeed.DeepSpeedEngine:
            # Inference-only engine, immediately put into eval mode.
            engine = self._init_eval_engine(model=frozen, ds_config=self.ds_eval_config)
            engine.eval()
            return engine

        self.actor_model = self._init_train_engine(
            model=self.actor_model,
            weight_decay=args.actor_weight_decay,
            lr=args.actor_lr,
            lr_scheduler_type=args.actor_lr_scheduler_type,
            lr_warmup_ratio=args.actor_lr_warmup_ratio,
            total_training_steps=actor_total_training_steps,
            ds_config=actor_ds_config,
        )
        self.actor_reference_model = build_frozen_engine(self.actor_reference_model)
        self.reward_critic_model = build_critic_engine(self.reward_critic_model)

        if args.actor_gradient_checkpointing:
            self.actor_model.gradient_checkpointing_enable()
        if args.critic_gradient_checkpointing:
            self.reward_critic_model.gradient_checkpointing_enable()

        self.cost_critic_model = build_critic_engine(self.cost_critic_model)
        self.cost_model = build_frozen_engine(self.cost_model)
        self.gold_model = build_frozen_engine(self.gold_model)

        if args.critic_gradient_checkpointing:
            self.cost_critic_model.gradient_checkpointing_enable()

    def set_train(self, mode: bool = True) -> None:
        """Switch the actor and both critics between training and eval mode.

        Args:
            mode: Truthy for training mode, falsy for evaluation mode. The
                actor's gradient checkpointing is enabled/disabled accordingly
                when ``args.actor_gradient_checkpointing`` is set.
        """
        trainable_engines = (
            self.actor_model,
            self.reward_critic_model,
            self.cost_critic_model,
        )

        if mode:
            for engine in trainable_engines:
                engine.train()
            if self.args.actor_gradient_checkpointing:
                self.actor_model.gradient_checkpointing_enable()
            return

        for engine in trainable_engines:
            engine.eval()
        if self.args.actor_gradient_checkpointing:
            self.actor_model.gradient_checkpointing_disable()

    def split_rl_micro_batches(
        self,
        prompt_only_batch: PromptOnlyBatch,
    ) -> list[PromptOnlyBatch]:
        """Roll out a prompt batch chunk by chunk and collect the experiences.

        The batch is cut along its first dimension into chunks of
        ``args.per_device_train_batch_size`` prompts; each chunk is passed to
        ``rollout`` and the returned experience dicts are concatenated.
        """
        num_prompts = prompt_only_batch['input_ids'].size(0)
        chunk_size = self.args.per_device_train_batch_size

        experiences = []
        for begin in range(0, num_prompts, chunk_size):
            # Bind the slice as a default so the lambda does not capture the loop variable.
            prompt_chunk = optree.tree_map(
                lambda tensor, rows=slice(begin, begin + chunk_size): tensor[rows],
                prompt_only_batch,
            )
            experiences.extend(self.rollout(prompt_chunk))
        return experiences

    def split_ptx_micro_batches(
        self,
        ptx_batch: dict[str, torch.Tensor],
    ) -> list[dict[str, torch.Tensor]]:
        """Cut a PTX batch into micro-batches along its first dimension.

        Each micro-batch holds at most ``args.per_device_train_batch_size`` rows
        of every tensor in ``ptx_batch``.
        """
        num_samples = ptx_batch['input_ids'].size(0)
        chunk_size = self.args.per_device_train_batch_size
        return [
            optree.tree_map(
                # Bind the slice as a default so the lambda does not capture the loop variable.
                lambda tensor, rows=slice(begin, begin + chunk_size): tensor[rows],
                ptx_batch,
            )
            for begin in range(0, num_samples, chunk_size)
        ]

    @torch.no_grad()
    def rollout(self, prompt_only_batch: PromptOnlyBatch) -> list[dict[str, Any]]:
        """Sample responses for a prompt batch and post-process them.

        Returns:
            One experience dict (as built by ``post_rollout``) per returned
            sequence index, i.e. ``args.num_return_sequences`` entries.
        """
        prompt_ids = prompt_only_batch['input_ids']
        generated = self.actor_model.module.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_only_batch['attention_mask'],
            generation_config=self.generation_config,
            synced_gpus=True,
            do_sample=True,
        )

        # Regroup the flat generation output as (num_prompts, num_returns, length)
        # and move the "return index" axis to the front so iterating yields one
        # (num_prompts, length) tensor per return index.
        grouped = generated.contiguous().view(
            prompt_ids.size(0),
            self.args.num_return_sequences,
            -1,
        )
        per_return_index = grouped.transpose(0, 1)

        experiences = []
        for seq in per_return_index:
            # Attend to everything except padding and unknown tokens.
            token_mask = torch.logical_and(
                seq.not_equal(self.tokenizer.pad_token_id),
                seq.not_equal(self.tokenizer.unk_token_id),
            )
            experiences.append(self.post_rollout(prompt_ids, seq, attention_mask=token_mask))
        return experiences

    def train(self) -> None:
        """Run the full RL training loop.

        For every prompt batch the trainer first generates experiences in eval
        mode, then performs ``args.update_iters`` passes over the resulting
        micro-batches in train mode (an RL step, optionally followed by a PTX
        step). Checkpoints are written every ``args.save_interval`` steps and
        evaluation runs according to ``args.eval_strategy``.
        """
        self.logger.print('***** Running training *****')

        progress_bar = tqdm(
            total=self.args.total_training_steps,
            desc=f'Training 1/{self.args.epochs} epoch',
            position=0,
            leave=True,
            disable=not is_main_process(),
        )

        if self.args.need_eval:
            self.logger.print('\n***** Evaluating at the beginning *****')
            self.logger.log(self.eval(), step=0)

        # Repeat the PTX loader often enough to cover every prompt batch
        # (ceiling division).
        num_prompt_only_batches = len(self.prompt_only_dataloader)
        num_ptx_batches = len(self.ptx_dataloader)
        num_ptx_replicas = (num_prompt_only_batches + num_ptx_batches - 1) // num_ptx_batches

        for epoch in range(self.args.epochs):
            ptx_stream = itertools.chain.from_iterable(
                [self.ptx_dataloader] * num_ptx_replicas,
            )
            for prompt_only_batch, ptx_batch in zip(self.prompt_only_dataloader, ptx_stream):
                # Phase 1: generate experiences.
                self.set_eval()
                prompt_only_batch = to_device(prompt_only_batch, self.args.device)
                rl_batches = self.split_rl_micro_batches(prompt_only_batch)
                if self.use_ptx:
                    ptx_batch = to_device(ptx_batch, self.args.device)
                    ptx_batches = self.split_ptx_micro_batches(ptx_batch)
                else:
                    ptx_batches = [None] * len(rl_batches)
                torch.cuda.empty_cache()

                # Phase 2: update on the generated experiences.
                self.set_train()
                for _ in range(self.args.update_iters):
                    for rl_batch, ptx_micro_batch in zip(rl_batches, ptx_batches):
                        rl_info = self.rl_step(rl_batch)
                        torch.cuda.empty_cache()
                        self.logger.log(rl_info, step=self.global_step)
                        if self.use_ptx:
                            ptx_info = self.ptx_step(ptx_micro_batch)
                            torch.cuda.empty_cache()
                            self.logger.log(ptx_info, step=self.global_step)

                        self.global_step += 1
                        progress_bar.set_description(
                            f'Training {epoch + 1}/{self.args.epochs} epoch '
                            f'(reward {rl_info["train/reward"]:.4f})',
                        )
                        progress_bar.update(1)

                        # Periodic actor checkpoint, one directory per step count.
                        is_save_step = self.global_step % self.args.save_interval == 0
                        if is_save_step:
                            self.logger.print(f'Saving checkpoint at step {self.global_step} ...')
                            save_dir = os.path.join(self.args.output_dir, f'{self.global_step}')
                            os.makedirs(save_dir, exist_ok=True)
                            if self.args.save_16bit:
                                self.actor_model.save_16bit_model(save_dir)
                            else:
                                self.actor_model.save_checkpoint(save_dir)
                            self.logger.print('Checkpoint saved.')

                        # Step-based evaluation; the `and` chain short-circuits so
                        # `eval_interval` is only read when it is relevant.
                        is_eval_step = (
                            self.args.need_eval
                            and self.args.eval_strategy == 'steps'
                            and self.global_step % self.args.eval_interval == 0
                        )
                        if is_eval_step:
                            self.logger.print(
                                f'\n***** Evaluating at step {self.global_step} *****',
                            )
                            self.logger.log(self.eval(), step=self.global_step)

            # Epoch-based evaluation.
            if self.args.need_eval and self.args.eval_strategy == 'epoch':
                self.logger.print(
                    f'\n***** Evaluating at epoch {epoch + 1}/{self.args.epochs} *****',
                )
                self.logger.log(self.eval(), step=self.global_step)

    def eval(self) -> dict[str, Any]:
        """Generate on the evaluation set and report averaged metrics.

        Returns:
            An empty dict when no evaluation dataloader exists; otherwise a
            mapping from metric name (as produced by ``eval_step``) to a float.
            The cross-rank average is reduced onto rank 0.
        """
        if self.eval_dataloader is None:
            return {}

        self.set_eval()
        prompts: list[str] = []
        generateds: list[str] = []
        scores: dict[str, list[torch.Tensor]] = {}

        batches = tqdm(
            self.eval_dataloader,
            desc='Evaluating',
            disable=not is_main_process(),
        )

        for batch in batches:
            batch = to_device(batch, self.args.device)
            prompt_ids = batch['input_ids']
            with torch.no_grad():
                seq = self.actor_model.module.generate(
                    input_ids=prompt_ids,
                    attention_mask=batch['attention_mask'],
                    max_length=self.args.max_length,
                    synced_gpus=True,
                    do_sample=True,
                )

            dist.barrier()

            # Attend to everything except padding and unknown tokens.
            attention_mask = torch.logical_and(
                seq.not_equal(self.tokenizer.pad_token_id),
                seq.not_equal(self.tokenizer.unk_token_id),
            )
            step_metrics = self.eval_step(prompt_ids, seq, attention_mask)
            for key, values in step_metrics.items():
                scores.setdefault(key, []).append(values)

            prompt_texts = self.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
            full_texts = self.tokenizer.batch_decode(seq, skip_special_tokens=True)
            # Strip the decoded prompt prefix to keep only the generated continuation.
            continuations = [
                text[len(prompt_texts[idx]) :] for idx, text in enumerate(full_texts)
            ]
            prompts.extend(prompt_texts)
            generateds.extend(continuations)

        # Display result in main process
        if is_main_process():
            per_sample_scores = {
                key: torch.cat(chunks, dim=0).to(torch.float32).tolist()
                for key, chunks in scores.items()
            }
            self.logger.print_table(
                title='Evaluating...',
                columns=['Prompt', 'Generated', *scores],
                rows=list(zip(prompts, generateds, *per_sample_scores.values())),
                max_num_rows=5,
            )

        # Gather results from all processes
        for key in list(scores):
            local_mean = torch.cat(scores[key], dim=0).mean()
            dist.reduce(local_mean, dst=0, op=dist.ReduceOp.AVG)
            scores[key] = local_mean.mean().item()
        dist.barrier()

        self.set_train()

        return scores

    def get_advantages_and_returns(
        self,
        values: torch.Tensor,
        rewards: torch.Tensor,
        sequence_mask: torch.BoolTensor,
        start: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute advantages and returns using Generalized Advantage Estimation (GAE).

        Args:
            values: Value estimates, indexed as ``[:, t]``.
            rewards: Per-step rewards with the same layout as ``values``.
            sequence_mask: Mask multiplied into both ``values`` and ``rewards``.
            start: First time step for which an advantage is produced.

        Returns:
            ``(advantages, returns)`` covering time steps ``start`` onwards;
            the advantages are detached.
        """
        # Modified from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py
        masked_values = values * sequence_mask
        masked_rewards = rewards * sequence_mask
        num_steps = masked_rewards.size(-1)

        running_gae = 0.0
        per_step_advantages = []
        # Walk backwards in time; the value after the last step is zero.
        for t in range(num_steps - 1, start - 1, -1):  # pylint: disable=invalid-name
            bootstrap = masked_values[:, t + 1] if t < num_steps - 1 else 0.0
            td_error = masked_rewards[:, t] + self.gamma * bootstrap - masked_values[:, t]
            running_gae = td_error + self.gamma * self.gae_lambda * running_gae
            per_step_advantages.append(running_gae)

        # Collected last-to-first; restore chronological order before stacking.
        per_step_advantages.reverse()
        advantages = torch.stack(per_step_advantages, dim=1)
        returns = advantages + masked_values[:, start:]
        return advantages.detach(), returns

    def critic_loss_fn(
        self,
        values: torch.Tensor,  # size = (B, L - S)
        old_values: torch.Tensor,  # size = (B, L - S)
        returns: torch.Tensor,  # size = (B, L - S)
        mask: torch.BoolTensor,  # size = (B, L - S)
    ) -> torch.Tensor:  # size = ()
        """Compute the clipped value-function loss.

        The new values are clamped to within ``clip_range_value`` of the old
        values; the loss is half the masked mean of the element-wise larger of
        the clipped and unclipped squared errors.
        """
        lower_bound = old_values - self.clip_range_value
        upper_bound = old_values + self.clip_range_value
        clipped_values = torch.clamp(values, lower_bound, upper_bound)  # size = (B, L - S)

        unclipped_error = torch.square(values - returns)
        clipped_error = torch.square(clipped_values - returns)
        # Take the worse of the two errors for every position.
        worst_error = torch.maximum(unclipped_error, clipped_error)
        return 0.5 * masked_mean(worst_error, mask)  # size = ()

    def save(
        self,
        model: deepspeed.DeepSpeedEngine | PreTrainedModel | None = None,
        ds_config: dict | None = None,
    ) -> None:
        """Save model and tokenizer via the base class.

        Args:
            model: Model to save; the actor engine is used when omitted.
            ds_config: DeepSpeed config to pass along; the training config is
                used when omitted.
        """
        super().save(
            model=self.actor_model if model is None else model,
            ds_config=self.ds_train_config if ds_config is None else ds_config,
        )


    @torch.no_grad()
    def post_rollout(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, Any]:
        """Score a batch of generated sequences and assemble one experience dict.

        Args:
            prompt: Prompt token ids, stored unchanged in the result.
            sequence: Token ids of prompt plus generated continuation.
            attention_mask: Mask over ``sequence``.

        Returns:
            A dict with the log-probabilities of the actor and reference model,
            the KL-based ``'reward'``, the cost and gold scores, both critics'
            value estimates and the input tensors.
        """

        def tokens_for(dest_tokenizer: PreTrainedTokenizerBase) -> tuple[Any, Any]:
            # Scorers that share the actor tokenizer consume the sequence as is;
            # otherwise the text is re-tokenized with the scorer's tokenizer.
            if dest_tokenizer is self.tokenizer:
                return sequence, attention_mask
            retokenized = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=dest_tokenizer,
                truncation=True,
                skip_special_tokens=True,
                device=self.args.device,
            )
            return retokenized['input_ids'], retokenized['attention_mask']

        cost_seq, cost_attention_mask = tokens_for(self.cost_tokenizer)
        gold_input_ids, gold_attention_mask = tokens_for(self.gold_tokenizer)

        # NOTE(review): unlike `cost`, the gold score is returned without
        # squeezing its last dimension — confirm this is what consumers expect.
        gold_reward = self.gold_model(gold_input_ids, attention_mask=gold_attention_mask).end_scores

        logits = self.actor_model(sequence, attention_mask=attention_mask).logits
        ref_logits = self.actor_reference_model(sequence, attention_mask=attention_mask).logits

        cost = self.cost_model(cost_seq, attention_mask=cost_attention_mask).end_scores
        reward_values = self.reward_critic_model(sequence, attention_mask=attention_mask).scores
        cost_values = self.cost_critic_model(sequence, attention_mask=attention_mask).scores

        cost = cost.squeeze(dim=-1)
        # Drop the last position so the value estimates line up with the
        # log-probabilities below, which are taken over `sequence[:, 1:]`.
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        cost_values = cost_values.squeeze(dim=-1)[:, :-1]

        next_tokens = sequence[:, 1:]
        log_probs = gather_log_probabilities(logits[:, :-1], next_tokens)
        ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], next_tokens)

        # Per-token log-ratio between actor and reference model.
        kl_divergence = log_probs - ref_log_probs

        # Feed the sliding window of recent episode costs.
        self.episode_costs.extend(cost.tolist())

        return {
            'prompt': prompt,
            'log_probs': log_probs,
            'ref_log_probs': ref_log_probs,
            'reward': -kl_divergence * self.kl_coeff,
            'cost': cost,
            'gold': gold_reward,
            'reward_values': reward_values,
            'cost_values': cost_values,
            'input_ids': sequence,
            'attention_mask': attention_mask,
        }

    @torch.no_grad()
    def eval_step(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, torch.Tensor]:
        """Compute per-sample evaluation metrics for generated sequences.

        Args:
            prompt: Prompt token ids; only its last dimension size is used.
            sequence: Token ids of prompt plus generated continuation.
            attention_mask: Mask over ``sequence``.

        Returns:
            A dict with the cost-model score (``'eval/reward'``), the gold-model
            score, the summed actor/reference log-ratio over the generated part
            and the number of unmasked generated tokens.
        """

        def tokens_for(dest_tokenizer: PreTrainedTokenizerBase) -> tuple[Any, Any]:
            # Scorers that share the actor tokenizer consume the sequence as is;
            # otherwise the text is re-tokenized with the scorer's tokenizer.
            if dest_tokenizer is self.tokenizer:
                return sequence, attention_mask
            retokenized = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=dest_tokenizer,
                truncation=True,
                skip_special_tokens=True,
                device=self.args.device,
            )
            return retokenized['input_ids'], retokenized['attention_mask']

        cost_seq, cost_attention_mask = tokens_for(self.cost_tokenizer)
        gold_input_ids, gold_attention_mask = tokens_for(self.gold_tokenizer)

        # The value reported as 'eval/reward' is produced by the cost model.
        cost_scores = self.cost_model(
            input_ids=cost_seq,
            attention_mask=cost_attention_mask,
        ).end_scores
        reward = cost_scores.squeeze(dim=-1)

        gold_scores = self.gold_model(
            gold_input_ids,
            attention_mask=gold_attention_mask,
        ).end_scores
        gold = gold_scores.squeeze(dim=-1)

        next_tokens = sequence[:, 1:]
        logits = self.actor_model(sequence, attention_mask=attention_mask, use_cache=False).logits
        log_probs = gather_log_probabilities(logits[:, :-1], next_tokens)

        ref_logits = self.actor_reference_model(
            sequence,
            attention_mask=attention_mask,
            use_cache=False,
        ).logits
        ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], next_tokens)

        # Log-probabilities are shifted by one token, so index `start` in them
        # corresponds to position `start + 1` of the attention mask.
        start = prompt.size(-1) - 1
        generated_mask = attention_mask[:, start + 1 :]
        mean_generated_length = generated_mask.sum(dim=-1).float()
        log_ratio = log_probs - ref_log_probs
        kl_seq = (log_ratio[:, start:] * generated_mask).sum(dim=-1)

        return {
            'eval/reward': reward,
            'eval/gold_reward': gold,
            'eval/kl_divergence': kl_seq,
            'eval/mean_generated_length': mean_generated_length,
        }

    def reward_cost_shape(
        self,
        reward: torch.Tensor,  # size = (B, L)
        cost: torch.Tensor,  # size = (B,)
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        sequence_mask: torch.BoolTensor,  # size = (B, L)
    ) -> tuple[torch.Tensor, torch.Tensor]:  # size = (B, L)
        """Build clipped per-token reward and cost signals.

        The per-token ``reward`` is passed through unchanged apart from
        clipping. The per-sequence ``cost`` is placed at the last unmasked
        position of each row (zeros elsewhere) and then clipped. Both outputs
        are clamped to ``[-self.clip_range_score, self.clip_range_score]``.

        Args:
            reward: Per-token reward.
            cost: One scalar per sequence.
            prompt: Unused.
            sequence_mask: Mask of valid positions; every row needs at least
                one non-zero entry, otherwise indexing the last one fails.

        Returns:
            ``(clipped_rewards, clipped_costs)``, each shaped like
            ``sequence_mask``.
        """
        bound = self.clip_range_score

        # Column of the last valid position in every row; size = (B,)
        last_valid = torch.cat([row.nonzero()[-1] for row in sequence_mask])
        row_ids = torch.arange(sequence_mask.size(0))

        # Sparse cost signal: all zeros except at each row's final valid token.
        per_token_cost = torch.zeros_like(
            sequence_mask,
            dtype=cost.dtype,
            device=cost.device,
        )
        per_token_cost[row_ids, last_valid] = cost

        clipped_rewards = reward.clamp(min=-bound, max=bound)
        clipped_costs = per_token_cost.clamp(min=-bound, max=bound)
        return clipped_rewards, clipped_costs

    def actor_loss_fn(
        self,
        log_probs: torch.Tensor,  # size = (B, L - S)
        old_log_probs: torch.Tensor,  # size = (B, L - S)
        reward_advantages: torch.Tensor,  # size = (B, L - S)
        cost_advantages: torch.Tensor,  # size = (B, L - S)
        mask: torch.BoolTensor,  # size = (B, L - S)
    ) -> torch.Tensor:  # size = ()
        """Clipped PPO surrogate loss on a Lagrange-weighted mix of advantages.

        The two advantage streams are blended as
        ``(1 - m) * reward_advantages + m * cost_advantages`` where ``m`` is
        ``self.squash_fn(self.lagrange)`` read out as a Python float, so no
        gradient reaches ``self.lagrange`` through this loss.

        Args:
            log_probs: Log-probabilities under the current policy.
            old_log_probs: Log-probabilities recorded at rollout time.
            reward_advantages: Advantages of the reward stream.
            cost_advantages: Advantages of the cost stream.
            mask: Positions that contribute to the loss.

        Returns:
            Scalar loss: the negated masked mean of the element-wise minimum
            of the unclipped and clipped surrogate objectives.
        """
        weight = self.squash_fn(self.lagrange).item()
        eps = self.clip_range_ratio

        # size = (B, L - S)
        # blend the KL-based and reward-model-based advantages
        mixed_advantages = (1 - weight) * reward_advantages + weight * cost_advantages

        importance = (log_probs - old_log_probs).exp()
        clipped_importance = importance.clamp(1.0 - eps, 1.0 + eps)

        unclipped_objective = mixed_advantages * importance
        clipped_objective = mixed_advantages * clipped_importance
        pessimistic_objective = torch.minimum(unclipped_objective, clipped_objective)

        return -masked_mean(pessimistic_objective, mask)  # size = ()

    # pylint: disable-next=too-many-locals
    def rl_step(self, rl_batch: dict[str, torch.Tensor]) -> dict[str, Any]:
        """Run one constrained-PPO update on a rollout batch.

        Steps, in order: (1) update the Lagrange parameter from the averaged
        episode cost on the main process and broadcast it from rank 0;
        (2) compute advantages/returns for the reward and cost streams;
        (3) update the actor, the reward critic and the cost critic through
        their DeepSpeed engines; (4) optionally record the generated samples;
        (5) reduce the statistics across processes and return them as floats.

        Args:
            rl_batch: Rollout batch with the keys ``'prompt'``, ``'log_probs'``,
                ``'ref_log_probs'``, ``'reward'``, ``'cost'``,
                ``'reward_values'``, ``'cost_values'``, ``'gold'``,
                ``'input_ids'`` and ``'attention_mask'``.

        Returns:
            A dict of ``'train/...'`` scalars (losses, Lagrange parameter,
            episode cost, scores, KL, advantage/return/value means, learning
            rates and generated-length statistics).
        """
        # Mean over the episode costs accumulated so far on this process.
        episode_cost = torch.tensor(self.episode_costs).mean().to(self.args.device)

        # Average across processes; `reduce` delivers the result to rank 0 (dst=0).
        dist.reduce(episode_cost, dst=0, op=dist.ReduceOp.AVG)

        # Only the main process steps the Lagrange parameter; the other
        # processes receive the new value through the broadcast below.
        if is_main_process():
            lagrange_loss = (episode_cost - self.threshold) * self.squash_fn(self.lagrange)
            self.lagrange_optimizer.zero_grad()
            lagrange_loss.backward()
            torch.nn.utils.clip_grad_norm_([self.lagrange], self.args.max_grad_norm)
            self.lagrange_optimizer.step()
            if self.lagrange_max is not None:
                # Keep the raw parameter inside [-lagrange_max, lagrange_max].
                with torch.no_grad():
                    self.lagrange.clamp_(min=-self.lagrange_max, max=self.lagrange_max)

        dist.broadcast(self.lagrange, src=0)

        prompt = rl_batch['prompt']
        old_log_probs = rl_batch['log_probs']
        ref_log_probs = rl_batch['ref_log_probs']
        reward = rl_batch['reward']  # per-token KL penalty (-kl), size = (B, L)
        cost = rl_batch['cost']  # reward of reward model, size = (B,)
        old_reward_values = rl_batch['reward_values']
        old_cost_values = rl_batch['cost_values']
        gold = rl_batch['gold']  # gold score, size = (B,)
        input_ids = rl_batch['input_ids']
        attention_mask = rl_batch['attention_mask']

        # Log-probs/values are indexed in next-token coordinates (position j
        # scores token j + 1), so the generated part starts at prompt_len - 1
        # and the mask is shifted by one to match.
        start = prompt.size(-1) - 1
        sequence_mask = attention_mask[:, 1:]

        with torch.no_grad():
            old_rewards, old_costs = self.reward_cost_shape(
                reward,
                cost,
                prompt,
                sequence_mask,
            )
            # NOTE(review): the advantages/returns are used below next to
            # tensors sliced with `[:, start:]`, so they are presumably already
            # restricted to the generated part -- confirm in
            # `get_advantages_and_returns`.
            reward_advantages, reward_returns = self.get_advantages_and_returns(
                old_reward_values,
                old_rewards,
                sequence_mask,
                start,
            )
            cost_advantages, cost_returns = self.get_advantages_and_returns(
                old_cost_values,
                old_costs,
                sequence_mask,
                start,
            )

        # Actor update.
        logits = self.actor_model(input_ids, attention_mask=attention_mask, use_cache=False).logits
        log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:])
        actor_loss = self.actor_loss_fn(
            log_probs[:, start:],
            old_log_probs[:, start:],
            reward_advantages,
            cost_advantages,
            sequence_mask[:, start:],
        )
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        # Reward-critic update; the last position is dropped so the values
        # line up with the shifted log-probs.
        reward_values = self.reward_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        reward_critic_loss = self.critic_loss_fn(
            reward_values[:, start:],
            old_reward_values[:, start:],
            reward_returns,
            sequence_mask[:, start:],
        )
        self.reward_critic_model.backward(reward_critic_loss)
        self.reward_critic_model.step()

        # Cost-critic update, same layout as the reward critic.
        cost_values = self.cost_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        cost_values = cost_values.squeeze(dim=-1)[:, :-1]
        cost_critic_loss = self.critic_loss_fn(
            cost_values[:, start:],
            old_cost_values[:, start:],
            cost_returns,
            sequence_mask[:, start:],
        )
        self.cost_critic_model.backward(cost_critic_loss)
        self.cost_critic_model.step()

        with torch.no_grad():
            mask = sequence_mask[:, start:]
            # KL is measured between the rollout-time policy and the reference
            # model, summed over generated tokens and averaged over the batch.
            kl_divergence = ((old_log_probs - ref_log_probs)[:, start:] * mask).sum(dim=-1).mean()
            mean_generated_length = mask.sum(dim=-1).float().mean()
            max_generated_length = mask.sum(dim=-1).float().max()

            if self.args.save_train_data:
                # Split each sequence into prompt and generated continuation.
                prompt_ids = input_ids[:, :start+1].contiguous()
                response_ids = input_ids[:, start+1:].contiguous()
                # The gather helpers run on every process (outside the
                # main-process check); presumably they are collective
                # operations -- verify in safe_rlhf.utils.
                gathered_prompt_ids = gather_input_ids(prompt_ids, self.tokenizer.pad_token_id)
                gathered_response_ids = gather_input_ids(response_ids, self.tokenizer.pad_token_id)
                gathered_gold = gather_scores(gold)
                gathered_reward = gather_scores(cost)
                prompt_texts = self.tokenizer.batch_decode(gathered_prompt_ids, skip_special_tokens=True)
                response_texts = self.tokenizer.batch_decode(gathered_response_ids, skip_special_tokens=True)
                if is_main_process():
                    assert len(prompt_texts) == len(response_texts) == len(gathered_gold) == len(gathered_reward)
                    # NOTE(review): `kl_divergence` and `mean_generated_length`
                    # are read here before their all-reduce further down, so
                    # they describe this process's batch only, whereas
                    # `generated_data` is built from the gathered tensors --
                    # confirm this is intended.
                    # The comprehension's `prompt`/`gold` names are local to it
                    # and do not rebind the outer tensors.
                    self.train_generated_data.append({
                        'step': self.global_step,
                        'kl_divergence': kl_divergence.item(),
                        'mean_generated_length': mean_generated_length.item(),
                        'generated_data': [{
                            'prompt': prompt,
                            'response': response,
                            'gold': gold.item(),
                            'proxy': proxy.item(),
                        } for prompt, response, gold, proxy in zip(prompt_texts, response_texts, gathered_gold, gathered_reward)],
                    })

            # From here on `cost` and `gold` are batch means, not per-sample scores.
            cost = cost.mean()
            gold = gold.mean()

            reward_advantage = masked_mean(reward_advantages, mask)
            reward_return = masked_mean(reward_returns, mask)
            reward_value = masked_mean(reward_values[:, start:], mask)
            cost_advantage = masked_mean(cost_advantages, mask)
            cost_return = masked_mean(cost_returns, mask)
            cost_value = masked_mean(cost_values[:, start:], mask)

            # Reduce every logged statistic across processes (mean, except the
            # maximum generated length which uses a max reduction).
            actor_loss = get_all_reduce_mean(actor_loss)
            reward_critic_loss = get_all_reduce_mean(reward_critic_loss)
            cost_critic_loss = get_all_reduce_mean(cost_critic_loss)
            kl_divergence = get_all_reduce_mean(kl_divergence)
            cost = get_all_reduce_mean(cost)
            gold = get_all_reduce_mean(gold)
            reward_advantage = get_all_reduce_mean(reward_advantage)
            reward_return = get_all_reduce_mean(reward_return)
            reward_value = get_all_reduce_mean(reward_value)
            cost_advantage = get_all_reduce_mean(cost_advantage)
            cost_return = get_all_reduce_mean(cost_return)
            cost_value = get_all_reduce_mean(cost_value)
            mean_generated_length = get_all_reduce_mean(mean_generated_length)
            max_generated_length = get_all_reduce_max(max_generated_length)

        dist.barrier()

        # NOTE(review): `episode_cost` was reduced to rank 0 only, so
        # 'train/episode_cost' is presumably meaningful on the main process
        # alone. 'train/reward' reports the cost-model score (`cost`).
        return {
            'train/actor_loss': actor_loss.item(),
            'train/reward_critic_loss': reward_critic_loss.item(),
            'train/cost_critic_loss': cost_critic_loss.item(),
            'train/lagrange': self.lagrange.item(),
            'train/episode_cost': episode_cost.item(),
            'train/reward': cost.item(),
            'train/kl_divergence': kl_divergence.item(),
            'train/gold_reward': gold.item(),
            'train/reward_advantage': reward_advantage.item(),
            'train/reward_return': reward_return.item(),
            'train/reward_value': reward_value.item(),
            'train/cost_advantage': cost_advantage.item(),
            'train/cost_return': cost_return.item(),
            'train/cost_value': cost_value.item(),
            'train/actor_lr': self.actor_model.optimizer.param_groups[0]['lr'],
            'train/reward_critic_lr': self.reward_critic_model.optimizer.param_groups[0]['lr'],
            'train/cost_critic_lr': self.cost_critic_model.optimizer.param_groups[0]['lr'],
            'train/mean_generated_length': mean_generated_length.item(),
            'train/max_generated_length': max_generated_length.item(),
        }
