from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register

from collections import defaultdict, Counter
import torch

from .ewma import RewardTrendMonitor
import re
# def extract_solution(solution_str, method="strict"):
#     assert method in ["strict", "flexible"]
#
#     if method == "strict":
#         # this also tests the formatting of the model
#         solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str)
#         if len(solutions) == 0:
#             final_answer = None
#         else:
#             # take the last solution
#             final_answer = solutions[-1].replace(",", "").replace("$", "")
#     elif method == "flexible":
#         answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
#         final_answer = None
#         if len(answer) == 0:
#             # no reward if there is no answer
#             pass
#         else:
#             invalid_str = ["", "."]
#             # find the last number that is not '.'
#             for final_answer in reversed(answer):
#                 if final_answer not in invalid_str:
#                     break
#     return final_answer
#
# def yes_or_no(response, ground_truth, format_score=0.0, score=1.0):
#     answer = extract_solution(solution_str=response)
#     if answer is None:
#         return 0
#     else:
#         if answer == ground_truth:
#             return score
#         else:
#             return format_score
#     pass

from verl.utils.reward_score.math import compute_score as yes_or_no
# from verl.utils.reward_score.gsm8k import compute_score as yes_or_no
@register("naive_length")
class NaiveLengthRewardManager:
    """Reward manager with group-level correctness and length statistics.

    Samples in a batch are split into consecutive groups of ``n``. For each group the
    manager records:

    - which responses are correct (via ``yes_or_no``);
    - the fraction of the group that is correct;
    - the min / max / mode / mean length of the correct responses.

    These statistics are passed to ``compute_score``, which must return a
    ``(score, correct_score, length_score)`` triple.

    The length component of the reward is kept only when every trend monitor agrees it
    should be applied. Otherwise it is subtracted back out of the total reward.
    """

    def __init__(
        self,
        tokenizer,
        num_examine,
        compute_score=None,
        reward_fn_key="data_source",
        max_response_length=3076,
    ) -> None:
        """Create the reward manager.

        Args:
            tokenizer: Tokenizer used to decode prompt and response ids.
            num_examine: Number of samples per data source to print for debugging.
            compute_score: Callable invoked as ``compute_score(is_correct,
                response_length, correct_proportion, min_correct_length,
                max_correct_length)`` and returning ``(score, correct_score,
                length_score)``. ``score`` may be a float or a dict with a ``"score"``
                key. Falls back to ``default_compute_score``; NOTE(review): verify
                that the fallback actually accepts this call signature.
            reward_fn_key: Key in ``non_tensor_batch`` naming the data source.
            max_response_length: Length statistic assigned to groups with no correct
                response. Defaults to 3076, the previously hard-coded value.
        """
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.compute_score = compute_score or default_compute_score
        self.reward_fn_key = reward_fn_key
        self.max_response_length = max_response_length
        # Trend monitors over increasingly long windows; the two longest use a tighter
        # threshold. All of them must agree before the length reward is applied.
        self.trend_monitor3 = self._make_trend_monitor(3, 0.001)
        self.trend_monitor5 = self._make_trend_monitor(5, 0.001)
        self.trend_monitor10 = self._make_trend_monitor(10, 0.001)
        self.trend_monitor15 = self._make_trend_monitor(15, 0.001)
        self.trend_monitor20 = self._make_trend_monitor(20, 0.001)
        self.trend_monitor25 = self._make_trend_monitor(25, 0.0001)
        self.trend_monitor30 = self._make_trend_monitor(30, 0.0001)

    @staticmethod
    def _make_trend_monitor(trend_window, threshold):
        """Build a RewardTrendMonitor with the shared window/decay/min-step settings."""
        return RewardTrendMonitor(
            window_size=100,
            decay_rate=0.1,
            trend_window=trend_window,
            threshold=threshold,
            min_steps=5,
        )

    def _all_trend_monitors(self):
        """Return the trend monitors in the order they are queried on every call."""
        return (
            self.trend_monitor3,
            self.trend_monitor5,
            self.trend_monitor10,
            self.trend_monitor15,
            self.trend_monitor20,
            self.trend_monitor25,
            self.trend_monitor30,
        )

    def _compute_group_stats(self, response_strs, ground_truths, lengths, n):
        """Return one stats dict per sample, computed over consecutive groups of ``n``.

        The final group may hold fewer than ``n`` samples. When a group has no correct
        response, every length statistic is ``self.max_response_length``.
        """
        total = len(response_strs)
        all_stats = []
        for start in range(0, total, n):
            end = min(start + n, total)
            group_lens = lengths[start:end]
            # Correctness is judged by the external `yes_or_no` scorer.
            correctness = [
                yes_or_no(resp, truth)
                for resp, truth in zip(response_strs[start:end], ground_truths[start:end])
            ]
            num_correct = sum(correctness)
            prop_correct = num_correct / len(correctness)

            if num_correct > 0:
                correct_lens = [length for length, ok in zip(group_lens, correctness) if ok]
                longest = max(correct_lens)
                shortest = min(correct_lens)
                mean_len = sum(correct_lens) / len(correct_lens)
                # The "highest density point" is taken to be the mode.
                mode_len = Counter(correct_lens).most_common(1)[0][0]
            else:
                # No correct response in the group: fall back to the maximum length.
                longest = shortest = mean_len = mode_len = self.max_response_length

            for ok in correctness:
                all_stats.append(
                    {
                        "is_correct": ok,
                        "correct_proportion": prop_correct,
                        "max_correct_length": longest,
                        "min_correct_length": shortest,
                        "mode_correct_length": mode_len,
                        "mean_correct_length": mean_len,
                    }
                )
        return all_stats

    def __call__(self, data: DataProto, n, return_dict=False):
        """Compute token-level rewards for ``data``.

        Args:
            data: Batch with ``prompts``, ``responses`` and ``attention_mask`` tensors.
                Each sample also carries ``reward_model["ground_truth"]`` and
                ``self.reward_fn_key`` in its ``non_tensor_batch``.
            n: Size of each consecutive group of samples (presumably rollouts per
                prompt; confirm with the caller).
            return_dict: If True, also return the correctness and length reward
                tensors, plus any extra info reported by dict-valued scores.

        Returns:
            A float32 tensor shaped like ``responses``. Each sample's reward sits on its
            last valid response token. When ``return_dict`` is True, a dict containing
            that tensor and the auxiliary outputs is returned instead.

        Raises:
            ValueError: If ``n`` is smaller than 1.
        """
        # If reward-model scores are already present, return them directly.
        if "rm_scores" in data.batch.keys():
            if return_dict:
                return {"reward_tensor": data.batch["rm_scores"]}
            return data.batch["rm_scores"]

        if n < 1:
            raise ValueError(f"n must be a positive group size, got {n}")

        # --- STEP 1: decode every response and collect ground truths and lengths ---
        response_strs = []
        ground_truths = []
        valid_response_lengths = []
        for item in data:
            resp_mask = item.batch["attention_mask"][item.batch["prompts"].shape[-1]:]
            valid_len = int(resp_mask.sum())
            resp_ids = item.batch["responses"][:valid_len]
            response_strs.append(self.tokenizer.decode(resp_ids, skip_special_tokens=True))
            valid_response_lengths.append(valid_len)
            ground_truths.append(item.non_tensor_batch["reward_model"]["ground_truth"])

        # --- STEP 2: group every n samples and compute per-group statistics ---
        group_stats = self._compute_group_stats(
            response_strs, ground_truths, valid_response_lengths, n
        )

        # --- STEP 3: inject stats per sample and compute the scores ---
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        correct_reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        length_reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

        reward_extra_info = defaultdict(list)
        printed_cnt = defaultdict(int)

        for i, data_item in enumerate(data):
            # Expose the group statistics on the sample's non_tensor_batch.
            ntb = data_item.non_tensor_batch
            ntb.update(group_stats[i])

            prompt_ids = data_item.batch["prompts"]
            prompt_len = prompt_ids.shape[-1]
            valid_prompt_len = int(data_item.batch["attention_mask"][:prompt_len].sum())
            valid_resp_len = valid_response_lengths[i]

            # Valid prompt tokens sit at the end of the prompt slice (left padding).
            # Slicing from `prompt_len - valid_prompt_len` also yields an empty prompt
            # when nothing is valid; `[-0:]` would return the whole slice.
            prompt_str = self.tokenizer.decode(
                prompt_ids[prompt_len - valid_prompt_len:], skip_special_tokens=True
            )
            response_str = response_strs[i]
            ground_truth = ground_truths[i]
            data_source = ntb[self.reward_fn_key]
            # NOTE(review): extra_info is not passed to compute_score here; this only
            # records num_turns into the sample's extra_info dict in place.
            extra_info = ntb.get("extra_info", {})
            extra_info["num_turns"] = ntb.get("__num_turns__", None)

            # Call compute_score; `score` may be a float or a dict. The correctness and
            # length components come back separately in both cases.
            score, correct_reward, length_reward = self.compute_score(
                ntb["is_correct"],
                valid_resp_len,
                ntb["correct_proportion"],
                ntb["min_correct_length"],
                ntb["max_correct_length"],
            )
            if isinstance(score, dict):
                reward = score["score"]
                for k, v in score.items():
                    reward_extra_info[k].append(v)
            else:
                reward = score

            reward_tensor[i, valid_resp_len - 1] = reward
            correct_reward_tensor[i, valid_resp_len - 1] = correct_reward
            length_reward_tensor[i, valid_resp_len - 1] = length_reward

            # Debug printing, limited to num_examine samples per data source.
            if printed_cnt[data_source] < self.num_examine:
                printed_cnt[data_source] += 1
                print("[prompt]", prompt_str)
                print("[response]", response_str)
                print("[ground_truth]", ground_truth)
                if isinstance(score, dict):
                    for k, v in score.items():
                        print(f"[{k}]", v)
                else:
                    print("[score]", score)

        correct_reward_mean = torch.mean(correct_reward_tensor.sum(-1)).detach().item()
        # Query every monitor in order. Build a list rather than use all(generator):
        # all() would short-circuit and skip the remaining monitors.
        decisions = [
            monitor.should_apply_length_reward(correct_reward_mean)
            for monitor in self._all_trend_monitors()
        ]
        if not all(decisions):
            # Length reward not yet enabled: remove it from the total reward.
            reward_tensor = reward_tensor - length_reward_tensor
            length_reward_tensor.zero_()

        if return_dict:
            return {
                "reward_tensor": reward_tensor,
                "correct_reward_tensor": correct_reward_tensor,
                "length_reward_tensor": length_reward_tensor,
                "reward_extra_info": reward_extra_info,
            }
        return reward_tensor
