# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement a multiprocess PPORM
"""

import itertools
from collections import defaultdict
from typing import Any, Dict, Optional, Tuple

import torch
import os
import torch.distributed
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from ...utils import torch_functional as VF
from ...protocol import DataProto
from ...utils.py_functional import append_to_dict
from ...utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs

from ...trainer.core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm, compute_mse_dpo_loss_rm
from ...utils.dataset import ImageProcessMixin
from ...utils.py_functional import sums_at_rightmost_from_ids, repeat_rightmost_values_to_clusters
from .base import BasePPORM
from .config import RewardConfig
from PIL import Image
import base64
from io import BytesIO
import numpy as np
import pydevd_pycharm

# Public API of this module: only the reward-model worker class is exported.
__all__ = [
    "DataParallelPRIMERewardModel",
]


class DataParallelPRIMERewardModel(BasePPORM, ImageProcessMixin):
    """Data-parallel (optionally FSDP / Ulysses sequence-parallel) PRIME reward model worker.

    The wrapped ``reward_module`` is expected to emit one scalar logit per token
    (``_forward_micro_batch`` squeezes the trailing logit dimension). ``compute_rm_score``
    turns those per-token values into per-token rewards, and ``update_rm`` trains the
    module with the loss selected by ``config.loss_type``.
    """

    def __init__(
        self,
        config: RewardConfig,
        reward_module: nn.Module,
        ref_module: nn.Module,
        reward_optimizer: optim.Optimizer,
    ):
        super().__init__(config)
        # Process rank taken from the launcher's RANK environment variable (defaults to 0).
        self.rank = int(os.getenv("RANK", "0"))
        # Image size bounds; presumably consumed by ImageProcessMixin.process_image (confirm).
        self.max_pixels = config.max_pixels
        self.min_pixels = config.min_pixels
        self.config = config
        self.reward_module = reward_module
        self.ref_module = ref_module  # stored but not used by any method of this class
        self.reward_optimizer = reward_optimizer
        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
        if config.use_torch_compile:
            self.log_probs_from_logits = torch.compile(VF.log_probs_from_logits, dynamic=True)
        else:
            self.log_probs_from_logits = VF.log_probs_from_logits

    def _forward_micro_batch(self, micro_batch, prompt_length, temperature):
        """Run the reward module and return its per-token values over the response.

        Args:
            micro_batch: dict with ``input_ids`` (bsz, seqlen), ``attention_mask``,
                ``position_ids`` (bsz, seqlen), or (bsz, 3, seqlen) for Qwen2-VL mrope,
                and optionally ``multi_modal_inputs`` (one dict of tensors per sample).
            prompt_length: number of leading prompt positions; the remaining
                ``seqlen - prompt_length`` positions are the response ("actions").
            temperature: accepted for interface compatibility but not applied. The module
                emits one scalar per token, and the padding-free path never scaled it.

        Returns:
            Tensor of shape ``(bsz, num_actions)``. Entry ``t`` is the module's value at
            the position immediately *before* response token ``t`` (shifted by one), in
            both the padding-free and the padded code paths.
        """
        input_ids = micro_batch["input_ids"]
        batch_size, seqlen = input_ids.shape
        attention_mask = micro_batch["attention_mask"]
        position_ids = micro_batch["position_ids"]
        if position_ids.dim() == 3:  # qwen2vl mrope
            position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)

        num_actions = seqlen - prompt_length

        # Concatenate the per-sample multimodal tensors (e.g. pixel_values, image_grid_thw)
        # along dim 0 so the whole micro batch is fed to the model at once.
        # NOTE(review): if every sample in the micro batch is image-less, these become
        # empty tensors. Confirm the model tolerates that.
        multi_modal_inputs = {}
        if "multi_modal_inputs" in micro_batch:
            per_sample_inputs = micro_batch["multi_modal_inputs"]
            for key in per_sample_inputs[0].keys():
                multi_modal_inputs[key] = torch.cat([inputs[key] for inputs in per_sample_inputs], dim=0)

        if self.config.padding_free:
            input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # (total_nnz, 1)
            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

            # unpad the position_ids to align the rotary
            if position_ids.dim() == 3:
                position_ids_rmpad = (
                    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
                    .transpose(0, 1)
                    .unsqueeze(1)
                )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
            else:
                position_ids_rmpad = index_first_axis(
                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
                ).transpose(0, 1)

            # pad and slice the inputs if sp > 1
            if self.ulysses_sequence_parallel_size > 1:
                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                    input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
                )

            output = self.reward_module(
                input_ids=input_ids_rmpad,
                attention_mask=None,
                position_ids=position_ids_rmpad,
                **multi_modal_inputs,
                use_cache=False,
            )
            rm_logits_rmpad = output.logits.squeeze(0)  # drop the packed batch dim of size 1

            if self.ulysses_sequence_parallel_size > 1:
                rm_logits_rmpad = gather_outputs_and_unpad(
                    rm_logits_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
                )
            # Re-pad to (bsz, seqlen, 1) and drop the single logit dimension.
            rm_values = pad_input(
                hidden_states=rm_logits_rmpad, indices=indices, batch=batch_size, seqlen=seqlen
            ).squeeze(-1)
        else:
            # Padded path. Use the (possibly transposed) local position_ids so that mrope
            # receives (3, bsz, seqlen), exactly as the padding-free path does.
            output = self.reward_module(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **multi_modal_inputs,
                use_cache=False,
            )
            # One scalar per token, matching the padding-free path's trailing squeeze.
            rm_values = output.logits.squeeze(-1)

        # Same response window for both paths: the value preceding each response token.
        return rm_values[:, -num_actions - 1 : -1]

    def _optimizer_step(self):
        """Clip gradients, step the optimizer if the norm is finite, then zero the gradients.

        Returns:
            The total gradient norm reported by the clipping call.
        """
        if isinstance(self.reward_module, FSDP):
            grad_norm = self.reward_module.clip_grad_norm_(self.config.max_grad_norm)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(), max_norm=self.config.max_grad_norm)

        if not torch.isfinite(grad_norm):
            print("Gradient norm is not finite. Skip update.")
        else:
            self.reward_optimizer.step()

        self.reward_optimizer.zero_grad()
        return grad_norm

    def prime_norm(self, token_level_scores):
        """Optionally rescale token-level scores when ``config.prime_norm == "batch_norm"``.

        The scores are divided by the largest absolute right-to-left cumulative sum
        (computed along dim 1) over the whole batch, plus ``1e-6``. Any other
        ``prime_norm`` value returns the input unchanged.
        """
        if self.config.prime_norm == "batch_norm":
            reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1])
            token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)
        return token_level_scores

    def _prepare_multi_modal_inputs(self, data: DataProto, processor=None) -> list:
        """Decode images into ``multi_modal_inputs`` when needed and return the non-tensor keys to select.

        When ``data.non_tensor_batch`` has both ``multi_modal_inputs`` and ``doc_id``, the
        base64 (URL-safe) images in ``multi_modal_data[i]["data"]`` are decoded, passed
        through ``self.process_image`` and ``processor.image_processor``, and the result
        overwrites ``data.non_tensor_batch["multi_modal_inputs"]`` in place. Samples
        without images get empty ``pixel_values`` / ``image_grid_thw`` tensors.

        Returns:
            ``["multi_modal_inputs"]`` if that key is present, otherwise ``[]``.

        Raises:
            ValueError: if images must be processed but ``processor`` is None.
        """
        non_tensor_batch = data.non_tensor_batch
        if "multi_modal_inputs" not in non_tensor_batch:
            return []

        if "doc_id" in non_tensor_batch:
            mm_inputs = []
            for result_strs in non_tensor_batch["multi_modal_data"]:
                if len(result_strs["data"]) > 0:
                    if processor is None:
                        raise ValueError("A processor is required to encode the images in `multi_modal_data`.")
                    img_data = [
                        self.process_image(Image.open(BytesIO(base64.urlsafe_b64decode(result_str))))
                        for result_str in result_strs["data"]
                    ]
                    img_input = dict(processor.image_processor(img_data, return_tensors="pt"))
                else:
                    img_input = {
                        "pixel_values": torch.tensor([], dtype=torch.float32),
                        "image_grid_thw": torch.tensor([], dtype=torch.int64),
                    }
                mm_inputs.append(img_input)
            non_tensor_batch["multi_modal_inputs"] = np.array(mm_inputs, object)

        return ["multi_modal_inputs"]

    def compute_rm_score(self, data: DataProto, processor=None) -> torch.Tensor:
        """Compute per-token rewards for every sample in ``data`` without gradients.

        The reward at response position ``t`` is the module's value at ``t`` minus its
        value at ``t - 1``, with 0 used before the first response position.

        Returns:
            Tensor of shape ``(batch, response_length)``.
        """
        self.reward_module.eval()
        temperature = data.meta_info["temperature"]
        select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "ref_log_probs"]
        prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1]
        non_tensor_select_keys = self._prepare_multi_modal_inputs(data, processor)

        micro_batches = data.select(select_keys, non_tensor_select_keys).split(
            self.config.micro_batch_size_per_device_for_experience
        )

        rewards = []
        for micro_batch in micro_batches:
            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
            with torch.no_grad():
                values = self._forward_micro_batch(model_inputs, prompt_length, temperature)
            # First difference along the response, treating the value before position 0 as 0.
            leading_zero = values.new_zeros((values.shape[0], 1))
            rewards.append(torch.diff(values, dim=-1, prepend=leading_zero))

        return torch.cat(rewards, dim=0)

    def _compute_rm_loss(self, q, model_inputs, response_mask, loss_mask, beta):
        """Return the scalar training loss for one micro batch according to ``config.loss_type``.

        Args:
            q: reward-module values, shape ``(bsz, response_length)``, with gradients.
            model_inputs: micro-batch dict. ``token_level_scores`` is always read;
                ``turn_sequence_mask`` is read for ``mse``; ``Q_bc`` / ``acc_bc`` are read
                for ``dpo`` / ``bon_acc`` / ``bon_rm``.
            response_mask: mask passed to the detached-DPO losses.
            loss_mask: mask passed to the ``ce`` loss and applied to ``q`` for ``mse``.
            beta: loss temperature forwarded to the ``ce`` / DPO losses.

        Raises:
            NotImplementedError: for an unknown ``loss_type``.
        """
        acc = model_inputs["token_level_scores"]
        loss_type = self.config.loss_type

        if loss_type == "ce":
            return compute_ce_dpo_loss_rm(q, acc, response_mask=loss_mask, beta=beta)

        if loss_type == "mse":
            # Per-sample target built by repeat_rightmost_values_to_clusters over the turn
            # ids (+1). Presumably this spreads each turn's right-most score over that
            # turn's tokens (inferred from the helper's name; confirm).
            turn_ids = model_inputs["turn_sequence_mask"]
            acc_repeat = torch.zeros_like(q)
            for i in range(q.shape[0]):
                acc_repeat[i] = repeat_rightmost_values_to_clusters(acc[i], turn_ids[i] + 1)
            return compute_mse_dpo_loss_rm(q * loss_mask, acc_repeat)

        if loss_type in ("dpo", "bon_acc", "bon_rm"):
            # The dpo implementation is detached: it needs the average reward of winning and
            # losing samples before the update. "dpo" uses the helper's default mode; the
            # BoN variants reweight samples toward the best-of-N distribution.
            bon_kwargs = {} if loss_type == "dpo" else {"bon_mode": loss_type}
            return compute_detach_dpo_loss_rm(
                q,
                acc,
                Q_bc=model_inputs["Q_bc"],
                acc_bc=model_inputs["acc_bc"],
                response_mask=response_mask,
                beta=beta,
                **bon_kwargs,
            )

        raise NotImplementedError(f"Unsupported reward-model loss_type: {loss_type!r}")

    def update_rm(self, data: DataProto, processor=None) -> Dict[str, Any]:
        """Train the reward module for ``config.ppo_epochs`` passes over ``data``.

        Gradients are accumulated over the micro batches of each mini batch, followed by
        one optimizer step per mini batch.

        Returns:
            Metrics dict (filled via ``append_to_dict``) with ``reward_model/dpo_loss``
            (one entry per micro batch), ``reward_model/grad_norm`` (one per mini batch)
            and ``reward_model/raw_reward`` (the mean over samples of the summed masked
            values, averaged across epochs).
        """
        # make sure we are in training mode
        self.reward_module.train()

        beta = self.config.beta_train
        temperature = data.meta_info["temperature"]

        batch_keys = data.batch.keys()
        select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "prompts", "ref_log_probs", "token_level_scores"]
        if "loss_mask" in batch_keys:
            select_keys += ["loss_mask", "turn_sequence_mask", "end_of_response_position_mask"]
        select_keys += [key for key in ("Q_bc", "acc_bc") if key in batch_keys]
        non_tensor_select_keys = self._prepare_multi_modal_inputs(data, processor)

        # Split to make minibatch iterator for updating the actor
        # See PPO paper for details. https://arxiv.org/abs/1707.06347
        mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)
        micro_batch_size = self.config.micro_batch_size_per_device_for_update
        # Clamp to 1 so a micro batch larger than the mini batch cannot cause a division by zero.
        gradient_accumulation = max(1, self.config.global_batch_size_per_device // micro_batch_size)

        metrics = defaultdict(list)
        q_lst = []
        for _ in range(self.config.ppo_epochs):
            # Reset per epoch so every epoch contributes a tensor of the same shape to q_lst.
            epoch_q_lst = []
            for mini_batch in mini_batches:
                for micro_batch in mini_batch.split(micro_batch_size):
                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                    prompt_length = model_inputs["prompts"].shape[-1]

                    # model_inputs is a plain dict, so membership tests are unambiguous here.
                    if "end_of_response_position_mask" in model_inputs:
                        response_mask = model_inputs["end_of_response_position_mask"]
                    else:
                        response_mask = model_inputs["attention_mask"][:, prompt_length:]
                    # Without an explicit loss_mask, fall back to the response mask.
                    loss_mask = model_inputs.get("loss_mask", response_mask)

                    q = self._forward_micro_batch(model_inputs, prompt_length, temperature)
                    epoch_q_lst.append(q.detach() * loss_mask)

                    dpo_loss = self._compute_rm_loss(q, model_inputs, response_mask, loss_mask, beta)
                    (dpo_loss / gradient_accumulation).backward()
                    append_to_dict(metrics, {"reward_model/dpo_loss": dpo_loss.detach().item()})

                grad_norm = self._optimizer_step()
                append_to_dict(metrics, {"reward_model/grad_norm": grad_norm.detach().item()})

            q_lst.append(torch.cat(epoch_q_lst, dim=0))

        q = torch.stack(q_lst, dim=0).mean(dim=0)
        append_to_dict(metrics, {"reward_model/raw_reward": q.sum(dim=-1).mean().item()})
        return metrics