import matplotlib.pyplot as plt
import math
from typing import List
import os
import json
from dataclasses import dataclass
from functools import total_ordering
from grader import math_equal
from parser import extract_answer
import ipdb
from scipy.special import comb  # 高效计算组合数

def pass_at_k(num_correct: int, total: int = 1024, k: int = 1) -> float:
    """Unbiased pass@k estimate from ``num_correct`` correct out of ``total`` samples.

    Computes ``1 - C(total - num_correct, k) / C(total, k)``. This is the
    probability that at least one of ``k`` samples, drawn without replacement
    from the ``total`` generated samples, is correct.

    The binomial coefficients are computed exactly with ``math.comb`` (Python
    integers). Their ratio uses integer true division, which is correctly
    rounded, so the result stays finite for any ``total``. A floating-point
    ``comb`` overflows to ``inf`` once ``C(total, k)`` exceeds ~1.8e308
    (e.g. ``total=2048, k=1024``), which turns the ratio into ``nan``.

    Args:
        num_correct: Number of correct samples.
        total: Number of generated samples per question.
        k: Number of draws.

    Returns:
        The pass@k estimate as a float in [0, 1].
    """
    if num_correct == 0:
        return 0.0
    if num_correct >= total:
        return 1.0
    if k > total:
        # More draws than samples: treated as solved whenever any sample is correct.
        return 1.0
    if total - num_correct < k:
        # Fewer wrong samples than draws: every k-subset contains a correct one.
        return 1.0
    return 1.0 - math.comb(total - num_correct, k) / math.comb(total, k)

def is_correct(output, ground_truth, data_name="aime24"):
    """Return whether one model output's final answer matches the ground truth.

    The final answer is extracted from ``output`` with
    ``extract_answer(output, data_name)``. It is then compared to
    ``ground_truth`` with ``math_equal``.

    Args:
        output: A single raw model completion. Must be a ``str``; lists are
            not accepted.
        ground_truth: The reference answer as a ``str``.
        data_name: Dataset name forwarded to ``extract_answer``. Presumably it
            selects dataset-specific extraction rules; confirm in
            ``parser.extract_answer``. Defaults to ``"aime24"``.

    Returns:
        The result of ``math_equal``, which callers use as a boolean.

    Raises:
        TypeError: If ``output`` or ``ground_truth`` is not a ``str``.
    """
    # Explicit raise rather than assert, so the check survives ``python -O``.
    if not (isinstance(output, str) and isinstance(ground_truth, str)):
        raise TypeError(
            f"Output and ground truth must be strings, got {type(output)} and {type(ground_truth)}"
        )
    final_answer = extract_answer(output, data_name)
    return math_equal(final_answer, ground_truth)

@dataclass
class Metrics:
    """Per-question record: question text, reference answer, and correct-sample counts per shot setting.

    Ordering compares ``correct_count_0shot`` only, and is intentionally
    *reversed*: ``a < b`` means ``a`` has MORE correct 0-shot samples than
    ``b``. As a result, ``sorted(metrics_list)`` lists the questions the
    0-shot run solves most often first.

    All four ordering operators are defined explicitly on that single key,
    so they agree with one another. ``==`` remains the dataclass-generated
    field-by-field comparison.
    """

    question_id: int
    correct_count_0shot: int = 0  # correct samples in the 0-shot run
    correct_count_4shot: int = 0  # correct samples in the 4-shot run
    correct_count_8shot: int = 0  # correct samples in the 8-shot run
    content: str = ""  # question text
    ground_truth: str = ""  # reference answer

    # NOTE: functools.total_ordering is deliberately not used. It derives
    # <=, > and >= from __lt__ combined with the field-wise dataclass __eq__.
    # For two questions with equal 0-shot counts but different fields, that
    # made both ``a > b`` and ``b > a`` true.

    def __lt__(self, other):
        if not isinstance(other, Metrics):
            return NotImplemented
        return self.correct_count_0shot > other.correct_count_0shot

    def __le__(self, other):
        if not isinstance(other, Metrics):
            return NotImplemented
        return self.correct_count_0shot >= other.correct_count_0shot

    def __gt__(self, other):
        if not isinstance(other, Metrics):
            return NotImplemented
        return self.correct_count_0shot < other.correct_count_0shot

    def __ge__(self, other):
        if not isinstance(other, Metrics):
            return NotImplemented
        return self.correct_count_0shot <= other.correct_count_0shot

def _count_correct(outputs, ground_truth, keep):
    """Count outputs judged correct by ``is_correct``; also return the first ``keep`` correct and ``keep`` wrong outputs."""
    correct_count = 0
    correct_samples = []
    wrong_samples = []
    for output in outputs:
        if is_correct(output, ground_truth):
            correct_count += 1
            if len(correct_samples) < keep:
                correct_samples.append(output)
        elif len(wrong_samples) < keep:
            wrong_samples.append(output)
    return correct_count, correct_samples, wrong_samples


def _print_sample_report(qid, question, shot_type, correct_count, ground_truth,
                         correct_samples, wrong_samples):
    """Print the per-question summary line, followed by the collected example outputs."""
    import pprint

    pp = pprint.PrettyPrinter(indent=4)
    print(f"QID: {qid:02d}, Question: {question}, shot: {shot_type} correct count: {correct_count}, "
          f"ground truth: {ground_truth}")
    print("Here are some correct samples:")
    for i, sample in enumerate(correct_samples):
        print(f"Sample {i+1}:")
        pp.pprint(sample)
    if wrong_samples:
        print("Here are some wrong samples:")
        for i, sample in enumerate(wrong_samples):
            print(f"Sample {i+1}:")
            pp.pprint(sample)


def load_and_count_from_dir(metrics_list, dir_path, shot_type, special_qids=(28,),
                            num_samples=1024):
    """Count correct sampled answers per question for one shot setting.

    Reads every ``*.jsonl`` file under
    ``<dir_path>/eval_results/global_step_0/aime24``, skipping files whose
    name contains ``"metrics"``. Each line is a JSON record with keys
    ``idx``, ``question``, ``code`` (the sampled outputs) and ``gt``. For each
    record, the outputs judged correct by ``is_correct`` are counted, and the
    count is stored on ``metrics_list[idx]`` in ``correct_count_<shot_type>``.
    The question text and ground truth are filled in the first time a
    question is seen.

    Args:
        metrics_list: Sequence of ``Metrics`` indexed by question id. Its
            entries are mutated in place.
        dir_path: Run directory containing ``eval_results/...``.
        shot_type: One of ``"0shot"``, ``"4shot"``, ``"8shot"``.
        special_qids: If not ``None``, only these question ids are counted,
            and up to three correct and three wrong outputs are printed for
            each. If ``None``, every question is counted and no samples are
            printed. Defaults to ``(28,)``.
        num_samples: Required number of sampled outputs per record.

    Raises:
        ValueError: If ``shot_type`` is unknown, or if a record's outputs are
            not a list of exactly ``num_samples`` entries.
    """
    if shot_type not in ("0shot", "4shot", "8shot"):
        raise ValueError(f"shot_type must be '0shot', '4shot' or '8shot', got {shot_type!r}")
    count_field = f"correct_count_{shot_type}"
    verbose = special_qids is not None
    results_dir = os.path.join(dir_path, "eval_results/global_step_0/aime24")

    # Sorted so that, when several files contain the same qid, which write
    # wins (the last one) does not depend on directory listing order.
    for filename in sorted(os.listdir(results_dir)):
        # Only per-sample JSONL files; skip the aggregated "metrics" files.
        if not filename.endswith(".jsonl") or "metrics" in filename:
            continue
        print(f"Processing file: {filename}")
        file_path = os.path.join(results_dir, filename)

        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                data = json.loads(line)
                qid = data["idx"]
                question = data["question"]
                outputs = data["code"]
                ground_truth = data["gt"]

                # Validated for every record, including ones skipped below.
                if not isinstance(outputs, list) or len(outputs) != num_samples:
                    got = (f"a list of length {len(outputs)}" if isinstance(outputs, list)
                           else type(outputs).__name__)
                    raise ValueError(
                        f"Invalid outputs for question {qid} in file {filename}, "
                        f"expected list of length {num_samples}, but got {got}"
                    )

                if verbose and qid not in special_qids:
                    continue

                metrics = metrics_list[qid]
                if metrics.content == "":
                    metrics.question_id = qid
                    metrics.content = question
                    metrics.ground_truth = ground_truth

                correct_count, correct_samples, wrong_samples = _count_correct(
                    outputs, ground_truth, keep=3 if verbose else 0
                )
                setattr(metrics, count_field, correct_count)
                if verbose:
                    _print_sample_report(qid, question, shot_type, correct_count,
                                         ground_truth, correct_samples, wrong_samples)


def get_aime24_jsonl(metrics_list: List["Metrics"], save_path: str = "FIGS/aime24_questions.jsonl"):
    """Write one JSON line per question: ID, question text, ground truth, and 0/4/8-shot correct counts.

    Args:
        metrics_list: Per-question metrics, written in list order.
        save_path: Output JSONL path. Its parent directory is created if it
            does not exist.
    """
    records = [
        {
            "question_id": m.question_id,
            "content": m.content,
            "ground_truth": m.ground_truth,
            "correct_count_0shot": m.correct_count_0shot,
            "correct_count_4shot": m.correct_count_4shot,
            "correct_count_8shot": m.correct_count_8shot,
        }
        for m in metrics_list
    ]

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Explicit utf-8: ensure_ascii=False emits non-ASCII text verbatim, which
    # would fail under a non-UTF-8 default locale encoding.
    with open(save_path, "w", encoding="utf-8") as f:
        for item in records:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
    print(f"AIME24 JSONL 数据已保存至 {save_path}")


def plot_pass_at_k(metrics_list: List["Metrics"], k: int, save_path: str = "pass_at_k_difficulty.png",
                   total: int = 1024):
    """Plot per-question pass@k for the 0/4/8-shot runs over shaded difficulty bands.

    Questions are ordered with ``sorted(metrics_list)``, i.e. by
    ``Metrics.__lt__``, which puts higher 0-shot correct counts first.
    Question ``i`` in that order is drawn at ``x = i``.

    Args:
        metrics_list: Per-question metrics.
        k: The k in pass@k.
        save_path: Output image path. Its parent directory is created if
            needed.
        total: Number of samples per question, passed to ``pass_at_k``.
    """
    ordered = sorted(metrics_list)

    positions = list(range(len(ordered)))
    tick_labels = [f"Q{m.question_id:02d}" for m in ordered]

    y_0shot = [pass_at_k(m.correct_count_0shot, total, k) for m in ordered]
    y_4shot = [pass_at_k(m.correct_count_4shot, total, k) for m in ordered]
    y_8shot = [pass_at_k(m.correct_count_8shot, total, k) for m in ordered]

    fig = plt.figure(figsize=(13, 6))

    # Difficulty bands as [start, end) index ranges. The boundaries are
    # hard-coded for 30 questions (7 / 8 / 8 / 7).
    regions = [
        (0, 7, "Easy", "#d3f9d8"),    # light green
        (7, 15, "Mid", "#fff5cc"),    # light yellow
        (15, 23, "Hard", "#ffe0e0"),  # light red
        (23, 30, "Exh", "#e0e8ff"),   # light blue
    ]
    for start, end, label, color in regions:
        # Questions sit at integer x, so the band for indices [start, end)
        # spans [start - 0.5, end - 0.5]. This keeps each point inside its
        # band instead of on the edge between two bands.
        left, right = start - 0.5, end - 0.5
        plt.axvspan(left, right, color=color, alpha=0.4)
        plt.text((left + right) / 2, 1.03, label, ha="center", va="bottom", fontsize=12, fontweight="bold")

    # 0-shot baseline: dashed, semi-transparent green line.
    plt.plot(positions, y_0shot, label="0-shot", marker="o", color="green", linestyle="--", alpha=0.7)
    plt.plot(positions, y_4shot, label="4-shot", marker="s", color="blue")
    plt.plot(positions, y_8shot, label="8-shot", marker="^", color="red")

    plt.title(f"Pass@{k} across Questions with Difficulty Regions")
    plt.xlabel("Questions (sorted by 0-shot performance)")
    plt.ylabel(f"Pass@{k} Accuracy")
    plt.ylim(0, 1.05)
    plt.xticks(positions, tick_labels, rotation=45)
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.legend()
    plt.tight_layout()

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path)
    # Close the figure so repeated calls (e.g. one per k) do not accumulate
    # open figures.
    plt.close(fig)
    print(f"图已保存至：{save_path}")

def main():
    """Count correct answers for the 0/4/8-shot runs and print a per-question summary."""
    base_dir = "/mnt/afs/codes/wangpy/limit-of-RLVR/math/examples/math_eval/EVAL/temp"
    metrics_list = [Metrics(question_id=qid) for qid in range(30)]

    # (run sub-directory, shot label) pairs, processed in this order.
    runs = (
        ("numshot0", "0shot"),
        ("numshot4", "4shot"),
        ("numshot8", "8shot"),
    )
    for run_dir, shot_type in runs:
        # NOTE(review): special_qids is left at its default, so only the
        # default question id(s) are counted. Pass special_qids=None to
        # count every question.
        load_and_count_from_dir(metrics_list, os.path.join(base_dir, run_dir), shot_type)

    # Quick sanity check of the collected counts.
    for entry in metrics_list:
        print(
            f"QID: {entry.question_id:02d}, 0s: {entry.correct_count_0shot}, "
            f"4s: {entry.correct_count_4shot}, 8s: {entry.correct_count_8shot}"
        )

    # Plotting (currently disabled):
    # for k in [1, 32, 256, 512, 1024]:
    #     plot_pass_at_k(metrics_list, k, save_path=f"FIGS/pass_at_{k}_plot.png")

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
