# Copyright 2023-2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from typing import Any

import torch
import torch.distributed as dist
import numpy as np
import pdb

from safe_rlhf.trainers import RLTrainer
from safe_rlhf.utils import (
    batch_retokenize,
    gather_log_probabilities,
    get_all_reduce_max,
    get_all_reduce_mean,
    masked_mean,
)


class PPOTrainer(RLTrainer):
    """PPO trainer whose rollout reward comes from one reward model or a mix of two."""
    TRAINING_TYPE = 'ppo'

    @torch.no_grad()
    def post_rollout(
        self,
        prompt: torch.Tensor,
        sequence: torch.Tensor,
        attention_mask: torch.BoolTensor,
        Z = None
    ) -> dict[str, Any]:
        
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                sequence,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_seq = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_seq = sequence
            reward_attention_mask = attention_mask

        logits = self.actor_model(sequence, attention_mask=attention_mask).logits
        ref_logits = self.actor_reference_model(sequence, attention_mask=attention_mask).logits

        if self.args.reward_model_name_or_path_2 == "none":
            reward = self.reward_model(reward_seq, attention_mask=reward_attention_mask).end_scores

        else:
            reward_1 = self.reward_model_1(reward_seq, attention_mask=reward_attention_mask).end_scores
            reward_2 = self.reward_model_2(reward_seq, attention_mask=reward_attention_mask).end_scores
            if self.args.reward_type == 'linear':
                reward = reward_1 * self.args.reward_coeff_1 + reward_2 * self.args.reward_coeff_2
            elif self.args.reward_type == 'direct-probability':
                p_1 = torch.sigmoid(reward_1)
                p_2 = torch.sigmoid(reward_2)
                # F = lambda x : 1 / (1 + torch.exp(-(x - 0.5)))
                # F = lambda x : torch.exp(0.5)
                A = p_1 * self.args.reward_coeff_1 + p_2 * self.args.reward_coeff_2
                reward = torch.log(A / (1 - A))
            elif self.args.reward_type == 'relative-probability':
                batchsize = reward_1.size(0)
                random_index = np.random.randint(0, batchsize)
                selected_reward_1 = reward_1[random_index, :]
                selected_reward_2 = reward_2[random_index, :]
                extended_reward_1 = selected_reward_1.repeat(batchsize, 1)
                extended_reward_2 = selected_reward_2.repeat(batchsize, 1)
                A = torch.sigmoid(reward_1 - extended_reward_1) * self.args.reward_coeff_1 + torch.sigmoid(reward_2 - extended_reward_2) * self.args.reward_coeff_2
                reward = torch.log(A / (1 - A))
            elif self.args.reward_type == 'average-probability':
                batchsize = reward_1.size(0)
                A = []
                for index in range(batchsize):
                    selected_reward_1 = reward_1[index, :]
                    selected_reward_2 = reward_2[index, :]
                    extended_reward_1 = selected_reward_1.repeat(batchsize, 1)
                    extended_reward_2 = selected_reward_2.repeat(batchsize, 1)
                    A.append((torch.sigmoid(reward_1 - extended_reward_1) * self.args.reward_coeff_1 + torch.sigmoid(reward_2 - extended_reward_2) * self.args.reward_coeff_2).squeeze())
                B = torch.stack(A).mean(-2)
                reward = torch.log(B / (1 - B))
            elif self.args.reward_type == 'DPO-mix':
                max_ab = torch.max(reward_1 / self.args.kl_coeff,reward_2 / self.args.kl_coeff)
                # exp_reward_1 = torch.exp(reward_1 / self.args.kl_coeff)
                # exp_reward_2 = torch.exp(reward_2 / self.args.kl_coeff)
                reward = (max_ab + torch.log(torch.exp(reward_1 / self.args.kl_coeff - max_ab) * self.args.reward_coeff_1 + torch.exp(reward_2 / self.args.kl_coeff - max_ab) * self.args.reward_coeff_2)) * self.kl_coeff
                # pdb.set_trace()
            elif self.args.reward_type == 'estimate-mix':
                max_ab = torch.max(reward_1 / self.args.kl_coeff,reward_2 / self.args.kl_coeff)
                # exp_reward_1 = torch.exp(reward_1 / self.args.kl_coeff)
                # exp_reward_2 = torch.exp(reward_2 / self.args.kl_coeff)
                reward = (max_ab + torch.log(torch.exp(reward_1 / self.args.kl_coeff - max_ab) * self.args.reward_coeff_1 * Z[:,0].unsqueeze(-1) + torch.exp(reward_2 / self.args.kl_coeff - max_ab) * self.args.reward_coeff_2 * Z[:,1].unsqueeze(-1))) * self.kl_coeff
            elif self.args.reward_type == 'max':
                reward = torch.cat((reward_1 * self.args.reward_coeff_1,reward_2 * self.args.reward_coeff_2),dim=1).max(dim=1).values
                # pdb.set_trace()
            elif self.args.reward_type == 'relative-infer':
                sft_infer_tokenize_output = batch_retokenize(
                    Z,
                    src_tokenizer=self.tokenizer,
                    dest_tokenizer=self.reward_tokenizer,
                    skip_special_tokens=True,
                    device=self.args.device,
                )
                sft_infer_seq = sft_infer_tokenize_output['input_ids']
                sft_infer_attention_mask = sft_infer_tokenize_output['attention_mask']

                sft_reward_1 = self.reward_model_1(sft_infer_seq, attention_mask=sft_infer_attention_mask).end_scores
                sft_reward_2 = self.reward_model_2(sft_infer_seq, attention_mask=sft_infer_attention_mask).end_scores
                A = torch.sigmoid(reward_1 - sft_reward_1) * self.args.reward_coeff_1 + torch.sigmoid(reward_2 - sft_reward_2) * self.args.reward_coeff_2
                reward = torch.log(A / (1 - A))
                
        
        reward_values = self.reward_critic_model(sequence, attention_mask=attention_mask).scores

        reward = reward.squeeze(dim=-1)
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]

        log_probs = gather_log_probabilities(logits[:, :-1], sequence[:, 1:])
        ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], sequence[:, 1:])
        info = {
            'prompt': prompt,
            'log_probs': log_probs,
            'ref_log_probs': ref_log_probs,
            'reward': reward,
            'reward_values': reward_values,
            'input_ids': sequence,
            'attention_mask': attention_mask,
        }

        if self.args.reward_model_name_or_path_2 != "none":
            info['reward_1'] = reward_1.squeeze(dim=-1)
            info['reward_2'] = reward_2.squeeze(dim=-1)
        else:
            info['reward_1'] = reward
            info['reward_2'] = reward
        return info

    @torch.no_grad()
    def eval_step(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.BoolTensor,
        start = None,
    ) -> dict[str, torch.Tensor]:
        if self.reward_tokenizer is not self.tokenizer:
            reward_tokenize_output = batch_retokenize(
                input_ids,
                src_tokenizer=self.tokenizer,
                dest_tokenizer=self.reward_tokenizer,
                skip_special_tokens=True,
                device=self.args.device,
            )
            reward_input_ids = reward_tokenize_output['input_ids']
            reward_attention_mask = reward_tokenize_output['attention_mask']
        else:
            reward_input_ids = input_ids
            reward_attention_mask = attention_mask

        
        mask = reward_attention_mask[:, 1:][:, start:]
        logits = self.actor_model(reward_input_ids, attention_mask=reward_attention_mask).logits
        ref_logits = self.actor_reference_model(reward_input_ids, attention_mask=reward_attention_mask).logits
        log_probs = gather_log_probabilities(logits[:, :-1], reward_input_ids[:, 1:])
        ref_log_probs = gather_log_probabilities(ref_logits[:, :-1], reward_input_ids[:, 1:])
        # pdb.set_trace()

        reward_1 = self.reward_model_1(
            reward_input_ids,
            attention_mask=reward_attention_mask,
        ).end_scores.squeeze(dim=-1)

        reward_2 = self.reward_model_2(
            reward_input_ids,
            attention_mask=reward_attention_mask,
        ).end_scores.squeeze(dim=-1)

        kl_divergence = ((log_probs - ref_log_probs)[:, start:] * mask).sum(dim=-1)

        return {
            'eval/reward_1': reward_1,
            'eval/reward_2': reward_2,
            'eval/log_probs':log_probs.sum(dim=-1),
            'eval/ref_log_probs':ref_log_probs.sum(dim=-1),
            'eval/kl': kl_divergence,
        }

    def add_kl_divergence_regularization(
        self,
        reward: torch.Tensor,  # size = (B,)
        prompt: torch.LongTensor,  # size = (B, S) # pylint: disable=unused-argument
        log_probs: torch.Tensor,  # size = (B, L)
        ref_log_probs: torch.Tensor,  # size = (B, L)
        sequence_mask: torch.BoolTensor,  # size = (B, L)
    ) -> torch.Tensor:  # size = (B, L)
        end_index = torch.cat([m.nonzero()[-1] for m in sequence_mask])  # size = (B,)

        # size = (B, L)
        kl_divergence_estimate = log_probs - ref_log_probs
        kl_penalty_rewards = -self.kl_coeff * kl_divergence_estimate
        rewards = torch.scatter_add(
            kl_penalty_rewards,
            dim=-1,
            index=end_index.unsqueeze(dim=-1),
            src=reward.to(kl_penalty_rewards.dtype).unsqueeze(dim=-1),
        )
        return torch.clamp(rewards, min=-self.clip_range_score, max=self.clip_range_score)

    def actor_loss_fn(
        self,
        log_probs: torch.Tensor,  # size = (B, L - S)
        old_log_probs: torch.Tensor,  # size = (B, L - S)
        advantages: torch.Tensor,  # size = (B, L - S)
        mask: torch.BoolTensor,  # size = (B, L - S)
    ) -> torch.Tensor:  # size = ()
        # size = (B, L - S)
        ratios = torch.exp(log_probs - old_log_probs)
        surrogate1 = advantages * ratios
        surrogate2 = advantages * torch.clamp(
            ratios,
            1.0 - self.clip_range_ratio,
            1.0 + self.clip_range_ratio,
        )
        surrogate = torch.minimum(surrogate1, surrogate2)
        return -masked_mean(surrogate, mask)  # size = ()

    def rl_step(self, rl_batch: dict[str, torch.Tensor]) -> dict[str, Any]:
        prompt = rl_batch['prompt']
        old_log_probs = rl_batch['log_probs']
        ref_log_probs = rl_batch['ref_log_probs']
        reward = rl_batch['reward']
        old_reward_values = rl_batch['reward_values']
        input_ids = rl_batch['input_ids']
        attention_mask = rl_batch['attention_mask']

        start = prompt.size(-1) - 1
        sequence_mask = attention_mask[:, 1:]

        with torch.no_grad():
            old_rewards = self.add_kl_divergence_regularization(
                reward,
                prompt,
                old_log_probs,
                ref_log_probs,
                sequence_mask,
            )
            reward_advantages, reward_returns = self.get_advantages_and_returns(
                old_reward_values,
                old_rewards,
                sequence_mask,
                start,
            )

        logits = self.actor_model(input_ids, attention_mask=attention_mask, use_cache=False).logits
        log_probs = gather_log_probabilities(logits[:, :-1], input_ids[:, 1:])
        actor_loss = self.actor_loss_fn(
            log_probs[:, start:],
            old_log_probs[:, start:],
            reward_advantages,
            sequence_mask[:, start:],
        )
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        reward_values = self.reward_critic_model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=False,
        ).scores
        reward_values = reward_values.squeeze(dim=-1)[:, :-1]
        reward_critic_loss = self.critic_loss_fn(
            reward_values[:, start:],
            old_reward_values[:, start:],
            reward_returns,
            sequence_mask[:, start:],
        )
        self.reward_critic_model.backward(reward_critic_loss)
        self.reward_critic_model.step()

        with torch.no_grad():
            mask = sequence_mask[:, start:]
            kl_divergence = ((old_log_probs - ref_log_probs)[:, start:] * mask).sum(dim=-1).mean()
            mean_generated_length = mask.sum(dim=-1).float().mean()
            max_generated_length = mask.sum(dim=-1).float().max()

            reward = reward.mean()
            reward_1 = rl_batch['reward_1'].mean()
            reward_2 = rl_batch['reward_2'].mean()
            reward_with_kl_penalty = (old_rewards[:, start:] * mask).sum(dim=-1).mean()
            reward_advantage = masked_mean(reward_advantages, mask)
            reward_return = masked_mean(reward_returns, mask)
            reward_value = masked_mean(reward_values[:, start:], mask)

            actor_loss = get_all_reduce_mean(actor_loss)
            reward_critic_loss = get_all_reduce_mean(reward_critic_loss)
            reward = get_all_reduce_mean(reward)
            reward_1 = get_all_reduce_mean(reward_1)
            reward_2 = get_all_reduce_mean(reward_2)
            reward_with_kl_penalty = get_all_reduce_mean(reward_with_kl_penalty)
            reward_advantage = get_all_reduce_mean(reward_advantage)
            reward_return = get_all_reduce_mean(reward_return)
            reward_value = get_all_reduce_mean(reward_value)
            kl_divergence = get_all_reduce_mean(kl_divergence)
            mean_generated_length = get_all_reduce_mean(mean_generated_length)
            max_generated_length = get_all_reduce_max(max_generated_length)

        dist.barrier()

        info = {
            'train/actor_loss': actor_loss.item(),
            'train/reward_critic_loss': reward_critic_loss.item(),
            'train/reward': reward.item(),
            'train/reward_1': reward_1.item(),
            'train/reward_2': reward_2.item(),
            'train/reward_with_kl_penalty': reward_with_kl_penalty.item(),
            'train/reward_advantage': reward_advantage.item(),
            'train/reward_return': reward_return.item(),
            'train/reward_value': reward_value.item(),
            'train/kl_divergence': kl_divergence.item(),
            'train/actor_lr': self.actor_model.optimizer.param_groups[0]['lr'],
            'train/reward_critic_lr': self.reward_critic_model.optimizer.param_groups[0]['lr'],
            'train/mean_generated_length': mean_generated_length.item(),
            'train/max_generated_length': max_generated_length.item(),
        }
        return info
