from __future__ import annotations

import argparse
from collections import deque
from typing import Any

import deepspeed
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase

from safe_rlhf.models import AutoModelForScore, load_pretrained_models
from safe_rlhf.trainers import RLTrainer
from safe_rlhf.utils import (
    batch_retokenize,
    gather_log_probabilities,
    get_all_reduce_max,
    get_all_reduce_mean,
    get_all_reduce_sum,
    is_main_process,
    is_same_tokenizer,
    masked_mean,
)


class RePOTrainer(RLTrainer):
    """RL trainer registered under the ``'repo'`` training type.

    On top of what ``RLTrainer`` provides, this class loads a cost model and
    a trainable cost critic, and maintains a multiplier (``lambda``) that
    weights cost advantages against reward advantages in the actor loss.
    """

    TRAINING_TYPE = 'repo'

    # Cost scorer and cost value head. Both are loaded in `init_models` and
    # wrapped into DeepSpeed engines in `init_engines`.
    cost_model: deepspeed.DeepSpeedEngine
    cost_critic_model: deepspeed.DeepSpeedEngine

    # Tokenizers returned alongside the two models above (see `init_models`).
    cost_tokenizer: PreTrainedTokenizerBase
    cost_critic_tokenizer: PreTrainedTokenizerBase

    def __init__(
        self,
        args: argparse.Namespace,
        ds_train_config: dict[str, Any],
        ds_eval_config: dict[str, Any],
    ) -> None:
        super().__init__(args=args, ds_train_config=ds_train_config, ds_eval_config=ds_eval_config)

        # Lagrange multiplier
        self.reward_scale = args.reward_scale
        self.lambda_used_log = self.args.lambda_used_log
        if self.lambda_used_log:
            self.log_lambda = torch.nn.Parameter(
                torch.tensor(np.log(self.args.lambda_init), device=self.args.device),
                requires_grad=True,
            )
            self.log_lambda_optimizer = torch.optim.SGD([self.log_lambda], lr=self.args.lambda_lr)
        else:
            self._lambda = torch.tensor(self.args.lambda_init, device=self.args.device)
            self.lambda_lr = self.args.lambda_lr
        self.kl_coeff_target = 6.0
        self.log_lambda_max = np.log(self.args.lambda_max) if self.args.lambda_max else None
        
        self.lambda_update_delay_steps = self.args.lambda_update_delay_steps
        self.episode_costs = deque(maxlen=self.args.episode_cost_window_size)
        self.threshold = self.args.threshold
        self.output_attentions = False
        self.redist_factor = 0.8

    def init_models(self) -> None:
        super().init_models()
        self.cost_model, self.cost_tokenizer = load_pretrained_models(
            self.args.cost_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='right',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'cost',
                'do_normalize': self.args.normalize_cost,
            },
        )
        self.cost_model.set_normalize(self.args.normalize_cost)

        if self.args.cost_critic_model_name_or_path is None:
            self.args.cost_critic_model_name_or_path = self.args.cost_model_name_or_path
        self.cost_critic_model, self.cost_critic_tokenizer = load_pretrained_models(
            self.args.cost_critic_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='left',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'critic',
                'do_normalize': False,
            },
        )
        self.cost_critic_model.set_normalize(False)

        if is_same_tokenizer(self.tokenizer, self.cost_tokenizer):
            self.cost_tokenizer = self.tokenizer
        if not is_same_tokenizer(self.tokenizer, self.cost_critic_tokenizer):
            raise ValueError(
                (
                    'Cost critic tokenizer must be the same as actor tokenizer. '
                    'Expected {0.__module__}.{0.__qualname__}(vocab_size={1}), '
                    'but got {2.__module__}.{2.__qualname__}(vocab_size={3}). '
                    'Please consider pass `--cost_critic_model_name_or_path` from the command line.'
                ).format(
                    type(self.tokenizer),
                    len(self.tokenizer),
                    type(self.cost_critic_tokenizer),
                    len(self.cost_critic_tokenizer),
                ),
            )
        self.cost_critic_tokenizer = self.tokenizer

    def init_engines(self) -> None:
        super().init_engines()

        self.cost_critic_model = self._init_train_engine(
            model=self.cost_critic_model,
            weight_decay=self.args.critic_weight_decay,
            lr=self.args.critic_lr,
            lr_scheduler_type=self.args.critic_lr_scheduler_type,
            lr_warmup_ratio=self.args.critic_lr_warmup_ratio,
            total_training_steps=self.args.total_training_steps,
            ds_config=self.ds_train_config,
        )

        self.cost_model = self._init_eval_engine(
            model=self.cost_model,
            ds_config=self.ds_eval_config,
        )
        self.cost_model.eval()

        if self.args.critic_gradient_checkpointing:
            self.cost_critic_model.gradient_checkpointing_enable()

    def set_train(self, mode: bool = True) -> None:
        super().set_train(mode=mode)
        if mode:
            self.cost_critic_model.train()
        else:
            self.cost_critic_model.eval()

    @torch.no_grad()
    def post_rollout(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, Any]:
        # if is_main_process():
        #     print("generating process.....")
        #     print("[DEBUG] input id",  prompt[0])
        #     print("[DEBUG] sequence", sequence[0])
        #     print("[DEBUG] mask", attention_mask[0])
        #     print("============================================")
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_seq = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_seq = sequence
            reward_attention_mask = attention_mask

        if self.cost_tokenizer is not self.tokenizer:
            cost_tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.cost_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            cost_seq = cost_tokenize_output['input_ids']
            cost_attention_mask = cost_tokenize_output['attention_mask']
        else:
            cost_seq = sequence
            cost_attention_mask = attention_mask
        start = prompt.size(-1) - 1
        mask = attention_mask[:, start:]
        length = mask.sum(dim=-1).float().reshape(-1, 1)
        # if sequence.shape[1] >= 512:
        #     print("[DEBUG] output sequence length is too long")
        logits = self.actor_model(sequence, attention_mask=attention_mask).logits
        ref_logits = self.actor_reference_model(sequence, attention_mask=attention_mask).logits

        reward_outputs = self.reward_model(reward_seq, attention_mask=reward_attention_mask, output_attentions=self.output_attentions)
        reward = reward_outputs.end_scores
        rewards_attentions = reward_outputs.attentions
        # if is_main_process():
        #     print("[DEBUG] length : ", length)
        #     print("[DEBUG] reward : ", reward)
        cost_outputs = self.cost_model(cost_seq, attention_mask=cost_attention_mask, output_attentions=self.output_attentions)
        cost = cost_outputs.end_scores
        cost_attentions = cost_outputs.attentions
        reward_values = self.reward_critic_model(sequence, attention_mask=attention_mask).scores
        cost_values = self.cost_critic_model(sequence, attention_mask=attention_mask).scores

        reward = reward.squeeze(dim=-1)
        cost = cost.squeeze(dim=-1)
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        cost_values = cost_values.squeeze(dim=-1)[:, :-1]

        log_probs = gather_log_probabilities(logits[:, :-1], sequence[:, 1:])
        ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], sequence[:, 1:])

        self.episode_costs.extend(cost.tolist())
        # if is_main_process():
        #     print("[DEBUG] reward_value shape : ", reward_values[0].shape)
        #     print("[DEBUG] cost_value shape : ", cost_values[0].shape)
        #     print("[DEBUG] sequence shape : ", sequence[0].shape)
        #     print("[DEBUG] mask shape : ", attention_mask[0].shape)
        return {
            'prompt': prompt,
            'log_probs': log_probs,
            'ref_log_probs': ref_log_probs,
            'reward': reward,
            'cost': cost,
            'reward_values': reward_values,
            'cost_values': cost_values,
            'reward_attentions': rewards_attentions,
            'cost_attentions': cost_attentions,
            'input_ids': sequence,
            'attention_mask': attention_mask,
        }

    @torch.no_grad()
    def eval_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, torch.Tensor]:
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                input_ids,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_input_ids = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_input_ids = input_ids
            reward_attention_mask = attention_mask

        if self.cost_tokenizer is not self.tokenizer:
            cost_tokenize_output = batch_retokenize(
                input_ids,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.cost_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            cost_input_ids = cost_tokenize_output['input_ids']
            cost_attention_mask = cost_tokenize_output['attention_mask']
        else:
            cost_input_ids = input_ids
            cost_attention_mask = attention_mask

        reward = self.reward_model(
            reward_input_ids,
            attention_mask=reward_attention_mask,
        ).end_scores.squeeze(dim=-1)
        cost = self.cost_model(
            cost_input_ids,
            attention_mask=cost_attention_mask,
        ).end_scores.squeeze(dim=-1)
        return {
            'eval/reward': reward,
            'eval/cost': cost,
        }
        
    def get_kl_divergence_regularization(
        self,
        log_probs: torch.Tensor,  # size = (B, L - S)
        ref_log_probs: torch.Tensor,  # size = (B, L - S)
    ):
        # size = (B, L)
        kl_divergence = log_probs - ref_log_probs
        return kl_divergence


    def add_kl_divergence_regularization(
        self,
        reward: torch.Tensor,  # size = (B,)
        cost: torch.Tensor,  # size = (B,)
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        log_probs: torch.Tensor,  # size = (B, L)
        ref_log_probs: torch.Tensor,  # size = (B, L)
        sequence_mask: torch.BoolTensor,  # size = (B, L)
    ) -> tuple[torch.Tensor, torch.Tensor]:  # size = (B, L)
        end_index = torch.cat([m.nonzero()[-1] for m in sequence_mask])  # size = (B,)
        length = sequence_mask.size(-1)
        # size = (B, L)
        kl_divergence_estimate = log_probs - ref_log_probs
        kl_penalty_rewards = -self.kl_coeff * kl_divergence_estimate
        
        rewards = torch.scatter_add(
            1.0  * kl_penalty_rewards,
            dim=-1,
            index=end_index.unsqueeze(dim=-1),
            src=reward.to(kl_penalty_rewards.dtype).unsqueeze(dim=-1),
        )
        costs = torch.scatter_add(
            -1.0 * kl_penalty_rewards,
            dim=-1,
            index=end_index.unsqueeze(dim=-1),
            src=cost.to(kl_penalty_rewards.dtype).unsqueeze(dim=-1),
        )
        
        
        
        return (
            torch.clamp(rewards, min=-self.clip_range_score, max=self.clip_range_score),
            torch.clamp(costs, min=-self.clip_range_score, max=self.clip_range_score),
        )

    def actor_loss_fn(
        self,
        log_probs: torch.Tensor,  # size = (B, L - S)
        old_log_probs: torch.Tensor,  # size = (B, L - S)
        ref_log_probs: torch.Tensor,  # size = (B, L - S)
        reward_advantages: torch.Tensor,  # size = (B, L - S)
        cost_advantages: torch.Tensor,  # size = (B, L - S)
        reward_returns: torch.Tensor,
        cost_returns: torch.Tensor,
        mask: torch.BoolTensor,  # size = (B, L - S)
        old_reward: torch.Tensor,
        old_cost: torch.Tensor,
        reward: torch.Tensor,
        cost: torch.Tensor,
        reward_values: torch.Tensor,
        cost_values: torch.Tensor,
        kl_divergence: torch.Tensor,
    ) -> torch.Tensor:  # size = ()
        if self.lambda_used_log:
            multiplier = self.log_lambda.exp().item()
        else:
            multiplier = self._lambda.item()
            
        
        length = mask.size(-1)  
        unsafe_tag = torch.where(cost > self.threshold, torch.ones_like(cost), torch.zeros_like(cost))
        unsafe_tag = unsafe_tag.unsqueeze(1).repeat(1, length)
        safe_tag = torch.where(cost > self.threshold, torch.zeros_like(cost), torch.ones_like(cost))
        safe_tag = safe_tag.unsqueeze(1).repeat(1, length)
        
        
        advantages = (reward_advantages - multiplier * (cost_advantages * unsafe_tag - reward_advantages * safe_tag)) / (1 + multiplier)
            
        ppo_ratios = torch.exp(log_probs - old_log_probs)
        surrogate1 = advantages * ppo_ratios
        surrogate2 = advantages * torch.clamp(
            ppo_ratios,
            1.0 - self.clip_range_ratio,
            1.0 + self.clip_range_ratio,
        )
        surrogate = torch.minimum(surrogate1, surrogate2)
        loss = -masked_mean(surrogate, mask)
        return loss


    # pylint: disable-next=too-many-locals
    def rl_step(self, rl_batch: dict[str, torch.Tensor]) -> dict[str, Any]:
        
        prompt = rl_batch['prompt']
        old_log_probs = rl_batch['log_probs']
        ref_log_probs = rl_batch['ref_log_probs']
        reward = rl_batch['reward']
        cost = rl_batch['cost']
        old_reward_values = rl_batch['reward_values']
        old_cost_values = rl_batch['cost_values']
        input_ids = rl_batch['input_ids']
        attention_mask = rl_batch['attention_mask']
        
        episode_cost = torch.tensor(self.episode_costs).mean().to(self.args.device)
        cost_mean = torch.clamp(cost, min=self.threshold).mean() - self.threshold
        dist.reduce(cost_mean, dst=0, op=dist.ReduceOp.AVG)
        dist.reduce(episode_cost, dst=0, op=dist.ReduceOp.AVG)
        if self.lambda_used_log:
            if is_main_process() and self.global_step >= self.lambda_update_delay_steps:
                lambda_loss = -cost_mean * self.log_lambda.exp()
                self.log_lambda_optimizer.zero_grad()
                lambda_loss.backward()
                self.log_lambda_optimizer.step()
                if self.log_lambda_max is not None:
                    with torch.no_grad():
                        self.log_lambda.clamp_(max=self.log_lambda_max)
            dist.broadcast(self.log_lambda, src=0)
        else:          
            if is_main_process() and self.global_step >= self.lambda_update_delay_steps:
                self._lambda = self.lambda_lr * torch.clamp_(cost_mean - self.threshold, min=0) + self._lambda * (1 - self._lambda)
                if self.log_lambda_max is not None:
                    with torch.no_grad():
                        self._lambda.clamp_(max=self.log_lambda_max)
            dist.broadcast(self._lambda, src=0)
                    

        


        # length = attention_mask.size(-1)
        start = prompt.size(-1) - 1
        sequence_mask = attention_mask[:, 1:]
        # unsafe_tag = torch.where(cost > self.threshold, torch.ones_like(cost), torch.zeros_like(cost))
        with torch.no_grad():
            old_rewards, old_costs = self.add_kl_divergence_regularization(
                reward,
                cost,
                prompt,
                old_log_probs,
                ref_log_probs,
                sequence_mask,
            )
            reward_advantages, reward_returns = self.get_advantages_and_returns(
                old_reward_values,
                old_rewards,
                sequence_mask,
                start,
            )
            # unsafe_tag = unsafe_tag.unsqueeze(1).repeat(1, length)
            cost_advantages, cost_returns = self.get_advantages_and_returns(
                old_cost_values,
                old_costs,
                sequence_mask,
                start,
            )
            kl_divergence = self.get_kl_divergence_regularization(old_log_probs, ref_log_probs)
        
        logits = self.actor_model(input_ids, attention_mask=attention_mask, use_cache=False).logits
        log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:])
        actor_loss = self.actor_loss_fn(
            log_probs[:, start:],
            old_log_probs[:, start:],
            ref_log_probs[:, start:],
            reward_advantages * self.reward_scale,
            cost_advantages,
            reward_returns,
            cost_returns,
            sequence_mask[:, start:],
            old_rewards[:, start:],
            old_costs[:, start:],
            reward,
            cost,
            old_reward_values[:, start:],
            old_cost_values[:, start:],
            kl_divergence[:, start:],
        )
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        reward_values = self.reward_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        reward_critic_loss = self.critic_loss_fn(
            reward_values[:, start:],
            old_reward_values[:, start:],
            reward_returns,
            sequence_mask[:, start:],
        )
        self.reward_critic_model.backward(reward_critic_loss)
        self.reward_critic_model.step()

        cost_values = self.cost_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        cost_values = cost_values.squeeze(dim=-1)[:, :-1]
        cost_critic_loss = self.critic_loss_fn(
            cost_values[:, start:],
            old_cost_values[:, start:],
            cost_returns,
            sequence_mask[:, start:],
        )
        self.cost_critic_model.backward(cost_critic_loss)
        self.cost_critic_model.step()
        cost_tag = torch.where(cost > self.threshold, torch.ones_like(cost), torch.zeros_like(cost))
        with torch.no_grad():
            mask = sequence_mask[:, start:]
            kl_divergence = ((old_log_probs - ref_log_probs)[:, start:] * mask).sum(dim=-1).mean()
            mean_generated_length = mask.sum(dim=-1).float().mean()
            max_generated_length = mask.sum(dim=-1).float().max()

            reward = reward.mean()
            cost = cost.mean()
            unsafe_num = cost_tag.sum()
            reward_with_kl_penalty = (old_rewards[:, start:] * mask).sum(dim=-1).mean()
            reward_advantage = masked_mean(reward_advantages, mask)
            reward_return = masked_mean(reward_returns, mask)
            reward_value = masked_mean(reward_values[:, start:], mask)
            cost_with_kl_penalty = (old_costs[:, start:] * mask).sum(dim=-1).mean()
            cost_advantage = masked_mean(cost_advantages, mask)
            cost_return = masked_mean(cost_returns, mask)
            cost_value = masked_mean(cost_values[:, start:], mask)

            actor_loss = get_all_reduce_mean(actor_loss)
            unsafe_num = get_all_reduce_sum(unsafe_num)
            reward_critic_loss = get_all_reduce_mean(reward_critic_loss)
            cost_critic_loss = get_all_reduce_mean(cost_critic_loss)
            reward = get_all_reduce_mean(reward)
            cost = get_all_reduce_mean(cost)
            reward_with_kl_penalty = get_all_reduce_mean(reward_with_kl_penalty)
            reward_advantage = get_all_reduce_mean(reward_advantage)
            reward_return = get_all_reduce_mean(reward_return)
            reward_value = get_all_reduce_mean(reward_value)
            cost_with_kl_penalty = get_all_reduce_mean(cost_with_kl_penalty)
            cost_advantage = get_all_reduce_mean(cost_advantage)
            cost_return = get_all_reduce_mean(cost_return)
            cost_value = get_all_reduce_mean(cost_value)
            kl_divergence = get_all_reduce_mean(kl_divergence)
            mean_generated_length = get_all_reduce_mean(mean_generated_length)
            max_generated_length = get_all_reduce_max(max_generated_length)

        dist.barrier()
        if self.lambda_used_log:
            return {
                'train/actor_loss': actor_loss.item(),
                'train/reward_critic_loss': reward_critic_loss.item(),
                'train/cost_critic_loss': cost_critic_loss.item(),
                'train/lambda': self.log_lambda.exp().item(),
                'train/episode_cost': episode_cost.item(),
                'train/reward': reward.item(),
                'train/cost': cost.item(),
                'train/reward_with_kl_penalty': reward_with_kl_penalty.item(),
                'train/reward_advantage': reward_advantage.item(),
                'train/reward_return': reward_return.item(),
                'train/reward_value': reward_value.item(),
                'train/cost_with_kl_penalty': cost_with_kl_penalty.item(),
                'train/cost_advantage': cost_advantage.item(),
                'train/cost_return': cost_return.item(),
                'train/cost_value': cost_value.item(),
                'train/kl_divergence': kl_divergence.item(),
                'train/actor_lr': self.actor_model.optimizer.param_groups[0]['lr'],
                'train/reward_critic_lr': self.reward_critic_model.optimizer.param_groups[0]['lr'],
                'train/cost_critic_lr': self.cost_critic_model.optimizer.param_groups[0]['lr'],
                'train/mean_generated_length': mean_generated_length.item(),
                'train/max_generated_length': max_generated_length.item(),
                'train/unsafe_num': unsafe_num.item(),
            }
        else:
            return {
                'train/actor_loss': actor_loss.item(),
                'train/reward_critic_loss': reward_critic_loss.item(),
                'train/cost_critic_loss': cost_critic_loss.item(),
                'train/lambda': self._lambda.item(),
                'train/episode_cost': episode_cost.item(),
                'train/reward': reward.item(),
                'train/cost': cost.item(),
                'train/reward_with_kl_penalty': reward_with_kl_penalty.item(),
                'train/reward_advantage': reward_advantage.item(),
                'train/reward_return': reward_return.item(),
                'train/reward_value': reward_value.item(),
                'train/cost_with_kl_penalty': cost_with_kl_penalty.item(),
                'train/cost_advantage': cost_advantage.item(),
                'train/cost_return': cost_return.item(),
                'train/cost_value': cost_value.item(),
                'train/kl_divergence': kl_divergence.item(),
                'train/actor_lr': self.actor_model.optimizer.param_groups[0]['lr'],
                'train/reward_critic_lr': self.reward_critic_model.optimizer.param_groups[0]['lr'],
                'train/cost_critic_lr': self.cost_critic_model.optimizer.param_groups[0]['lr'],
                'train/mean_generated_length': mean_generated_length.item(),
                'train/max_generated_length': max_generated_length.item(),
                'train/unsafe_num': unsafe_num.item(),
            }