# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from transformers import PreTrainedTokenizer

from ...protocol import DataProto
from ...utils.reward_score import math_compute_score, r1v_compute_score


class CustomRewardManager:
    """Score decoded responses and place one scalar reward per sample.

    For every sample the response tokens are decoded and handed to the
    configured ``compute_score`` function, which returns
    ``(score, acc_score, format_score, traj)``.  The final reward written on
    the last valid response token is::

        ACC_REWARD_WEIGHT * diversity_adjusted_acc
        + FORMAT_REWARD_WEIGHT * format_score
        + overlong_penalty            # only when the overlong buffer is enabled

    where the accuracy score is scaled down when several samples sharing the
    same ``uid`` produced the same ``traj`` value.
    """

    # Maximum relative reduction of the accuracy score when every one of the
    # ``n`` samples of a uid produced the same ``traj``.
    DIVERSITY_PENALTY_WEIGHT = 0.1
    # Mixing weights of the (diversity-adjusted) accuracy and format scores.
    ACC_REWARD_WEIGHT = 0.5
    FORMAT_REWARD_WEIGHT = 0.5

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizer",
        num_examine: int,
        compute_score: str,
        is_pra=True,
        max_resp_len=None,
        overlong_buffer_cfg=None,
    ):
        """Configure the reward manager.

        Args:
            tokenizer: Tokenizer used to decode prompt/response token ids.
            num_examine: Number of samples per call that are printed for
                inspection.
            compute_score: Name of the scoring function, ``"math"`` or
                ``"r1v"``.
            is_pra: Stored on the instance.
                NOTE(review): not read anywhere in this class - confirm
                whether it was meant to toggle the diversity adjustment.
            max_resp_len: Maximum response length; required when
                ``overlong_buffer_cfg`` is given.
            overlong_buffer_cfg: Optional config object exposing
                ``enable_overlong_buffer``, ``overlong_buffer_len`` and
                ``overlong_penalty_factor``.

        Raises:
            NotImplementedError: If ``compute_score`` is not a known name.
            AssertionError: If ``overlong_buffer_cfg`` is given without
                ``max_resp_len``.
        """
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        if compute_score == "math":
            self.compute_score = math_compute_score
        elif compute_score == "r1v":
            self.compute_score = r1v_compute_score
        else:
            raise NotImplementedError(f"Unknown compute_score: {compute_score!r}")
        self.is_pra = is_pra

        self.overlong_buffer_cfg = overlong_buffer_cfg
        self.max_resp_len = max_resp_len

        if self.overlong_buffer_cfg is not None:
            assert self.max_resp_len is not None, f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None"

    def __call__(self, data: "DataProto", n: int) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """Compute rewards for a batch.

        Args:
            data: Batch exposing ``batch["prompts"]``, ``batch["responses"]``,
                ``batch["attention_mask"]`` and, per item,
                ``non_tensor_batch["ground_truth" | "uid" | "image_number"]``.
            n: Normaliser for the diversity penalty (presumably the number of
                rollouts sampled per uid - confirm against the caller).  With
                ``n <= 1`` no diversity penalty is applied.

        Returns:
            ``(reward_tensor, acc_reward_tensor, format_reward_tensor)``:
            ``reward_tensor`` has the shape of ``responses`` with the final
            reward on the last valid response token of each row; the other
            two hold the raw per-sample accuracy and format scores.
        """
        responses = data.batch["responses"]
        batch_size = responses.shape[0]
        reward_tensor = torch.zeros_like(responses, dtype=torch.float32)
        acc_reward_tensor = torch.zeros(batch_size, dtype=torch.float32)
        format_reward_tensor = torch.zeros(batch_size, dtype=torch.float32)

        # The config is optional (defaults to None), so guard before reading it.
        overlong_cfg = self.overlong_buffer_cfg
        overlong_enabled = overlong_cfg is not None and bool(overlong_cfg.enable_overlong_buffer)

        # uid -> traj -> number of samples of that uid that produced this traj.
        uid_traj_counts = {}
        # Per-sample bookkeeping needed once all counts are known:
        # (uid, traj, valid_response_length, overlong_reward).
        samples = []
        already_print = 0

        for i in range(len(data)):
            data_item = data[i]  # DataProtoItem

            prompt_ids = data_item.batch["prompts"]
            prompt_length = prompt_ids.shape[-1]
            attention_mask = data_item.batch["attention_mask"]

            response_ids = data_item.batch["responses"]
            valid_response_length = int(attention_mask[prompt_length:].sum())
            valid_response_ids = response_ids[:valid_response_length]
            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)

            ground_truth = data_item.non_tensor_batch["ground_truth"]
            uid = data_item.non_tensor_batch["uid"]
            is_single_image = data_item.non_tensor_batch["image_number"] == 1
            score, acc_score, format_score, traj = self.compute_score(response_str, ground_truth, is_single_image)

            traj_counts = uid_traj_counts.setdefault(uid, {})
            traj_counts[traj] = traj_counts.get(traj, 0) + 1

            overlong_reward = 0.0
            if overlong_enabled:
                # Linear penalty once the response grows into the last
                # `overlong_buffer_len` tokens before `max_resp_len`.
                overlong_buffer_len = overlong_cfg.overlong_buffer_len
                expected_len = self.max_resp_len - overlong_buffer_len
                exceed_len = valid_response_length - expected_len
                overlong_penalty_factor = overlong_cfg.overlong_penalty_factor
                overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
                score += overlong_reward

            acc_reward_tensor[i] = acc_score
            format_reward_tensor[i] = format_score
            samples.append((uid, traj, valid_response_length, overlong_reward))

            if already_print < self.num_examine:
                already_print += 1
                # The prompt is only needed for logging, so decode it lazily.
                valid_prompt_length = int(attention_mask[:prompt_length].sum())
                valid_prompt_ids = prompt_ids[-valid_prompt_length:]
                prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
                print("[prompt]", prompt_str)
                print("[response]", response_str)
                print("[ground_truth]", ground_truth)
                if overlong_enabled:
                    print("[overlong_reward]", overlong_reward)
                print("[score]", score, "acc:",acc_score, "format:",format_score)

        # Diversity: shrink the accuracy score of samples whose traj was also
        # produced by other samples of the same uid.
        for i, (uid, traj, valid_response_length, overlong_reward) in enumerate(samples):
            duplicates = uid_traj_counts[uid][traj] - 1
            # n == 1 leaves nothing to compare against; avoid dividing by zero.
            repeat_fraction = duplicates / (n - 1) if n > 1 else 0.0
            acc = acc_reward_tensor[i]
            final_acc = -1 * acc * repeat_fraction * self.DIVERSITY_PENALTY_WEIGHT + acc
            # The overlong penalty is carried into the final reward; it is 0.0
            # when the overlong buffer is disabled.
            reward_tensor[i, valid_response_length - 1] = (
                self.ACC_REWARD_WEIGHT * final_acc
                + self.FORMAT_REWARD_WEIGHT * format_reward_tensor[i]
                + overlong_reward
            )

        return reward_tensor, acc_reward_tensor, format_reward_tensor
