

from __future__ import annotations

import os
from typing import Any

import argparse
import deepspeed
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, PreTrainedTokenizer

from safe_rlhf.models import AutoModelForScore, AutoModelForScoreLM, load_pretrained_models
from safe_rlhf.models.score_lm import ScoreLMOutput
from safe_rlhf.models.score_model import ScoreModelOutput
from safe_rlhf.trainers import RLTrainer
from safe_rlhf.utils import (
    batch_retokenize,
    gather_log_probabilities,
    get_all_reduce_max,
    get_all_reduce_min,
    get_all_reduce_mean,
    is_same_tokenizer,
    masked_mean,
    masked_min,
    gather_input_ids,
    gather_scores,
    is_main_process,
)


class BehaviorSupportedTrainer(RLTrainer):
    """PPO-style RL trainer whose critic targets take behavior-policy support into account.

    On top of ``RLTrainer`` this class loads an optional behavior model and a gold
    score model, collects per-token behavior log-probabilities during rollout, and
    (when ``args.use_supported_value`` is set) overwrites the critic's return targets
    with ``args.unsupported_value`` at the positions that follow tokens whose behavior
    probability is below ``args.eps``.
    """

    TRAINING_TYPE = 'value_supported_ppo'
    # NOTE(review): this is a class-level mutable list, so it is shared by every
    # instance of the class in the process. `rl_step` appends one record per step to
    # it on the main process when `args.save_train_data` is set.
    train_generated_data = []

    def __init__(
        self,
        args: argparse.Namespace,
        ds_train_config: dict[str, Any],
        ds_eval_config: dict[str, Any],
    ) -> None:
        super().__init__(args, ds_train_config, ds_eval_config)

        # Supported Value PPO
        self.use_supported_value = self.args.use_supported_value
        self.unsupported_value = self.args.unsupported_value

        # Save rollout samples for analysis
        self.interval_to_save_rollout = self.args.interval_to_save_rollout
        self.flag_to_save_rollout = self.interval_to_save_rollout > 0
        self.file_to_save_rollout = os.path.join(
            self.args.output_dir,
            f'rollout_samples_{self.args.local_rank}.txt',
        )
        self.file_to_save_rollout = (
            None if not self.flag_to_save_rollout else open(self.file_to_save_rollout, 'w')
        )

    def init_models(
        self,
        language_model_type: type[AutoModelForCausalLM] = AutoModelForCausalLM,
        reward_model_type: type[AutoModelForScore | AutoModelForScoreLM] = AutoModelForScoreLM,
        reward_critic_type: type[AutoModelForScore] = AutoModelForScore,
    ) -> None:
        super().init_models(
            language_model_type=language_model_type,
            reward_model_type=reward_model_type,
            reward_critic_type=reward_critic_type,
        )

        self.use_behavior_model = self.args.behavior_model_name_or_path is not None
        if self.use_behavior_model:
            self.behavior_model, self.behavior_tokenizer = load_pretrained_models(
                self.args.behavior_model_name_or_path,
                model_max_length=self.args.max_length,
                padding_side='right',
                auto_model_type=AutoModelForCausalLM,
                trust_remote_code=self.args.trust_remote_code,
            )

        if (
            self.args.behavior_model_name_or_path is None
            and not is_same_tokenizer(self.tokenizer, self.reward_tokenizer)
        ) or (
            self.args.behavior_model_name_or_path is not None
            and not is_same_tokenizer(self.tokenizer, self.behavior_tokenizer)
        ):
            raise ValueError(
                'Value Supported PPO requires the same tokenizer for the language model and the '
                'behavior model.'
            )

        self.gold_model, self.gold_tokenizer = load_pretrained_models(
            self.args.gold_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='right',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'reward',
                'do_normalize': self.args.normalize_gold,
            },
        )
        self.gold_model.set_normalize(self.args.normalize_gold)

        if is_same_tokenizer(self.tokenizer, self.gold_tokenizer):
            self.gold_tokenizer = self.tokenizer

    def init_engines(self) -> None:
        super().init_engines()

        if self.use_behavior_model:
            self.behavior_model = self._init_eval_engine(
                model=self.behavior_model,
                ds_config=self.ds_eval_config,
            )
            self.behavior_model.eval()

        self.gold_model = self._init_eval_engine(
            model=self.gold_model,
            ds_config=self.ds_eval_config,
        )
        self.gold_model.eval()

    def get_reward(
        self,
        reward_model: deepspeed.DeepSpeedEngine,
        src_tokenizer: PreTrainedTokenizer,
        dest_tokenizer: PreTrainedTokenizer,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
        get_behavior_log_probs: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:

        if get_behavior_log_probs and self.use_behavior_model:
            behavior_output = self.behavior_model(sequence, attention_mask=attention_mask)
            behavior_logits = behavior_output.logits
            behavior_log_probs = gather_log_probabilities(behavior_logits[:, :-1], sequence[:, 1:])

        if not is_same_tokenizer(src_tokenizer, dest_tokenizer):
            tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=src_tokenizer,
                dest_tokenizer=dest_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            sequence = tokenize_output['input_ids']
            attention_mask = tokenize_output['attention_mask']
            # for gpt2
            if 'GPT2' in reward_model.module.config.architectures[0]:
                sequence = sequence[:, :1024]
                attention_mask = attention_mask[:, :1024]

        reward_output = reward_model(sequence, attention_mask=attention_mask)

        if get_behavior_log_probs:
            if not self.use_behavior_model:
                if not isinstance(reward_output, ScoreLMOutput):
                    raise ValueError(
                        f'get_behavior_log_probs is True, but reward_model is not a ScoreLM model'
                        f'and behavior_model is not provided.'
                    )

                behavior_logits = reward_output.logits
                behavior_log_probs = gather_log_probabilities(
                    behavior_logits[:, :-1], sequence[:, 1:]
                )
            return {
                'reward': reward_output.end_scores.squeeze(dim=-1),
                'behavior_log_probs': behavior_log_probs,
            }
        else:
            return {
                'reward': reward_output.end_scores.squeeze(dim=-1),
            }

    @torch.no_grad()
    def post_rollout(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, Any]:
        logits = self.actor_model(sequence, attention_mask=attention_mask).logits
        log_probs = gather_log_probabilities(logits[:, :-1], sequence[:, 1:])
        ref_logits = self.actor_reference_model(sequence, attention_mask=attention_mask).logits
        ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], sequence[:, 1:])

        reward_output = self.get_reward(
            reward_model=self.reward_model,
            src_tokenizer=self.tokenizer,
            dest_tokenizer=self.reward_tokenizer,
            sequence=sequence,
            attention_mask=attention_mask,
            get_behavior_log_probs=True,
        )
        reward, behavior_log_probs = reward_output['reward'], reward_output['behavior_log_probs']
        if (
            log_probs.size() != behavior_log_probs.size() and self.args.use_supported_value
        ):  # TODO: add behavior model
            raise ValueError(
                f'Value Supported PPO requires the same sequence length for log_probs and'
                f'behavior_log_probs, but {log_probs.size()} != {behavior_log_probs.size()}'
            )

        gold_reward = self.get_reward(
            reward_model=self.gold_model,
            src_tokenizer=self.tokenizer,
            dest_tokenizer=self.gold_tokenizer,
            sequence=sequence,
            attention_mask=attention_mask,
            get_behavior_log_probs=False,
        )['reward']

        reward_values = self.reward_critic_model(sequence, attention_mask=attention_mask).scores
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]

        return {
            'prompt': prompt,
            'log_probs': log_probs,
            'ref_log_probs': ref_log_probs,
            'reward': reward,
            'behavior_log_probs': behavior_log_probs,
            'gold_reward': gold_reward,
            'reward_values': reward_values,
            'input_ids': sequence,
            'attention_mask': attention_mask,
        }

    @torch.no_grad()
    def eval_step(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, torch.Tensor]:
        reward_output = self.get_reward(
            reward_model=self.reward_model,
            src_tokenizer=self.tokenizer,
            dest_tokenizer=self.reward_tokenizer,
            sequence=sequence,
            attention_mask=attention_mask,
            get_behavior_log_probs=True,
        )
        reward, behavior_log_probs = reward_output['reward'], reward_output['behavior_log_probs']

        start = prompt.size(-1) - 1
        response_mask = attention_mask[:, 1:][:, start:]
        behavior_probs = torch.exp(behavior_log_probs[:, start:])
        behavior_probs = (behavior_probs * response_mask).sum(dim=-1) / response_mask.sum(dim=-1)

        gold_reward = self.get_reward(
            reward_model=self.gold_model,
            src_tokenizer=self.tokenizer,
            dest_tokenizer=self.gold_tokenizer,
            sequence=sequence,
            attention_mask=attention_mask,
            get_behavior_log_probs=False,
        )['reward']

        return {
            'eval/reward': reward,
            'eval/behavior_probs': behavior_probs,
            'eval/gold_reward': gold_reward,
        }

    def add_kl_divergence_regularization(
        self,
        reward: torch.Tensor,  # size = (B,)
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        log_probs: torch.Tensor,  # size = (B, L)
        ref_log_probs: torch.Tensor,  # size = (B, L)
        sequence_mask: torch.BoolTensor,  # size = (B, L)
    ) -> torch.Tensor:  # size = (B, L)
        end_index = torch.cat([m.nonzero()[-1] for m in sequence_mask])  # size = (B,)

        # size = (B, L)
        kl_divergence_estimate = log_probs - ref_log_probs
        kl_penalty_rewards = -self.kl_coeff * kl_divergence_estimate
        rewards = torch.scatter_add(
            kl_penalty_rewards,
            dim=-1,
            index=end_index.unsqueeze(dim=-1),
            src=reward.to(kl_penalty_rewards.dtype).unsqueeze(dim=-1),
        )
        return torch.clamp(rewards, min=-self.clip_range_score, max=self.clip_range_score)

    def actor_loss_fn(
        self,
        log_probs: torch.Tensor,  # size = (B, L - S)
        old_log_probs: torch.Tensor,  # size = (B, L - S)
        advantages: torch.Tensor,  # size = (B, L - S)
        mask: torch.BoolTensor,  # size = (B, L - S)
    ) -> torch.Tensor:  # size = ()
        # size = (B, L - S)
        ratios = torch.exp(log_probs - old_log_probs)
        surrogate1 = advantages * ratios
        surrogate2 = advantages * torch.clamp(
            ratios,
            1.0 - self.clip_range_ratio,
            1.0 + self.clip_range_ratio,
        )
        surrogate = torch.minimum(surrogate1, surrogate2)
        return -masked_mean(surrogate, mask)  # size = ()

    def get_supported_advantages_and_returns(
        self,
        values: torch.Tensor,
        rewards: torch.Tensor,
        behavior_log_probs: torch.Tensor,
        sequence_mask: torch.BoolTensor,
        start: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute advantages and returns using Generalized Advantage Estimation (GAE)."""
        # Modified from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py
        last_gae_lambda = 0.0
        advantages_reversed = []
        values = values * sequence_mask
        rewards = rewards * sequence_mask
        length = rewards.size(-1)
        for t in reversed(range(start, length)):  # pylint: disable=invalid-name
            next_values = values[:, t + 1] if t < length - 1 else 0.0
            delta = rewards[:, t] + self.gamma * next_values - values[:, t]
            last_gae_lambda = delta + self.gamma * self.gae_lambda * last_gae_lambda
            advantages_reversed.append(last_gae_lambda)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + values[:, start:]

        if self.use_supported_value:
            behavior_probs = torch.exp(behavior_log_probs[:, start:])
            unsupported_index = (behavior_probs < self.args.eps) * sequence_mask[:, start:]
            unsupported_index = torch.cat(
                [
                    torch.zeros((unsupported_index.size(0), 1), dtype=torch.bool).to(
                        unsupported_index.device
                    ),
                    unsupported_index[:, :-1],
                ],
                dim=1,
            )
            returns = torch.where(
                unsupported_index,
                torch.full_like(returns, self.unsupported_value),
                returns,
            )

        return advantages.detach(), returns

    def rl_step(self, rl_batch: dict[str, torch.Tensor]) -> dict[str, Any]:
        """Run one actor update and one reward-critic update on a rollout batch.

        Args:
            rl_batch: Batch with the keys read below: ``prompt``, ``log_probs``,
                ``ref_log_probs``, ``reward``, ``behavior_log_probs``, ``gold_reward``,
                ``reward_values``, ``input_ids`` and ``attention_mask`` (the same keys
                that ``post_rollout`` returns).

        Returns:
            A dict of ``train/*`` scalars: losses, reward statistics, value statistics,
            KL divergence, generated-length statistics, unsupported-action statistics,
            and the current learning rate of the first parameter group of each
            optimizer.
        """
        prompt = rl_batch['prompt']
        old_log_probs = rl_batch['log_probs']
        ref_log_probs = rl_batch['ref_log_probs']
        reward = rl_batch['reward']
        behavior_log_probs = rl_batch['behavior_log_probs']
        gold_reward = rl_batch['gold_reward']
        old_reward_values = rl_batch['reward_values']
        input_ids = rl_batch['input_ids']
        attention_mask = rl_batch['attention_mask']

        # Log-probs below are taken from logits[:, :-1] against input_ids[:, 1:], so
        # index i refers to token i + 1; the first response token therefore sits at
        # index `prompt_length - 1` and the mask is shifted by one to match.
        start = prompt.size(-1) - 1
        sequence_mask = attention_mask[:, 1:]

        with torch.no_grad():
            # `modified_reward` is an alias of `reward`; no modification is applied here.
            modified_reward = reward
            old_rewards = self.add_kl_divergence_regularization(
                modified_reward,
                prompt,
                old_log_probs,
                ref_log_probs,
                sequence_mask,
            )
            reward_advantages, reward_returns = self.get_supported_advantages_and_returns(
                old_reward_values,
                old_rewards,
                behavior_log_probs,
                sequence_mask,
                start,
            )

        # Actor update: fresh forward pass with gradients, loss over response positions.
        logits = self.actor_model(input_ids, attention_mask=attention_mask, use_cache=False).logits
        log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:])
        actor_loss = self.actor_loss_fn(
            log_probs[:, start:],
            old_log_probs[:, start:],
            reward_advantages,
            sequence_mask[:, start:],
        )
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        # Critic update: the last position is dropped so values align with log-probs.
        reward_values = self.reward_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]

        reward_critic_loss = self.critic_loss_fn(
            reward_values[:, start:],
            old_reward_values[:, start:],
            reward_returns,
            sequence_mask[:, start:],
        )
        self.reward_critic_model.backward(reward_critic_loss)
        self.reward_critic_model.step()

        # Everything below is logging / bookkeeping only.
        # NOTE(review): the gather_* and get_all_reduce_* helpers presumably perform
        # collective communication; if so, every rank must reach them in the same
        # order -- confirm before reordering any of these statements.
        with torch.no_grad():
            mask = sequence_mask[:, start:]
            # Per-sequence sum of (old - ref) log-prob differences, averaged over the batch.
            kl_divergence = ((old_log_probs - ref_log_probs)[:, start:] * mask).sum(dim=-1).mean()
            mean_generated_length = mask.sum(dim=-1).float().mean()
            max_generated_length = mask.sum(dim=-1).float().max()

            if self.args.save_train_data:
                prompt_ids = input_ids[:, : start + 1].contiguous()
                response_ids = input_ids[:, start + 1 :].contiguous()
                gathered_prompt_ids = gather_input_ids(prompt_ids, self.tokenizer.pad_token_id)
                gathered_response_ids = gather_input_ids(response_ids, self.tokenizer.pad_token_id)
                gathered_gold = gather_scores(gold_reward)
                gathered_reward = gather_scores(reward)
                # Per-sequence count of unmasked response tokens with behavior prob < eps.
                behavior_probs = torch.exp(behavior_log_probs[:, start:])
                num_unsupported_actions = ((behavior_probs < self.args.eps).float() * mask).sum(
                    dim=-1
                )
                gathered_num_unsupported_actions = gather_scores(num_unsupported_actions)

                prompt_texts = self.tokenizer.batch_decode(
                    gathered_prompt_ids, skip_special_tokens=True
                )
                response_texts = self.tokenizer.batch_decode(
                    gathered_response_ids, skip_special_tokens=True
                )

                if is_main_process():
                    # NOTE: `assert` is stripped under `python -O`; this is a sanity
                    # check only, not input validation.
                    assert (
                        len(prompt_texts)
                        == len(response_texts)
                        == len(gathered_gold)
                        == len(gathered_reward)
                    )
                    # The comprehension's loop variables (`prompt`, ...) live in the
                    # comprehension's own scope and do not rebind the outer `prompt`.
                    self.train_generated_data.append(
                        {
                            'step': self.global_step,
                            'kl_divergence': kl_divergence.item(),
                            'mean_generated_length': mean_generated_length.item(),
                            'generated_data': [
                                {
                                    'prompt': prompt,
                                    'response': response,
                                    'gold': gold.item(),
                                    'proxy': proxy.item(),
                                    'num_unsupported_actions': num.item(),
                                }
                                for prompt, response, gold, proxy, num in zip(
                                    prompt_texts,
                                    response_texts,
                                    gathered_gold,
                                    gathered_reward,
                                    gathered_num_unsupported_actions,
                                )
                            ],
                        }
                    )

            # Reduce per-sample tensors to local scalars before the all-reduce calls.
            reward = reward.mean()
            gold_reward = gold_reward.mean()
            modified_reward = modified_reward.mean()
            reward_with_kl_penalty = (old_rewards[:, start:] * mask).sum(dim=-1).mean()
            reward_advantage = masked_mean(reward_advantages, mask)
            reward_return = masked_mean(reward_returns, mask)

            min_reward_value = masked_min(reward_values[:, start:], mask)
            mean_reward_value = masked_mean(reward_values[:, start:], mask)

            actor_loss = get_all_reduce_mean(actor_loss)
            reward_critic_loss = get_all_reduce_mean(reward_critic_loss)
            reward = get_all_reduce_mean(reward)
            gold_reward = get_all_reduce_mean(gold_reward)
            modified_reward = get_all_reduce_mean(modified_reward)
            reward_with_kl_penalty = get_all_reduce_mean(reward_with_kl_penalty)
            reward_advantage = get_all_reduce_mean(reward_advantage)
            reward_return = get_all_reduce_mean(reward_return)

            min_reward_value = get_all_reduce_min(min_reward_value)
            mean_reward_value = get_all_reduce_mean(mean_reward_value)

            kl_divergence = get_all_reduce_mean(kl_divergence)
            mean_generated_length = get_all_reduce_mean(mean_generated_length)
            max_generated_length = get_all_reduce_max(max_generated_length)

            # Recomputed here because the earlier computation only runs when
            # `args.save_train_data` is set; this one yields batch-level scalars.
            behavior_probs = torch.exp(behavior_log_probs[:, start:])
            num_unsupported_actions = (
                ((behavior_probs < self.args.eps).float() * mask).sum(dim=-1).mean()
            )
            num_unsupported_actions = get_all_reduce_mean(num_unsupported_actions)
            behavior_probs = masked_mean(behavior_probs, mask)
            behavior_probs = get_all_reduce_mean(behavior_probs)
        # Synchronize all processes before returning the metrics.
        dist.barrier()

        return {
            'train/actor_loss': actor_loss.item(),
            'train/reward_critic_loss': reward_critic_loss.item(),
            'train/reward': reward.item(),
            'train/gold_reward': gold_reward.item(),
            'train/modified_reward': modified_reward.item(),
            'train/reward_with_kl_penalty': reward_with_kl_penalty.item(),
            'train/reward_advantage': reward_advantage.item(),
            'train/reward_return': reward_return.item(),
            'train/mean_reward_value': mean_reward_value.item(),
            'train/min_reward_value': min_reward_value.item(),
            'train/kl_divergence': kl_divergence.item(),
            'train/actor_lr': self.actor_model.optimizer.param_groups[0]['lr'],
            'train/reward_critic_lr': self.reward_critic_model.optimizer.param_groups[0]['lr'],
            'train/mean_generated_length': mean_generated_length.item(),
            'train/max_generated_length': max_generated_length.item(),
            'train/num_unsupported_actions': num_unsupported_actions.item(),
            'train/behavior_probs': behavior_probs.item(),
        }
