import numpy as np
import os
from tqdm import tqdm
from collections import Counter
import sys
import pandas as pd
from typing import Iterable, Union, Any
from pathlib import Path
import json

def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
    """Lazily yield one decoded JSON value per non-blank line of a JSON Lines file.

    Args:
        file: Path of the UTF-8 encoded file to read.

    Yields:
        The value decoded from each non-blank line, in file order.

    Raises:
        SystemExit: If a line is not valid JSON. The offending line is printed
            first and the exit status is 1.
    """
    with open(file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Whitespace-only lines carry no record; skip them instead of
                # treating them as a parse failure.
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print("Error in loading:", line)
                sys.exit(1)
            # Yield outside the ``try`` so that GeneratorExit (raised here when
            # the consumer closes the generator early) or an exception thrown
            # in by the consumer is never mistaken for a parse failure.
            yield record

# Benchmark pair to evaluate, written as "<variant>,<original>". A name that
# starts with "VAR_" is later compared against the same name without that
# prefix. Other pairs this script has been pointed at:
#   "VAR_amc23,amc23"
#   "VAR_aime24,aime24"
dataset_list = "VAR_aime25,aime25".split(",")

# How the scores that share one question id are combined:
# "mean" averages them, "all" requires every one of them to be truthy.
agg_mode = "mean"

# (model name, prompt type) pairs; both parts are used to locate result files.
model_name_list = [
    ("DeepSeek-R1-0528", "none"),
    ("SEED-THINK-v1.6", "none"),
    ("Qwen3-235B-A22B", "none"),
    ("OpenAI-o4-mini-high", "none"),
]

# Number of trials per model (directories VAR_score_pass_1_1 .. _<N_samples>).
N_samples = 4
# Number of bootstrap resampling rounds.
N_boostrap = 1000

# Only these two aggregation modes are implemented below; any other value
# would leave the per-question scores empty and silently report NaN.
if agg_mode not in ("all", "mean"):
    raise ValueError(f"Unsupported agg_mode: {agg_mode!r} (expected 'all' or 'mean')")

# One row per (model, dataset): bootstrap mean / std of the aggregated score
# plus two counts describing how samples group into question ids.
df_data = []
for dataset in dataset_list:
    for (model_name, prompt_type) in model_name_list:
        # Load the scored samples of every trial, each sorted by "idx" so the
        # trials can be indexed position by position.
        trial_pass1_sample_list = []
        for trial in range(1, N_samples + 1):
            pre_path = f"VAR_score_pass_1_{trial}/{model_name}/{dataset}"
            pattern = f"{prompt_type}_-1_seed0_t0.0_s0_e-1.jsonl"
            matches = [val for val in os.listdir(pre_path) if val.endswith(pattern)]
            # Exactly one result file is expected per trial directory.
            assert len(matches) == 1, f"{model_name}: {matches}"
            pass1_sample_filepath = os.path.join(pre_path, matches[0])
            pass1_sample_list = sorted(load_jsonl(pass1_sample_filepath), key=lambda x: x["idx"])
            trial_pass1_sample_list.append(pass1_sample_list)

        # Every trial must contain the same number of samples.
        num_items = len(trial_pass1_sample_list[0])
        assert all(len(trial_samples) == num_items for trial_samples in trial_pass1_sample_list[1:])

        score_bootstrap = []
        # Bound up front so the counts below are defined even with zero rounds.
        data = {}
        for _ in range(N_boostrap):
            # For every sample position draw one trial at random and group the
            # drawn scores by question id (the part of "id" before the first "_").
            data = {}
            for sample_i in range(num_items):
                random_idx = np.random.choice(N_samples, 1)[0]
                pass1_sample = trial_pass1_sample_list[random_idx][sample_i]
                q_id = str(pass1_sample["id"]).split("_")[0]
                data.setdefault(q_id, []).append(pass1_sample["score"][0])

            # Collapse each question's scores to a single value.
            if agg_mode == "all":
                data_score = {q_id: int(np.all(scores)) for q_id, scores in data.items()}
            else:  # agg_mode == "mean"
                data_score = {q_id: np.mean(scores) for q_id, scores in data.items()}
            score_bootstrap.append(np.mean(list(data_score.values())))

        # Counts are taken from the grouping of the last bootstrap round:
        # question ids with more than one score, and question ids overall.
        num_refined_samples = sum(len(val) > 1 for val in data.values())
        num_samples = len(data)

        tmp_data = [
            model_name,
            dataset,
            prompt_type,
            np.mean(score_bootstrap),
            np.std(score_bootstrap),
            num_refined_samples,
            num_samples,
        ]
        df_data.append(tmp_data)

# Build the per-(model, dataset) table and express the bootstrap mean and
# standard deviation as percentages rounded to two decimals.
detail_columns = [
    "model_name",
    "dataset",
    "prompt_type",
    "data_score",
    "data_score_std",
    "num_refined_samples",
    "num_samples",
]
df = pd.DataFrame(df_data, columns=detail_columns)
for score_col in ("data_score", "data_score_std"):
    df[score_col] = df[score_col].apply(lambda x: round(x * 100, 2))
print("============ Detail infos ============")
print(df)

########################### Results Table ###########################
# Reshape to one row per model, with a score and a std column per dataset.
df = df[["model_name", "dataset", "data_score", "data_score_std"]]
df_pivot = df.pivot(index="model_name", columns="dataset", values=["data_score", "data_score_std"])

# Flatten the multi-level columns into "<dataset> (<metric>)" labels.
df_pivot.columns = [f"{col[1]} ({col[0]})" for col in df_pivot.columns]

show_cols = []
for dataset in dataset_list:
    if not dataset.startswith("VAR_"):
        continue
    # The comparison dataset is the same name without the "VAR_" prefix.
    comp_dataset = dataset[4:]
    base_col = f"{comp_dataset} (data_score)"
    var_col = f"{dataset} (data_score)"
    # Relative change of the variant score with respect to the original, in percent.
    df_pivot[f"Diff {comp_dataset}"] = (df_pivot[var_col] - df_pivot[base_col]) / df_pivot[base_col] * 100
    show_cols.extend([
        base_col, f"{comp_dataset} (data_score_std)",
        var_col, f"{dataset} (data_score_std)",
        f"Diff {comp_dataset}",
    ])

# DataFrame.applymap is deprecated since pandas 2.1 (in favour of
# DataFrame.map); DataFrame.round rounds the numeric columns directly and is
# available on every pandas version.
df_pivot[show_cols] = df_pivot[show_cols].round(1)

print("============ Results Table ============")
print(df_pivot[show_cols])
print(df_pivot[show_cols].mean())