# gapKandO_merged.py
import argparse, json, math
from pathlib import Path
from typing import List, Dict

import torch
import numpy as np
import pandas as pd

# ---------- Helpers ----------
def get_label_from_name(name: str, class_map: dict) -> int:
    """
    Translate an image name of the form "<class>_xxxxx" into an integer label.

    The text before the first underscore is looked up in ``class_map`` and the
    stored value is converted with ``int``.

    Raises:
        ValueError: if ``name`` contains no underscore.
        KeyError: if the class prefix is missing from a dict ``class_map``.
    """
    prefix, separator, _ = name.partition("_")
    if not separator:
        # partition() yields an empty separator when "_" is absent
        raise ValueError(f"Image name does not contain '_' to infer class: {name}")
    label = class_map[prefix]
    return int(label)

def load_grouping(path: Path) -> Dict[str, List[str]]:
    """
    Read a grouping JSON file (UTF-8) and return its parsed content.

    Callers in this file treat the result as a mapping from cluster id to a
    list of image names; the file content itself is not validated here.
    """
    with open(path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def round4(x: float) -> float:
    """Convert ``x`` to ``float`` and round it to four decimal places."""
    as_float = float(x)
    return round(as_float, 4)

def load_train_probs_and_names(model: str, version: int):
    """
    Load training probabilities and the corresponding image names for a model.

    The data is read from ``"{model}_imagenet_train.pth"`` in the current
    working directory. The file must hold a dict with the keys ``"probs"``
    (a tensor) and ``"img_names"``.

    Args:
        model: Model name used as the file-name prefix.
        version: Currently unused when building the path.
            NOTE(review): confirm whether versioned files should be selected here.

    Returns:
        ``(probs_tr, img_tr)``: the probabilities cast to float32 on the CPU,
        and the stored image names unchanged.
    """
    tr_pth = f"{model}_imagenet_train.pth"
    # map_location="cpu" lets a file saved from a GPU run load on a CPU-only
    # host and keeps the result on the same device as tensors built later
    # from plain Python lists.
    d = torch.load(tr_pth, map_location="cpu", weights_only=True)
    probs_tr = d["probs"].float()
    img_tr = d["img_names"]
    return probs_tr, img_tr

def compute_err_vector(img_tr, probs_tr, class_list_path="class_list.json"):
    """
    Compute a binary error vector over training images: 1 if prediction != label, else 0.

    Args:
        img_tr: Iterable of image names; each label is inferred from the name
            prefix (before the first underscore) via ``get_label_from_name``.
        probs_tr: Tensor whose argmax along dim 1 is taken as the prediction.
        class_list_path: JSON file holding the ordered list of class names;
            a class's position in that list is its integer label.

    Returns:
        Float tensor with 1.0 where the prediction differs from the label and
        0.0 where it matches, on the same device as ``probs_tr``.
    """
    # Close the class-list file deterministically instead of leaking the handle.
    with open(class_list_path, "r", encoding="utf-8") as f:
        class_names = json.load(f)
    class_map = {c: i for i, c in enumerate(class_names)}

    pred_tr = probs_tr.argmax(1)
    # Build labels on the predictions' device so the comparison below cannot
    # fail with a CPU/GPU mismatch.
    labels_tr = torch.tensor(
        [get_label_from_name(nm, class_map) for nm in img_tr],
        dtype=torch.long,
        device=pred_tr.device,
    )
    err_tr = torch.ne(pred_tr, labels_tr).float()  # 1 if wrong, 0 if correct
    return err_tr

def model_tag(model: str, version: int) -> str:
    """
    Build the short tag that identifies a model/version pair.

    Version 0 yields the bare model name; every other version yields
    ``"<model>_v<N>"``, where version 3 is rendered with N = 1.
    NOTE(review): version 3 sharing the "_v1" tag with version 1 looks
    intentional but is not explained here; confirm.
    """
    if version == 0:
        return model
    shown_version = 1 if version == 3 else version
    return f"{model}_v{shown_version}"

# ---------- Core computations ----------
def calc_gap_for_grouping(grouping: Dict[str, List[str]], name_to_idx: Dict[str, int],
                          err_tr: torch.Tensor, n: int, K_arg: int, delta: float):
    """
    Compute gap_O and gap_K for one grouping (cluster id -> list of image names).

    For every non-empty cluster, F_i is the mean of ``err_tr`` over the cluster's
    images that appear in ``name_to_idx``; clusters with no known image are
    dropped. With T the set of remaining clusters:

        gap_O = (1 + sum_i F_i) * sqrt(ln(1/delta) / (2n))
        gap_K = (1 + sqrt(2) * max_i F_i) * sqrt(|T| * ln(2K/delta) / n)
                + 2 * |T| * ln(2K/delta) / n

    Returns:
        ``None`` when no cluster contains a known image, otherwise a dict with
        the keys "T_size", "gapO" and "gapK".
    """
    # Per-cluster mean error, visiting cluster ids in ascending numeric order.
    cluster_means = []
    for key in sorted(grouping, key=int):
        members = grouping[key]
        if not members:
            continue
        rows = [name_to_idx[member] for member in members if member in name_to_idx]
        if rows:
            cluster_means.append(err_tr[rows].mean().item())

    if not cluster_means:
        return None  # no image matches

    log_inv_delta = math.log(1.0 / delta)
    scale_O = math.sqrt(log_inv_delta / (2.0 * n))
    log_2K_delta = math.log(2.0 * K_arg / delta)

    num_clusters = len(cluster_means)
    total_err = float(np.sum(cluster_means))
    worst_err = float(np.max(cluster_means))

    # gap_O
    gap_o = (1.0 + total_err) * scale_O

    # gap_K = square-root term + linear term
    sqrt_part = (1.0 + math.sqrt(2.0) * worst_err) * math.sqrt(num_clusters * log_2K_delta / n)
    linear_part = (2.0 * num_clusters * log_2K_delta) / n
    gap_k = sqrt_part + linear_part

    result = {}
    result["T_size"] = num_clusters
    result["gapO"] = float(gap_o)
    result["gapK"] = float(gap_k)
    return result

# ---------- Modes ----------
def run_mode_seeds(args):
    """
    Mode 'seeds': fix K, evaluate one grouping file per seed, and append one
    summary row to ``./gapOandK.csv``.

    Reads from ``args``: ``model_name``, ``version``, ``K`` (a single integer),
    ``seeds`` (comma-separated integers), ``train_group_json_tpl`` and ``delta``.
    The template is formatted with ``seed`` and ``K``, so it may use either or
    both placeholders.

    Seeds whose grouping matches no training image are skipped. Per-seed gaps
    are rounded to 4 decimals before the mean/std are taken.

    Raises:
        RuntimeError: if no seed yields a result.
    """
    MODEL = args.model_name
    version = int(args.version)
    K_arg = int(args.K)  # required
    seed_list = [int(s.strip()) for s in args.seeds.split(",") if s.strip()]
    tpl = args.train_group_json_tpl
    delta = float(args.delta)

    out_dir = Path(".")
    out_dir.mkdir(parents=True, exist_ok=True)
    out_csv = out_dir / "gapOandK.csv"

    # load train probs + image names
    probs_tr, img_tr = load_train_probs_and_names(MODEL, version)
    n = probs_tr.shape[0]
    name_to_idx = {nm: i for i, nm in enumerate(img_tr)}
    err_tr = compute_err_vector(img_tr, probs_tr)

    gapO_list, gapK_list, T_sizes = [], [], []

    for seed in seed_list:
        # Passing K as well lets the same "{K}...{seed}" template serve both
        # modes; str.format ignores keyword arguments the template does not use.
        grouping_path = Path(tpl.format(seed=seed, K=K_arg))
        grouping = load_grouping(grouping_path)
        res = calc_gap_for_grouping(grouping, name_to_idx, err_tr, n, K_arg, delta)
        if res is None:
            # this seed matches no images => skip
            continue
        gapO_list.append(round4(res["gapO"]))
        gapK_list.append(round4(res["gapK"]))
        T_sizes.append(res["T_size"])

    if not gapO_list:
        raise RuntimeError("Failed to compute gap for any seed (grouping files may not match training image names).")

    row = {
        "model": model_tag(MODEL, version),
        "version": version,
        "K": K_arg,
        "delta": delta,
        "n": n,
        "|T|_mean": round4(float(np.mean(T_sizes))),
        "gapO_mean": round4(float(np.mean(gapO_list))),
        "gapO_std":  round4(float(np.std(gapO_list))),
        "gapK_mean": round4(float(np.mean(gapK_list))),
        "gapK_std":  round4(float(np.std(gapK_list))),
        "gapO_list": gapO_list,
        "gapK_list": gapK_list
    }

    # Append to an existing CSV without repeating the header; otherwise create
    # the file with a header row.
    if out_csv.exists():
        pd.DataFrame([row]).to_csv(out_csv, mode="a", index=False, header=False)
    else:
        pd.DataFrame([row]).to_csv(out_csv, index=False)

def run_mode_byK(args):
    """
    Mode 'byK': fix one seed, evaluate one grouping file per K, and append one
    summary row to ``./gapOandK_byK_result.csv``.

    Reads from ``args``: ``model_name``, ``version``, ``K`` (comma-separated
    integers), ``seeds`` (only the first seed is used; 70 if the list is
    empty), ``train_group_json_tpl`` (formatted with ``K`` and ``seed``) and
    ``delta``. Per-K gaps are rounded to 4 decimals before the mean/std are taken.

    Raises:
        ValueError: if ``--K`` contains no integer.
        RuntimeError: if the grouping for some K matches no training image.
    """
    MODEL = args.model_name
    version = int(args.version)
    Ks = [int(s.strip()) for s in args.K.split(",") if s.strip()]
    if not Ks:
        # Fail before any data is loaded; otherwise an empty list would produce
        # a row of NaN statistics in the CSV.
        raise ValueError("For --mode byK, --K must contain at least one integer (e.g., 100,200).")
    seed_list = [int(s.strip()) for s in args.seeds.split(",") if s.strip()]
    seed = seed_list[0] if seed_list else 70
    tpl = args.train_group_json_tpl
    delta = float(args.delta)

    # load train probs + image names
    probs_tr, img_tr = load_train_probs_and_names(MODEL, version)
    n = probs_tr.shape[0]
    name_to_idx = {nm: i for i, nm in enumerate(img_tr)}
    err_tr = compute_err_vector(img_tr, probs_tr)

    gapO_byK, gapK_byK = [], []

    for K_arg in Ks:
        grouping_path = Path(tpl.format(K=K_arg, seed=seed))
        grouping = load_grouping(grouping_path)
        res = calc_gap_for_grouping(grouping, name_to_idx, err_tr, n, K_arg, delta)
        if res is None:
            raise RuntimeError(f"Seed {seed} with K={K_arg} matches no images in training set.")
        gapO_byK.append(round4(res["gapO"]))
        gapK_byK.append(round4(res["gapK"]))

    gapO_mean = round4(float(np.mean(gapO_byK)))
    gapO_std  = round4(float(np.std(gapO_byK)))
    gapK_mean = round4(float(np.mean(gapK_byK)))
    gapK_std  = round4(float(np.std(gapK_byK)))

    row = {
        "model": model_tag(MODEL, version),
        "gapO_mean": gapO_mean,
        "gapO_std":  gapO_std,
        "gapK_mean": gapK_mean,
        "gapK_std":  gapK_std,
        "gapO_list": gapO_byK,   # in the same order as K in --K
        "gapK_list": gapK_byK    # in the same order as K in --K
    }

    out_csv = Path("./gapOandK_byK_result.csv")
    df = pd.DataFrame([row])
    # Append without a header when the file already exists; otherwise create it.
    if out_csv.exists():
        df.to_csv(out_csv, mode="a", index=False, header=False)
    else:
        df.to_csv(out_csv, index=False)

# ---------- Main ----------
def main():
    """
    Parse the command line and dispatch to the selected mode.

    ``--mode seeds`` requires ``--K`` to be a single integer and calls
    ``run_mode_seeds``; ``--mode byK`` treats ``--K`` as a comma-separated list
    and calls ``run_mode_byK``.

    Raises:
        ValueError: in 'seeds' mode when ``--K`` is not a single integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["seeds", "byK"], required=True,
                        help="seeds: fix K, run multiple seeds; byK: fix seed, run multiple K")
    parser.add_argument("--model_name", required=True, type=str)

    # common args
    parser.add_argument("--version", default=1, type=int, help="Model outputs version (0,1,2,3)")
    parser.add_argument("--delta", type=float, default=0.01, help="Confidence parameter delta")

    # K and seeds are used by both modes, with mode-dependent meaning
    parser.add_argument("--K", type=str, default="200",
                        help="Mode 'seeds': a single integer K; Mode 'byK': a comma-separated list of K values")
    parser.add_argument("--seeds", type=str, default="50,65,70,83,100",
                        help="Comma-separated list of seeds (default: 50,65,70,83,100). For 'byK', first one is used (default 70 if empty).")

    # grouping path template:
    # - mode seeds: must contain {seed}, e.g., /.../train_k200_seed{seed}.json
    # - mode byK: must contain both {K} and {seed}, e.g., /.../train_k{K}_seed{seed}.json
    parser.add_argument("--train_group_json_tpl", required=True, type=str,
                        help="Path template for train grouping JSON files.")

    args = parser.parse_args()

    if args.mode == "seeds":
        # enforce K to be a single integer; catch only the conversion failure so
        # KeyboardInterrupt/SystemExit are never swallowed
        try:
            int(args.K)
        except ValueError as err:
            raise ValueError("For --mode seeds, --K must be a single integer (e.g., 200).") from err
        run_mode_seeds(args)
    else:
        # byK: --K is a comma-separated list
        run_mode_byK(args)

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
