#!/usr/bin/env python3
import argparse
import csv
import glob
import json
import math
import os
import random
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

# ---------- constants ----------
# Seed for the validation-subset draw (pick_val_subset) and for the global
# `random` state set in main().
SEED = 1337
# Default confidence parameter δ for the bound formulas.
DELTA = 0.01
# α of the bound formulas; it enters the square-root terms as uhat * α * ln(γ).
ALPHA = 100.0
GAMMA_THEO = (0.04) ** (-1.0 / ALPHA)   # γ = 0.04^(-1/α)
K = 75                                   # default expected number of clusters
# Number of validation ids sampled into the "val" subset; the remaining
# validation ids are reported as the "test" split.
VAL_SUBSET_SIZE = 1000                   # from the 5k val

# ---------- math helpers ----------
def compute_b(n: int, delta: float) -> float:
    """Return b = sqrt((n / 2) * ln(2 / delta))."""
    log_term = math.log(2.0 / delta)
    radicand = 0.5 * n * log_term
    return math.sqrt(radicand)


def compute_uhat_td25(n: int, n_i_list: List[int], b: float, gamma: float, delta: float, T_size: int) -> float:
    """Return the TD25 û value as the sum of its four additive terms.

    Args:
        n: Total number of samples in the split.
        n_i_list: Per-cluster sample counts.
        b: Value produced by ``compute_b`` for the same ``n`` and ``delta``.
        gamma: γ parameter.
        delta: Confidence parameter δ.
        T_size: Number of non-empty clusters.
    """
    gamma_sq = gamma ** 2
    two_n = 2.0 * n

    # Sum of squared cluster fractions (n_i / n)^2.
    squared_fractions = sum((size / n) ** 2 for size in n_i_list)

    linear_in_b = gamma * (1.0 + 2.0 * b) / two_n
    quadratic_in_b = (gamma * T_size * (b ** 2)) / (2.0 * (n ** 2))
    fraction_part = gamma_sq / 2.0 * squared_fractions
    confidence_part = gamma_sq * math.sqrt(math.log(2.0 / delta) / two_n)

    return linear_in_b + quadratic_in_b + fraction_part + confidence_part


def compute_A3(C: float, uhat: float, alpha: float, gamma: float, delta: float, n: int) -> float:
    """Return A3 = C*sqrt(max(0, uhat*alpha*ln(gamma))) + C*sqrt(ln(2/delta) / (2n))."""
    # Clamp at zero so a negative product (e.g. gamma < 1) cannot reach sqrt.
    uncertainty_radicand = max(0.0, uhat * alpha * math.log(gamma))
    confidence_radicand = math.log(2.0 / delta) / (2.0 * n)
    uncertainty_part = C * math.sqrt(uncertainty_radicand)
    confidence_part = C * math.sqrt(confidence_radicand)
    return uncertainty_part + confidence_part


def compute_uhat_old(n: int, n_i_list: List[int], gamma: float, delta: float, K: int) -> float:
    """Return the OLD-bound û value from its three additive terms.

    Args:
        n: Total number of samples in the split.
        n_i_list: Per-cluster sample counts.
        gamma: γ parameter.
        delta: Confidence parameter δ.
        K: Number of clusters; enters through ln(2K/δ).
    """
    gamma_sq = gamma ** 2
    squared_fractions = sum((size / n) ** 2 for size in n_i_list)

    base_part = gamma / (2.0 * n)
    fraction_part = gamma_sq / 2.0 * squared_fractions
    confidence_part = gamma_sq * math.sqrt((2.0 / n) * math.log(2.0 * K / delta))

    return base_part + fraction_part + confidence_part


def compute_g2_old(C: float, n: int, K: int, delta: float, T_size: int, n_i_list: List[int]) -> float:
    """Return the OLD-bound g2 term.

    g2 = C*(1 + sqrt(2))*sqrt(L)*(sum_i sqrt(n_i) / n) + 4*C*T_size*L / n,
    where L = ln(2K/δ) and the sum runs over clusters with n_i > 0.
    """
    log_term = math.log(2.0 * K / delta)
    non_empty_sizes = (size for size in n_i_list if size > 0)
    root_total = sum(map(math.sqrt, non_empty_sizes))

    sqrt_part = C * (1.0 + math.sqrt(2.0)) * math.sqrt(log_term) * (root_total / n)
    linear_part = (4.0 * C * T_size * log_term) / n
    return sqrt_part + linear_part


# --- helpers ---
def _pm_std(values) -> float:
    """Return the sample standard deviation (ddof=1) of ``values``.

    Inputs with fewer than two elements yield 0.0.
    """
    samples = np.asarray(list(values), dtype=float)
    if samples.size > 1:
        return float(np.std(samples, ddof=1))
    return 0.0


def ensure_dir(p: str):
    """Create directory ``p`` (and any missing parents) if it is not already there."""
    if not os.path.isdir(p):
        os.makedirs(p, exist_ok=True)


def init_bound_td25_csv(path: str, seeds: List[int]):
    """Write the TD25 bound CSV header to ``path`` unless the file already exists.

    The header holds the fixed summary columns followed by six per-seed
    columns for every seed in ``seeds``.
    """
    if os.path.exists(path):
        return

    columns = [
        "model", "split", "loss_type", "A1", "C", "b", "gamma", "alpha", "delta",
        "Bound5_mean", "Bound5_pm", "Unc_mean", "Unc_pm", "n", "T_mean", "T_pm",
        "train_loss_all", "val_loss_1k", "test_loss_4k",
    ]
    per_seed_fields = ("A2", "A3", "Bound5", "Unc", "uhat", "T_size")
    columns.extend(f"{field}_seed_{s}" for s in seeds for field in per_seed_fields)

    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(columns)


def init_bound_old_csv(path: str, seeds: List[int]):
    """Write the OLD bound CSV header to ``path`` unless the file already exists.

    The header holds the fixed summary columns followed by five per-seed
    columns for every seed in ``seeds``.
    """
    if os.path.exists(path):
        return

    columns = [
        "model", "split", "loss_type", "A1", "C", "gamma", "alpha", "delta",
        "Bound3_mean", "Bound3_pm", "Term2_mean", "Term2_pm", "g2_mean", "g2_pm",
        "n", "T_mean", "T_pm",
        "train_loss_all", "val_loss_1k", "test_loss_4k",
    ]
    per_seed_fields = ("term2", "g2", "Bound3", "uhat_old", "T_size")
    columns.extend(f"{field}_seed_{s}" for s in seeds for field in per_seed_fields)

    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(columns)


def init_metrics_csv(path: str):
    """Write the metrics CSV header to ``path`` unless the file already exists."""
    if os.path.exists(path):
        return
    columns = ["model", "train_loss_all", "val_loss_1k", "test_loss_4k", "C", "alpha", "gamma", "delta"]
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(columns)


def append_row(path: str, row: List):
    """Append ``row`` as a single CSV record at the end of the file at ``path``."""
    with open(path, "a", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(row)


# ---------- IO ----------
def load_inference_losses(pth_path: str) -> Dict[int, float]:
    """Load per-image losses (1 - mean IoU) from an inference dump.

    The file is read with ``torch.load`` and must contain the keys
    ``"image_ids"`` and ``"mean_ious"``, holding equally long sequences.

    Args:
        pth_path: Path to the saved inference results.

    Returns:
        Mapping from integer image id to ``1.0 - mean_iou``.

    Raises:
        ValueError: If the two sequences differ in length.
    """
    # NOTE(security): depending on the installed torch version and its
    # `weights_only` default, torch.load may unpickle arbitrary objects;
    # only point this at inference dumps you trust.
    d = torch.load(pth_path, map_location="cpu")
    ids = [int(i) for i in d["image_ids"]]
    mious = [float(x) for x in d["mean_ious"]]
    # Explicit check rather than `assert`: asserts vanish under `python -O`,
    # and zip() would then silently drop the unmatched tail.
    if len(ids) != len(mious):
        raise ValueError(f"length mismatch in {pth_path}")
    return {img_id: 1.0 - miou for img_id, miou in zip(ids, mious)}


def pick_val_subset(all_val_ids: List[int], k: int = VAL_SUBSET_SIZE) -> Tuple[List[int], List[int]]:
    """Split ``all_val_ids`` into a seeded random sample of size ``k`` and the rest.

    The draw uses a private ``random.Random(SEED)`` generator. The second
    list keeps the leftover ids in their original order.
    """
    chosen = random.Random(SEED).sample(all_val_ids, k)
    chosen_lookup = set(chosen)
    remainder = []
    for candidate in all_val_ids:
        if candidate not in chosen_lookup:
            remainder.append(candidate)
    return chosen, remainder


def read_groupings(group_dir: str, K: int) -> Tuple[List[int], Dict[int, Dict[str, Dict[int, List[int]]]]]:
    """Load every ``grouping_K_{K}_seed_*.json`` file found in ``group_dir``.

    Returns:
        ``(seeds, store)``: the seeds in sorted-path order, and
        ``store[seed][split][cluster_id]`` holding the list of integer ids
        for ``split`` in ``{"train", "val"}``. A missing split entry in the
        JSON becomes an empty list.

    Raises:
        FileNotFoundError: If no matching file exists.
    """
    pattern = os.path.join(group_dir, f"grouping_K_{K}_seed_*.json")
    json_paths = sorted(glob.glob(pattern))
    if len(json_paths) == 0:
        raise FileNotFoundError(f"No grouping JSONs found in {group_dir} (expected 'grouping_K_{K}_seed_*.json').")

    seed_order: List[int] = []
    by_seed: Dict[int, Dict[str, Dict[int, List[int]]]] = {}
    for json_path in json_paths:
        # The seed is the last underscore-separated token of the file name.
        last_token = os.path.basename(json_path).split("_")[-1]
        seed = int(last_token.replace(".json", ""))
        seed_order.append(seed)

        with open(json_path, "r") as handle:
            clusters = json.load(handle)

        train_map: Dict[int, List[int]] = {}
        val_map: Dict[int, List[int]] = {}
        for key, members in clusters.items():
            cluster_id = int(key)
            train_map[cluster_id] = [int(x) for x in members.get("train", [])]
            val_map[cluster_id] = [int(x) for x in members.get("val", [])]
        by_seed[seed] = {"train": train_map, "val": val_map}

    return seed_order, by_seed


def intersect_ids(lst: List[int], allowed: set) -> List[int]:
    """Return the elements of ``lst`` found in ``allowed``, keeping order and duplicates."""
    kept = []
    for item in lst:
        if item in allowed:
            kept.append(item)
    return kept


def compute_per_seed_terms(
    model_name: str,
    split_name: str,
    loss_type: str,
    ids_in_split: List[int],
    losses_map: Dict[int, float],
    groupings_by_seed: Dict[int, Dict[str, Dict[int, List[int]]]],
    C: float,
    gamma: float,
    alpha: float,
    delta: float,
    seeds: List[int],
    old_bound: bool = False,
    num_clusters: Optional[int] = None,
) -> Tuple[Dict[int, float], Dict[int, float], Dict[int, float], Dict[int, float], Dict[int, int], float, float, int]:
    """
    For a given split (train or val), compute per-seed bound terms.

      - TD25 (old_bound=False): A2, A3, Bound5 = A1 + A2 + A3, û
      - OLD  (old_bound=True) : term2, g2, Bound3 = A1 + term2 + g2, û_old

    Args:
      model_name: Only used in the empty-split error message.
      split_name: Key into each seed's grouping ("train" or "val").
      loss_type: Accepted for call-site symmetry; not used in the computation.
      ids_in_split: Ids that make up the split S; n = len(ids_in_split).
      losses_map: Per-id loss; must contain every id in ids_in_split.
      groupings_by_seed: groupings_by_seed[seed][split][cluster_id] -> ids.
      C, gamma, alpha, delta: Constants of the bound formulas.
      seeds: Seeds to evaluate; each must be a key of groupings_by_seed.
      old_bound: Select the OLD formulas instead of TD25.
      num_clusters: Number of cluster ids (0 .. num_clusters-1) to scan, also
        used as K in the OLD formulas. Defaults to the module-level K.

    Returns:
      per_seed_A (or term2), per_seed_A3_or_g2, per_seed_Bound, per_seed_unc_or_uhat,
      T_sizes, A1, b, n

    Raises:
      ValueError: If ids_in_split is empty.
    """
    if num_clusters is None:
        # Backward-compatible default: the module-level cluster count.
        num_clusters = K

    id_set = set(ids_in_split)
    n = len(ids_in_split)
    # Explicit raise instead of `assert` so the guard survives `python -O`
    # (n == 0 would otherwise end in a division by zero further down).
    if n == 0:
        raise ValueError(f"Empty split for {model_name}:{split_name}")
    b = compute_b(n, delta)
    A1 = float(np.mean([losses_map[i] for i in ids_in_split]))  # mean loss over S

    per1: Dict[int, float] = {}  # A2 or term2
    per2: Dict[int, float] = {}  # A3 or g2
    perB: Dict[int, float] = {}  # Bound5 or Bound3
    perU: Dict[int, float] = {}  # û (TD25) or û_old (OLD)
    T_sizes: Dict[int, int] = {}

    for s in seeds:
        g = groupings_by_seed[s][split_name]
        F_Si = []      # mean loss of every non-empty cluster
        n_i_list = []  # size of every cluster, restricted to this split
        for cid in range(num_clusters):
            ids_c = intersect_ids(g.get(cid, []), id_set)
            n_i = len(ids_c)
            n_i_list.append(n_i)
            if n_i > 0:
                F_Si.append(float(np.mean([losses_map[i] for i in ids_c])))

        # |T|: number of clusters holding at least one sample of the split.
        T_size = sum(1 for n_i in n_i_list if n_i > 0)
        T_sizes[s] = T_size

        if not old_bound:
            uhat = compute_uhat_td25(n, n_i_list, b, gamma, delta, T_size)
            A2 = (b / n) * float(np.sum(F_Si))
            A3 = compute_A3(C, uhat, alpha, gamma, delta, n)
            per1[s] = A2
            per2[s] = A3
            perB[s] = A1 + A2 + A3
            perU[s] = uhat
        else:
            u_old = compute_uhat_old(n, n_i_list, gamma, delta, num_clusters)
            # Clamp at zero so a negative product cannot reach sqrt.
            term2 = C * math.sqrt(max(0.0, u_old * alpha * math.log(gamma)))
            g2 = compute_g2_old(C, n, num_clusters, delta, T_size, n_i_list)
            per1[s] = term2
            per2[s] = g2
            perB[s] = A1 + term2 + g2
            perU[s] = u_old

    return per1, per2, perB, perU, T_sizes, A1, b, n


def _read_last_row_as_dict(csv_path: str) -> Dict[str, str]:
    """Return the final data row of ``csv_path`` keyed by the header names.

    Raises:
        ValueError: If the file has no data rows.
    """
    final_row = None
    with open(csv_path, "r", newline="") as handle:
        # Walk the whole file; the loop variable ends up on the last record.
        for final_row in csv.DictReader(handle):
            pass
    if final_row is None:
        raise ValueError(f"No data rows in {csv_path}")
    return final_row


def _avg_seed_cols(row: Dict[str, str], prefix: str) -> float:
    """Average the values of all columns in ``row`` whose name starts with ``prefix``.

    Returns NaN when no column matches.
    """
    matching = []
    for column, cell in row.items():
        if column.startswith(prefix):
            matching.append(float(cell))
    if not matching:
        return float("nan")
    return float(np.mean(matching))


# ---------- main ----------
def main():
    """Compute TD25 and OLD bounds for every model folder and write CSV summaries.

    For each sub-folder of ``--inference_dir`` that holds both
    ``train_inference.pth`` and ``val_inference.pth``, the per-image losses
    are loaded, the validation ids are split into a fixed-size "val" subset
    and a "test" remainder, and one row is appended to each per-model CSV
    under ``--out_dir/<model>/``. Finally the TD25 results of all processed
    models are gathered into two aggregate CSVs in ``--out_dir``.
    """
    # compute_per_seed_terms reads the module-level K when it is not told a
    # cluster count, so the command-line value has to be published there.
    global K

    parser = argparse.ArgumentParser(description="Compute TD25 and OLD bounds for segmentation models.")
    parser.add_argument("--inference_dir", type=str, required=True,
                        help="Base directory containing model subfolders with train_inference.pth and val_inference.pth.")
    parser.add_argument("--group_dir", type=str, required=True,
                        help="Directory containing grouping_K_{K}_seed_{s}.json files.")
    parser.add_argument("--out_dir", type=str, required=True,
                        help="Where to save results (CSV summaries).")
    parser.add_argument("--K", type=int, default=75, help="Number of clusters (default: 75).")
    parser.add_argument("--delta", type=float, default=DELTA, help="Confidence parameter δ.")
    args = parser.parse_args()

    if args.K <= 0:
        parser.error("--K must be a positive integer.")
    if not 0.0 < args.delta < 1.0:
        parser.error("--delta must lie strictly between 0 and 1.")

    # Use the requested values everywhere, not only in file names / metrics.csv.
    K = args.K
    delta = args.delta

    random.seed(SEED)
    ensure_dir(args.out_dir)

    # Discover models (sorted so the processing and output order is reproducible)
    model_dirs = sorted(p for p in glob.glob(os.path.join(args.inference_dir, "*")) if os.path.isdir(p))
    if not model_dirs:
        raise FileNotFoundError(f"No model folders found under {args.inference_dir}")

    # Load groupings
    seeds, groupings = read_groupings(args.group_dir, args.K)
    print(f"Found grouping seeds: {seeds}")

    processed = []  # (model name, per-model output directory)
    for mdir in model_dirs:
        model = os.path.basename(mdir.rstrip("/"))
        train_pth = os.path.join(mdir, "train_inference.pth")
        val_pth   = os.path.join(mdir, "val_inference.pth")

        if not (os.path.exists(train_pth) and os.path.exists(val_pth)):
            print(f"Skipping {model}, missing inference files.")
            continue

        print(f"\n=== Model: {model} ===")
        out_mdir = os.path.join(args.out_dir, model)
        ensure_dir(out_mdir)

        metrics_csv           = os.path.join(out_mdir, "metrics.csv")
        bound_train_td25_csv  = os.path.join(out_mdir, "bound_train_miou.csv")
        bound_val_td25_csv    = os.path.join(out_mdir, "bound_val_miou.csv")
        bound_train_old_csv   = os.path.join(out_mdir, "bound_train_miou_old.csv")
        bound_val_old_csv     = os.path.join(out_mdir, "bound_val_miou_old.csv")

        init_metrics_csv(metrics_csv)
        init_bound_td25_csv(bound_train_td25_csv, seeds)
        init_bound_td25_csv(bound_val_td25_csv,   seeds)
        init_bound_old_csv(bound_train_old_csv,   seeds)
        init_bound_old_csv(bound_val_old_csv,     seeds)

        train_losses = load_inference_losses(train_pth)
        val_losses   = load_inference_losses(val_pth)

        train_ids_all = sorted(train_losses.keys())
        val_ids_all   = sorted(val_losses.keys())
        val_1k_ids, test_4k_ids = pick_val_subset(val_ids_all, VAL_SUBSET_SIZE)

        train_loss_all = float(np.mean([train_losses[i] for i in train_ids_all]))
        val_loss_1k    = float(np.mean([val_losses[i] for i in val_1k_ids]))
        test_loss_4k   = float(np.mean([val_losses[i] for i in test_4k_ids]))
        C = 1.0

        append_row(metrics_csv, [model, f"{train_loss_all:.6f}", f"{val_loss_1k:.6f}", f"{test_loss_4k:.6f}",
                                 f"{C:.6f}", f"{ALPHA:.6f}", f"{GAMMA_THEO:.8f}", f"{delta:.6f}"])

        # ---------- TD25 '(5)' ----------
        for split_name, ids_in_split, losses_map, out_csv in [
            ("train", train_ids_all, train_losses, bound_train_td25_csv),
            ("val",   val_1k_ids,    val_losses,   bound_val_td25_csv),
        ]:
            perA2, perA3, perB5, perUhat, T_sizes, A1, b, n = compute_per_seed_terms(
                model_name=model,
                split_name=split_name,
                loss_type="miou_loss",
                ids_in_split=ids_in_split,
                losses_map=losses_map,
                groupings_by_seed=groupings,
                C=C, gamma=GAMMA_THEO, alpha=ALPHA, delta=delta,
                seeds=seeds,
                old_bound=False,
            )
            b5_vals  = list(perB5.values())
            # "Unc" is the part of the bound on top of the empirical loss A1.
            unc_vals = [perA2[s] + perA3[s] for s in seeds]
            t_vals   = list(T_sizes.values())

            B5_mean  = float(np.mean(b5_vals)); B5_pm  = _pm_std(b5_vals)
            Unc_mean = float(np.mean(unc_vals)); Unc_pm = _pm_std(unc_vals)
            T_mean   = float(np.mean(t_vals));   T_pm   = _pm_std(t_vals)

            row = [
                model, split_name, "miou_loss",
                f"{A1:.6f}", f"{C:.6f}", f"{b:.6f}",
                f"{GAMMA_THEO:.8f}", f"{ALPHA:.6f}", f"{delta:.6f}",
                f"{B5_mean:.6f}", f"{B5_pm:.6f}", f"{Unc_mean:.6f}", f"{Unc_pm:.6f}",
                f"{n:d}", f"{T_mean:.2f}", f"{T_pm:.2f}",
                f"{train_loss_all:.6f}", f"{val_loss_1k:.6f}", f"{test_loss_4k:.6f}",
            ]
            for s in seeds:
                row += [f"{perA2[s]:.6f}", f"{perA3[s]:.6f}",
                        f"{perB5[s]:.6f}", f"{(perA2[s]+perA3[s]):.6f}",
                        f"{perUhat[s]:.6f}", f"{T_sizes[s]:d}"]
            append_row(out_csv, row)
            print(f"  TD25 '{split_name}': Bound5_mean={B5_mean:.4f} ±{B5_pm:.4f}, Unc_mean={Unc_mean:.4f} ±{Unc_pm:.4f}, n={n}, |T|~{T_mean:.1f} ±{T_pm:.1f}")

        # ---------- OLD '(3)' ----------
        for split_name, ids_in_split, losses_map, out_csv in [
            ("train", train_ids_all, train_losses, bound_train_old_csv),
            ("val",   val_1k_ids,    val_losses,   bound_val_old_csv),
        ]:
            perT2, perG2, perB3, perUold, T_sizes, A1, b, n = compute_per_seed_terms(
                model_name=model,
                split_name=split_name,
                loss_type="miou_loss",
                ids_in_split=ids_in_split,
                losses_map=losses_map,
                groupings_by_seed=groupings,
                C=C, gamma=GAMMA_THEO, alpha=ALPHA, delta=delta,
                seeds=seeds,
                old_bound=True,
            )
            b3_vals = list(perB3.values())
            t2_vals = list(perT2.values())
            g2_vals = list(perG2.values())
            t_vals  = list(T_sizes.values())

            B3_mean = float(np.mean(b3_vals)); B3_pm = _pm_std(b3_vals)
            T2_mean = float(np.mean(t2_vals)); T2_pm = _pm_std(t2_vals)
            G2_mean = float(np.mean(g2_vals)); G2_pm = _pm_std(g2_vals)
            T_mean  = float(np.mean(t_vals));  T_pm  = _pm_std(t_vals)

            row = [
                model, split_name, "miou_loss",
                f"{A1:.6f}", f"{C:.6f}",
                f"{GAMMA_THEO:.8f}", f"{ALPHA:.6f}", f"{delta:.6f}",
                f"{B3_mean:.6f}", f"{B3_pm:.6f}",
                f"{T2_mean:.6f}", f"{T2_pm:.6f}",
                f"{G2_mean:.6f}", f"{G2_pm:.6f}",
                f"{n:d}", f"{T_mean:.2f}", f"{T_pm:.2f}",
                f"{train_loss_all:.6f}", f"{val_loss_1k:.6f}", f"{test_loss_4k:.6f}",
            ]
            for s in seeds:
                row += [f"{perT2[s]:.6f}", f"{perG2[s]:.6f}", f"{perB3[s]:.6f}",
                        f"{perUold[s]:.6f}", f"{T_sizes[s]:d}"]
            append_row(out_csv, row)
            print(f"  OLD  '{split_name}': Bound3_mean={B3_mean:.4f} ±{B3_pm:.4f} "
                  f"(term2={T2_mean:.4f} ±{T2_pm:.4f}, g2={G2_mean:.4f} ±{G2_pm:.4f}), "
                  f"n={n}, |T|~{T_mean:.1f} ±{T_pm:.1f}")

        # Remember where this model's CSVs were written; the aggregation step
        # must read them from the output folder, not from the inference folder.
        processed.append((model, out_mdir))

    # ---------- Aggregation across models ----------
    def _aggregate(split: str, out_csv: str):
        """Collect the latest TD25 row of every processed model into ``out_csv``."""
        rows = []
        for model, model_out_dir in processed:
            metrics_path = os.path.join(model_out_dir, "metrics.csv")
            bound_path   = os.path.join(model_out_dir, f"bound_{split}_miou.csv")

            if not (os.path.exists(metrics_path) and os.path.exists(bound_path)):
                print(f"[aggregate:{split}] Skipping {model} (missing metrics/bound CSV)")
                continue

            # The CSVs are append-only, so the last row belongs to this run.
            mrow = _read_last_row_as_dict(metrics_path)
            brow = _read_last_row_as_dict(bound_path)

            # Pull losses from metrics.csv
            train_loss_all = float(mrow["train_loss_all"])
            val_loss_1k    = float(mrow["val_loss_1k"])
            test_loss_4k   = float(mrow["test_loss_4k"])

            # Bound5 mean from bound CSV
            bound5_mean = float(brow["Bound5_mean"])

            # Average A2 and A3 across seeds from bound CSV
            a2_mean = _avg_seed_cols(brow, "A2_seed_")
            a3_mean = _avg_seed_cols(brow, "A3_seed_")

            rows.append([
                model,
                f"{train_loss_all:.6f}",
                f"{val_loss_1k:.6f}",
                f"{test_loss_4k:.6f}",
                f"{bound5_mean:.6f}",
                f"{a2_mean:.6f}",
                f"{a3_mean:.6f}",
            ])

        # Write aggregation
        with open(out_csv, "w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["model","train_loss_all","val_loss_1k","test_loss_4k","Bound5_mean","A2_mean","A3_mean"])
            w.writerows(rows)
        print(f"[aggregate:{split}] Wrote {out_csv} ({len(rows)} models)")

    # Aggregates go next to the per-model results instead of the current directory.
    _aggregate("train", os.path.join(args.out_dir, "aggregate_train_summary.csv"))
    _aggregate("val",   os.path.join(args.out_dir, "aggregate_val_summary.csv"))

    print("\nDone. Wrote per-model CSVs and aggregate summaries.")


if __name__ == "__main__":
    main()
