import os, math, json, csv, random
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset, Subset
import matplotlib.pyplot as plt


def ensure_dir(path: str):
    """Create directory ``path`` (with any missing parents); no-op if it already exists."""
    os.makedirs(name=path, exist_ok=True)


def get_or_make_indices(
    ds_len: int,
    train_subset_len: int,
    val_subset_len: int,
    seed: int,
    index_dir: str,
    train_idx_name: str = "cifar10_train20k_idx.npy",
    val_idx_name: str =  "cifar10_val10k_idx.npy",
    holdout_idx_name: str = "cifar10_holdout20k_idx.npy",
):
    """
    Load, or create and persist, a deterministic train/val/holdout index split.

    Returns (train_idx_np, val_idx_np, holdout_idx_np).
    If all three index files exist in ``index_dir``, they are loaded and their
    lengths are checked against the requested sizes.
    Otherwise a seeded permutation of ``range(ds_len)`` is cut into
    train / val / holdout (holdout = everything left over), all three arrays
    are saved, and they are returned.

    Raises:
        ValueError: if a subset length is negative or the train and val
            subsets together exceed ``ds_len``.
        AssertionError: if previously saved index files do not have the
            requested lengths.
    """
    # Validate up front: oversized requests would otherwise be silently
    # truncated by slicing, and the short files written here would then fail
    # the length checks below on every later call.
    if train_subset_len < 0 or val_subset_len < 0:
        raise ValueError(
            f"Subset lengths must be non-negative, got train={train_subset_len}, val={val_subset_len}."
        )
    if train_subset_len + val_subset_len > ds_len:
        raise ValueError(
            f"train_subset_len + val_subset_len = {train_subset_len + val_subset_len} exceeds ds_len = {ds_len}."
        )

    os.makedirs(index_dir, exist_ok=True)
    train_path = os.path.join(index_dir, train_idx_name)
    val_path = os.path.join(index_dir, val_idx_name)
    holdout_path = os.path.join(index_dir, holdout_idx_name)

    if all(os.path.exists(p) for p in (train_path, val_path, holdout_path)):
        train_idx = np.load(train_path)
        val_idx = np.load(val_path)
        holdout_idx = np.load(holdout_path)
        # Basic sanity checks
        assert len(train_idx) == train_subset_len, "Saved train index length mismatch."
        assert len(val_idx) == val_subset_len, "Saved val index length mismatch."
        assert len(holdout_idx) == ds_len - train_subset_len - val_subset_len, "Saved holdout index length mismatch."
        return train_idx, val_idx, holdout_idx

    # Create deterministic permutation (a dedicated generator keeps the global
    # torch RNG state untouched).
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(ds_len, generator=g)
    val_end = train_subset_len + val_subset_len

    train_idx = perm[:train_subset_len].cpu().numpy()
    val_idx = perm[train_subset_len:val_end].cpu().numpy()
    holdout_idx = perm[val_end:].cpu().numpy()

    np.save(train_path, train_idx)
    np.save(val_path, val_idx)
    np.save(holdout_path, holdout_idx)
    return train_idx, val_idx, holdout_idx


# ---------------------- hyper/paths (overfit) ----------------------
# base seed, dataset cache directory, directory for newly created index splits
SEED = 1337
DATA_DIR = "./data"
INDEX_DIR = "./splits"

# optimisation settings
NUM_EPOCHS = 150
BATCH_SIZE = 32
LR = 1e-4
WEIGHT_DECAY = 0.0
OPTIMIZER = "adam"
SCHEDULER = "none"

# data loading / console logging
NUM_WORKERS = 4
PIN_MEMORY = True
LOG_INTERVAL = 50

# sizes of the train / val index splits
SUBSET_TRAIN_SIZE = 20_000
SUBSET_VAL_SIZE = 10_000

# tiny train selection
TRAIN_PER_CLASS = 100
NUM_CLASSES = 10

# checkpointing / mixed precision
SAVE_BEST = True
BEST_CKPT_PATH = "./cifar10_best_overfit.pt"
LAST_CKPT_PATH = "./cifar10_last_overfit.pt"
USE_AMP = True

# regularization knobs (disabled for overfitting)
LABEL_SMOOTH = 0.0
CLIP_MAX_NORM = 0.0

# theorem constants
DELTA = 0.01
DELTA2 = 0.01  # for Unc_TP
ALPHA = 100.0
GAMMA_THEO = 0.03 ** (-1.0 / ALPHA)
K = 200
SEEDS: Tuple[int, ...] = (42, 43, 44, 45, 46)

# prepared assets location
# NOTE(review): the default path mentions fashionmnist while the files read
# from it are named cifar10_* -- confirm this default is intended.
GROUPS_DIR = os.environ.get("GROUPS_DIR", "archive/dgm-fashionmnist")

# outputs: one CSV per (bound_set, loss_type) pair
METRICS_CSV = "metrics.csv"
BOUND_TD25 = {
    ("train", "01"): "bound_train_01.csv",
    ("train", "ce"): "bound_train_ce.csv",
    ("val", "01"): "bound_val_01.csv",
    ("val", "ce"): "bound_val_ce.csv",
}
BOUND_OLD = {
    ("train", "01"): "bound_train_01_old.csv",
    ("train", "ce"): "bound_train_ce_old.csv",
    ("val", "01"): "bound_val_01_old.csv",
    ("val", "ce"): "bound_val_ce_old.csv",
}

LOSS_PLOT = "loss_plot_overfit.png"
ACC_PLOT = "acc_plot_overfit.png"

# correctness logs (directory is created as an import-time side effect)
CORR_DIR = "epoch_correctness"
os.makedirs(CORR_DIR, exist_ok=True)

# CIFAR-10 normalization
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)

# ---------------------- utils ----------------------
def seed_all(seed:int):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with ``seed``.

    Also turns off the cuDNN benchmark mode and requests deterministic cuDNN
    kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Make initialization deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_device():
    """Pick the compute device: CUDA if available, else MPS if available, else CPU."""
    if torch.cuda.is_available():
        name = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)

def non_blocking_ok(device):
    """Return True only when ``device`` is a CUDA device."""
    is_cuda = device.type == "cuda"
    return is_cuda
def pin_memory_ok(device):
    """Return a truthy value only when PIN_MEMORY is set and ``device`` is CUDA."""
    if not PIN_MEMORY:
        return PIN_MEMORY
    return device.type == "cuda"

# ---------------------- model ----------------------
# ===== RESNET 18 =====
class SmallCNN(nn.Module):
    """Thin wrapper around an untrained torchvision ResNet-18 classifier.

    NOTE(review): a second class named ``SmallCNN`` is defined later in this
    file; because it comes later it rebinds the name, so this ResNet-18
    variant is not the one instantiated through ``SmallCNN(...)``.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        backbone = models.resnet18(weights=None, num_classes=num_classes)
        self.net = backbone

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        logits = self.net(x)
        return logits

# ===== DENSENET 121 =====
class SmallCNN(nn.Module):
    """Untrained torchvision DenseNet-121 with its stem and head replaced.

    The first convolution becomes a 3x3, stride-1 conv, the first max-pool and
    the pooling layer of ``transition3`` are swapped for ``nn.Identity``, and
    the classifier is a fresh linear layer with ``num_classes`` outputs.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        backbone = models.densenet121(weights=None)
        feats = backbone.features

        # CIFAR-10 is RGB (3-channel), tiny images (32x32):
        # lighter, non-strided stem and less early downsampling
        feats.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        feats.pool0 = nn.Identity()             # drop the first max-pool
        feats.transition3.pool = nn.Identity()  # keep spatial size larger

        # new classification head sized for num_classes
        backbone.classifier = nn.Linear(backbone.classifier.in_features, num_classes)

        self.net = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped network's output for ``x``."""
        out = self.net(x)
        return out

def remove_batchnorm(model: nn.Module):
    """Recursively replace every ``nn.BatchNorm2d`` submodule with ``nn.Identity``.

    The module tree is modified in place; ``model`` itself is returned. Other
    normalization types (e.g. ``BatchNorm1d``) are left untouched.
    """
    for child_name, child in list(model.named_children()):
        if not isinstance(child, nn.BatchNorm2d):
            # Not a BN layer: descend and look for BN layers further down.
            remove_batchnorm(child)
            continue
        setattr(model, child_name, nn.Identity())
    return model


# ---------------------- data ----------------------
def build_transforms():
    """Return the evaluation transform: ``ToTensor`` followed by CIFAR mean/std normalization."""
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ]
    return transforms.Compose(steps)

def load_indices_or_make(n_full:int):
    """Return (train_idx, val_idx, holdout_idx) index arrays.

    If all three prepared ``.npy`` files exist under GROUPS_DIR they are
    loaded as-is; otherwise the split is delegated to ``get_or_make_indices``,
    which works in INDEX_DIR.

    NOTE(review): files found in GROUPS_DIR are returned without any length
    checks, and the fallback reads/writes INDEX_DIR rather than GROUPS_DIR.
    """
    file_names = (
        "cifar10_train20k_idx.npy",
        "cifar10_val10k_idx.npy",
        "cifar10_holdout20k_idx.npy",
    )
    prepared = [os.path.join(GROUPS_DIR, file_name) for file_name in file_names]
    if not all(os.path.exists(path) for path in prepared):
        return get_or_make_indices(
            ds_len=n_full,
            train_subset_len=SUBSET_TRAIN_SIZE,
            val_subset_len=SUBSET_VAL_SIZE,
            seed=SEED,
            index_dir=INDEX_DIR,
            train_idx_name=file_names[0],
            val_idx_name=file_names[1],
            holdout_idx_name=file_names[2],
        )
    train_idx, val_idx, hold_idx = (np.load(path) for path in prepared)
    return train_idx, val_idx, hold_idx

def pick_balanced_subset_100(base_train: datasets.CIFAR10, candidate_ids: List[int],
                             per_class: int = TRAIN_PER_CLASS, seed: int = SEED) -> List[int]:
    """Choose up to ``per_class`` candidate ids per class, deterministically for ``seed``.

    Candidates are bucketed by their label in ``base_train.targets`` (labels
    outside ``[0, NUM_CLASSES)`` are ignored), each bucket is shuffled with a
    private ``random.Random(seed)``, and the first ``per_class`` ids are kept.
    If some class is short, the gap is filled from the shuffled pool of unused
    ids of the other classes, so the result may then be unbalanced.

    Returns the selected ids sorted in ascending order.
    """
    targets = np.array(base_train.targets, dtype=int)
    buckets = [[] for _ in range(NUM_CLASSES)]
    for sample_id in candidate_ids:
        label = int(targets[sample_id])
        if label < 0 or label >= NUM_CLASSES:
            continue
        buckets[label].append(sample_id)

    shuffler = random.Random(seed)
    chosen, spare = [], []
    for bucket in buckets:
        shuffler.shuffle(bucket)
        chosen += bucket[:per_class]
        spare += bucket[per_class:]

    shortfall = NUM_CLASSES * per_class - len(chosen)
    if shortfall > 0:
        # top up from the unused ids of better-populated classes
        shuffler.shuffle(spare)
        chosen += spare[:shortfall]
    chosen.sort()
    return chosen

def build_dataloaders(device):
    """Build the CIFAR-10 datasets, index splits and every DataLoader used by the run.

    Returns an 11-tuple:
        (base_train, tiny_train_ids (int64 array), val_idx, holdout_idx,
         train_loader, train_eval_loader, val_loader, hold_loader, test_loader,
         full_train_loader, full_test_loader)

    Only ``train_loader`` shuffles; all others iterate in dataset order.
    Train, val and holdout subsets are all drawn from the CIFAR-10 training
    split; ``test_loader`` wraps the official test split.
    """
    tf = build_transforms()
    base_train = datasets.CIFAR10(DATA_DIR, train=True, download=True, transform=tf)   # 50k
    test_set = datasets.CIFAR10(DATA_DIR, train=False, download=True, transform=tf)    # 10k

    train_idx_np, val_idx_np, holdout_idx_np = load_indices_or_make(len(base_train))
    tiny_train_ids = pick_balanced_subset_100(
        base_train, train_idx_np.tolist(), per_class=TRAIN_PER_CLASS, seed=SEED
    )

    # keyword arguments shared by every loader
    shared = dict(num_workers=NUM_WORKERS, pin_memory=pin_memory_ok(device), drop_last=False)

    def ordered_loader(dataset, batch_size):
        # non-shuffling loader used for all evaluation passes
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, **shared)

    train_loader = DataLoader(Subset(base_train, tiny_train_ids), batch_size=BATCH_SIZE, shuffle=True, **shared)
    train_eval_loader = ordered_loader(Subset(base_train, tiny_train_ids), 256)
    val_loader = ordered_loader(Subset(base_train, val_idx_np.tolist()), 256)
    hold_loader = ordered_loader(Subset(base_train, holdout_idx_np.tolist()), 256)
    test_loader = ordered_loader(test_set, 256)

    # additional loaders for correctness logging (full datasets in order)
    full_train_loader = ordered_loader(base_train, 512)
    full_test_loader = ordered_loader(test_set, 512)

    tiny_train_np = np.array(tiny_train_ids, dtype=np.int64)
    return (base_train, tiny_train_np, val_idx_np, holdout_idx_np,
            train_loader, train_eval_loader, val_loader, hold_loader, test_loader,
            full_train_loader, full_test_loader)

class IndexedSubset(Dataset):
    """View of ``base_ds`` restricted to ``indices`` that also yields the original index.

    Each item is ``(x, y, idx)`` where ``idx`` is the position in ``base_ds``.
    """

    def __init__(self, base_ds, indices:List[int]):
        self.base_ds = base_ds
        self.indices = list(indices)  # private copy of the index sequence

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        original_idx = self.indices[i]
        sample, label = self.base_ds[original_idx]
        return sample, label, original_idx

# ---------------------- eval helpers ----------------------
@torch.no_grad()
def evaluate_ce_acc(model, loader, device, nb, amp_enabled):
    """Evaluate ``model`` over ``loader``; return (mean cross-entropy, accuracy).

    The model is switched to eval mode. Autocast is active only when
    ``amp_enabled`` is set and the device is CUDA. ``nb`` is passed as
    ``non_blocking`` to the host-to-device copies. Both results are averaged
    over the number of samples seen.
    """
    model.eval()
    loss_sum, num_correct, num_seen = 0.0, 0, 0
    use_autocast = amp_enabled and device.type == "cuda"
    with torch.amp.autocast(device_type=device.type, enabled=use_autocast):
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=nb)
            targets = targets.to(device, non_blocking=nb)
            logits = model(inputs)
            loss_sum += F.cross_entropy(logits, targets, reduction="sum").item()
            num_correct += (logits.argmax(1) == targets).sum().item()
            num_seen += targets.size(0)
    return loss_sum / num_seen, num_correct / num_seen

def per_index_both_losses(model, device, base_ds, indices, batch_size=1024):
    """Compute per-sample cross-entropy and 0-1 loss for the given ``base_ds`` indices.

    Returns ``(ce_out, z1_out)``: two dicts keyed by the original dataset
    index, holding the sample's cross-entropy and its 0/1 error
    (1.0 = misclassified) respectively. Runs in eval mode, without gradients
    and in dataset order.
    """
    indexed = IndexedSubset(base_ds, indices)
    loader = DataLoader(indexed, batch_size, False)
    ce_out: Dict[int, float] = {}
    z1_out: Dict[int, float] = {}
    model.eval()
    with torch.no_grad():
        for inputs, targets, sample_ids in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            ce_vals = F.cross_entropy(logits, targets, reduction="none").cpu().numpy()
            err_vals = (logits.argmax(1) != targets).float().cpu().numpy()
            for sid, ce_v, err_v in zip(sample_ids.numpy(), ce_vals, err_vals):
                key = int(sid)
                ce_out[key] = float(ce_v)
                z1_out[key] = float(err_v)
    return ce_out, z1_out

def compute_C_holdout(model, device, hold_loader, loss_kind:str):
    """Return the loss constant C for ``loss_kind``.

    For ``"01"`` this is 1.0 and the model is never touched. For any other
    value it is the largest per-sample cross-entropy observed over
    ``hold_loader`` (0.0 if the loader yields nothing).
    """
    if loss_kind == "01":
        return 1.0
    model.eval()
    worst = 0.0
    with torch.no_grad():
        for inputs, targets in hold_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            per_sample = F.cross_entropy(model(inputs), targets, reduction="none")
            batch_max = float(per_sample.max().item())
            if batch_max > worst:
                worst = batch_max
    return worst

@torch.no_grad()
def correctness_mask_from_loader(model, loader, device, nb, amp_enabled):
    """Return a uint8 vector (1 = correct prediction) in the order ``loader`` yields samples.

    Runs in eval mode; autocast is active only when ``amp_enabled`` is set and
    the device is CUDA.
    """
    model.eval()
    use_autocast = amp_enabled and device.type == "cuda"
    chunks = []
    with torch.amp.autocast(device_type=device.type, enabled=use_autocast):
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=nb)
            targets = targets.to(device, non_blocking=nb)
            hits = model(inputs).argmax(1) == targets
            chunks.append(hits.detach().cpu().numpy().astype(np.uint8))
    mask = np.concatenate(chunks, axis=0)  # uint8 {0,1}
    return mask

# ---------------------- bounds & Unc_TP ----------------------
def compute_b(n:int, delta:float)->float:
    """Return sqrt(0.5 * n * ln(2 / delta))."""
    log_term = math.log(2.0/delta)
    return math.sqrt(0.5*n*log_term)

# TD25 '(5)' û — uses |T| (non-empty), NOT K
def compute_uhat_td25(n:int, n_i_list:List[int], b:float, gamma:float, delta:float, T_size:int)->float:
    """Return û for the TD25 '(5)' bound as the sum of four terms.

        gamma*(1 + 2b) / (2n)
      + gamma*T_size*b^2 / (2 n^2)
      + gamma^2/2 * sum_i (n_i/n)^2
      + gamma^2 * sqrt(ln(2/delta) / (2n))
    """
    gamma_sq = gamma**2
    linear_part = gamma*(1.0+2.0*b)/(2.0*n)
    cluster_count_part = (gamma*T_size*(b**2))/(2.0*(n**2))
    frac_sq_sum = sum((ni/n)**2 for ni in n_i_list)
    concentration_part = gamma_sq/2.0 * frac_sq_sum
    confidence_part = gamma_sq * math.sqrt(math.log(2.0/delta)/(2.0*n))
    return linear_part+cluster_count_part+concentration_part+confidence_part

def compute_A3_mod(C:float, a_o:float, uhat:float, alpha:float, gamma:float, delta:float, n:int)->float:
    """Modified A3: C*sqrt(û α ln γ) + a_o*sqrt(ln(2/δ)/(2n)).

    The radicand of the first square root is clamped at 0, so that term
    vanishes whenever û·α·ln γ is negative.
    """
    radicand = max(0.0, uhat * alpha * math.log(gamma))
    spread_term = C * math.sqrt(radicand)
    confidence_term = a_o * math.sqrt(math.log(2.0/delta)/(2.0*n))
    return spread_term + confidence_term

# Old '(3)' û
def compute_uhat_old(n:int, n_i_list:List[int], gamma:float, delta:float, K:int)->float:
    """Return û for the old '(3)' bound.

    gamma/(2n) + gamma^2/2 * sum_i (n_i/n)^2 + gamma^2 * sqrt((2/n) * ln(2K/delta))
    """
    frac_sq_sum = sum((ni/n)**2 for ni in n_i_list)
    base_part = gamma/(2.0*n)
    concentration_part = (gamma**2)/2.0*frac_sq_sum
    confidence_part = (gamma**2)*math.sqrt((2.0/n)*math.log(2.0*K/delta))
    return base_part + concentration_part + confidence_part

# Old '(3)' g2(δ) — uses |T| and {n_i}
def compute_g2_old(C:float, n:int, K:int, delta:float, T_size:int, n_i_list:List[int])->float:
    """Return g2(δ) for the old '(3)' bound.

    C*(1+sqrt(2)) * sqrt(L/n) * (sum_i sqrt(n_i) / n)  +  4*C*T_size*L / n,
    with L = ln(2K/delta); only clusters with n_i > 0 enter the sum.
    """
    log_term = math.log(2.0*K/delta)
    sqrt_total = sum(math.sqrt(ni) for ni in n_i_list if ni>0)
    first = C*(1.0+math.sqrt(2.0)) * math.sqrt(log_term/n) * (sqrt_total / n)
    second = (4.0 * C * T_size * log_term) / n
    return first + second

# ---- Unc_TP (delta2 used here) ----
def compute_a_hat_tp(n:int, n_i_list:List[int], gamma:float, delta2:float, K:int)->float:
    """Return â for Unc_TP.

    gamma/(2n) + gamma^2/2 * sum_i (n_i/n)^2 + gamma^2 * sqrt((2/n) * ln(2K/delta2))
    """
    frac_sq_sum = sum((ni/n)**2 for ni in n_i_list)
    base_part = gamma/(2.0*n)
    concentration_part = (gamma**2)/2.0*frac_sq_sum
    confidence_part = (gamma**2)*math.sqrt( (2.0/n) * math.log(2.0*K/delta2) )
    return base_part + concentration_part + confidence_part

def compute_unc_tp(C:float, alpha:float, gamma:float, delta2:float, n:int,
                   n_i_list:List[int], a_i_list:List[float], a_o:float)->float:
    """Return Unc_TP as the sum of three terms.

    ``n_i_list`` and ``a_i_list`` are aligned per cluster; clusters with
    ``n_i == 0`` are excluded from the sums in terms 2 and 3, but the full
    ``n_i_list`` feeds â. The cluster count is read from the module-level
    ``K``, not from a parameter.
    """
    # T = non-empty clusters (aligned with n_i_list > 0)
    kept = [pos for pos, count in enumerate(n_i_list) if count > 0]
    kept_counts = [n_i_list[pos] for pos in kept]
    kept_losses = [a_i_list[pos] for pos in kept]

    a_hat = compute_a_hat_tp(n, n_i_list, gamma, delta2, K)
    ln_term = math.log(2.0*K/delta2)

    spread = C * math.sqrt( max(0.0, a_hat * alpha * math.log(gamma)) )
    weighted = sum( math.sqrt(ni) * (a_o + math.sqrt(2.0*ai)) for ni,ai in zip(kept_counts, kept_losses) )
    per_cluster = math.sqrt( ln_term / n ) * weighted
    residual = (2.0 * ln_term / n) * ( a_o * len(kept) + sum(kept_losses) )
    return spread + per_cluster + residual

def load_grouping(which:str, seed:int)->Dict[str,List[int]]:
    """Load the grouping JSON for split ``which`` and ``seed`` from GROUPS_DIR.

    The file name also embeds the module-level ``K``.

    Raises:
        FileNotFoundError: if the expected JSON file does not exist.
    """
    f = os.path.join(GROUPS_DIR, f"cifar10_{which}_grouping_K_{K}_seed_{seed}.json")
    if os.path.exists(f):
        with open(f, "r") as fh:
            grouping = json.load(fh)
        return grouping
    raise FileNotFoundError(f"Grouping JSON not found: {f}")

# ---------------------- CSV helpers ----------------------
def init_basic_metrics_csv(path):
    """Create the basic metrics CSV with its header row; leave an existing file untouched."""
    if not os.path.exists(path):
        columns = ["epoch","train_loss","train_acc","val_loss","val_acc","test_loss","test_acc","lr"]
        with open(path, "w", newline="") as handle:
            csv.writer(handle).writerow(columns)

def init_bound_td25_csv(path, seeds):
    """Create a TD25 bound CSV with its header row; leave an existing file untouched.

    The header holds the shared columns followed by seven per-seed columns
    for every entry of ``seeds``, in order.
    """
    if os.path.exists(path):
        return
    columns = [
        "epoch","bound_set","loss_type","A1","C","b","gamma","alpha","delta",
        "Bound5_mean","Unc_mean","UncTP_mean",
        "train_loss","train_acc","val_loss","val_acc","test_loss","test_acc","lr",
    ]
    for s in seeds:
        columns.extend((
            f"A2_seed_{s}", f"A3_seed_{s}", f"Bound5_seed_{s}", f"Unc_seed_{s}",
            f"uhat_seed_{s}", f"T_size_seed_{s}", f"UncTP_seed_{s}",
        ))
    with open(path, "w", newline="") as handle:
        csv.writer(handle).writerow(columns)

def init_bound_old_csv(path, seeds):
    """Create an old-'(3)' bound CSV with its header row; leave an existing file untouched.

    The header holds the shared columns followed by six per-seed columns for
    every entry of ``seeds``, in order.
    """
    if os.path.exists(path):
        return
    columns = [
        "epoch","bound_set","loss_type","A1","C","gamma","alpha","delta",
        "Bound3_mean","Term2_mean","g2_mean","UncTP_mean",
        "train_loss","train_acc","val_loss","val_acc","test_loss","test_acc","lr",
    ]
    for s in seeds:
        columns.extend((
            f"term2_seed_{s}", f"g2_seed_{s}", f"Bound3_seed_{s}",
            f"uhat_old_seed_{s}", f"T_size_seed_{s}", f"UncTP_seed_{s}",
        ))
    with open(path, "w", newline="") as handle:
        csv.writer(handle).writerow(columns)

def append_row(path, row):
    """Append ``row`` as a single CSV record to the file at ``path``."""
    with open(path, "a", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(row)

# ---------------------- train one epoch ----------------------
def train_one_epoch(model, loader, optimizer, device, nb, scaler, amp_enabled):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy).

    Both statistics are accumulated on the fly during training, so they
    reflect the weights as they change within the epoch. With AMP on CUDA the
    backward pass and optimizer step go through ``scaler``; otherwise a plain
    backward/step is used. No gradient clipping is applied here.
    """
    model.train()
    loss_sum, num_correct, num_seen = 0.0, 0, 0
    on_cuda = device.type == "cuda"
    for b, (inputs, targets) in enumerate(loader, start=1):
        inputs = inputs.to(device, non_blocking=nb)
        targets = targets.to(device, non_blocking=nb)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device.type, enabled=amp_enabled and on_cuda):
            logits = model(inputs)
            loss = F.cross_entropy(logits, targets, label_smoothing=LABEL_SMOOTH)
        if on_cuda and amp_enabled:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        with torch.no_grad():
            batch_len = targets.size(0)
            # batch loss is a mean, so weight it by the batch size
            loss_sum += float(loss.item())*batch_len
            num_correct += (logits.argmax(1) == targets).sum().item()
            num_seen += batch_len
        if LOG_INTERVAL and b%LOG_INTERVAL==0:
            print(f"  [Train] Batch {b:4d} | Loss {loss.item():.4f}")
    return loss_sum/num_seen, num_correct/num_seen

# ---------------------- main ----------------------
def main():
    """Run the full overfitting experiment end to end.

    Steps, in order:
      1. Seed RNGs, pick the device, build datasets / loaders.
      2. Load the per-seed cluster groupings and precompute the
         model-independent quantities (n_i, |T|, b, TD25 û).
      3. Build the model (BatchNorm2d layers stripped) and the Adam optimizer.
      4. For every epoch: train, evaluate on train/val/test, append to the
         metrics CSV, save checkpoints, compute per-index losses, then write
         one row per (bound_set, loss_type) to the TD25 and old-'(3)' bound
         CSVs, and store a per-sample correctness vector.
      5. Save loss/accuracy plots and the stacked correctness array.

    All outputs are written relative to the current working directory using
    the module-level path constants. CSV files that already exist are
    appended to, not truncated.
    """
    seed_all(SEED); device=get_device(); nb=non_blocking_ok(device); amp=(device.type=="cuda") and USE_AMP
    print(f"Using device: {device} | AMP: {amp}")

    (full_train, train_idx_np, val_idx_np, holdout_idx_np,
     train_loader, train_eval_loader, val_loader, hold_loader, test_loader,
     full_train_loader, full_test_loader) = build_dataloaders(device)

    # train_idx_np here is the tiny balanced training subset returned by build_dataloaders
    train_ids, val_ids = train_idx_np.tolist(), val_idx_np.tolist()
    n_train, n_val = len(train_ids), len(val_ids)
    assert n_train == NUM_CLASSES * TRAIN_PER_CLASS, f"Expected {NUM_CLASSES*TRAIN_PER_CLASS} train, got {n_train}"

    # load groupings + precompute n_i, |T|
    train_id_set = set(train_ids)
    groupings = {"train":{}, "val":{}}
    n_i_lists = {"train":{}, "val":{}}
    T_sizes   = {"train":{}, "val":{}}
    for s in SEEDS:
        gtr=load_grouping("train",s); gva=load_grouping("val",s)
        groupings["train"][s]=gtr; groupings["val"][s]=gva

        # Consistency checks: the tiny train set must be contained in the train
        # grouping, and the val grouping must cover the val ids exactly.
        all_tr = sorted([idx for k in range(K) for idx in gtr[str(k)]])
        all_va = sorted([idx for k in range(K) for idx in gva[str(k)]])
        assert set(train_ids).issubset(set(all_tr)), f"Tiny train not subset of grouping seed {s}."
        assert all_va == sorted(val_ids),            f"Val grouping seed {s} not covering val10k exactly."

        # train cluster sizes count only members that are in the tiny train set
        n_i_lists["train"][s]=[len([idx for idx in gtr[str(k)] if idx in train_id_set]) for k in range(K)]
        n_i_lists["val"][s]  =[len(gva[str(k)]) for k in range(K)]
        T_sizes["train"][s]=sum(1 for ni in n_i_lists["train"][s] if ni>0)
        T_sizes["val"][s]  =sum(1 for ni in n_i_lists["val"][s]   if ni>0)

    # b and the TD25 û depend only on sizes/constants, so they are computed once up front
    b_vals = {"train": compute_b(n_train, DELTA), "val": compute_b(n_val, DELTA)}
    uhat_td25 = {
        "train": {s: compute_uhat_td25(n_train, n_i_lists["train"][s], b_vals["train"], GAMMA_THEO, DELTA, T_sizes["train"][s]) for s in SEEDS},
        "val":   {s: compute_uhat_td25(n_val,   n_i_lists["val"][s],   b_vals["val"],   GAMMA_THEO, DELTA, T_sizes["val"][s])   for s in SEEDS},
    }

    # opt (no scheduler)
    model = SmallCNN(NUM_CLASSES).to(device)
    model.net = remove_batchnorm(model.net)
    if OPTIMIZER.lower()!="adam": raise ValueError("Fixed optimizer is Adam.")
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scaler = torch.amp.GradScaler("cuda", enabled=amp)

    # init CSVs
    init_basic_metrics_csv(METRICS_CSV)
    for _,p in BOUND_TD25.items(): init_bound_td25_csv(p, SEEDS)
    for _,p in BOUND_OLD.items():  init_bound_old_csv(p, SEEDS)

    best_test_acc=0.0
    hist={"train_loss":[], "train_acc":[], "test_loss":[], "test_acc":[]}
    all_epoch_correct = []  # will hold (60000,) uint8 per epoch

    for epoch in range(1, NUM_EPOCHS+1):
        print(f"\n=== Epoch {epoch}/{NUM_EPOCHS} ===")
        # The running statistics returned by training are discarded; the reported
        # train metrics come from a separate eval-mode pass below.
        _trainL,_trainA = train_one_epoch(model, train_loader, optimizer, device, nb, scaler, amp)
        train_ce, train_acc = evaluate_ce_acc(model, train_eval_loader, device, nb, amp)
        val_ce,   val_acc   = evaluate_ce_acc(model, val_loader, device, nb, amp)
        test_ce,  test_acc  = evaluate_ce_acc(model, test_loader, device, nb, amp)
        lr = optimizer.param_groups[0]["lr"]

        append_row(METRICS_CSV, [epoch, f"{train_ce:.6f}", f"{train_acc:.6f}",
                                 f"{val_ce:.6f}", f"{val_acc:.6f}", f"{test_ce:.6f}", f"{test_acc:.6f}", f"{lr:.8f}"])
        hist["train_loss"].append(train_ce); hist["train_acc"].append(train_acc)
        hist["test_loss"].append(test_ce);  hist["test_acc"].append(test_acc)

        print(f"[Metrics] train_ce={train_ce:.4f} | train_acc={train_acc*100:.2f}% | "
              f"val_ce={val_ce:.4f} | val_acc={val_acc*100:.2f}% | "
              f"test_ce={test_ce:.4f} | test_acc={test_acc*100:.2f}% | lr={lr:.6f}")

        # NOTE(review): the "best" checkpoint is selected on test accuracy.
        torch.save({"model_state":model.state_dict(),"test_acc":test_acc,"epoch":epoch}, LAST_CKPT_PATH)
        if SAVE_BEST and test_acc>best_test_acc:
            best_test_acc=test_acc
            torch.save({"model_state":model.state_dict(),"test_acc":best_test_acc,"epoch":epoch}, BEST_CKPT_PATH)
            print(f"  ↳ Saved new best to {BEST_CKPT_PATH} (test_acc={best_test_acc*100:.2f}%)")

        # per-index losses
        ce_tr, z1_tr = per_index_both_losses(model, device, full_train, train_ids, batch_size=1024)
        ce_va, z1_va = per_index_both_losses(model, device, full_train, val_ids,   batch_size=1024)

        # A1 = mean per-sample loss on each set, for each loss type
        A1 = {("train","01"): float(np.mean(list(z1_tr.values()))),
              ("train","ce"): float(np.mean(list(ce_tr.values()))),
              ("val","01")  : float(np.mean(list(z1_va.values()))),
              ("val","ce")  : float(np.mean(list(ce_va.values())))}

        # C is 1 for the 0-1 loss; for CE it is the max per-sample CE on the holdout set
        C_vals = {"01":1.0, "ce":compute_C_holdout(model, device, hold_loader, "ce")}

        # ===== precompute Unc_TP ingredients & a_o on validation clusters per seed =====
        def build_a_i_and_counts(seed:int, loss_kind:str):
            """Return (n_i, a_i): size and mean loss of every val cluster (0.0 if empty)."""
            g_val = groupings["val"][seed]
            per_idx = z1_va if loss_kind == "01" else ce_va
            n_i = []
            a_i = []
            for k in range(K):
                ids_k = g_val[str(k)]
                n_k = len(ids_k)
                n_i.append(n_k)
                if n_k > 0:
                    vals = [per_idx[idx] for idx in ids_k]
                    a_i.append(float(np.mean(vals)))
                else:
                    a_i.append(0.0)
            return n_i, a_i

        uncTP_val = {loss: {} for loss in ["01","ce"]}
        a_o_val   = {loss: {} for loss in ["01","ce"]}
        for loss_kind in ["01","ce"]:
            for s in SEEDS:
                n_i_val, a_i_val = build_a_i_and_counts(s, loss_kind)
                # a_o from validation:
                # NOTE(review): as written a_o always ends up 0.0 -- the sums below run
                # over EMPTY clusters only, whose p_k and a_i are both 0, so denom is 0.
                # Confirm whether a different definition of a_o was intended.
                nonempty_mask = [ni > 0 for ni in n_i_val]
                if all(nonempty_mask):  # |T| = K
                    a_o = 0.0
                else:
                    n_total = float(n_val)
                    pk = [ni / n_total for ni in n_i_val]  # p_k = n_k / n
                    numer = sum(p * a for p, a, m in zip(pk, a_i_val, nonempty_mask) if not m)
                    denom = sum(p for p, m in zip(pk, nonempty_mask) if not m)
                    a_o = 0.0 if denom == 0.0 else numer / denom
                a_o_val[loss_kind][s] = a_o

                uncTP_val[loss_kind][s] = compute_unc_tp(
                    C=C_vals[loss_kind], alpha=ALPHA, gamma=GAMMA_THEO, delta2=DELTA2,
                    n=n_val, n_i_list=n_i_val, a_i_list=a_i_val, a_o=a_o
                )

        # ---- TD25 '(5)' rows (modified final term uses a_o) ----
        for (whichS, loss_kind), out_csv in BOUND_TD25.items():
            ids = train_ids if whichS=="train" else val_ids
            n = len(ids)
            b = b_vals[whichS]
            uhat_by_seed = uhat_td25[whichS]
            grouping_for_S = groupings[whichS]
            # pick the per-index loss dict matching (set, loss type)
            per_idx = (z1_tr if whichS=="train" and loss_kind=="01" else
                       ce_tr if whichS=="train" else
                       z1_va if loss_kind=="01" else ce_va)
            A1_val = A1[(whichS, loss_kind)]
            C_used = C_vals[loss_kind]

            A2_s, A3_s, B5_s, Unc_s = {}, {}, {}, {}
            for s in SEEDS:
                g = grouping_for_S[s]
                F_Si=[]
                # NOTE(review): id_set does not depend on s; it is rebuilt for every seed.
                id_set = set(ids)
                for k in range(K):
                    ids_k=[idx for idx in g[str(k)] if idx in id_set]
                    if len(ids_k)==0: continue
                    vals=[per_idx[idx] for idx in ids_k]
                    F_Si.append(float(np.mean(vals)))
                # A2 = (b/n) * sum of mean losses over the non-empty clusters
                A2 = (b/n) * float(np.sum(F_Si))
                # modified A3 with a_o from VAL:
                A3 = compute_A3_mod(C_used, a_o_val[loss_kind][s], uhat_by_seed[s], ALPHA, GAMMA_THEO, DELTA, n)
                B5 = A1_val + A2 + A3
                Unc = A2 + A3
                A2_s[s]=A2; A3_s[s]=A3; B5_s[s]=B5; Unc_s[s]=Unc

            # Unc_TP is always taken from the validation clusters, also for the "train" rows
            unc_tp_mean = float(np.mean([uncTP_val[loss_kind][s] for s in SEEDS]))

            row=[epoch, whichS, loss_kind, f"{A1_val:.6f}", f"{C_used:.6f}", f"{b:.6f}",
                 f"{GAMMA_THEO:.8f}", f"{ALPHA:.6f}", f"{DELTA:.6f}",
                 f"{np.mean(list(B5_s.values())):.6f}", f"{np.mean(list(Unc_s.values())):.6f}",
                 f"{unc_tp_mean:.6f}",
                 f"{train_ce:.6f}", f"{train_acc:.6f}",
                 f"{val_ce:.6f}", f"{val_acc:.6f}",
                 f"{test_ce:.6f}", f"{test_acc:.6f}", f"{lr:.8f}"]
            for s in SEEDS:
                row += [f"{A2_s[s]:.6f}", f"{A3_s[s]:.6f}", f"{B5_s[s]:.6f}", f"{Unc_s[s]:.6f}",
                        f"{uhat_by_seed[s]:.6f}", f"{T_sizes[whichS][s]:d}",
                        f"{uncTP_val[loss_kind][s]:.6f}"]
            append_row(out_csv, row)

        # ---- OLD '(3)' rows ----
        for (whichS, loss_kind), out_csv in BOUND_OLD.items():
            ids = train_ids if whichS=="train" else val_ids
            n = len(ids)
            n_i_map = n_i_lists[whichS]
            Tsize_map = T_sizes[whichS]
            A1_val = A1[(whichS, loss_kind)]
            C_used = C_vals[loss_kind]

            term2_s, g2_s, B3_s, uhat_old_s = {}, {}, {}, {}
            for s in SEEDS:
                n_i = n_i_map[s]
                u_old = compute_uhat_old(n, n_i, GAMMA_THEO, DELTA, K)
                g2 = compute_g2_old(C_used, n, K, DELTA, Tsize_map[s], n_i)
                # clamped at 0 so a negative radicand yields a zero term instead of an error
                term2 = C_used * math.sqrt(max(0.0, u_old * ALPHA * math.log(GAMMA_THEO)))
                B3 = A1_val + term2 + g2
                uhat_old_s[s]=u_old; g2_s[s]=g2; term2_s[s]=term2; B3_s[s]=B3

            unc_tp_mean = float(np.mean([uncTP_val[loss_kind][s] for s in SEEDS]))

            row=[epoch, whichS, loss_kind,
                 f"{A1_val:.6f}", f"{C_used:.6f}",
                 f"{GAMMA_THEO:.8f}", f"{ALPHA:.6f}", f"{DELTA:.6f}",
                 f"{np.mean(list(B3_s.values())):.6f}",
                 f"{np.mean(list(term2_s.values())):.6f}",
                 f"{np.mean(list(g2_s.values())):.6f}",
                 f"{unc_tp_mean:.6f}",
                 f"{train_ce:.6f}", f"{train_acc:.6f}",
                 f"{val_ce:.6f}", f"{val_acc:.6f}",
                 f"{test_ce:.6f}", f"{test_acc:.6f}", f"{lr:.8f}"]
            for s in SEEDS:
                row += [f"{term2_s[s]:.6f}", f"{g2_s[s]:.6f}", f"{B3_s[s]:.6f}",
                        f"{uhat_old_s[s]:.6f}", f"{T_sizes[whichS][s]:d}",
                        f"{uncTP_val[loss_kind][s]:.6f}"]
            append_row(out_csv, row)

        # ---------- correctness logging (50k train + 10k test) ----------
        corr_train = correctness_mask_from_loader(model, full_train_loader, device, nb, amp)  # (50000,)
        corr_test  = correctness_mask_from_loader(model, full_test_loader,  device, nb, amp)  # (10000,)
        corr_all   = np.concatenate([corr_train, corr_test], axis=0).astype(np.uint8)        # (60000,)
        all_epoch_correct.append(corr_all)
        np.save(os.path.join(CORR_DIR, f"correctness_epoch_{epoch:03d}.npy"), corr_all)

    # plots
    epochs=range(1, NUM_EPOCHS+1)
    plt.figure(figsize=(8,5))
    plt.plot(epochs,hist["train_loss"],label="Train CE (eval)")
    plt.plot(epochs,hist["test_loss"],label="Test CE")
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("CIFAR-10 Overfit: Train vs Test Loss")
    plt.legend(); plt.grid(True); plt.tight_layout(); plt.savefig(LOSS_PLOT,dpi=150)

    plt.figure(figsize=(8,5))
    plt.plot(epochs,[a*100 for a in hist["train_acc"]],label="Train Acc (eval)")
    plt.plot(epochs,[a*100 for a in hist["test_acc"]],label="Test Acc")
    plt.xlabel("Epoch"); plt.ylabel("Accuracy (%)"); plt.title("CIFAR-10 Overfit: Train vs Test Accuracy")
    # NOTE(review): plt.show() may block with interactive matplotlib backends -- confirm it is wanted in batch runs.
    plt.legend(); plt.grid(True); plt.tight_layout(); plt.savefig(ACC_PLOT,dpi=150); plt.show()

    # save stacked correctness tensor (NUM_EPOCHS x 60000)
    all_epoch_correct_np = np.stack(all_epoch_correct, axis=0)
    np.save(os.path.join(CORR_DIR, "correctness_all_60k_epochs.npy"), all_epoch_correct_np)

    # best-effort release of cached accelerator memory
    if device.type=="cuda": torch.cuda.empty_cache()
    elif device.type=="mps":
        try: torch.mps.empty_cache()
        except Exception: pass

    print("\nTraining complete. Wrote:")
    print(f"  - {METRICS_CSV}")
    for _,f in BOUND_TD25.items(): print(f"  - {f}")
    for _,f in BOUND_OLD.items():  print(f"  - {f}")
    print(f"  - Per-epoch correctness: {CORR_DIR}/correctness_epoch_###.npy")
    print(f"  - Stacked correctness:   {CORR_DIR}/correctness_all_60k_epochs.npy")
    if os.path.exists(BEST_CKPT_PATH): print(f"Best checkpoint: {BEST_CKPT_PATH}")

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
