from __future__ import annotations

import argparse
from collections import deque
from typing import Any

import deepspeed
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase

from safe_rlhf.models import AutoModelForScore, load_pretrained_models
from safe_rlhf.trainers import RLTrainer
from safe_rlhf.utils import (
    batch_retokenize,
    gather_log_probabilities,
    get_all_reduce_max,
    get_all_reduce_mean,
    is_main_process,
    is_same_tokenizer,
)


class PPOLagTrainer(RLTrainer):
    TRAINING_TYPE = 'ppo_lag'

    cost_model: deepspeed.DeepSpeedEngine
    cost_critic_model: deepspeed.DeepSpeedEngine

    cost_tokenizer: PreTrainedTokenizerBase
    cost_critic_tokenizer: PreTrainedTokenizerBase

    def __init__(
        self,
        args: argparse.Namespace,
        ds_train_config: dict[str, Any],
        ds_eval_config: dict[str, Any],
    ) -> None:
        super().__init__(args=args, ds_train_config=ds_train_config, ds_eval_config=ds_eval_config)

        # Lagrange multiplier
        self.log_lambda = torch.nn.Parameter(
            torch.tensor(np.log(self.args.lambda_init), device=self.args.device),
            requires_grad=True,
        )
        self.log_lambda_max = np.log(self.args.lambda_max) if self.args.lambda_max else None
        self.log_lambda_optimizer = torch.optim.SGD([self.log_lambda], lr=self.args.lambda_lr)
        self.lambda_update_delay_steps = self.args.lambda_update_delay_steps
        self.episode_costs = deque(maxlen=self.args.episode_cost_window_size)
        self.threshold = self.args.threshold

    def init_models(self) -> None:
        super().init_models()
        self.cost_model, self.cost_tokenizer = load_pretrained_models(
            self.args.cost_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='right',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'cost',
                'do_normalize': self.args.normalize_cost,
            },
        )
        self.cost_model.set_normalize(self.args.normalize_cost)

        if self.args.cost_critic_model_name_or_path is None:
            self.args.cost_critic_model_name_or_path = self.args.cost_model_name_or_path
        self.cost_critic_model, self.cost_critic_tokenizer = load_pretrained_models(
            self.args.cost_critic_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='left',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'critic',
                'do_normalize': False,
            },
        )
        self.cost_critic_model.set_normalize(False)

        if is_same_tokenizer(self.tokenizer, self.cost_tokenizer):
            self.cost_tokenizer = self.tokenizer
        if not is_same_tokenizer(self.tokenizer, self.cost_critic_tokenizer):
            raise ValueError(
                (
                    'Cost critic tokenizer must be the same as actor tokenizer. '
                    'Expected {0.__module__}.{0.__qualname__}(vocab_size={1}), '
                    'but got {2.__module__}.{2.__qualname__}(vocab_size={3}). '
                    'Please consider pass `--cost_critic_model_name_or_path` from the command line.'
                ).format(
                    type(self.tokenizer),
                    len(self.tokenizer),
                    type(self.cost_critic_tokenizer),
                    len(self.cost_critic_tokenizer),
                ),
            )
        self.cost_critic_tokenizer = self.tokenizer

    def init_engines(self) -> None:
        super().init_engines()

        self.cost_critic_model = self._init_train_engine(
            model=self.cost_critic_model,
            weight_decay=self.args.critic_weight_decay,
            lr=self.args.critic_lr,
            lr_scheduler_type=self.args.critic_lr_scheduler_type,
            lr_warmup_ratio=self.args.critic_lr_warmup_ratio,
            total_training_steps=self.args.total_training_steps,
            ds_config=self.ds_train_config,
        )

        self.cost_model = self._init_eval_engine(
            model=self.cost_model,
            ds_config=self.ds_eval_config,
        )
        self.cost_model.eval()

        if self.args.critic_gradient_checkpointing:
            self.cost_critic_model.gradient_checkpointing_enable()

    def set_train(self, mode: bool = True) -> None:
        super().set_train(mode=mode)
        if mode:
            self.cost_critic_model.train()
        else:
            self.cost_critic_model.eval()

    @torch.no_grad()
    def post_rollout(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, Any]:
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_seq = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_seq = sequence
            reward_attention_mask = attention_mask

        if self.cost_tokenizer is not self.tokenizer:
            cost_tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.cost_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            cost_seq = cost_tokenize_output['input_ids']
            cost_attention_mask = cost_tokenize_output['attention_mask']
        else:
            cost_seq = sequence
            cost_attention_mask = attention_mask

        logits = self.actor_model(sequence, attention_mask=attention_mask).logits
        ref_logits = self.actor_reference_model(sequence, attention_mask=attention_mask).logits

        reward_score = self.reward_model(
            reward_seq,
            attention_mask=reward_attention_mask,
        ).end_scores
        cost_score = self.cost_model(
            cost_seq,
            attention_mask=cost_attention_mask,
        ).end_scores
        reward_value = self.reward_critic_model(
            sequence,
            attention_mask=attention_mask,
        ).scores
        cost_value = self.cost_critic_model(
            sequence,
            attention_mask=attention_mask,
        ).scores

        reward_score = reward_score.squeeze(dim=-1)
        cost_score = cost_score.squeeze(dim=-1)
        reward_value = reward_value.squeeze(dim=-1)[:, :-1]
        cost_value = cost_value.squeeze(dim=-1)[:, :-1]

        log_probs = gather_log_probabilities(logits[:, :-1], sequence[:, 1:])
        ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], sequence[:, 1:])

        self.episode_costs.extend(cost_score.tolist())

        return {
            'prompt': prompt,
            'log_probs': log_probs,
            'ref_log_probs': ref_log_probs,
            'rewards': reward_score,
            'costs': cost_score,
            'reward_values': reward_value,
            'cost_values': cost_value,
            'input_ids': sequence,
            'attention_mask': attention_mask,
        }

    @torch.no_grad()
    def eval_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, torch.Tensor]:
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                input_ids,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_input_ids = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_input_ids = input_ids
            reward_attention_mask = attention_mask

        if self.cost_tokenizer is not self.tokenizer:
            cost_tokenize_output = batch_retokenize(
                input_ids,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.cost_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            cost_input_ids = cost_tokenize_output['input_ids']
            cost_attention_mask = cost_tokenize_output['attention_mask']
        else:
            cost_input_ids = input_ids
            cost_attention_mask = attention_mask

        reward_score = self.reward_model(
            reward_input_ids,
            attention_mask=reward_attention_mask,
        ).end_scores.squeeze(dim=-1)
        cost_score = self.cost_model(
            cost_input_ids,
            attention_mask=cost_attention_mask,
        ).end_scores.squeeze(dim=-1)
        return {
            'eval/reward': reward_score,
            'eval/cost': cost_score,
        }

    def add_kl_divergence_regularization(
        self,
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        log_probs: torch.Tensor,  # size = (B, L)
        ref_log_probs: torch.Tensor,  # size = (B, L)
        reward_score: torch.Tensor,  # size = (B,)
        cost_score: torch.Tensor,  # size = (B,)
        sequence_mask: torch.BoolTensor,  # size = (B, L)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        kl_divergence_estimate = -self.kl_coeff * (log_probs - ref_log_probs)  # size = (B, L)
        rewards = 0.5 * kl_divergence_estimate  # size = (B, L)
        costs = -0.5 * kl_divergence_estimate  # size = (B, L)
        reward_clip = torch.clamp(  # size = (B,)
            reward_score,
            min=-self.clip_range_score,
            max=self.clip_range_score,
        )
        cost_clip = torch.clamp(  # size = (B,)
            cost_score,
            min=-self.clip_range_score,
            max=self.clip_range_score,
        )
        batch_size = log_probs.size(0)
        for i in range(batch_size):
            end_index = sequence_mask[i].nonzero()[-1]
            rewards[i, end_index] += reward_clip[i]
            costs[i, end_index] += cost_clip[i]

        return rewards, costs

    def actor_loss_fn(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        c_advantages: torch.Tensor,
        mask: torch.BoolTensor,
    ) -> torch.Tensor:
        # policy gradient loss
        multiplier = self.log_lambda.exp().item()
        advantages_all = (advantages - multiplier * c_advantages) / (1 + multiplier)

        ratio = torch.exp(log_probs - old_log_probs)
        pg_loss1 = -advantages_all * ratio
        pg_loss2 = -advantages_all * torch.clamp(
            ratio,
            1.0 - self.clip_range_ratio,
            1.0 + self.clip_range_ratio,
        )
        return torch.sum(torch.maximum(pg_loss1, pg_loss2) * mask) / mask.sum()

    # pylint: disable-next=too-many-locals
    def rl_step(self, rl_batch: dict[str, torch.Tensor]) -> dict[str, Any]:
        """Run one PPO-Lagrangian update on a rollout batch.

        Steps, in order:
          1. Update the Lagrange multiplier on the main process from the mean
             of the recent episode costs, then broadcast it to every process.
          2. Build KL-regularised per-token rewards/costs and compute the
             reward and cost advantages/returns (no gradients).
          3. Update the actor, the reward critic and the cost critic, each
             with its own ``backward`` / ``step``.
          4. Aggregate scalar metrics across processes.

        Args:
            rl_batch: Batch with the keys produced by ``post_rollout``
                (``prompt``, ``log_probs``, ``ref_log_probs``, ``rewards``,
                ``costs``, ``reward_values``, ``cost_values``, ``input_ids``,
                ``attention_mask``).

        Returns:
            Dict of ``train/*`` metrics as Python scalars.
        """
        # Mean of the sliding window of cost scores filled in ``post_rollout``.
        episode_costs = torch.tensor(self.episode_costs).mean().to(self.args.device)

        # Average over processes; ``dist.reduce`` delivers the result to rank 0 only.
        # NOTE(review): on other ranks the tensor content after the call is not the
        # global average, so the ``train/episode_costs`` metric below is only
        # meaningful on the main process.
        dist.reduce(episode_costs, dst=0, op=dist.ReduceOp.AVG)

        # Multiplier update is skipped for the first ``lambda_update_delay_steps`` steps.
        if is_main_process() and self.global_step >= self.lambda_update_delay_steps:
            # Descending on -(cost - threshold) * lambda raises lambda when the mean
            # cost exceeds the threshold and lowers it otherwise.
            lambda_loss = -(episode_costs - self.threshold) * self.log_lambda.exp()
            self.log_lambda_optimizer.zero_grad()
            lambda_loss.backward()
            self.log_lambda_optimizer.step()
            if self.log_lambda_max is not None:
                with torch.no_grad():
                    self.log_lambda.clamp_(max=self.log_lambda_max)

        # Every process takes the multiplier value from rank 0.
        dist.broadcast(self.log_lambda, src=0)

        prompt = rl_batch['prompt']
        old_log_probs = rl_batch['log_probs']
        ref_log_probs = rl_batch['ref_log_probs']
        rewards = rl_batch['rewards']
        costs = rl_batch['costs']
        old_reward_values = rl_batch['reward_values']
        old_cost_values = rl_batch['cost_values']
        input_ids = rl_batch['input_ids']
        attention_mask = rl_batch['attention_mask']

        # First index of the generated part in the next-token-shifted tensors.
        # (assumes the prompt occupies the first ``prompt.size(-1)`` positions of
        # ``input_ids`` — TODO confirm against the rollout code)
        start = prompt.size(-1) - 1
        # Shifted by one to line up with log-probs, which are taken over ``input_ids[:, 1:]``.
        sequence_mask = attention_mask[:, 1:]

        with torch.no_grad():
            old_rewards, old_costs = self.add_kl_divergence_regularization(
                prompt,
                old_log_probs,
                ref_log_probs,
                rewards,
                costs,
                sequence_mask,
            )
            # ``get_advantages_and_returns`` is defined outside this class; it is
            # called once per signal (reward, then cost) with the same mask and start.
            reward_advantages, reward_returns = self.get_advantages_and_returns(
                old_reward_values,
                old_rewards,
                sequence_mask,
                start,
            )
            cost_advantages, cost_returns = self.get_advantages_and_returns(
                old_cost_values,
                old_costs,
                sequence_mask,
                start,
            )

        # --- Actor update -----------------------------------------------------
        logits = self.actor_model(input_ids, attention_mask=attention_mask, use_cache=False).logits
        log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:])
        # Only the positions from ``start`` on enter the loss.
        actor_loss = self.actor_loss_fn(
            log_probs[:, start:],
            old_log_probs[:, start:],
            reward_advantages,
            cost_advantages,
            sequence_mask[:, start:],
        )
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        # --- Reward critic update ---------------------------------------------
        reward_values = self.reward_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        # Drop the last position so values align with the shifted mask / log-probs.
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        reward_critic_loss = self.critic_loss_fn(
            reward_values[:, start:],
            old_reward_values[:, start:],
            reward_returns,
            sequence_mask[:, start:],
        )
        self.reward_critic_model.backward(reward_critic_loss)
        self.reward_critic_model.step()

        # --- Cost critic update -----------------------------------------------
        cost_values = self.cost_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        cost_values = cost_values.squeeze(dim=-1)[:, :-1]
        cost_critic_loss = self.critic_loss_fn(
            cost_values[:, start:],
            old_cost_values[:, start:],
            cost_returns,
            sequence_mask[:, start:],
        )
        self.cost_critic_model.backward(cost_critic_loss)
        self.cost_critic_model.step()

        # --- Metrics (no gradients needed) --------------------------------------
        with torch.no_grad():
            # Per-sequence sum of (old - reference) log-probs over the masked
            # positions from ``start`` on, averaged over the batch.
            kl_divergence = (
                ((old_log_probs - ref_log_probs) * sequence_mask)[:, start:].sum(dim=-1).mean()
            )
            # Number of unmasked positions from ``start`` on, per sequence.
            mean_generated_length = sequence_mask[:, start:].float().sum(dim=-1).mean()
            max_generated_length = sequence_mask[:, start:].float().sum(dim=-1).max()

        # Raw (unclipped, un-regularised) scores averaged over the local batch.
        rewards = rewards.mean()
        costs = costs.mean()

        # Cross-process aggregation through the project helpers (mean / max, per their names).
        actor_loss = get_all_reduce_mean(actor_loss)
        reward_critic_loss = get_all_reduce_mean(reward_critic_loss)
        cost_critic_loss = get_all_reduce_mean(cost_critic_loss)
        rewards = get_all_reduce_mean(rewards)
        costs = get_all_reduce_mean(costs)
        kl_divergence = get_all_reduce_mean(kl_divergence)
        mean_generated_length = get_all_reduce_mean(mean_generated_length)
        max_generated_length = get_all_reduce_max(max_generated_length)

        # Keep all processes in lock-step before returning.
        dist.barrier()

        return {
            'train/actor_loss': actor_loss.item(),
            'train/reward_critic_loss': reward_critic_loss.item(),
            'train/cost_critic_loss': cost_critic_loss.item(),
            'train/lambda': self.log_lambda.exp().item(),
            'train/episode_costs': episode_costs.item(),
            'train/reward': rewards.item(),
            'train/cost': costs.item(),
            'train/kl_divergence': kl_divergence.item(),
            'train/actor_lr': self.actor_model.optimizer.param_groups[0]['lr'],
            'train/reward_critic_lr': self.reward_critic_model.optimizer.param_groups[0]['lr'],
            'train/cost_critic_lr': self.cost_critic_model.optimizer.param_groups[0]['lr'],
            'train/mean_generated_length': mean_generated_length.item(),
            'train/max_generated_length': max_generated_length.item(),
        }
