# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from transformers import PreTrainedTokenizer

from verl import DataProto
from verl.utils.reward_score import math_compute_score, r1v_compute_score, seg_compute_score, seg_strict_compute_score, vision_reasoner_compute_score, dr2seg_compute_score

import numpy as np
import re
import json

def response_length_reward_fn(
    length_lst,
    des_length_lst,
    score_lst,
    rollout_n,
    target_len=45,          # target answer length; every token beyond it is penalized
    gamma=0.05,             # penalty per token beyond target_len
    beta=0.15,              # penalty when response length < description length
    score_threshold=3.0,    # at or below this, only format / non-repetition scores were earned (no target found)
):
    """Compute per-rollout weights in ``[0, 1]`` that discourage long answers.

    The three input lists are flat and equally long. They are laid out as
    ``batch_size`` contiguous groups of ``rollout_n`` rollouts each. For every
    group:

    * If the group's maximum score is ``<= score_threshold``, every rollout in
      the group gets weight ``1.0`` and no length shaping is applied.
    * Otherwise each rollout gets
      ``clip(1 - gamma * max(0, L - target_len) - beta * [L < L_des], 0, 1)``,
      where ``L`` is its response length and ``L_des`` its description length.
      There is no tolerance band: the penalty starts right after ``target_len``.

    Args:
        length_lst: response (answer) lengths, one per rollout.
        des_length_lst: description lengths, one per rollout.
        score_lst: raw scores, one per rollout.
        rollout_n: number of consecutive rollouts forming one group.
        target_len: length above which the linear penalty applies.
        gamma: linear penalty per token beyond ``target_len``.
        beta: flat penalty applied when ``L < L_des``.
        score_threshold: groups whose best score does not exceed this skip shaping.

    Returns:
        list[float] of length ``len(length_lst)``. Values are in ``[0, 1]``
        for finite inputs.

    Raises:
        ValueError: if ``rollout_n`` is not positive, the three lists differ in
            length, or their common length is not a multiple of ``rollout_n``.
    """
    if rollout_n <= 0:
        raise ValueError(f"rollout_n must be positive, got {rollout_n}")

    length_arr = np.array(length_lst, dtype=float)
    des_length_arr = np.array(des_length_lst, dtype=float)
    score_arr = np.array(score_lst, dtype=float)

    total = len(length_arr)
    # Explicit checks (not `assert`): mismatched inputs would otherwise broadcast
    # silently or leave trailing rollouts with a zero weight.
    if len(des_length_arr) != total or len(score_arr) != total:
        raise ValueError(
            "length_lst, des_length_lst and score_lst must have equal lengths, got "
            f"{total}, {len(des_length_arr)}, {len(score_arr)}"
        )
    if total % rollout_n != 0:
        raise ValueError(f"number of rollouts ({total}) is not a multiple of rollout_n ({rollout_n})")
    batch_size = total // rollout_n

    # One row per group of `rollout_n` consecutive rollouts.
    lengths = length_arr.reshape(batch_size, rollout_n)
    des_lengths = des_length_arr.reshape(batch_size, rollout_n)
    scores = score_arr.reshape(batch_size, rollout_n)

    # Absolute length penalty: linear in the number of tokens past target_len.
    excess = np.maximum(0, lengths - target_len)
    length_penalty = -gamma * excess

    # Description-length constraint: flat penalty when the response is
    # shorter than the description.
    hard_penalty = np.where(lengths < des_lengths, -1.0 * beta, 0.0)

    merged = np.clip(1 + length_penalty + hard_penalty, 0, 1)

    # Groups whose best score is at or below the threshold skip shaping entirely.
    skip_groups = scores.max(axis=1) <= score_threshold
    merged[skip_groups] = 1.0

    return merged.reshape(-1).tolist()




class CustomRewardManager:
    """Turn decoded rollouts into a sparse, length-shaped reward tensor.

    The scoring function is selected by name. It is called as
    ``fn(response_str, description_answers, ground_truth)`` and must return a
    3-tuple ``(score, response_answers_str, description_answers_str)``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, num_examine: int, compute_score: str):
        """Store the tokenizer and print budget, and resolve the scorer by name.

        Args:
            tokenizer: decodes prompt/response ids and re-encodes the extracted
                answer strings for length measurement.
            num_examine: number of leading samples whose details are printed.
            compute_score: one of ``"math"``, ``"r1v"``, ``"seg"``,
                ``"seg_strict"``, ``"vision_reasoner"``, ``"dr2seg"``.

        Raises:
            NotImplementedError: if ``compute_score`` matches none of the names.
        """
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        # Ordered name -> scorer table, matched with `==`.
        known_scorers = (
            ("math", math_compute_score),
            ("r1v", r1v_compute_score),
            ("seg", seg_compute_score),
            ("seg_strict", seg_strict_compute_score),
            ("vision_reasoner", vision_reasoner_compute_score),
            ("dr2seg", dr2seg_compute_score),
        )
        for scorer_name, scorer_fn in known_scorers:
            if compute_score == scorer_name:
                self.compute_score = scorer_fn
                break
        else:
            raise NotImplementedError()

    def __call__(self, data: DataProto, rollout_n) -> torch.Tensor:
        """Score every sample in ``data`` and return a sparse reward tensor.

        The result is a float32 tensor shaped like ``data.batch["responses"]``.
        It is all zeros except, for each sample ``i``, at position
        ``valid_response_length - 1``, which holds ``score * weight``.
        ``weight`` comes from ``response_length_reward_fn`` evaluated over
        contiguous groups of ``rollout_n`` samples. Presumably each group holds
        the rollouts of one prompt — TODO confirm against the caller's batch layout.
        """
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

        scores = []
        answer_token_counts = []
        description_token_counts = []
        response_valid_lengths = []

        for idx in range(len(data)):
            item = data[idx]  # DataProtoItem
            attention_mask = item.batch["attention_mask"]

            prompt_ids = item.batch["prompts"]
            prompt_len = prompt_ids.shape[-1]
            # The mask covers prompt + response. The valid prompt tokens are
            # taken from the END of the prompt slice, i.e. the prompt is
            # treated as left-padded.
            prompt_valid_len = attention_mask[:prompt_len].sum()
            response_valid_len = attention_mask[prompt_len:].sum()

            prompt_text = self.tokenizer.decode(prompt_ids[-prompt_valid_len:], skip_special_tokens=True)
            response_text = self.tokenizer.decode(
                item.batch["responses"][:response_valid_len], skip_special_tokens=True
            )

            ground_truth = item.non_tensor_batch["solution"]
            score, answer_text, description_text = self.compute_score(
                response_text, item.non_tensor_batch["description_answers"], ground_truth
            )

            # Lengths of the extracted answers, measured in tokens.
            answer_token_counts.append(len(self.tokenizer.encode(answer_text)))
            description_token_counts.append(len(self.tokenizer.encode(description_text)))
            scores.append(score)
            response_valid_lengths.append(response_valid_len)

            if idx < self.num_examine:
                print("[prompt]", prompt_text)
                print("[response]", response_text)
                print("[description response]", description_text)
                print("[ground_truth]", ground_truth)
                print("[score]", score)

        weights = response_length_reward_fn(answer_token_counts, description_token_counts, scores, rollout_n)

        for idx, (score, weight, valid_len) in enumerate(zip(scores, weights, response_valid_lengths)):
            adjusted = score * weight
            if idx < self.num_examine:
                print("[adjust score]", adjusted)
            # Sparse reward: only the last valid response token receives it.
            reward_tensor[idx, valid_len - 1] = adjusted
        return reward_tensor