
from __future__ import annotations

import argparse
from collections import deque
from typing import Any

import deepspeed
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase

from safe_rlhf.models import AutoModelForScore, load_pretrained_models
from safe_rlhf.trainers import RLTrainer
from safe_rlhf.utils import (
    batch_retokenize,
    gather_log_probabilities,
    get_all_reduce_max,
    get_all_reduce_mean,
    is_main_process,
    is_same_tokenizer,
    masked_mean,
)


class PPOCALagTrainer(RLTrainer):
    """PPO trainer that trades reward against cost through a learned multiplier.

    On top of the base ``RLTrainer`` this class loads a cost model and a cost
    critic, keeps a log-space Lagrange multiplier that is updated from a
    sliding window of episode costs, and builds the PPO reward/cost signals
    from the per-token ``scores`` of the reward and cost models.
    """

    TRAINING_TYPE = 'ppo_ca_lag'

    # Cost scorer; wrapped in an evaluation engine by ``init_engines``.
    cost_model: deepspeed.DeepSpeedEngine
    # Value model for the cost stream; wrapped in a training engine by ``init_engines``.
    cost_critic_model: deepspeed.DeepSpeedEngine

    # Tokenizer of the cost model; aliased to the actor tokenizer when both match.
    cost_tokenizer: PreTrainedTokenizerBase
    # Set to the actor tokenizer by ``init_models`` (a mismatch raises there).
    cost_critic_tokenizer: PreTrainedTokenizerBase

    def __init__(
        self,
        args: argparse.Namespace,
        ds_train_config: dict[str, Any],
        ds_eval_config: dict[str, Any],
    ) -> None:
        super().__init__(args=args, ds_train_config=ds_train_config, ds_eval_config=ds_eval_config)

        # Lagrange multiplier
        self.log_lambda = torch.nn.Parameter(
            torch.tensor(np.log(self.args.lambda_init), device=self.args.device),
            requires_grad=True,
        )
        self.log_lambda_max = np.log(self.args.lambda_max) if self.args.lambda_max else None
        self.log_lambda_optimizer = torch.optim.SGD([self.log_lambda], lr=self.args.lambda_lr)
        self.lambda_update_delay_steps = self.args.lambda_update_delay_steps
        self.episode_costs = deque(maxlen=self.args.episode_cost_window_size)
        self.threshold = self.args.threshold
        self.compensate = self.args.compensate

    def init_models(self) -> None:
        super().init_models()
        self.cost_model, self.cost_tokenizer = load_pretrained_models(
            self.args.cost_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='right',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'cost',
                'do_normalize': self.args.normalize_cost,
            },
        )
        self.cost_model.set_normalize(self.args.normalize_cost)

        if self.args.cost_critic_model_name_or_path is None:
            self.args.cost_critic_model_name_or_path = self.args.cost_model_name_or_path
        self.cost_critic_model, self.cost_critic_tokenizer = load_pretrained_models(
            self.args.cost_critic_model_name_or_path,
            model_max_length=self.args.max_length,
            auto_model_type=AutoModelForScore,
            padding_side='left',
            trust_remote_code=self.args.trust_remote_code,
            auto_model_kwargs={
                'score_type': 'critic',
                'do_normalize': False,
            },
        )
        self.cost_critic_model.set_normalize(False)

        if is_same_tokenizer(self.tokenizer, self.cost_tokenizer):
            self.cost_tokenizer = self.tokenizer
        if not is_same_tokenizer(self.tokenizer, self.cost_critic_tokenizer):
            raise ValueError(
                (
                    'Cost critic tokenizer must be the same as actor tokenizer. '
                    'Expected {0.__module__}.{0.__qualname__}(vocab_size={1}), '
                    'but got {2.__module__}.{2.__qualname__}(vocab_size={3}). '
                    'Please consider pass `--cost_critic_model_name_or_path` from the command line.'
                ).format(
                    type(self.tokenizer),
                    len(self.tokenizer),
                    type(self.cost_critic_tokenizer),
                    len(self.cost_critic_tokenizer),
                ),
            )
        self.cost_critic_tokenizer = self.tokenizer

    def init_engines(self) -> None:
        super().init_engines()

        self.cost_critic_model = self._init_train_engine(
            model=self.cost_critic_model,
            weight_decay=self.args.critic_weight_decay,
            lr=self.args.critic_lr,
            lr_scheduler_type=self.args.critic_lr_scheduler_type,
            lr_warmup_ratio=self.args.critic_lr_warmup_ratio,
            total_training_steps=self.args.total_training_steps,
            ds_config=self.ds_train_config,
        )

        self.cost_model = self._init_eval_engine(
            model=self.cost_model,
            ds_config=self.ds_eval_config,
        )
        self.cost_model.eval()

        if self.args.critic_gradient_checkpointing:
            self.cost_critic_model.gradient_checkpointing_enable()

    def set_train(self, mode: bool = True) -> None:
        super().set_train(mode=mode)
        if mode:
            self.cost_critic_model.train()
        else:
            self.cost_critic_model.eval()

    @torch.no_grad()
    def post_rollout(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, Any]:
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_seq = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_seq = sequence
            reward_attention_mask = attention_mask

        if self.cost_tokenizer is not self.tokenizer:
            cost_tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.cost_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            cost_seq = cost_tokenize_output['input_ids']
            cost_attention_mask = cost_tokenize_output['attention_mask']
        else:
            cost_seq = sequence
            cost_attention_mask = attention_mask

        logits = self.actor_model(sequence, attention_mask=attention_mask).logits
        ref_logits = self.actor_reference_model(sequence, attention_mask=attention_mask).logits

        reward = self.reward_model(reward_seq, attention_mask=reward_attention_mask).end_scores
        cost = self.cost_model(cost_seq, attention_mask=cost_attention_mask).end_scores
        reward_values = self.reward_critic_model(sequence, attention_mask=attention_mask).scores
        cost_values = self.cost_critic_model(sequence, attention_mask=attention_mask).scores

        reward_fg = self.reward_model(reward_seq, attention_mask=reward_attention_mask).scores
        reward_fg = reward_fg.squeeze(dim=-1)

        cost_fg = self.cost_model(reward_seq, attention_mask=reward_attention_mask).scores
        cost_fg = cost_fg.squeeze(dim=-1)

        reward = reward.squeeze(dim=-1)
        cost = cost.squeeze(dim=-1)
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        cost_values = cost_values.squeeze(dim=-1)[:, :-1]

        log_probs = gather_log_probabilities(logits[:, :-1], sequence[:, 1:])
        ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], sequence[:, 1:])

        self.episode_costs.extend(cost.tolist())

        return {
            'prompt': prompt,
            'log_probs': log_probs,
            'ref_log_probs': ref_log_probs,
            'reward': reward,
            'cost': cost,
            'reward_values': reward_values,
            'cost_values': cost_values,
            'input_ids': sequence,
            'attention_mask': attention_mask,
            'reward_fg': reward_fg,
            'cost_fg': cost_fg,
        }

    @torch.no_grad()
    def eval_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.BoolTensor,
    ) -> dict[str, torch.Tensor]:
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                input_ids,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_input_ids = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_input_ids = input_ids
            reward_attention_mask = attention_mask

        if self.cost_tokenizer is not self.tokenizer:
            cost_tokenize_output = batch_retokenize(
                input_ids,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.cost_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            cost_input_ids = cost_tokenize_output['input_ids']
            cost_attention_mask = cost_tokenize_output['attention_mask']
        else:
            cost_input_ids = input_ids
            cost_attention_mask = attention_mask

        reward = self.reward_model(
            reward_input_ids,
            attention_mask=reward_attention_mask,
        ).end_scores.squeeze(dim=-1)
        cost = self.cost_model(
            cost_input_ids,
            attention_mask=cost_attention_mask,
        ).end_scores.squeeze(dim=-1)
        return {
            'eval/reward': reward,
            'eval/cost': cost,
        }

    def add_kl_divergence_regularization(
        self,
        reward: torch.Tensor,  # size = (B,)
        cost: torch.Tensor,  # size = (B,)
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        log_probs: torch.Tensor,  # size = (B, L)
        ref_log_probs: torch.Tensor,  # size = (B, L)
        sequence_mask: torch.BoolTensor,  # size = (B, L)
    ) -> tuple[torch.Tensor, torch.Tensor]:  # size = (B, L)
        end_index = torch.cat([m.nonzero()[-1] for m in sequence_mask])  # size = (B,)

        # size = (B, L)
        kl_divergence_estimate = log_probs - ref_log_probs
        kl_penalty_rewards = -self.kl_coeff * kl_divergence_estimate
        rewards = torch.scatter_add(
            0.5 * kl_penalty_rewards,
            dim=-1,
            index=end_index.unsqueeze(dim=-1),
            src=reward.to(kl_penalty_rewards.dtype).unsqueeze(dim=-1),
        )
        costs = torch.scatter_add(
            -0.5 * kl_penalty_rewards,
            dim=-1,
            index=end_index.unsqueeze(dim=-1),
            src=cost.to(kl_penalty_rewards.dtype).unsqueeze(dim=-1),
        )
        return (
            torch.clamp(rewards, min=-self.clip_range_score, max=self.clip_range_score),
            torch.clamp(costs, min=-self.clip_range_score, max=self.clip_range_score),
        )

    def actor_loss_fn(
        self,
        log_probs: torch.Tensor,  # size = (B, L - S)
        old_log_probs: torch.Tensor,  # size = (B, L - S)
        reward_advantages: torch.Tensor,  # size = (B, L - S)
        cost_advantages: torch.Tensor,  # size = (B, L - S)
        mask: torch.BoolTensor,  # size = (B, L - S)
    ) -> torch.Tensor:  # size = ()
        multiplier = self.log_lambda.exp().item()

        # size = (B, L - S)
        advantages = (reward_advantages - multiplier * cost_advantages) / (1.0 + multiplier)
        ratios = torch.exp(log_probs - old_log_probs)
        surrogate1 = advantages * ratios
        surrogate2 = advantages * torch.clamp(
            ratios,
            1.0 - self.clip_range_ratio,
            1.0 + self.clip_range_ratio,
        )
        surrogate = torch.minimum(surrogate1, surrogate2)
        return -masked_mean(surrogate, mask)  # size = ()

    # pylint: disable-next=too-many-locals
    def rl_step(self, rl_batch: dict[str, torch.Tensor]) -> dict[str, Any]:
        """Run one PPO update of the multiplier, the actor and both critics.

        Args:
            rl_batch: A batch carrying the keys produced by ``post_rollout``
                (``prompt``, ``log_probs``, ``ref_log_probs``, ``reward``,
                ``cost``, ``reward_values``, ``cost_values``, ``input_ids``,
                ``attention_mask``, ``reward_fg``, ``cost_fg``).

        Returns:
            A dict of ``train/*`` scalars for logging.
        """
        # Mean cost over the local sliding window of recent episodes.
        episode_cost = torch.tensor(self.episode_costs).mean().to(self.args.device)

        # Average the window means onto rank 0, which owns the multiplier update.
        dist.reduce(episode_cost, dst=0, op=dist.ReduceOp.AVG)

        if is_main_process() and self.global_step >= self.lambda_update_delay_steps:
            # Gradient step on log(lambda): lambda grows while the episode cost
            # exceeds the threshold and shrinks otherwise.
            lambda_loss = -(episode_cost - self.threshold) * self.log_lambda.exp()
            self.log_lambda_optimizer.zero_grad()
            lambda_loss.backward()
            self.log_lambda_optimizer.step()
            if self.log_lambda_max is not None:
                with torch.no_grad():
                    self.log_lambda.clamp_(max=self.log_lambda_max)

        # Every rank continues with rank 0's multiplier.
        dist.broadcast(self.log_lambda, src=0)

        prompt = rl_batch['prompt']
        old_log_probs = rl_batch['log_probs']
        ref_log_probs = rl_batch['ref_log_probs']
        reward = rl_batch['reward']
        cost = rl_batch['cost']
        old_reward_values = rl_batch['reward_values']
        old_cost_values = rl_batch['cost_values']
        input_ids = rl_batch['input_ids']
        attention_mask = rl_batch['attention_mask']

        # Per-token scores of the reward and cost models.
        reward_fg = rl_batch['reward_fg']
        cost_fg = rl_batch['cost_fg']

        # Log-probabilities are shifted by one token, hence the ``- 1`` and ``[:, 1:]``.
        # Presumably ``start`` is the first response position, assuming
        # ``input_ids`` begins with the prompt — TODO confirm.
        start = prompt.size(-1) - 1
        sequence_mask = attention_mask[:, 1:]

        with torch.no_grad():

            # Per-token rewards/costs come from score increments plus the KL
            # term; the sequence-level ``reward``/``cost`` are only logged below.
            old_rewards, old_costs = self.add_kl_divergence_regularization_fg(
                reward_fg,
                cost_fg,
                prompt,
                old_log_probs,
                ref_log_probs,
                sequence_mask,
                start,
            )


            reward_advantages, reward_returns = self.get_advantages_and_returns(
                old_reward_values,
                old_rewards,
                sequence_mask,
                start,
            )
            cost_advantages, cost_returns = self.get_advantages_and_returns(
                old_cost_values,
                old_costs,
                sequence_mask,
                start,
            )

        # Actor update.
        logits = self.actor_model(input_ids, attention_mask=attention_mask, use_cache=False).logits
        log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:])
        actor_loss = self.actor_loss_fn(
            log_probs[:, start:],
            old_log_probs[:, start:],
            reward_advantages,
            cost_advantages,
            sequence_mask[:, start:],
        )
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        # Reward critic update.
        reward_values = self.reward_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        reward_critic_loss = self.critic_loss_fn(
            reward_values[:, start:],
            old_reward_values[:, start:],
            reward_returns,
            sequence_mask[:, start:],
        )
        self.reward_critic_model.backward(reward_critic_loss)
        self.reward_critic_model.step()

        # Cost critic update.
        cost_values = self.cost_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        cost_values = cost_values.squeeze(dim=-1)[:, :-1]
        cost_critic_loss = self.critic_loss_fn(
            cost_values[:, start:],
            old_cost_values[:, start:],
            cost_returns,
            sequence_mask[:, start:],
        )
        self.cost_critic_model.backward(cost_critic_loss)
        self.cost_critic_model.step()

        # Logging statistics over positions from ``start`` onward.
        with torch.no_grad():
            mask = sequence_mask[:, start:]
            kl_divergence = ((old_log_probs - ref_log_probs)[:, start:] * mask).sum(dim=-1).mean()
            mean_generated_length = mask.sum(dim=-1).float().mean()
            max_generated_length = mask.sum(dim=-1).float().max()

            reward = reward.mean()
            cost = cost.mean()

            reward_with_kl_penalty = (old_rewards[:, start:] * mask).sum(dim=-1).mean()
            reward_advantage = masked_mean(reward_advantages, mask)
            reward_return = masked_mean(reward_returns, mask)
            reward_value = masked_mean(reward_values[:, start:], mask)
            cost_with_kl_penalty = (old_costs[:, start:] * mask).sum(dim=-1).mean()
            cost_advantage = masked_mean(cost_advantages, mask)
            cost_return = masked_mean(cost_returns, mask)
            cost_value = masked_mean(cost_values[:, start:], mask)

            # Aggregate the statistics across processes.
            actor_loss = get_all_reduce_mean(actor_loss)
            reward_critic_loss = get_all_reduce_mean(reward_critic_loss)
            cost_critic_loss = get_all_reduce_mean(cost_critic_loss)
            reward = get_all_reduce_mean(reward)
            cost = get_all_reduce_mean(cost)
            reward_with_kl_penalty = get_all_reduce_mean(reward_with_kl_penalty)
            reward_advantage = get_all_reduce_mean(reward_advantage)
            reward_return = get_all_reduce_mean(reward_return)
            reward_value = get_all_reduce_mean(reward_value)
            cost_with_kl_penalty = get_all_reduce_mean(cost_with_kl_penalty)
            cost_advantage = get_all_reduce_mean(cost_advantage)
            cost_return = get_all_reduce_mean(cost_return)
            cost_value = get_all_reduce_mean(cost_value)
            kl_divergence = get_all_reduce_mean(kl_divergence)
            mean_generated_length = get_all_reduce_mean(mean_generated_length)
            max_generated_length = get_all_reduce_max(max_generated_length)

        dist.barrier()

        return {
            'train/actor_loss': actor_loss.item(),
            'train/reward_critic_loss': reward_critic_loss.item(),
            'train/cost_critic_loss': cost_critic_loss.item(),
            'train/lambda': self.log_lambda.exp().item(),
            'train/episode_cost': episode_cost.item(),
            'train/reward': reward.item(),
            'train/cost': cost.item(),
            'train/reward_with_kl_penalty': reward_with_kl_penalty.item(),
            'train/reward_advantage': reward_advantage.item(),
            'train/reward_return': reward_return.item(),
            'train/reward_value': reward_value.item(),
            'train/cost_with_kl_penalty': cost_with_kl_penalty.item(),
            'train/cost_advantage': cost_advantage.item(),
            'train/cost_return': cost_return.item(),
            'train/cost_value': cost_value.item(),
            'train/kl_divergence': kl_divergence.item(),
            'train/actor_lr': self.actor_model.optimizer.param_groups[0]['lr'],
            'train/reward_critic_lr': self.reward_critic_model.optimizer.param_groups[0]['lr'],
            'train/cost_critic_lr': self.cost_critic_model.optimizer.param_groups[0]['lr'],
            'train/mean_generated_length': mean_generated_length.item(),
            'train/max_generated_length': max_generated_length.item(),
        }



    def add_kl_divergence_regularization_fg(
        self,
        reward_fg: torch.Tensor,  # size = (B,)
        cost_fg: torch.Tensor,  # size = (B,)
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        log_probs: torch.Tensor,  # size = (B, L)
        ref_log_probs: torch.Tensor,  # size = (B, L)
        sequence_mask: torch.BoolTensor,  # size = (B, L)
        start: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:  # size = (B, L)
        end_index = torch.cat([m.nonzero()[-1] for m in sequence_mask])  # size = (B,)

        # size = (B, L)
        kl_divergence_estimate = log_probs - ref_log_probs
        kl_penalty_rewards = -self.kl_coeff * kl_divergence_estimate


        reward_fg = reward_fg[:,:-1]
        reward_fg_diff = torch.zeros_like(reward_fg)
        reward_fg_diff[:, 1:] = reward_fg[:, 1:] - reward_fg[:, :-1]

        cost_fg = cost_fg[:,:-1]
        cost_fg_diff = torch.zeros_like(cost_fg)
        cost_fg_diff[:, 1:] = cost_fg[:, 1:] - cost_fg[:, :-1]


        if self.compensate==1:
            num_tokens = end_index - start
            reward_compensate = reward_fg[:, start-1]/num_tokens #(B,)
            cost_compensate = cost_fg[:, start-1] /num_tokens    #(B,)
            reward_compensate = reward_compensate.unsqueeze(-1).repeat(1, reward_fg_diff.size(-1))
            cost_compensate = cost_compensate.unsqueeze(-1).repeat(1, cost_fg_diff.size(-1))

            reward_fg_diff = kl_penalty_rewards + reward_fg_diff + reward_compensate
            cost_fg_diff = kl_penalty_rewards + cost_fg_diff + cost_compensate
        else:
            # print(8888888)
            # print(kl_penalty_rewards)
            # print(999999)
            # print(reward_fg_diff)
            reward_fg_diff = kl_penalty_rewards +reward_fg_diff
            cost_fg_diff = kl_penalty_rewards +cost_fg_diff



        return (
            torch.clamp(reward_fg_diff, min=-self.clip_range_score, max=self.clip_range_score),
            torch.clamp(cost_fg_diff, min=-self.clip_range_score, max=self.clip_range_score),
        )

    def add_kl_divergence_regularization_fg_step(
        self,
        reward_fg: torch.Tensor,  # size = (B,)
        cost_fg: torch.Tensor,  # size = (B,)
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        log_probs: torch.Tensor,  # size = (B, L)
        ref_log_probs: torch.Tensor,  # size = (B, L)
        sequence_mask: torch.BoolTensor,  # size = (B, L)
        start: int = 0,
        step: int = 10,
    ) -> tuple[torch.Tensor, torch.Tensor]:  # size = (B, L)
        end_index = torch.cat([m.nonzero()[-1] for m in sequence_mask])  # size = (B,)

        # size = (B, L)
        kl_divergence_estimate = log_probs - ref_log_probs
        kl_penalty_rewards = -self.kl_coeff * kl_divergence_estimate

        reward_fg = reward_fg[:,:-1]
        reward_fg_diff = torch.zeros_like(reward_fg)
        shifted_reward_fg = torch.zeros_like(reward_fg)
        shifted_reward_fg[:, step:] = reward_fg[:, :-step]  # 向右平移step位
        mask = torch.zeros_like(reward_fg, dtype=torch.bool)
        mask[:, step-1::step] = True
        reward_fg_diff[:, step-1:] = torch.where(mask[:, step-1:], reward_fg[:, step-1:] - shifted_reward_fg[:, step-1:], torch.tensor(0.0))


        cost_fg = cost_fg[:,:-1]
        cost_fg_diff = torch.zeros_like(cost_fg)
        shifted_cost_fg = torch.zeros_like(cost_fg)
        shifted_cost_fg[:, step:] = cost_fg[:, :-step]  # 向右平移step位
        mask = torch.zeros_like(cost_fg, dtype=torch.bool)
        mask[:, step-1::step] = True
        cost_fg_diff[:, step-1:] = torch.where(mask[:, 1:], cost_fg[:, step-1:] - shifted_cost_fg[:, step-1:], torch.tensor(0.0))

        if self.compensate==1:
            num_tokens = end_index - start
            reward_compensate = reward_fg[:, start-1]/(num_tokens/3) #(B,)
            cost_compensate = cost_fg[:, start-1] /(num_tokens/3)    #(B,)
            reward_compensate = reward_compensate.unsqueeze(-1).repeat(1, reward_fg_diff.size(-1))
            cost_compensate = cost_compensate.unsqueeze(-1).repeat(1, cost_fg_diff.size(-1))

            reward_fg_diff[:, step-1:] = torch.where(mask[:, 1:], reward_fg_diff[:, step-1:] + reward_compensate[:, step-1:], torch.tensor(0.0))
            cost_fg_diff[:, step-1:] = torch.where(mask[:, 1:], cost_fg_diff[:, step-1:] + cost_compensate[:, step-1:], torch.tensor(0.0))

        reward_fg_diff = kl_penalty_rewards +reward_fg_diff
        cost_fg_diff = kl_penalty_rewards +cost_fg_diff

        return (
            torch.clamp(reward_fg_diff, min=-self.clip_range_score, max=self.clip_range_score),
            torch.clamp(cost_fg_diff, min=-self.clip_range_score, max=self.clip_range_score),
        )


