# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import argparse
from collections import deque
from typing import Any

import deepspeed
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase

from safe_rlhf.models import AutoModelForScore, load_pretrained_models
from safe_rlhf.trainers import RLTrainer
from safe_rlhf.utils import (
    batch_retokenize,
    gather_log_probabilities,
    get_all_reduce_mean,
    is_main_process,
    is_same_tokenizer,
)


class PPOLagTrainer(RLTrainer):
    """PPO trainer that adds a cost signal weighted by a learnable multiplier.

    On top of what ``RLTrainer`` provides, this class loads an evaluation-only
    cost model that scores finished sequences and a trainable cost critic, and
    mixes reward and cost advantages in the actor loss using
    ``exp(log_lambda)`` as the weight on the cost term.
    """

    TRAINING_TYPE = 'ppo_lag'

    # Cost scorer; wrapped by `_init_eval_engine` and switched to eval mode in `init_engines`.
    cost_model: deepspeed.DeepSpeedEngine
    # Cost value estimator; wrapped by `_init_train_engine` and optimized in `rl_step`.
    cost_critic_model: deepspeed.DeepSpeedEngine

    # Tokenizer loaded with the cost model; aliased to the actor tokenizer when they match.
    cost_tokenizer: PreTrainedTokenizerBase
    # Set to the actor tokenizer by `init_models` (a mismatch raises there instead).
    cost_critic_tokenizer: PreTrainedTokenizerBase

    def __init__(
        self,
        args: argparse.Namespace,
        ds_train_config: dict[str, Any],
        ds_eval_config: dict[str, Any],
    ) -> None:
        super().__init__(args=args, ds_train_config=ds_train_config, ds_eval_config=ds_eval_config)

        # Lagrange multiplier
        self.log_lambda = torch.nn.Parameter(
            torch.tensor(self.args.lambda_init).log().to(self.args.device).detach(),
            requires_grad=True,
        )
        self.log_lambda_optimizer = torch.optim.Adam([self.log_lambda], lr=self.args.lambda_lr)
        self.ep_costs = deque(maxlen=64)
        self.threshold = self.args.threshold

    def init_models(self) -> None:
        super().init_models()
        self.cost_model, self.cost_tokenizer = load_pretrained_models(
            self.args.cost_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='right',
            trust_remote_code=self.args.trust_remote_code,
        )

        if self.args.cost_critic_model_name_or_path is None:
            self.args.cost_critic_model_name_or_path = self.args.cost_model_name_or_path
        self.cost_critic_model, self.cost_critic_tokenizer = load_pretrained_models(
            self.args.cost_critic_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='left',
            trust_remote_code=self.args.trust_remote_code,
        )

        if is_same_tokenizer(self.tokenizer, self.cost_tokenizer):
            self.cost_tokenizer = self.tokenizer
        if not is_same_tokenizer(self.tokenizer, self.cost_critic_tokenizer):
            raise ValueError(
                (
                    'Cost critic tokenizer must be the same as actor tokenizer. '
                    'Expected {0.__module__}.{0.__qualname__}(vocab_size={1}), '
                    'but got {2.__module__}.{2.__qualname__}(vocab_size={3}). '
                    'Please consider pass `--cost_critic_model_name_or_path` from the command line.'
                ).format(
                    type(self.tokenizer),
                    len(self.tokenizer),
                    type(self.cost_critic_tokenizer),
                    len(self.cost_critic_tokenizer),
                ),
            )
        self.cost_critic_tokenizer = self.tokenizer

    def init_engines(self) -> None:
        super().init_engines()

        self.cost_critic_model = self._init_train_engine(
            self.cost_critic_model,
            self.args.critic_weight_decay,
            self.args.critic_lr,
        )
        self.cost_model = self._init_eval_engine(self.cost_model)
        self.cost_model.eval()

        if self.args.critic_gradient_checkpointing:
            self.cost_critic_model.gradient_checkpointing_enable()

    def set_train(self, mode: bool = True) -> None:
        super().set_train(mode=mode)
        if mode:
            self.cost_critic_model.train()
        else:
            self.cost_critic_model.eval()

    @torch.no_grad()
    def post_rollout(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, Any]:
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_seq = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_seq = sequence
            reward_attention_mask = attention_mask

        if self.cost_tokenizer is not self.tokenizer:
            cost_tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.cost_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            cost_seq = cost_tokenize_output['input_ids']
            cost_attention_mask = cost_tokenize_output['attention_mask']
        else:
            cost_seq = sequence
            cost_attention_mask = attention_mask

        logits = self.actor_model(sequence, attention_mask=attention_mask).logits
        ref_logits = self.actor_reference_model(sequence, attention_mask=attention_mask).logits

        reward_score = self.reward_model(
            reward_seq,
            attention_mask=reward_attention_mask,
        ).end_scores
        cost_score = self.cost_model(
            cost_seq,
            attention_mask=cost_attention_mask,
        ).end_scores
        reward_value = self.reward_critic_model(
            sequence,
            attention_mask=attention_mask,
        ).scores
        cost_value = self.cost_critic_model(
            sequence,
            attention_mask=attention_mask,
        ).scores

        reward_score = reward_score.squeeze(dim=-1)
        cost_score = cost_score.squeeze(dim=-1)
        reward_value = reward_value.squeeze(dim=-1)[:, :-1]
        cost_value = cost_value.squeeze(dim=-1)[:, :-1]

        log_probs = gather_log_probabilities(logits[:, :-1], sequence[:, 1:])
        ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], sequence[:, 1:])

        self.ep_costs.extend(cost_score.tolist())

        return {
            'prompt': prompt,
            'log_probs': log_probs,
            'ref_log_probs': ref_log_probs,
            'rewards': reward_score,
            'costs': cost_score,
            'reward_values': reward_value,
            'cost_values': cost_value,
            'input_ids': sequence,
            'attention_mask': attention_mask,
        }

    @torch.no_grad()
    def eval_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, torch.Tensor]:
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                input_ids,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_input_ids = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_input_ids = input_ids
            reward_attention_mask = attention_mask

        if self.cost_tokenizer is not self.tokenizer:
            cost_tokenize_output = batch_retokenize(
                input_ids,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.cost_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            cost_input_ids = cost_tokenize_output['input_ids']
            cost_attention_mask = cost_tokenize_output['attention_mask']
        else:
            cost_input_ids = input_ids
            cost_attention_mask = attention_mask

        reward_score = self.reward_model(
            reward_input_ids,
            attention_mask=reward_attention_mask,
        ).end_scores.squeeze(dim=-1)
        cost_score = self.cost_model(
            cost_input_ids,
            attention_mask=cost_attention_mask,
        ).end_scores.squeeze(dim=-1)
        return {
            'eval/reward': reward_score,
            'eval/cost': cost_score,
        }

    def add_kl_divergence_regularization(
        self,
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        log_probs: torch.Tensor,  # size = (B, L)
        ref_log_probs: torch.Tensor,  # size = (B, L)
        reward_score: torch.Tensor,  # size = (B,)
        cost_score: torch.Tensor,  # size = (B,)
        sequence_mask: torch.BoolTensor,  # size = (B, L)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        kl_divergence_estimate = -self.kl_coeff * (log_probs - ref_log_probs)  # size = (B, L)
        rewards = 0.5 * kl_divergence_estimate  # size = (B, L)
        costs = -0.5 * kl_divergence_estimate  # size = (B, L)
        reward_clip = torch.clamp(  # size = (B,)
            reward_score,
            min=-self.clip_range_score,
            max=self.clip_range_score,
        )
        cost_clip = torch.clamp(  # size = (B,)
            cost_score,
            min=-self.clip_range_score,
            max=self.clip_range_score,
        )
        batch_size = log_probs.size(0)
        for i in range(batch_size):
            end_index = sequence_mask[i].nonzero()[-1]
            rewards[i, end_index] += reward_clip[i]
            costs[i, end_index] += cost_clip[i]

        return rewards, costs

    def actor_loss_fn(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        c_advantages: torch.Tensor,
        mask: torch.BoolTensor,
    ) -> torch.Tensor:
        # policy gradient loss
        multiplier = self.log_lambda.exp().item()
        advantages_all = (advantages - multiplier * c_advantages) / (1 + multiplier)

        ratio = torch.exp(log_probs - old_log_probs)
        pg_loss1 = -advantages_all * ratio
        pg_loss2 = -advantages_all * torch.clamp(
            ratio,
            1.0 - self.clip_range_ratio,
            1.0 + self.clip_range_ratio,
        )
        return torch.sum(torch.maximum(pg_loss1, pg_loss2) * mask) / mask.sum()

    def rl_step(self, rl_batch: dict[str, torch.Tensor]) -> dict[str, Any]:
        ep_costs = torch.tensor(self.ep_costs).mean().to(self.args.device)

        dist.reduce(ep_costs, dst=0, op=dist.ReduceOp.AVG)

        if is_main_process():
            lambda_loss = -(ep_costs - self.threshold) * self.log_lambda.exp()
            self.log_lambda_optimizer.zero_grad()
            lambda_loss.backward()
            self.log_lambda_optimizer.step()

        dist.broadcast(self.log_lambda, src=0)

        prompt = rl_batch['prompt']
        old_log_probs = rl_batch['log_probs']
        ref_log_probs = rl_batch['ref_log_probs']
        rewards = rl_batch['rewards']
        costs = rl_batch['costs']
        old_reward_values = rl_batch['reward_values']
        old_cost_values = rl_batch['cost_values']
        input_ids = rl_batch['input_ids']
        attention_mask = rl_batch['attention_mask']

        start = prompt.size(-1) - 1
        sequence_mask = attention_mask[:, 1:]

        with torch.no_grad():
            old_rewards, old_costs = self.add_kl_divergence_regularization(
                prompt,
                old_log_probs,
                ref_log_probs,
                rewards,
                costs,
                sequence_mask,
            )
            reward_advantages, reward_returns = self.get_advantages_and_returns(
                old_reward_values,
                old_rewards,
                sequence_mask,
                start,
            )
            cost_advantages, cost_returns = self.get_advantages_and_returns(
                old_cost_values,
                old_costs,
                sequence_mask,
                start,
            )

        logits = self.actor_model(input_ids, attention_mask=attention_mask, use_cache=False).logits
        log_prob = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:])
        actor_loss = self.actor_loss_fn(
            log_prob[:, start:],
            old_log_probs[:, start:],
            reward_advantages,
            cost_advantages,
            sequence_mask[:, start:],
        )
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        reward_values = self.reward_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        reward_critic_loss = self.critic_loss_fn(
            reward_values[:, start:],
            old_reward_values[:, start:],
            reward_returns,
            sequence_mask[:, start:],
        )
        self.reward_critic_model.backward(reward_critic_loss)
        self.reward_critic_model.step()

        cost_values = self.cost_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        cost_values = cost_values.squeeze(dim=-1)[:, :-1]
        cost_critic_loss = self.critic_loss_fn(
            cost_values[:, start:],
            old_cost_values[:, start:],
            cost_returns,
            sequence_mask[:, start:],
        )
        self.cost_critic_model.backward(cost_critic_loss)
        self.cost_critic_model.step()

        with torch.no_grad():
            kl_divergence = (
                ((old_log_probs - ref_log_probs) * sequence_mask)[:, start:].sum(dim=-1).mean()
            )

        rewards = rewards.mean()
        costs = costs.mean()

        actor_loss = get_all_reduce_mean(actor_loss)
        reward_critic_loss = get_all_reduce_mean(reward_critic_loss)
        cost_critic_loss = get_all_reduce_mean(cost_critic_loss)
        rewards = get_all_reduce_mean(rewards)
        costs = get_all_reduce_mean(costs)
        kl_divergence = get_all_reduce_mean(kl_divergence)

        dist.barrier()

        return {
            'train/actor_loss': actor_loss.item(),
            'train/reward_critic_loss': reward_critic_loss.item(),
            'train/cost_critic_loss': cost_critic_loss.item(),
            'train/lambda': self.log_lambda.exp().item(),
            'train/episode_costs': ep_costs.item(),
            'train/reward': rewards.item(),
            'train/cost': costs.item(),
            'train/kl_divergence': kl_divergence.item(),
        }
