import json
import numpy as np
from utils import (
    think_boxed_reward_fn,
    think_boxed_reward_fn_base
)
import re

def extract_answer_mc(model_response: str):
    """
    Parse a multiple-choice answer from a model response.

    Returns:
        'A'/'B'/'C'/'D', or 'uncertainty' when the model explicitly abstains;
        None when nothing can be parsed (including non-str input).

    Resolution order:
      0. The whole stripped response is a single option letter, optionally
         followed by '.' or ')' (e.g. "A", "b.", "C)").
      1. Windows of up to 300 characters starting at each occurrence of the
         word "answer"; each window yields at most one candidate and the
         candidate from the LAST window wins.
      2. Fallback: the right-most \\boxed{...} holding an option letter or
         "uncertainty" (by position in the text).
    """
    if not isinstance(model_response, str):
        return None
    text = model_response.strip()

    # 0. Minimal form: the response is just an option letter.
    m_simple = re.fullmatch(r'([A-D])[\.\)]?\s*', text, flags=re.IGNORECASE)
    if m_simple:
        return m_simple.group(1).upper()

    # Unwrap \text{...} whose content contains no braces (single pass).
    text_clean = re.sub(r'\\text\{\s*([^{}]+?)\s*\}', r'\1', text)

    # Normalize any remaining \boxed{\text{A. ...}} (a \text{} the pass above
    # left intact) to \boxed{\text{A}} so the boxed rules below can read it.
    # A function replacement is inserted literally, so the backslashes survive.
    text_clean = re.sub(
        r'\\boxed\{\s*\\text\{\s*([A-D])\s*[\.\)]?[^}]*\}\s*\}',
        lambda m: r'\boxed{\text{' + m.group(1).upper() + r'}}',
        text_clean,
        flags=re.IGNORECASE,
    )

    # 1. Parse locally around each "answer" token.
    uncertainty_patterns = [
        # answer: uncertainty / "answer": "uncertainty" / answer: `uncertainty`
        r'[*"_\']*\banswer\b[*"_\']*\s*[:：]\s*["`\']?uncertainty["`\']?\b',
        r'\banswer\b[^\n]*?\\boxed\{\s*\\text\{\s*uncertainty\s*\}\s*\}',
        r'\banswer\b[^\n]*?\\boxed\{\s*uncertainty\s*\}',
    ]
    choice_patterns = [
        # answer: C / "answer": "C" / **answer**: `C`
        r'[*"_\']*\banswer\b[*"_\']*\s*[:：]\s*["`\']?([A-D])["`\']?\b',
        # **Answer:** C. / Answer: C)
        r'\banswer\b[^\n]*?[:：]\s*["`\']?([A-D])["`\']?\b',
        # answer \boxed{\text{C}}
        r'\banswer\b[^\n]*?\\boxed\{\s*\\text\{\s*([A-D])\s*\}\s*\}',
        # answer \boxed{C}
        r'\banswer\b[^\n]*?\\boxed\{\s*([A-D])\s*\}',
        # Answer C. (no colon). The letter is matched case-sensitively via
        # (?-i:...): under IGNORECASE the article in "the answer is a bit
        # tricky" would otherwise be read as option A.
        r'\banswer\b[^\n]*?\b((?-i:[A-D]))\b',
    ]

    candidates = []
    for m in re.finditer(r'[*"_\']*\banswer\b[*"_\']*', text_clean, flags=re.IGNORECASE):
        segment = text_clean[m.start():m.end() + 300]

        # An explicit abstention in this window takes precedence over letters.
        if any(re.search(p, segment, flags=re.IGNORECASE) for p in uncertainty_patterns):
            candidates.append('uncertainty')
            continue

        # First matching choice pattern (in priority order) wins for this window.
        for pat in choice_patterns:
            m_choice = re.search(pat, segment, flags=re.IGNORECASE)
            if m_choice:
                candidates.append(m_choice.group(1).upper())
                break

    # "Last one wins" among the answer windows.
    if candidates:
        return candidates[-1]

    # 2. Fallback: the right-most boxed uncertainty / letter. Matches are
    # compared by start position across ALL patterns; the patterns are
    # mutually exclusive at any given position, so there are no ties.
    boxed_patterns = [
        r'\\boxed\{\s*\\text\{\s*uncertainty\s*\}\s*\}',
        r'\\boxed\{\s*uncertainty\s*\}',
        r'\\boxed\{\s*\\text\{\s*([A-D])\s*\}\s*\}',
        r'\\boxed\{\s*([A-D])\s*[\.\)]\s*[^}]*\}',
        r'\\boxed\{\s*([A-D])\s*\}',
    ]
    last_pos, last_choice = -1, None
    for pat in boxed_patterns:
        for m in re.finditer(pat, text_clean, flags=re.IGNORECASE):
            if m.start() > last_pos:
                last_pos = m.start()
                # The uncertainty patterns have no capture group.
                last_choice = 'uncertainty' if m.lastindex is None else m.group(1).upper()
    return last_choice


def think_boxed_reward_fn_mc(model_response, gt, fast=False):
    """
    Score a multiple-choice response against the ground-truth choice.

    Only the text after the last "</think>" tag is parsed (the whole response
    when no such tag is present), so letters mentioned while reasoning do not
    count as the final answer.

    Args:
        model_response: raw model output; non-str values are treated as
            unparseable.
        gt: ground-truth choice, compared with ``==`` against the parsed
            uppercase letter (presumably 'A'-'D' — confirm against the data).
        fast: unused; kept for signature parity with the other reward fns.

    Returns:
        (info, reward) where info is {"formatted": bool} and reward is
        1.0 for a correct letter, 0.0 for a wrong or unparseable answer,
        and None when the model answered 'uncertainty'.
    """
    if isinstance(model_response, str) and "</think>" in model_response:
        model_response = model_response.split("</think>")[-1]

    model_answer = extract_answer_mc(model_response)
    if model_answer is None:
        return {"formatted": False}, 0.0  # Cannot even parse anything.

    if model_answer == 'uncertainty':
        return {"formatted": True}, None

    if len(model_answer) == 1:
        # Formatted but wrong earns 0.0: no separate format reward, to avoid hacking.
        return {"formatted": True}, (1.0 if model_answer == gt else 0.0)

    return {"formatted": False}, 0.0


def cal_metrics_whole_mc(fn, is_base=False, print_md=True):
    """
    Score multiple-choice predictions stored in a JSON file and report metrics.

    The file holds a list of items with keys "task_name", "model_output",
    "gt" and "output_lengths"; only items whose task is in ``tasks_scope``
    are scored. Each (prediction, gt) pair is graded with
    ``think_boxed_reward_fn_mc`` and counted as uncertain (reward None),
    correct (reward >= 0.9) or error (anything else).

    Per task: correct/error/uncertain ratios, TruthScore (correct - error),
    PAQ (correct / (correct + error), i.e. accuracy among non-abstained
    responses), F1 of PAQ and the correct ratio, the mean of per-item mean
    output lengths, and the mean of per-item "formatted" rates. "average" is
    the unweighted mean of each metric across tasks.

    Args:
        fn: path to the predictions JSON file.
        is_base: unused here; kept for signature parity with ``cal_metrics_whole``.
        print_md: if True, print the results as a Markdown table.

    Returns:
        dict with "per_task" ({metric: {task: value}}) and "average"
        ({metric: value}; empty when no in-scope task was found).

    Side effect: each scored item gains "reward" and "formatted_info" lists
    in place; nothing is written back to disk.
    """
    print(f"Loading predictions from {fn}")
    with open(fn, "r") as f:
        data = json.load(f)

    tasks_scope = ["gpqa_diamond", "mmlu_redux2"]

    # Group items by task, keeping only in-scope tasks.
    task_data = {}
    for item in data:
        task_name = item["task_name"]
        if task_name in tasks_scope:
            task_data.setdefault(task_name, []).append(item)

    for task_name, task_items in task_data.items():
        print(f"Task: {task_name}：{len(task_items)}")

    metric_names = [
        "correct_ratio", "error_ratio", "uncertain_ratio", "truth_score",
        "paq", "f1", "avg_len", "formatted",
    ]
    results = {"per_task": {name: {} for name in metric_names}, "average": {}}

    for task_name in sorted(task_data):
        batch_lengths = []
        batch_formatted = []
        total_correct = total_error = total_uncertain = total_responses = 0

        for item in task_data[task_name]:
            preds = item["model_output"]
            gts = item["gt"]
            lengths = item["output_lengths"]

            rewards = []
            fmt_vals = []
            # NOTE(review): "gt" is zipped element-wise with "model_output";
            # presumably both are aligned lists. A plain string gt would be
            # zipped character by character — confirm the data schema.
            for p, gt in zip(preds, gts):
                info, r = think_boxed_reward_fn_mc(p, gt, fast=False)
                if r is None:
                    rewards.append(0.0)
                    total_uncertain += 1
                elif r >= 0.9:
                    rewards.append(1.0)
                    total_correct += 1
                else:
                    rewards.append(r)
                    total_error += 1
                total_responses += 1
                fmt_vals.append(info.get("formatted", 0))

            batch_lengths.append(np.mean(lengths))
            batch_formatted.append(np.mean(fmt_vals))
            item["reward"] = rewards
            item["formatted_info"] = fmt_vals

        correct_ratio = total_correct / total_responses if total_responses > 0 else 0.0
        error_ratio = total_error / total_responses if total_responses > 0 else 0.0
        uncertain_ratio = total_uncertain / total_responses if total_responses > 0 else 0.0
        # PAQ = Correct / (Correct + Error): uncertain responses are excluded.
        paq = correct_ratio / (correct_ratio + error_ratio) if (correct_ratio + error_ratio) > 0 else 0.0
        # F1 = 2PR / (P + R) with P = PAQ, R = correct ratio.
        f1 = 2 * (paq * correct_ratio) / (paq + correct_ratio) if (paq + correct_ratio) > 0 else 0.0

        task_metrics = {
            "correct_ratio": correct_ratio,
            "error_ratio": error_ratio,
            "uncertain_ratio": uncertain_ratio,
            "truth_score": correct_ratio - error_ratio,
            "paq": paq,
            "f1": f1,
            "avg_len": np.mean(batch_lengths),
            "formatted": np.mean(batch_formatted) if batch_formatted else 0.0,
        }
        for name, value in task_metrics.items():
            results["per_task"][name][task_name] = value

    # Unweighted mean across tasks (only when at least one task was scored).
    if results["per_task"]["truth_score"]:
        results["average"] = {
            name: np.mean(list(per_task.values()))
            for name, per_task in results["per_task"].items()
        }

    if print_md:
        print("\n### 📊 Evaluation Results (Per Task & Average)")

        md_rows = [
            "| Task | Correct | Error | Uncert | TruthScore | PAQ | F1 | AvgLen | Fmt |",
            "|------|---------|-------|--------|------------|-----|----|--------|-----|"
        ]

        def to_percent(x):
            """Convert float [0,1] to percentage string with 2 decimals"""
            return f"{x * 100:.2f}"

        def format_row(label, metrics, bold=False):
            """Render one Markdown row; ratios as percentages, avg_len with 1 decimal."""
            cells = [
                to_percent(metrics[k])
                for k in ("correct_ratio", "error_ratio", "uncertain_ratio", "truth_score", "paq", "f1")
            ]
            cells.append(f"{metrics['avg_len']:.1f}")
            cells.append(to_percent(metrics["formatted"]))
            if bold:
                cells = [f"**{c}**" for c in cells]
            return "| " + " | ".join([label] + cells) + " |"

        for task in sorted(results["per_task"]["truth_score"]):
            task_metrics = {name: results["per_task"][name][task] for name in metric_names}
            md_rows.append(format_row(task, task_metrics))

        # The average row is omitted when no task was scored (average is empty).
        if results["average"]:
            md_rows.append(format_row("**Average**", results["average"], bold=True))

        print("\n".join(md_rows))

    # To persist the per-item annotations, uncomment:
    # json.dump(data, open(fn, "w"), indent=4, ensure_ascii=False)

    return results


def cal_metrics_whole(fn, is_base=False, print_md=True):
    """
    Score math-task predictions stored in a JSON file and report metrics.

    The file holds a list of items with keys "task_name", "model_output",
    "gt" and "output_lengths"; only items whose task is in ``tasks_scope``
    are scored. Each (prediction, gt) pair is graded with
    ``think_boxed_reward_fn_base`` (when ``is_base``) or
    ``think_boxed_reward_fn`` and counted as uncertain (reward None),
    correct (reward >= 0.9) or error (anything else).

    Per task: correct/error/uncertain ratios, TruthScore (correct - error),
    PAQ (correct / (correct + error), i.e. accuracy among non-abstained
    responses), F1 of PAQ and the correct ratio, the mean of per-item mean
    output lengths, and the mean of per-item "formatted" rates. "average" is
    the unweighted mean of each metric across tasks.

    Args:
        fn: path to the predictions JSON file.
        is_base: select the base-model reward function.
        print_md: if True, print the results as a Markdown table.

    Returns:
        dict with "per_task" ({metric: {task: value}}) and "average"
        ({metric: value}; empty when no in-scope task was found).

    Side effect: each scored item gains "reward" and "formatted_info" lists
    in place; nothing is written back to disk.
    """
    print(f"Loading predictions from {fn}")
    with open(fn, "r") as f:
        data = json.load(f)

    tasks_scope = ["aime", "amc", "math", "minerva", "olympiad_bench"]
    reward_fn = think_boxed_reward_fn_base if is_base else think_boxed_reward_fn

    # Group items by task, keeping only in-scope tasks.
    task_data = {}
    for item in data:
        task_name = item["task_name"]
        if task_name in tasks_scope:
            task_data.setdefault(task_name, []).append(item)

    metric_names = [
        "correct_ratio", "error_ratio", "uncertain_ratio", "truth_score",
        "paq", "f1", "avg_len", "formatted",
    ]
    results = {"per_task": {name: {} for name in metric_names}, "average": {}}

    for task_name in sorted(task_data):
        batch_lengths = []
        batch_formatted = []
        total_correct = total_error = total_uncertain = total_responses = 0

        for item in task_data[task_name]:
            preds = item["model_output"]
            gts = item["gt"]
            lengths = item["output_lengths"]

            rewards = []
            fmt_vals = []
            # NOTE(review): "gt" is zipped element-wise with "model_output";
            # presumably both are aligned lists. A plain string gt would be
            # zipped character by character — confirm the data schema.
            for p, gt in zip(preds, gts):
                info, r = reward_fn(p, gt, fast=False)
                if r is None:
                    rewards.append(0.0)
                    total_uncertain += 1
                elif r >= 0.9:
                    rewards.append(1.0)
                    total_correct += 1
                else:
                    rewards.append(r)
                    total_error += 1
                total_responses += 1
                fmt_vals.append(info.get("formatted", 0))

            batch_lengths.append(np.mean(lengths))
            batch_formatted.append(np.mean(fmt_vals))
            item["reward"] = rewards
            item["formatted_info"] = fmt_vals

        correct_ratio = total_correct / total_responses if total_responses > 0 else 0.0
        error_ratio = total_error / total_responses if total_responses > 0 else 0.0
        uncertain_ratio = total_uncertain / total_responses if total_responses > 0 else 0.0
        # PAQ = Correct / (Correct + Error): uncertain responses are excluded.
        paq = correct_ratio / (correct_ratio + error_ratio) if (correct_ratio + error_ratio) > 0 else 0.0
        # F1 = 2PR / (P + R) with P = PAQ, R = correct ratio.
        f1 = 2 * (paq * correct_ratio) / (paq + correct_ratio) if (paq + correct_ratio) > 0 else 0.0

        task_metrics = {
            "correct_ratio": correct_ratio,
            "error_ratio": error_ratio,
            "uncertain_ratio": uncertain_ratio,
            "truth_score": correct_ratio - error_ratio,
            "paq": paq,
            "f1": f1,
            "avg_len": np.mean(batch_lengths),
            "formatted": np.mean(batch_formatted) if batch_formatted else 0.0,
        }
        for name, value in task_metrics.items():
            results["per_task"][name][task_name] = value

    # Unweighted mean across tasks (only when at least one task was scored).
    if results["per_task"]["truth_score"]:
        results["average"] = {
            name: np.mean(list(per_task.values()))
            for name, per_task in results["per_task"].items()
        }

    if print_md:
        print("\n### 📊 Evaluation Results (Per Task & Average)")

        md_rows = [
            "| Task | Correct | Error | Uncert | TruthScore | PAQ | F1 | AvgLen | Fmt |",
            "|------|---------|-------|--------|------------|-----|----|--------|-----|"
        ]

        def to_percent(x):
            """Convert float [0,1] to percentage string with 2 decimals"""
            return f"{x * 100:.2f}"

        def format_row(label, metrics, bold=False):
            """Render one Markdown row; ratios as percentages, avg_len with 1 decimal."""
            cells = [
                to_percent(metrics[k])
                for k in ("correct_ratio", "error_ratio", "uncertain_ratio", "truth_score", "paq", "f1")
            ]
            cells.append(f"{metrics['avg_len']:.1f}")
            cells.append(to_percent(metrics["formatted"]))
            if bold:
                cells = [f"**{c}**" for c in cells]
            return "| " + " | ".join([label] + cells) + " |"

        for task in sorted(results["per_task"]["truth_score"]):
            task_metrics = {name: results["per_task"][name][task] for name in metric_names}
            md_rows.append(format_row(task, task_metrics))

        # The average row is omitted when no task was scored (average is empty).
        if results["average"]:
            md_rows.append(format_row("**Average**", results["average"], bold=True))

        print("\n".join(md_rows))

    return results


