import argparse, json, math
from pathlib import Path
from typing import List, Dict, Optional

import torch
import numpy as np
import pandas as pd
from openpyxl import load_workbook


# ===========================================================
# ========================= HELPERS =========================
# ===========================================================
def append_rows_to_xlsx(xlsx_path: Path, rows: List[dict], sheet_name: str):
    """Append ``rows`` to ``sheet_name`` in the workbook at ``xlsx_path``.

    Three cases are handled:

    * the workbook does not exist: it is created with a header row;
    * the workbook exists but the sheet does not: the sheet is added with a
      header row;
    * the sheet exists: the rows are written below the last used row without
      repeating the header (unless the sheet has no content yet).

    Args:
        xlsx_path: Path of the ``.xlsx`` file to create or extend.
        rows: One dict per row; passed to ``pd.DataFrame``.
        sheet_name: Name of the target worksheet.
    """
    df = pd.DataFrame(rows)

    if not xlsx_path.exists():
        df.to_excel(xlsx_path, index=False, sheet_name=sheet_name)
        return

    book = load_workbook(xlsx_path)
    if sheet_name not in book.sheetnames:
        with pd.ExcelWriter(xlsx_path, engine="openpyxl", mode="a") as writer:
            df.to_excel(writer, sheet_name=sheet_name, index=False)
        return

    ws = book[sheet_name]
    # openpyxl reports max_row == 1 (and max_column == 1) even when the sheet
    # holds no cells at all, so "max_row == 0" can never signal an empty sheet.
    # Detect emptiness explicitly so an empty sheet still receives its header.
    sheet_is_empty = (
        ws.max_row == 1
        and ws.max_column == 1
        and ws.cell(row=1, column=1).value is None
    )
    startrow = 0 if sheet_is_empty else ws.max_row

    with pd.ExcelWriter(xlsx_path, engine="openpyxl", mode="a", if_sheet_exists="overlay") as writer:
        df.to_excel(writer, sheet_name=sheet_name, index=False, header=(startrow == 0), startrow=startrow)


def get_label_from_name(name: str, class_map: dict):
    """Look up the label for an image from the prefix of its file name.

    The text before the first underscore in ``name`` is used as the key
    into ``class_map``.

    Raises:
        ValueError: If ``name`` contains no underscore.
        KeyError: If the prefix is not a key of ``class_map``.
    """
    prefix, separator, _rest = name.partition("_")
    if not separator:
        raise ValueError(f"Image name does not contain '_' to infer class: {name}")
    return class_map[prefix]


def load_grouping(path: Path) -> Dict[str, List[str]]:
    """Read the UTF-8 JSON file at ``path`` and return the parsed object."""
    with open(path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def safe_mean_std(values: List[Optional[float]]):
    """Return ``(mean, population std)`` of ``values``, ignoring missing entries.

    ``None`` entries (and NaN entries) are skipped. When nothing remains —
    an empty list or a list of only ``None``/NaN — ``(nan, nan)`` is returned
    directly instead of letting NumPy emit "Mean of empty slice" /
    "Degrees of freedom <= 0" RuntimeWarnings.

    Args:
        values: Numbers, with ``None`` marking a missing value.

    Returns:
        Tuple ``(mean, std)`` as Python floats; ``std`` uses ``ddof=0``.
    """
    arr = np.array([np.nan if v is None else float(v) for v in values], dtype=float)
    # ``all()`` of an empty array is True, so the empty input is covered too.
    if np.isnan(arr).all():
        return float("nan"), float("nan")
    return float(np.nanmean(arr)), float(np.nanstd(arr, ddof=0))


def round4(x: float) -> float:
    """Convert ``x`` to ``float`` and round it to four decimal places."""
    as_float = float(x)
    return round(as_float, 4)


def fmt_mean_std(mean: float, std: float) -> str:
    """Render ``mean`` and ``std`` with four decimals, joined by `` ± ``."""
    rendered = [f"{value:.4f}" for value in (mean, std)]
    return " ± ".join(rendered)


def load_val_for_test_error(MODEL: str, seed: int):
    """Compute the mean error over the images listed in the test split.

    Reads ``{MODEL}_imagenet_val.pth`` (keys ``"accuracies"`` and
    ``"img_names"``) and ``test40k_list_seed{seed}.json`` (a JSON list of
    image names) from the current working directory. The per-image loss is
    ``1 - accuracy``; only images whose name appears in the test list count.

    The loss is kept as a float (matching ``load_val_err_and_map``) rather
    than truncated with ``int()``, which would silently turn any fractional
    accuracy into a loss of 0.

    Args:
        MODEL: Prefix of the ``.pth`` inference file.
        seed: Value substituted into the test-list file name.

    Returns:
        The mean loss rounded to four decimals, or NaN when no image of the
        inference file is in the test list.
    """
    d_val = torch.load(f"{MODEL}_imagenet_val.pth", weights_only=True)
    acc_val = d_val["accuracies"]
    img_val = d_val["img_names"]

    with open(f"test40k_list_seed{seed}.json", "r") as f:
        test_set = set(json.load(f))

    test_losses = [1.0 - float(acc) for name, acc in zip(img_val, acc_val) if name in test_set]
    if not test_losses:
        return round4(float("nan"))
    return round4(sum(test_losses) / len(test_losses))


def load_train_err_and_map(MODEL: str, version: int, CLASS_MAP: dict):
    """Load train inference results and derive the per-image 0/1 error.

    Reads ``{MODEL}_imagenet_train.pth`` (keys ``"probs"`` and
    ``"img_names"``) from the current working directory. The prediction is
    the argmax of ``probs`` along dim 1; the label comes from the image-name
    prefix via ``get_label_from_name``.

    The checkpoint is mapped to CPU on load so that ``pred`` and the
    CPU-built ``labels`` tensor always share a device (a checkpoint saved
    from GPU would otherwise fail in ``torch.ne`` or fail to load on a
    CPU-only host).

    Args:
        MODEL: Prefix of the ``.pth`` inference file.
        version: Unused here; kept so the signature parallels the callers.
        CLASS_MAP: Mapping from class prefix to integer label.

    Returns:
        Tuple ``(err, n, MAP, count, img)``: float error tensor, number of
        rows in ``probs``, image-name -> row-index dict, ``len(img)``, and
        the image-name sequence itself.
    """
    tr_pth = f"{MODEL}_imagenet_train.pth"
    d = torch.load(tr_pth, map_location="cpu", weights_only=True)
    probs = d["probs"].float()
    img = d["img_names"]

    pred = probs.argmax(1)
    labels = torch.tensor([get_label_from_name(nm, CLASS_MAP) for nm in img])
    err = torch.ne(pred, labels).float()

    # If a name occurs more than once, the last occurrence wins.
    MAP = {name: i for i, name in enumerate(img)}
    return err, probs.shape[0], MAP, len(img), img


def load_val_err_and_map(MODEL: str, version: int):
    """Load val inference results and derive the per-image error.

    Reads ``{MODEL}_imagenet_val.pth`` (keys ``"accuracies"`` and
    ``"img_names"``) from the current working directory and converts each
    accuracy ``a`` into the error ``1 - a``.

    Args:
        MODEL: Prefix of the ``.pth`` inference file.
        version: Not used by this function.

    Returns:
        Tuple ``(err, n, MAP, count, img_names)``: float32 error tensor, the
        constant 10000, image-name -> position dict, ``len(img_names)``, and
        the image-name sequence itself.
    """
    checkpoint = torch.load(f"{MODEL}_imagenet_val.pth", weights_only=True)
    accuracies = checkpoint["accuracies"]
    names = checkpoint["img_names"]

    per_image_error = []
    for accuracy in accuracies:
        per_image_error.append(1 - float(accuracy))
    err = torch.tensor(per_image_error, dtype=torch.float32)

    # If a name occurs more than once, the last occurrence wins.
    name_to_position = {}
    for position, name in enumerate(names):
        name_to_position[name] = position

    # NOTE(review): the sample size is fixed at 10000 regardless of how many
    # images the file holds — confirm this matches the intended val subset.
    sample_size = 10000
    return err, sample_size, name_to_position, len(names), names


def main():
    """Command-line entry point: compute bound terms and persist summaries.

    Steps:

    1. Parse the CLI arguments (mode, model name, K, version, grouping-file
       template, seed list).
    2. Load the per-image error vector for the chosen mode and the test error.
    3. For every seed, load the cluster grouping, compute the per-cluster mean
       error ``F(S_i)`` and the terms ``u_hat``, ``A2``, ``A3``.
    4. Append one summary row (mean ± std over seeds) to
       ``computed_bound5.xlsx`` (sheet ``TRAIN`` or ``VAL``).
    5. Store the per-cluster mean/std of ``F(S_i)`` over seeds under the model
       tag in ``final_F_Si.json``.

    All input and output files are resolved relative to the current working
    directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", required=True, choices=["train", "val"],
                        help="Choose whether to run on TRAIN clusters or VAL clusters.")
    parser.add_argument("--model_name", required=True, type=str)
    parser.add_argument("--K", default=200, type=int)
    parser.add_argument("--version", default=1, type=int)
    parser.add_argument("--group_json_tpl", required=True, type=str,
                        help="Template path to the grouping JSON (for TRAIN or VAL), e.g., /path/group_k{K}_seed{seed}.json")
    parser.add_argument("--seeds", type=str, default="50,65,70,83,100")
    args = parser.parse_args()

    MODE = args.mode
    MODEL = args.model_name
    K_arg, version = args.K, args.version
    seed_list = [int(s.strip()) for s in args.seeds.split(",") if s.strip()]
    tpl = args.group_json_tpl

    # Without any seed there is nothing to aggregate; the summary would be a
    # row of "nan ± nan", so reject the input up front.
    if not seed_list:
        parser.error("--seeds must contain at least one integer seed")

    # Constants / Hyperparameters
    delta, alpha, gamma, C = 0.01, 100, 0.04 ** (-1.0 / 100), 1.0
    out_dir = Path(".")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Class list: position in the JSON list becomes the integer label.
    class_list_path = "class_list.json"
    with open(class_list_path) as f:
        CLASS_MAP = {c: i for i, c in enumerate(json.load(f))}

    # Train error & MAP by mode
    if MODE == "train":
        err_vec, n, MAP, _, _ = load_train_err_and_map(MODEL, version, CLASS_MAP)
    else:
        err_vec, n, MAP, _, _ = load_val_err_and_map(MODEL, version)

    A1 = err_vec.mean().item()

    # Load test error
    # NOTE(review): ``version`` is passed where the callee expects ``seed``
    # (it selects test40k_list_seed{version}.json) — confirm this is intended.
    test_error = load_val_for_test_error(MODEL, version)

    # Compute bound terms that do not depend on the seed
    b = math.sqrt(0.5 * n * math.log(2.0 / delta))
    C_sqrt_ln_term = C * math.sqrt(math.log(2.0 / delta) / (2.0 * n))
    ln_gamma = math.log(gamma)

    # Iterate multiple seeds
    stats = []
    FSi_per_seed = {}
    max_K_seen = K_arg

    for seed in seed_list:
        # Fill both placeholders advertised in the --group_json_tpl help text;
        # str.format ignores keyword arguments the template does not use, so
        # templates containing only {seed} keep working.
        grouping = load_grouping(Path(tpl.format(K=K_arg, seed=seed)))
        # Only keep clusters that are non-empty
        cluster_ids = [k for k in sorted(grouping.keys(), key=lambda x: int(x)) if grouping[k]]

        F_Si_dict, n_i = {}, []
        for cid in cluster_ids:
            # Names absent from the loaded inference file are skipped.
            idxs = [MAP[nm] for nm in grouping[cid] if nm in MAP]
            if not idxs:
                continue
            F_Si_dict[int(cid)] = err_vec[idxs].mean().item()
            n_i.append(len(idxs))

        # Extend K_all to cover cluster ids possibly exceeding K_arg
        K_all = max(K_arg, max(map(int, F_Si_dict.keys())) + 1 if F_Si_dict else K_arg)
        max_K_seen = max(max_K_seen, K_all)
        FSi_per_seed[seed] = [F_Si_dict.get(i, None) for i in range(K_all)]

        # Terms for the bound (based on cluster sizes)
        nis = np.array(n_i, dtype=float) if n_i else np.zeros(1)
        sum_nisq = float(np.sum((nis / n) ** 2))
        u_hat = (gamma*(1+2*b)/(2*n) + gamma*len(F_Si_dict)*(b**2)/(2*n**2)
                 + (gamma**2)/2*sum_nisq + (gamma**2)*math.sqrt(math.log(2/delta)/(2*n)))
        sum_FSi = sum(v for v in F_Si_dict.values())
        A2 = (b/n)*sum_FSi
        # Clamp at zero so the square root never sees a negative argument.
        A3 = C*math.sqrt(max(u_hat*alpha*ln_gamma, 0.0)) + C_sqrt_ln_term

        stats.append({
            "u_mu": u_hat, "A2": A2, "A3": A3,
            "Unc(Gamma)": A2 + A3, "bound5_rhs": A1 + A2 + A3,
            "T_size": len(F_Si_dict), "sum_FSi": sum_FSi,
            "sum_nisq": sum_nisq,
        })

    # Summary statistics
    model_tag = MODEL if version == 0 else (f"{MODEL}_v1" if version == 3 else f"{MODEL}_v{version}")
    summary_statistics = {
        "model": model_tag,
        "seeds": ",".join(map(str, seed_list)),
        "n": n, "delta": delta, "alpha": alpha, "gamma": gamma, "C": C,
        "b": round4(b), "C_sqrt_ln_term": round4(C_sqrt_ln_term), "A1": round4(A1),
        "test_error": test_error
    }

    # keys to record as mean±std
    keys_mean_std = ["A2", "A3", "u_mu", "Unc(Gamma)", "T_size", "sum_FSi", "sum_nisq", "bound5_rhs"]

    summary_row = dict(summary_statistics)
    for k in keys_mean_std:
        arr = np.array([s[k] for s in stats], dtype=float)
        mu = float(np.mean(arr))
        sig = float(np.std(arr, ddof=0))
        summary_row[k] = fmt_mean_std(mu, sig)

    # Excel
    xlsx_path = out_dir / "computed_bound5.xlsx"
    sheet_name = "TRAIN" if MODE == "train" else "VAL"
    append_rows_to_xlsx(xlsx_path, [summary_row], sheet_name)

    # JSON: per-cluster mean/std of F(Si,h)
    FSi_mean, FSi_std = [], []
    for i in range(max_K_seen):
        # Seeds whose list is shorter than max_K_seen contribute None.
        vals = [lst[i] if i < len(lst) else None for lst in FSi_per_seed.values()]
        mu, sig = safe_mean_std(vals)
        FSi_mean.append(round4(mu))
        FSi_std.append(round4(sig))

    json_path = out_dir / "final_F_Si.json"
    store = {}
    if json_path.exists():
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                store = json.load(f)
        except json.JSONDecodeError:
            # Best effort: an unreadable store is replaced rather than aborting.
            store = {}
    store[model_tag] = {"mean": FSi_mean, "std": FSi_std}
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(store, f, ensure_ascii=False, indent=2)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
