import os
import re
import sys
import json
import copy
import click
import pathlib
import tempfile
import multiprocessing
from functools import partial
from types import SimpleNamespace
from pathlib import Path, PurePosixPath
from shutil import copy2

import pandas as pd
from datetime import datetime
from huggingface_hub import HfApi, hf_hub_download

try:
    from datasets import load_dataset 
    HAS_DATASETS = True
except Exception:
    HAS_DATASETS = False


# Datasets processed by this script. Only the keys drive processing (they are
# iterated in this order and used as cache/output names); the values are not
# read anywhere in this file.
CATEGORY_MAP = {
    "ifeval": "IFEval",
    "bbh": "BBH",
    "math": "MATH",
    "gpqa": "GPQA",
    "musr": "MUSR",
    "mmlu_pro": "MMLU-Pro",
}

# Per-dataset field of a sample record that holds its score; a value of exactly
# 1 is treated as correct. "ifeval" is absent on purpose: it is scored from
# prompt_level_strict_acc / inst_level_strict_acc instead.
FIELD_MAP = {
    "bbh": "acc_norm",
    "gpqa": "acc_norm",
    "musr": "acc_norm",
    "mmlu_pro": "acc",
    "math": "exact_match",
}

# Datasets whose samples are stored as one file per subset. Each subset name is
# spliced into the samples_leaderboard_<ds>_<subset>_<timestamp>.json[l] file
# name pattern used to locate those files.
SUBSETS = {
    "bbh": [
        "boolean_expressions","causal_judgement","date_understanding","disambiguation_qa","formal_fallacies",
        "geometric_shapes","hyperbaton","logical_deduction_five_objects","logical_deduction_seven_objects",
        "logical_deduction_three_objects","movie_recommendation","navigate","object_counting","penguins_in_a_table",
        "reasoning_about_colored_objects","ruin_names","salient_translation_error_detection","snarks",
        "sports_understanding","temporal_sequences","tracking_shuffled_objects_five_objects",
        "tracking_shuffled_objects_seven_objects","tracking_shuffled_objects_three_objects","web_of_lies"
    ],
    "math": ["algebra_hard","counting_and_prob_hard","geometry_hard","intermediate_algebra_hard",
             "num_theory_hard","prealgebra_hard","precalculus_hard"],
    "musr": ["murder_mysteries","object_placements","team_allocation"],
    "gpqa": ["diamond","extended","main"]
}

# Entries [ds, model_name] or [ds, subset, model_name] for models whose sample
# count is below the reference model's; dumped to
# short_ds_and_models_to_remove.json by main when non-empty.
# NOTE(review): process_model runs inside multiprocessing.Pool workers, so its
# appends only reach the parent process if this list is shared explicitly.
short_ds_and_models_to_remove = []


def construct_union_function(model_name="01-ai__Yi-34B"):
    """Download the reference model's per-sample records used to align all models.

    For every dataset in CATEGORY_MAP (and every subset of the datasets listed
    in SUBSETS) the newest ``samples_leaderboard_*`` file is fetched from the
    ``open-llm-leaderboard/<model_name>-details`` dataset repo and parsed as
    JSON lines.

    Args:
        model_name: Reference model whose details repo defines the expected
            sample set. Defaults to the previously hard-coded
            ``01-ai__Yi-34B``.

    Returns:
        ``(union_data, length_limit)``: two dicts keyed by the CATEGORY_MAP
        keys. For a dataset without subsets the value is the parsed record
        list (resp. its length); for a dataset in SUBSETS it is a dict keyed by
        subset name. Entries whose file could not be found stay ``None``.

    Exits the process with status 1 if listing or downloading raises.
    """
    # Derive the skeleton from SUBSETS so it cannot drift from the subset lists
    # that process_model iterates over.
    template = {
        ds: ({subset: None for subset in SUBSETS[ds]} if ds in SUBSETS else None)
        for ds in CATEGORY_MAP
    }
    union_data = copy.deepcopy(template)
    length_limit = copy.deepcopy(template)
    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    api = HfApi(token=token)
    model_id = f"open-llm-leaderboard/{model_name}-details"

    def find_latest(leaderboard_files, stem):
        """Return the newest samples_leaderboard_<stem>_<timestamp>.json[l] path, or None."""
        pat = re.compile(
            rf"samples_leaderboard_{stem}_(\d{{4}}-\d\d-\d\dT\d\d-\d\d-\d\d\.\d{{6}})\.jsonl?$"
        )
        candidates = []
        for file in leaderboard_files:
            m = pat.search(file)
            if m:
                candidates.append((file, m.group(1)))
        return get_latest_file(candidates)

    def fetch_records(filename):
        """Download one samples file and parse it as JSON lines, skipping blank lines."""
        with tempfile.TemporaryDirectory() as tmp:
            fp = hf_hub_download(
                repo_id=model_id, filename=filename, repo_type="dataset",
                cache_dir=tmp, token=token
            )
            with open(fp, "r", encoding="utf-8") as f:
                return [json.loads(ln.strip()) for ln in f if ln.strip()]

    try:
        files = api.list_repo_files(repo_id=model_id, repo_type="dataset")
        leaderboard_files = [c for c in files if "leaderboard" in c]

        for ds in CATEGORY_MAP:
            # Datasets without subsets are handled as a single "None" subset.
            for subset in SUBSETS.get(ds, [None]):
                stem = ds if subset is None else f"{ds}_{subset}"
                latest_file = find_latest(leaderboard_files, stem)
                if not latest_file:
                    # The entry stays None and no model can be scored for it
                    # downstream, so make the gap visible instead of silent.
                    print(f"[WARN] reference model {model_name} has no samples file for {stem}")
                    continue
                records = fetch_records(latest_file)
                if subset is None:
                    union_data[ds] = records
                    length_limit[ds] = len(records)
                else:
                    union_data[ds][subset] = records
                    length_limit[ds][subset] = len(records)

    except Exception as e:
        print(f"Constructing union function failed! {e}")
        sys.exit(1)

    return union_data, length_limit

def id2hash_function(data_list: list, id):
    """Look up the ``doc_hash`` of the first record whose ``doc_id`` equals *id*.

    Raises:
        ValueError: if *data_list* is an empty list.

    Returns:
        The matching record's ``doc_hash``, or -1 when no record matches.
    """
    if data_list == []:
        raise ValueError("Empty list.")
    matches = (record['doc_hash'] for record in data_list if record['doc_id'] == id)
    return next(matches, -1)


def compare_timestamps(ts1, ts2):
    """Three-way compare two ``YYYY-MM-DDTHH-MM-SS.ffffff`` timestamps.

    Dashes in the time part are read as colons, so ``HH:MM:SS`` is accepted
    too. Returns -1 if ts1 is earlier, 1 if later, 0 if equal, and None if
    either value cannot be parsed.
    """
    def to_datetime(stamp):
        try:
            day, clock = stamp.split('T')
            hms, frac = clock.split('.')
            clock_std = hms.replace('-', ':')
            return datetime.strptime(f"{day} {clock_std}.{frac}", "%Y-%m-%d %H:%M:%S.%f")
        except Exception:
            return None

    first = to_datetime(ts1)
    second = to_datetime(ts2)
    if first is None or second is None:
        return None
    # bool arithmetic yields the int -1, 0 or 1
    return (first > second) - (first < second)

def get_latest_file(files_with_ts):
    """Return the file with the newest timestamp from ``(file, ts)`` pairs.

    The first pair seeds the result; a later pair replaces it only when
    compare_timestamps reports it strictly newer (unparseable timestamps never
    win). Returns None for an empty input.
    """
    best_file = best_ts = None
    for cand_file, cand_ts in files_with_ts:
        if best_file is not None:
            order = compare_timestamps(cand_ts, best_ts)
            if order is None or order <= 0:
                continue
        best_file, best_ts = cand_file, cand_ts
    return best_file

def model_name_from_repo(repo_id: str) -> str:
    """Strip the leaderboard org prefix and ``-details`` suffix from a repo id.

    ``open-llm-leaderboard/xxx-details`` -> ``xxx``
    """
    prefix = "open-llm-leaderboard/"
    suffix = "-details"
    name = repo_id.replace(prefix, "") if repo_id.startswith(prefix) else repo_id
    return name[:-len(suffix)] if name.endswith(suffix) else name

def cache_path(cache_dir: str, dataset: str, model_name: str) -> str:
    """Return ``<cache_dir>/.cache/<dataset>/<model_name>.csv``.

    The ``<dataset>`` directory is created if it does not exist yet.
    """
    target_dir = os.path.join(cache_dir, ".cache", dataset)
    os.makedirs(target_dir, exist_ok=True)
    file_name = f"{model_name}.csv"
    return os.path.join(target_dir, file_name)

def cache_exists(cache_dir: str, dataset: str, model_name: str) -> bool:
    """Report whether a cached CSV exists for (dataset, model_name).

    Goes through cache_path, so the dataset's cache directory is created as a
    side effect.
    """
    target = cache_path(cache_dir, dataset, model_name)
    return os.path.exists(target)

def save_series_to_cache(cache_dir: str, dataset: str, series: pd.Series):
    """Write *series* to its cache CSV as ``doc_hash``/``value`` columns.

    The file name is derived from ``series.name`` (the model name).
    """
    target = cache_path(cache_dir, dataset, series.name)
    frame = pd.DataFrame({"doc_hash": series.index, "value": series.values})
    frame.to_csv(target, index=False)

def load_series_from_cache(cache_dir: str, dataset: str, model_name: str) -> pd.Series:
    """Read a cached CSV back into a Series indexed by ``doc_hash``, named *model_name*."""
    frame = pd.read_csv(cache_path(cache_dir, dataset, model_name))
    values = frame["value"].values
    hashes = frame["doc_hash"].values
    return pd.Series(values, index=hashes, name=model_name)

def try_load_via_datasets(repo_id, repo_base, dataset, subset=None, token=None):
    """Fallback loader: read the ``latest`` split of a details config via ``datasets``.

    The config name is ``<repo_base>__leaderboard_<dataset>[_<subset>]``.
    Returns the rows as a list, or None when ``datasets`` is unavailable or
    loading fails for any reason.
    """
    if not HAS_DATASETS:
        return None
    cfg = (
        f"{repo_base}__leaderboard_{dataset}_{subset}" if subset
        else f"{repo_base}__leaderboard_{dataset}"
    )
    try:
        return list(load_dataset(repo_id, name=cfg, split="latest", token=token))
    except Exception:
        return None


def _abort_error(message):
    """Build a RuntimeError tagged so that process_model re-raises it.

    A builtin exception type is used (rather than a custom class) so that it
    unpickles cleanly in the parent process when raised inside a Pool worker.
    """
    err = RuntimeError(message)
    err.abort_run = True
    return err


def _latest_sample_file(leaderboard_files, ds, subset=None):
    """Return the newest samples_leaderboard_<ds>[_<subset>]_<timestamp>.json[l] path, or None."""
    stem = ds if subset is None else f"{ds}_{subset}"
    pat = re.compile(
        rf"samples_leaderboard_{stem}_(\d{{4}}-\d\d-\d\dT\d\d-\d\d-\d\d\.\d{{6}})\.jsonl?$"
    )
    candidates = []
    for file in leaderboard_files:
        m = pat.search(file)
        if m:
            candidates.append((file, m.group(1)))
    return get_latest_file(candidates)


def _download_samples(repo_id, filename, token, raw_dir=None):
    """Download one samples file from a dataset repo and parse it as JSON lines.

    Blank lines are skipped. When *raw_dir* is given, the downloaded file is
    also copied there under its base name.
    """
    with tempfile.TemporaryDirectory() as tmp:
        fp = hf_hub_download(
            repo_id=repo_id, filename=filename, repo_type="dataset",
            cache_dir=tmp, token=token
        )
        if raw_dir is not None:
            raw_dir.mkdir(parents=True, exist_ok=True)
            copy2(fp, raw_dir / PurePosixPath(filename).name)
        with open(fp, "r", encoding="utf-8") as f:
            return [json.loads(ln.strip()) for ln in f if ln.strip()]


def _error_flag(ds, item):
    """Return 0 when a sample record counts as correct and 1 otherwise."""
    if ds == "ifeval":
        psa = item.get("prompt_level_strict_acc")
        isa = item.get("inst_level_strict_acc") or []
        return 0 if (psa and all(isa)) else 1
    return 0 if item.get(FIELD_MAP[ds], 0) == 1 else 1


def _hash_lookup(reference):
    """Map doc_id -> doc_hash for a reference record list, keeping the first match."""
    lookup = {}
    for ref in reference:
        lookup.setdefault(ref["doc_id"], ref["doc_hash"])
    return lookup


def process_model(model, token, save_raw, debug, cache_dir, force, union_data, length_limit):
    """Score one model's leaderboard samples against the reference sample set.

    For each dataset (and subset) the newest sample file in ``model.id`` is
    used; only when no such file exists is the ``datasets`` fallback tried.
    Each sample becomes a 0 (correct) / 1 (incorrect) value keyed by the
    reference ``doc_hash``, and the per-dataset Series is written to the cache.

    Args:
        model: object whose ``.id`` is the details dataset repo id.
        token: Hugging Face token (may be None).
        save_raw: also copy downloaded files to ``<cache_dir>/raw/<model>/``.
        debug: print cache hits and file picks.
        cache_dir: root directory of the per-dataset CSV cache.
        force: ignore existing cache entries.
        union_data, length_limit: reference records / lengths as returned by
            construct_union_function.

    Returns:
        ``(model_name, model_results, skipped_datasets, reasons)``.

    Raises:
        RuntimeError: when a sample list is longer than the reference. This
            aborts the run in the parent; a ``sys.exit`` inside a Pool worker
            would instead kill the worker and leave ``pool.imap`` waiting forever.
    """
    api = HfApi(token=token)
    model_name = model_name_from_repo(model.id)

    model_results = {}
    skipped_datasets = []
    reasons = []

    for ds in CATEGORY_MAP.keys():
        if (not force) and cache_exists(cache_dir, ds, model_name):
            skipped_datasets.append(ds)
            if debug:
                print(f"[CACHE] {model_name}\t{ds}")

    if len(skipped_datasets) == len(CATEGORY_MAP):
        return model_name, model_results, skipped_datasets, reasons

    raw_dir = Path(cache_dir) / "raw" / model_name if save_raw else None

    try:
        files = api.list_repo_files(repo_id=model.id, repo_type="dataset")
        leaderboard_files = [c for c in files if "leaderboard" in c]

        if not leaderboard_files:
            reasons.append(f"{model_name}\tALL\tNO_LEADERBOARD_FILES_IN_LISTING")

        for ds in CATEGORY_MAP.keys():
            if ds in skipped_datasets:
                continue

            merged = {}
            # Datasets without subsets are handled as a single "None" subset.
            for subset in SUBSETS.get(ds, [None]):
                if subset is None:
                    reference, limit = union_data[ds], length_limit[ds]
                    label, short_entry = ds, [ds, model_name]
                    longer_msg = f"Found longer dataset: dataset is {ds}, model_name is {model_name}"
                else:
                    reference, limit = union_data[ds][subset], length_limit[ds][subset]
                    label, short_entry = f"{ds}/{subset}", [ds, subset, model_name]
                    longer_msg = f"Found longer subset: subset is {ds}_{subset}, model_name is {model_name}"

                latest_file = _latest_sample_file(leaderboard_files, ds, subset)
                if latest_file:
                    data_list = _download_samples(model.id, latest_file, token, raw_dir)
                    if debug:
                        print(f"[PICK] {model_name}/{label} -> {PurePosixPath(latest_file).name} ({len(data_list)})")
                else:
                    # Fall back to the datasets library only when no file matched;
                    # a short file is reported below rather than re-fetched.
                    data_list = try_load_via_datasets(model.id, model_name, ds, subset=subset, token=token)
                    if data_list is None:
                        reasons.append(f"{model_name}\t{label}\tNO_FILES_MATCH_PATTERN_AND_DATASETS_FAIL")
                        continue

                if limit is None:
                    # The reference model had no file for this unit, so there is
                    # nothing to align against.
                    reasons.append(f"{model_name}\t{label}\tNO_REFERENCE_DATA")
                    continue
                if len(data_list) < limit:
                    short_ds_and_models_to_remove.append(short_entry)
                    continue
                if len(data_list) > limit:
                    print(longer_msg)
                    raise _abort_error(longer_msg)

                lookup = _hash_lookup(reference)
                for item in data_list:
                    # NOTE(review): doc_ids missing from the reference all map to
                    # -1 and overwrite each other (same contract as id2hash_function).
                    merged[lookup.get(item["doc_id"], -1)] = _error_flag(ds, item)

            if merged:
                s = pd.Series(merged, name=model_name)
                model_results[ds] = s
                save_series_to_cache(cache_dir, ds, s)

        return model_name, model_results, skipped_datasets, reasons

    except Exception as e:
        if getattr(e, "abort_run", False):
            raise
        reasons.append(f"{model_name}\tALL\tEXCEPTION:{e}")
        return model_name, model_results, skipped_datasets, reasons


def load_models_from_file(models_file):
    """Read repo ids from *models_file*, one per line.

    Surrounding whitespace is stripped; blank lines and lines starting with
    ``#`` are ignored. Each id is wrapped in a SimpleNamespace with an ``id``
    attribute.
    """
    with open(models_file, "r", encoding="utf-8") as fh:
        stripped = [line.strip() for line in fh]
    return [
        SimpleNamespace(id=repo_id)
        for repo_id in stripped
        if repo_id and not repo_id.startswith("#")
    ]

def _share_short_list(shared):
    """Pool initializer: route this worker's short-data log to *shared*.

    process_model appends to the module-level short_ds_and_models_to_remove.
    Pool workers are separate processes, so without this the appends would
    land in each worker's private copy and never reach main.
    """
    global short_ds_and_models_to_remove
    short_ds_and_models_to_remove = shared


@click.command()
@click.option("--output-dir", required=True, type=click.Path(path_type=pathlib.Path))
@click.option("--num-processes", default=4, show_default=True)
@click.option("--models-file", default="models_used.txt", type=click.Path(path_type=pathlib.Path),
              show_default=True, help="File with repo ids, one per line")
@click.option("--cache-dir", default=None, type=click.Path(path_type=pathlib.Path),
              help="Where to keep intermediate caches; default = output-dir")
@click.option("--force", is_flag=True, help="Ignore cache and re-compute")
@click.option("--save-raw", is_flag=True, help="Also copy raw JSON/JSONL to cache_dir/raw/{model}/")
@click.option("--debug", is_flag=True, help="Print debug picks/warns")
def main(output_dir, num_processes, models_file, cache_dir, force, save_raw, debug):
    """Build per-dataset error matrices for the models listed in MODELS_FILE.

    Writes <dataset>_matrix.csv, skipped_cache.txt and failed_or_empty.txt to
    OUTPUT_DIR, plus reasons_detail.txt and short_ds_and_models_to_remove.json
    when they would be non-empty.
    """
    import tqdm as _tqdm

    union_data, length_limit = construct_union_function()

    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if cache_dir is None:
        cache_dir = output_dir
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(cache_dir, ".cache"), exist_ok=True)

    print(f"Cache dir        : {cache_dir}")
    print(f"Force recompute  : {force}, Save raw: {save_raw}, Debug: {debug}, Datasets available: {HAS_DATASETS}")

    models = load_models_from_file(models_file)
    print(f"Loaded {len(models)} model repos from {models_file}")

    series_map = {k: {} for k in CATEGORY_MAP.keys()}

    # Seed the matrices with every cached series (including models not in
    # models_file); freshly computed results overwrite these below.
    for ds in CATEGORY_MAP.keys():
        cache_ds_dir = os.path.join(cache_dir, ".cache", ds)
        if not os.path.isdir(cache_ds_dir):
            continue
        for fn in os.listdir(cache_ds_dir):
            if not fn.endswith(".csv"):
                continue
            model_name = fn[:-4]
            try:
                series_map[ds][model_name] = load_series_from_cache(cache_dir, ds, model_name)
            except Exception as e:
                # Best effort: an unreadable cache file is skipped, but not silently.
                if debug:
                    print(f"[WARN] unreadable cache file {ds}/{fn}: {e}")

    # The manager-backed list lets short-data entries recorded inside workers
    # reach this process (see _share_short_list).
    with multiprocessing.Manager() as manager:
        shared_short = manager.list()
        with multiprocessing.Pool(
            processes=num_processes,
            initializer=_share_short_list,
            initargs=(shared_short,),
        ) as pool:
            func = partial(
                process_model,
                token=token,
                save_raw=save_raw,
                debug=debug,
                cache_dir=cache_dir,
                force=force,
                union_data=union_data,
                length_limit=length_limit
            )
            results = list(_tqdm.tqdm(pool.imap(func, models), total=len(models), desc="Processing models"))
        # Copy out of the proxy before the manager process shuts down.
        short_entries = list(short_ds_and_models_to_remove) + shared_short[:]

    skipped_log = []
    reasons_all = []
    produced_any = set()

    for model_name, model_results, skipped_datasets, reasons in results:
        skipped_log.extend(f"{model_name}\t{ds}" for ds in skipped_datasets)
        reasons_all.extend(reasons)
        for ds, s in model_results.items():
            series_map[ds][s.name] = s
        if model_results:
            produced_any.add(model_name)

    for ds, mp in series_map.items():
        if mp:
            cols = list(mp.keys())
            df = pd.concat(mp.values(), axis=1, keys=cols).sort_index()
            out_csv = os.path.join(output_dir, f"{ds}_matrix.csv")
            df.to_csv(out_csv)
            print(f"{ds} saved, containing {len(cols)} model columns")
        else:
            print(f"{ds} does not contain any data")

    with open(os.path.join(output_dir, "skipped_cache.txt"), "w", encoding="utf-8") as f:
        for line in sorted(set(skipped_log)):
            f.write(line + "\n")

    # A model counts as failed/empty if it produced nothing and had no cache hit.
    skipped_models = {sl.split("\t")[0] for sl in skipped_log}
    failed_or_empty = [
        m for m in (model_name_from_repo(x.id) for x in models)
        if m not in produced_any and m not in skipped_models
    ]

    with open(os.path.join(output_dir, "failed_or_empty.txt"), "w", encoding="utf-8") as f:
        for m in sorted(set(failed_or_empty)):
            f.write(m + "\n")

    if reasons_all:
        with open(os.path.join(output_dir, "reasons_detail.txt"), "w", encoding="utf-8") as f:
            for r in reasons_all:
                f.write(r + "\n")
        summary = {}
        for r in reasons_all:
            typ = r.split("\t")[-1].split(":")[0]
            summary[typ] = summary.get(typ, 0) + 1
        print("Reason summary:", summary)

    if short_entries:
        with open(os.path.join(output_dir, "short_ds_and_models_to_remove.json"), "w", encoding="utf-8") as f:
            json.dump(short_entries, f, indent=4)

if __name__ == "__main__":
    # Script entry point; click parses the command-line options for main().
    main()