import numpy as np
import os
from torch import mode
from tqdm import tqdm
from collections import Counter
import sys
import pandas as pd
from typing import Iterable, Union, Any
from pathlib import Path
import json
import seaborn as sns


def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
    """Lazily yield one decoded JSON value per non-blank line of a JSON Lines file.

    Args:
        file: Path of the UTF-8 encoded file to read.

    Yields:
        The value decoded from each line, in file order.

    Raises:
        SystemExit: With exit status 1 when a line is not valid JSON; the
            offending line is printed first.
    """
    with open(file, "r", encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a trailing newline at end of file) hold no record.
            if not line.strip():
                continue
            # Keep the try body limited to decoding: wrapping the ``yield`` as
            # well would intercept GeneratorExit when the consumer closes the
            # generator early and misreport it as a load error.
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print("Error in loading:", line)
                sys.exit(1)
            yield record


# Datasets to evaluate, written as one comma-separated string and split into a
# list. A "VAR_"-prefixed name is later compared against the dataset of the
# same name without the prefix, so both members of a pair must be listed.
dataset_list = "VAR_amc23,amc23".split(",")
# dataset_list = "VAR_aime24,aime24".split(",")
# dataset_list = "VAR_aime25,aime25".split(",")

# Number of per-record scores that may be indexed when sampling.
N_samples = 16
# Number of bootstrap iterations run for every (dataset, model) pair.
N_boostrap = 1000
# How the scores sharing a question id are combined: "all" or "mean".
agg_mode = "mean"

# (model_name, prompt_type) pairs; both parts are used to locate the
# per-model result file on disk.
model_name_list = [
    # Entries whose names carry a 7B tag.
    ("Qwen/Qwen2.5-MATH-7B", "qwen25-math-cot"),
    ("PRIME-RL/Eurus-2-7B-PRIME", "qwen25-math-cot"),
    ("hkust-nlp/Qwen-2.5-Math-7B-SimpleRL-Zoo", "qwen25-math-cot"),
    ("sail/Qwen2.5-Math-7B-Oat-Zero", "qwen25-math-cot"),
    ("Skywork/Skywork-OR1-Math-7B", "deepseek-r1"),
    ("qihoo360/Light-R1-7B-DS", "deepseek-r1"),
    # Entries whose names carry a 32B tag.
    ("BytedTsinghua-SIA/DAPO-Qwen-32B", "qwen2.5-32B"),
    ("Kwaipilot/SRPO-Qwen-32B", "SRPO"),
    ("Qwen/Qwen2.5-32B", "qwen2.5-32B"),
]

# Fail fast on an unsupported aggregation mode: otherwise no per-question
# score would be produced and every reported mean would silently become NaN.
if agg_mode not in ("all", "mean"):
    raise ValueError(f"Unsupported agg_mode: {agg_mode!r} (expected 'all' or 'mean')")

# One row per (dataset, model):
# [model_name, dataset, prompt_type, mean score, std of score,
#  number of question ids with more than one record, number of question ids].
df_data = []

for dataset in dataset_list:
    # "VAR_" datasets draw a random score index per record; other datasets
    # cycle deterministically through the score indices.
    is_var_dataset = "VAR_" in dataset

    for (model_name, prompt_type) in model_name_list:

        # Locate the single result file for this model/dataset/prompt type.
        pre_path = f"VAR_score_pass_16/{model_name}/{dataset}"
        pattern = f"{prompt_type}_-1_seed0_t0.6_s0_e-1.jsonl"
        matches = [val for val in os.listdir(pre_path) if val.endswith(pattern)]
        # Explicit raise rather than assert so the check survives `python -O`.
        if len(matches) != 1:
            raise ValueError(f"{model_name}: {matches}")
        pass1_sample_filepath = os.path.join(pre_path, matches[0])
        pass1_sample_list = sorted(load_jsonl(pass1_sample_filepath), key=lambda x: x["idx"])

        # Group record positions by question id (the part of "id" before the
        # first underscore). The grouping does not depend on the bootstrap
        # iteration, so it is computed once instead of on every iteration.
        groups = {}
        for pos, pass1_sample in enumerate(pass1_sample_list):
            q_id = str(pass1_sample["id"]).split("_")[0]
            groups.setdefault(q_id, []).append(pos)
        num_samples = len(groups)
        num_refined_samples = sum(len(members) > 1 for members in groups.values())

        data_score_list = []
        for idx in range(N_boostrap):
            # Pick one score per record, in record order.
            picked_scores = []
            for pass1_sample in pass1_sample_list:
                if is_var_dataset:
                    random_idx = np.random.choice(N_samples, 1)[0]
                else:
                    random_idx = idx % N_samples
                picked_scores.append(pass1_sample["score"][random_idx])

            # Collapse the picked scores of each question id into one value.
            question_scores = []
            for members in groups.values():
                scores = [picked_scores[pos] for pos in members]
                if agg_mode == "all":
                    question_scores.append(int(np.all(scores)))
                else:  # agg_mode == "mean", validated above
                    question_scores.append(np.mean(scores))
            data_score_list.append(np.mean(question_scores))

        df_data.append([
            model_name,
            dataset,
            prompt_type,
            np.mean(data_score_list),
            np.std(data_score_list),
            num_refined_samples,
            num_samples,
        ])

# Collect the per-(model, dataset) rows into a table.
df = pd.DataFrame(
    df_data,
    columns=[
        "model_name", "dataset", "prompt_type",
        "data_score", "data_score_std",
        "num_refined_samples", "num_samples",
    ],
)
# Scale mean and std by 100 and round to two decimals.
df["data_score"] = df["data_score"].apply(lambda x: round(x * 100, 2))
df["data_score_std"] = df["data_score_std"].apply(lambda x: round(x * 100, 2))
print("============ Detail infos ============")
print(df)

########################### Results Table ###########################
# One row per model, with a (score, std) column pair per dataset.
df = df[["model_name", "dataset", "data_score", "data_score_std"]]
df_pivot = df.pivot(index="model_name", columns="dataset", values=["data_score", "data_score_std"])

# Flatten the multi-level columns into "<dataset> (<metric>)" labels.
df_pivot.columns = [f"{col[1]} ({col[0]})" for col in df_pivot.columns]

show_cols = []
for dataset in dataset_list:
    if dataset.startswith("VAR_"):
        # The comparison dataset is the same name without the "VAR_" prefix.
        comp_dataset = dataset[4:]
        # Relative change of the VAR_ score with respect to the comparison
        # dataset's score, scaled by 100.
        df_pivot[f"Diff {comp_dataset}"] = (
            (df_pivot[f"{dataset} (data_score)"] - df_pivot[f"{comp_dataset} (data_score)"]) /
            df_pivot[f"{comp_dataset} (data_score)"] * 100
        )
        show_cols.extend([
            f"{comp_dataset} (data_score)", f"{comp_dataset} (data_score_std)",
            f"{dataset} (data_score)", f"{dataset} (data_score_std)",
            f"Diff {comp_dataset}"
        ])


def _round_one_decimal(x):
    """Round int/float cells to one decimal; return anything else unchanged."""
    return round(x, 1) if isinstance(x, (int, float)) else x


# DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map;
# use the new name when available and fall back for older pandas versions.
shown = df_pivot[show_cols]
elementwise = shown.map if hasattr(shown, "map") else shown.applymap
df_pivot[show_cols] = elementwise(_round_one_decimal)

print("============ Results Table ============")
print(df_pivot[show_cols])
print(df_pivot[show_cols].mean())