import os, math, json, csv, random
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset, Subset
import matplotlib.pyplot as plt


def ensure_dir(path: str):
    """Create directory ``path`` (and any missing parents); no-op if it already exists."""
    os.makedirs(name=path, exist_ok=True)


def get_or_make_indices(
    ds_len: int,
    train_subset_len: int,
    val_subset_len: int,
    seed: int,
    index_dir: str,
    train_idx_name: str = "fashionmnist_train20k_idx.npy",
    val_idx_name: str = "fashionmnist_val10k_idx.npy",
    holdout_idx_name: str = "fashionmnist_holdout30k_idx.npy",
):
    """
    Return ``(train_idx, val_idx, holdout_idx)`` as NumPy index arrays.

    If all three index files already exist in ``index_dir`` they are loaded
    (``seed`` is not used in that case) and their lengths are checked against
    the requested sizes. Otherwise a permutation of ``range(ds_len)`` seeded
    with ``seed`` is split into train / val / holdout (in that order), all
    three arrays are saved to ``index_dir`` and returned. If only some of the
    files exist, the split is regenerated and all three files are rewritten.

    Raises:
        ValueError: if a subset length is negative, if the two subset lengths
            together exceed ``ds_len``, or if previously saved index files do
            not have the requested lengths.
    """
    if train_subset_len < 0 or val_subset_len < 0:
        raise ValueError("Subset lengths must be non-negative.")
    holdout_len = ds_len - train_subset_len - val_subset_len
    if holdout_len < 0:
        # Without this check the slices below would silently come out short.
        raise ValueError(
            f"train_subset_len + val_subset_len "
            f"({train_subset_len} + {val_subset_len}) exceeds ds_len ({ds_len})."
        )

    os.makedirs(index_dir, exist_ok=True)
    train_path = os.path.join(index_dir, train_idx_name)
    val_path = os.path.join(index_dir, val_idx_name)
    holdout_path = os.path.join(index_dir, holdout_idx_name)

    if all(os.path.exists(p) for p in (train_path, val_path, holdout_path)):
        train_idx = np.load(train_path)
        val_idx = np.load(val_path)
        holdout_idx = np.load(holdout_path)
        # Explicit raises (not ``assert``) so the checks survive ``python -O``.
        if len(train_idx) != train_subset_len:
            raise ValueError("Saved train index length mismatch.")
        if len(val_idx) != val_subset_len:
            raise ValueError("Saved val index length mismatch.")
        if len(holdout_idx) != holdout_len:
            raise ValueError("Saved holdout index length mismatch.")
        return train_idx, val_idx, holdout_idx

    # Create deterministic permutation from a dedicated generator so the
    # global torch RNG state is left untouched.
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(ds_len, generator=g)
    val_end = train_subset_len + val_subset_len

    train_idx = perm[:train_subset_len].cpu().numpy()
    val_idx = perm[train_subset_len:val_end].cpu().numpy()
    holdout_idx = perm[val_end:].cpu().numpy()

    np.save(train_path, train_idx)
    np.save(val_path, val_idx)
    np.save(holdout_path, holdout_idx)
    return train_idx, val_idx, holdout_idx


# ---------------------- hyper/paths ----------------------
# Global seed and on-disk locations for the dataset and the index splits.
SEED = 1337
DATA_DIR = "./data"
INDEX_DIR = "./splits"

# Optimisation settings (main() only accepts "adam" + "steplr").
NUM_EPOCHS = 150
BATCH_SIZE = 256
LR = 1e-3
WEIGHT_DECAY = 5e-4  # (slightly stronger)
OPTIMIZER = "adam"
SCHEDULER = "steplr"
STEP_SIZE = 5
GAMMA_STEP = 0.5

# DataLoader / logging settings.
NUM_WORKERS = 4
PIN_MEMORY = True
LOG_INTERVAL = 50

# Sizes of the train / val subsets carved out of the training split.
SUBSET_TRAIN_SIZE = 20_000
SUBSET_VAL_SIZE = 10_000

# Checkpointing and mixed precision.
SAVE_BEST = True
BEST_CKPT_PATH = "./fashionmnist_best.pt"
LAST_CKPT_PATH = "./fashionmnist_last.pt"
USE_AMP = True

# regularization knobs
LABEL_SMOOTH = 0.1
CLIP_MAX_NORM = 1.0

# simple augs for 28×28 grayscale
AUG_RANDOM_CROP_PAD = 4
AUG_HFLIP_P = 0.5
AUG_ROT_DEG = 10
AUG_ERASE_P = 0.25

# theorem constants
DELTA = 0.01
ALPHA = 100.0
GAMMA_THEO = 0.03 ** (-1.0 / ALPHA)
K = 200
SEEDS: Tuple[int, ...] = (42, 43, 44, 45, 46)

# prepared assets location (overridable through the GROUPS_DIR env variable)
GROUPS_DIR = os.environ.get("GROUPS_DIR", "archive/dgm-fashionmnist")

# outputs: one CSV per (bound set, loss type) for each of the two bounds
METRICS_CSV = "metrics.csv"
BOUND_TD25 = {
    ("train", "01"): "bound_train_01.csv",
    ("train", "ce"): "bound_train_ce.csv",
    ("val", "01"): "bound_val_01.csv",
    ("val", "ce"): "bound_val_ce.csv",
}
BOUND_OLD = {
    ("train", "01"): "bound_train_01_old.csv",
    ("train", "ce"): "bound_train_ce_old.csv",
    ("val", "01"): "bound_val_01_old.csv",
    ("val", "ce"): "bound_val_ce_old.csv",
}

LOSS_PLOT = "loss_plot.png"
ACC_PLOT = "acc_plot.png"

# FMNIST norm
FASHION_MEAN = 0.28604059698879553
FASHION_STD = 0.35302424451492237
NUM_CLASSES = 10

# ---------------------- utils ----------------------
def seed_all(seed:int):
    """Seed every RNG this script can draw from and configure cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, and torch (CPU and all
    CUDA devices). cuDNN autotuning (``benchmark``) is switched on and
    ``deterministic`` is switched off, i.e. the configuration favours speed;
    NOTE(review): with these cuDNN settings GPU runs are presumably not
    bit-for-bit reproducible even with fixed seeds -- confirm this is intended.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    # np.random.seed only accepts values in [0, 2**32 - 1]; fold the seed into range.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

def get_device():
    """Pick the compute device: CUDA if available, otherwise MPS if available, otherwise CPU."""
    if torch.cuda.is_available():
        name = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)

def non_blocking_ok(device):
    """Return True only when ``device`` is a CUDA device."""
    is_cuda = device.type == "cuda"
    return is_cuda
def pin_memory_ok(device):
    """Return whether pinned memory should be used: PIN_MEMORY must be set and ``device`` must be CUDA."""
    if not PIN_MEMORY:
        # Mirror ``PIN_MEMORY and ...``: a falsy flag is returned as-is.
        return PIN_MEMORY
    return device.type == "cuda"

# ---------------------- model ----------------------
# ===== RESNET 18 =====
class SmallCNN(nn.Module):
    """ResNet-18 (randomly initialised) with its first conv replaced to take 1-channel input.

    NOTE(review): a second class named ``SmallCNN`` is defined further down in
    this file and rebinds the name, so this ResNet-18 variant is not the one
    that ``SmallCNN(...)`` resolves to at run time.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        backbone = models.resnet18(weights=None, num_classes=num_classes)
        # Swap the stem convolution for a single-input-channel one.
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.net = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.net(x)

# ===== DENSENET 121 =====
class SmallCNN(nn.Module):
    """DenseNet-121 (randomly initialised) adapted to 1-channel input and ``num_classes`` outputs."""

    def __init__(self, num_classes: int = NUM_CLASSES):
        super().__init__()
        backbone = models.densenet121(weights=None)
        feats = backbone.features

        # Stem: single input channel, lighter 3x3 kernel, no stride.
        feats.conv0 = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False
        )
        # Drop the very early downsample.
        feats.pool0 = nn.Identity()
        # Skip the final /2 in the transitions.
        feats.transition3.pool = nn.Identity()

        # New classification head sized for ``num_classes``.
        backbone.classifier = nn.Linear(backbone.classifier.in_features, num_classes)

        self.net = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped network and return its output."""
        return self.net(x)

# ---------------------- data ----------------------
def build_transforms():
    """Build and return ``(train_tf, eval_tf)``.

    ``train_tf`` applies random crop (reflect padding), horizontal flip and
    rotation, then tensor conversion + normalisation, then random erasing.
    ``eval_tf`` only converts to tensor and normalises.
    """
    def tensor_and_normalize():
        # Fresh instances on every call so the two pipelines share no objects.
        return [
            transforms.ToTensor(),
            transforms.Normalize((FASHION_MEAN,), (FASHION_STD,)),
        ]

    augmentations = [
        transforms.RandomCrop(28, padding=AUG_RANDOM_CROP_PAD, padding_mode="reflect"),
        transforms.RandomHorizontalFlip(AUG_HFLIP_P),
        transforms.RandomRotation(AUG_ROT_DEG, fill=0),
    ]
    erasing = [transforms.RandomErasing(p=AUG_ERASE_P, scale=(0.02,0.15), ratio=(0.3,3.3))]

    train_tf = transforms.Compose(augmentations + tensor_and_normalize() + erasing)
    eval_tf = transforms.Compose(tensor_and_normalize())
    return train_tf, eval_tf

def load_indices_or_make(n_full:int):
    """Return ``(train_idx, val_idx, holdout_idx)`` arrays.

    Prefers the prepared index files under GROUPS_DIR; when any of the three is
    missing, falls back to ``get_or_make_indices`` with INDEX_DIR and SEED.
    """
    names = (
        "fashionmnist_train20k_idx.npy",
        "fashionmnist_val10k_idx.npy",
        "fashionmnist_holdout30k_idx.npy",
    )
    prepared = [os.path.join(GROUPS_DIR, name) for name in names]
    if not all(os.path.exists(path) for path in prepared):
        return get_or_make_indices(
            n_full, SUBSET_TRAIN_SIZE, SUBSET_VAL_SIZE, SEED, INDEX_DIR, *names
        )
    train_path, val_path, hold_path = prepared
    return np.load(train_path), np.load(val_path), np.load(hold_path)

def build_dataloaders(device):
    """Create the FashionMNIST datasets, index splits and all DataLoaders.

    Returns a 9-tuple:
        (base_train_eval, train_idx_np, val_idx_np, holdout_idx_np,
         train_loader, train_eval_loader, val_loader, hold_loader, test_loader)

    Only ``train_loader`` uses the augmenting transform and shuffling; every
    other loader reads through the eval transform in fixed order.
    """
    train_tf, eval_tf = build_transforms()
    base_train_aug = datasets.FashionMNIST(DATA_DIR, train=True, download=True, transform=train_tf)
    base_train_eval = datasets.FashionMNIST(DATA_DIR, train=True, download=True, transform=eval_tf)
    test_set = datasets.FashionMNIST(DATA_DIR, train=False, download=True, transform=eval_tf)

    train_idx_np, val_idx_np, holdout_idx_np = load_indices_or_make(len(base_train_eval))
    train_ids = train_idx_np.tolist()
    val_ids = val_idx_np.tolist()

    train_ds_aug = Subset(base_train_aug, train_ids)
    train_ds_eval = Subset(base_train_eval, train_ids)
    val_ds = Subset(base_train_eval, val_ids)
    holdout_ds = Subset(base_train_eval, holdout_idx_np.tolist())

    use_pinned = pin_memory_ok(device)

    def make_loader(dataset, shuffle):
        # All loaders share batch size, worker count and pinning; only shuffle varies.
        return DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            num_workers=NUM_WORKERS,
            pin_memory=use_pinned,
            drop_last=False,
        )

    train_loader = make_loader(train_ds_aug, True)
    train_eval_loader = make_loader(train_ds_eval, False)
    val_loader = make_loader(val_ds, False)
    hold_loader = make_loader(holdout_ds, False)
    test_loader = make_loader(test_set, False)

    return (
        base_train_eval, train_idx_np, val_idx_np, holdout_idx_np,
        train_loader, train_eval_loader, val_loader, hold_loader, test_loader,
    )

class IndexedSubset(Dataset):
    """View of ``base_ds`` restricted to ``indices`` that also yields each item's base index.

    ``__getitem__(i)`` returns ``(x, y, idx)`` where ``idx = indices[i]`` and
    ``(x, y) = base_ds[idx]``.
    """

    def __init__(self, base_ds, indices:List[int]):
        self.base_ds = base_ds
        self.indices = list(indices)  # private copy of the caller's sequence

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        base_index = self.indices[i]
        sample, label = self.base_ds[base_index]
        return sample, label, base_index

# ---------------------- eval helpers ----------------------
@torch.no_grad()
def evaluate_ce_acc(model, loader, device, nb, amp_enabled):
    """Evaluate ``model`` on ``loader`` and return ``(mean cross-entropy, accuracy)``.

    The model is put in eval mode. Autocast is enabled only when
    ``amp_enabled`` is set and the device is CUDA. ``nb`` is forwarded as the
    ``non_blocking`` flag of the host-to-device copies.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    use_autocast = amp_enabled and device.type=="cuda"
    with torch.amp.autocast(device_type=device.type, enabled=use_autocast):
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=nb)
            targets = targets.to(device, non_blocking=nb)
            logits = model(inputs)
            # Summed (not averaged) per batch so the final mean is per-sample.
            loss_sum += F.cross_entropy(logits, targets, reduction="sum").item()
            n_correct += (logits.argmax(1) == targets).sum().item()
            n_seen += targets.size(0)
    return loss_sum / n_seen, n_correct / n_seen

def per_index_both_losses(model, device, base_ds, indices, batch_size=1024):
    """Compute per-sample cross-entropy and 0-1 loss for the given base indices.

    Returns ``(ce_out, z1_out)``: two dicts keyed by the base-dataset index,
    holding the cross-entropy and the 0/1 misclassification indicator
    (1.0 when the argmax prediction differs from the label).
    """
    loader = DataLoader(IndexedSubset(base_ds, indices), batch_size, False)
    ce_out: Dict[int, float] = {}
    z1_out: Dict[int, float] = {}
    model.eval()
    with torch.no_grad():
        for x, y, idxs in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            ce_batch = F.cross_entropy(logits, y, reduction="none").cpu().numpy()
            z1_batch = (logits.argmax(1) != y).float().cpu().numpy()
            for raw_idx, ce_val, z1_val in zip(idxs.numpy(), ce_batch, z1_batch):
                key = int(raw_idx)
                ce_out[key] = float(ce_val)
                z1_out[key] = float(z1_val)
    return ce_out, z1_out

def compute_C_holdout(model, device, hold_loader, loss_kind:str):
    """Return the constant C for ``loss_kind``.

    For ``"01"`` this is 1.0 and the model is not touched. For any other value
    it is the largest per-sample cross-entropy observed over ``hold_loader``
    (0.0 if the loader yields nothing).
    """
    if loss_kind == "01":
        return 1.0
    model.eval()
    worst = 0.0
    with torch.no_grad():
        for x, y in hold_loader:
            x = x.to(device)
            y = y.to(device)
            per_sample = F.cross_entropy(model(x), y, reduction="none")
            batch_max = float(per_sample.max().item())
            if batch_max > worst:
                worst = batch_max
    return worst

# ---------------------- bounds ----------------------
def compute_b(n:int, delta:float)->float:
    """Return ``sqrt(0.5 * n * ln(2 / delta))``."""
    log_term = math.log(2.0/delta)
    return math.sqrt(0.5*n*log_term)

# TD25 '(5)' û — uses |T| (non-empty), NOT K
def compute_uhat_td25(n:int, n_i_list:List[int], b:float, gamma:float, delta:float, T_size:int)->float:
    """Return the TD25 '(5)' u-hat as the sum of four terms.

    part_a = gamma * (1 + 2b) / (2n)
    part_b = gamma * T_size * b^2 / (2 n^2)
    part_c = gamma^2 / 2 * sum_i (n_i / n)^2
    part_d = gamma^2 * sqrt(ln(2 / delta) / (2n))
    """
    gamma_sq = gamma**2
    part_a = gamma*(1.0+2.0*b)/(2.0*n)
    part_b = (gamma*T_size*(b**2))/(2.0*(n**2))
    frac_sq_sum = sum([(ni/n)**2 for ni in n_i_list])
    part_c = gamma_sq/2.0 * frac_sq_sum
    part_d = gamma_sq * math.sqrt(math.log(2.0/delta)/(2.0*n))
    return part_a+part_b+part_c+part_d

def compute_A3(C:float, uhat:float, alpha:float, gamma:float, delta:float, n:int)->float:
    """Return ``C*sqrt(max(0, uhat*alpha*ln(gamma))) + C*sqrt(ln(2/delta)/(2n))``."""
    radicand = uhat*alpha*math.log(gamma)
    first = C*math.sqrt(max(0.0, radicand))  # clamp: ln(gamma) < 0 when gamma < 1
    second = C*math.sqrt(math.log(2.0/delta)/(2.0*n))
    return first + second

# Old '(3)' û
def compute_uhat_old(n:int, n_i_list:List[int], gamma:float, delta:float, K:int)->float:
    """Return the old '(3)' u-hat.

    gamma / (2n) + gamma^2 / 2 * sum_i (n_i / n)^2 + gamma^2 * sqrt((2 / n) * ln(2K / delta))
    """
    gamma_sq = gamma**2
    frac_sq_sum = sum([(ni/n)**2 for ni in n_i_list])
    part_a = gamma/(2.0*n)
    part_b = gamma_sq/2.0*frac_sq_sum
    part_c = gamma_sq*math.sqrt((2.0/n)*math.log(2.0*K/delta))
    return part_a + part_b + part_c

# Old '(3)' g2(δ)  — uses |T| and {n_i}. If your exact formula differs, tweak here.
def compute_g2_old(C:float, n:int, K:int, delta:float, T_size:int, n_i_list:List[int])->float:
    """Return the old '(3)' g2 term.

    With L = ln(2K / delta):
        C * (1 + sqrt(2)) * sqrt(L / n) * (sum_i sqrt(n_i) / n)  +  4 * C * T_size * L / n
    Only strictly positive ``n_i`` contribute to the square-root sum.
    """
    log_term = math.log(2.0*K/delta)
    sqrt_total = sum([math.sqrt(ni) for ni in n_i_list if ni>0])
    size_part = C*(1.0+math.sqrt(2.0)) * math.sqrt(log_term/n) * (sqrt_total / n)
    count_part = (4.0 * C * T_size * log_term) / n
    return size_part + count_part

def load_grouping(which:str, seed:int)->Dict[str,List[int]]:
    """Load the grouping JSON for split ``which`` and ``seed`` from GROUPS_DIR.

    The file name is built from ``which``, the module-level K and ``seed``.

    Raises:
        FileNotFoundError: if the expected JSON file does not exist.
    """
    filename = f"fashionmnist_{which}_grouping_K_{K}_seed_{seed}.json"
    path = os.path.join(GROUPS_DIR, filename)
    if os.path.exists(path):
        with open(path,"r") as fh:
            return json.load(fh)
    raise FileNotFoundError(f"Grouping JSON not found: {path}")

# ---------------------- CSV helpers ----------------------
def init_basic_metrics_csv(path):
    """Write the metrics CSV header to ``path`` unless the file already exists."""
    if not os.path.exists(path):
        columns = ["epoch","train_loss","train_acc","val_loss","val_acc","test_loss","test_acc","lr"]
        with open(path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(columns)

def init_bound_td25_csv(path, seeds):
    """Write the TD25 bound CSV header to ``path`` unless the file already exists.

    The header has a fixed prefix followed by six columns per entry of ``seeds``.
    """
    if os.path.exists(path):
        return
    header = [
        "epoch","bound_set","loss_type","A1","C","b","gamma","alpha","delta",
        "Bound5_mean","Unc_mean",
        "train_loss","train_acc","val_loss","val_acc","test_loss","test_acc","lr",
    ]
    for s in seeds:
        header.extend((
            f"A2_seed_{s}", f"A3_seed_{s}", f"Bound5_seed_{s}",
            f"Unc_seed_{s}", f"uhat_seed_{s}", f"T_size_seed_{s}",
        ))
    with open(path, "w", newline="") as f:
        csv.writer(f).writerow(header)

def init_bound_old_csv(path, seeds):
    """Write the old-bound CSV header to ``path`` unless the file already exists.

    The header has a fixed prefix followed by five columns per entry of ``seeds``.
    """
    if os.path.exists(path):
        return
    header = [
        "epoch","bound_set","loss_type","A1","C","gamma","alpha","delta",
        "Bound3_mean","Term2_mean","g2_mean",
        "train_loss","train_acc","val_loss","val_acc","test_loss","test_acc","lr",
    ]
    for s in seeds:
        header.extend((
            f"term2_seed_{s}", f"g2_seed_{s}", f"Bound3_seed_{s}",
            f"uhat_old_seed_{s}", f"T_size_seed_{s}",
        ))
    with open(path, "w", newline="") as f:
        csv.writer(f).writerow(header)

def append_row(path, row):
    """Append ``row`` as one CSV record at the end of the file at ``path``."""
    with open(path, "a", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(row)

# ---------------------- train one epoch ----------------------
def train_one_epoch(model, loader, optimizer, device, nb, scaler, amp_enabled):
    """Train ``model`` for one pass over ``loader``.

    Uses label-smoothed cross-entropy (LABEL_SMOOTH) and, when CLIP_MAX_NORM
    is positive, gradient-norm clipping. On CUDA with ``amp_enabled`` the
    backward/step goes through ``scaler``; otherwise a plain backward/step is
    done. Prints the batch loss every LOG_INTERVAL batches.

    Returns:
        (mean training loss, training accuracy) accumulated over the batches
        as they were seen during the pass.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for step, (x, y) in enumerate(loader, start=1):
        x = x.to(device, non_blocking=nb)
        y = y.to(device, non_blocking=nb)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device.type, enabled=amp_enabled and device.type=="cuda"):
            logits = model(x)
            loss = F.cross_entropy(logits, y, label_smoothing=LABEL_SMOOTH)

        if device.type=="cuda" and amp_enabled:
            scaler.scale(loss).backward()
            if CLIP_MAX_NORM > 0:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_MAX_NORM)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if CLIP_MAX_NORM > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_MAX_NORM)
            optimizer.step()

        with torch.no_grad():
            loss_sum += float(loss.item())*y.size(0)
            n_correct += (logits.argmax(1)==y).sum().item()
            n_seen += y.size(0)

        if LOG_INTERVAL and step%LOG_INTERVAL==0:
            print(f"  [Train] Batch {step:4d} | Loss {loss.item():.4f}")
    return loss_sum/n_seen, n_correct/n_seen

# ---------------------- main ----------------------
def main():
    """Train the model and, after every epoch, log metrics and bound terms to CSV.

    Steps, in order:
      1. Seed RNGs, pick the device, build datasets / loaders.
      2. Load the K-cluster groupings for every seed in SEEDS and check they
         cover the train / val index sets exactly.
      3. Precompute the epoch-independent quantities (b, TD25 u-hat).
      4. For each epoch: train, evaluate on train/val/test, step the LR
         scheduler, write checkpoints, compute per-index losses and append one
         row per (bound set, loss type) to the TD25 and OLD bound CSVs.
      5. Save the loss / accuracy plots and print the list of written files.

    The CSV init helpers leave existing files untouched, so rows from this run
    are appended after whatever an earlier run already wrote to the same paths.
    """
    seed_all(SEED); device=get_device(); nb=non_blocking_ok(device); amp=(device.type=="cuda") and USE_AMP
    print(f"Using device: {device} | AMP: {amp}")

    (full_train, train_idx_np, val_idx_np, holdout_idx_np,
     train_loader, train_eval_loader, val_loader, hold_loader, test_loader) = build_dataloaders(device)

    train_ids, val_ids = train_idx_np.tolist(), val_idx_np.tolist()
    n_train, n_val = len(train_ids), len(val_ids)

    # load groupings + precompute n_i, |T| (non-empty)
    # Each of these maps split name -> {seed -> value}.
    groupings = {"train":{}, "val":{}}
    n_i_lists = {"train":{}, "val":{}}
    T_sizes   = {"train":{}, "val":{}}
    for s in SEEDS:
        gtr=load_grouping("train",s); gva=load_grouping("val",s)
        groupings["train"][s]=gtr; groupings["val"][s]=gva
        # Grouping JSON keys are the cluster ids as strings "0".."K-1".
        all_tr=sorted([idx for k in range(K) for idx in gtr[str(k)]])
        all_va=sorted([idx for k in range(K) for idx in gva[str(k)]])
        # NOTE(review): these coverage checks are asserts and vanish under `python -O`.
        assert all_tr==sorted(train_ids), f"Train grouping seed {s} not covering train20k exactly."
        assert all_va==sorted(val_ids),   f"Val grouping seed {s} not covering val10k exactly."
        n_i_lists["train"][s]=[len(gtr[str(k)]) for k in range(K)]
        n_i_lists["val"][s]  =[len(gva[str(k)]) for k in range(K)]
        # |T| = number of non-empty clusters.
        T_sizes["train"][s]=sum(1 for ni in n_i_lists["train"][s] if ni>0)
        T_sizes["val"][s]  =sum(1 for ni in n_i_lists["val"][s]   if ni>0)

    # b depends on n only
    b_vals = {"train": compute_b(n_train, DELTA), "val": compute_b(n_val, DELTA)}

    # TD25 û per seed (uses |T|)
    # Independent of the model, so it is computed once before training.
    uhat_td25 = {
        "train": {s: compute_uhat_td25(n_train, n_i_lists["train"][s], b_vals["train"], GAMMA_THEO, DELTA, T_sizes["train"][s]) for s in SEEDS},
        "val":   {s: compute_uhat_td25(n_val,   n_i_lists["val"][s],   b_vals["val"],   GAMMA_THEO, DELTA, T_sizes["val"][s])   for s in SEEDS},
    }

    # opt/sched
    model = SmallCNN(NUM_CLASSES).to(device)
    if OPTIMIZER.lower()!="adam": raise ValueError("Fixed optimizer is Adam.")
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    if SCHEDULER.lower()!="steplr": raise ValueError("Fixed scheduler is StepLR.")
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA_STEP)
    scaler = torch.amp.GradScaler("cuda", enabled=amp)

    # init CSVs
    init_basic_metrics_csv(METRICS_CSV)
    for _,p in BOUND_TD25.items(): init_bound_td25_csv(p, SEEDS)
    for _,p in BOUND_OLD.items():  init_bound_old_csv(p, SEEDS)

    best_test_acc=0.0
    hist={"train_loss":[], "train_acc":[], "test_loss":[], "test_acc":[]}

    for epoch in range(1, NUM_EPOCHS+1):
        print(f"\n=== Epoch {epoch}/{NUM_EPOCHS} ===")
        # The running train loss/acc returned here are discarded; the logged
        # train metrics come from the un-augmented eval pass on the next line.
        _trainL,_trainA = train_one_epoch(model, train_loader, optimizer, device, nb, scaler, amp)
        train_ce, train_acc = evaluate_ce_acc(model, train_eval_loader, device, nb, amp)
        val_ce,   val_acc   = evaluate_ce_acc(model, val_loader, device, nb, amp)
        test_ce,  test_acc  = evaluate_ce_acc(model, test_loader, device, nb, amp)
        # lr is read AFTER the scheduler step, so the logged value is the
        # learning rate that the next epoch will use.
        scheduler.step(); lr = optimizer.param_groups[0]["lr"]

        append_row(METRICS_CSV, [epoch, f"{train_ce:.6f}", f"{train_acc:.6f}",
                                 f"{val_ce:.6f}", f"{val_acc:.6f}", f"{test_ce:.6f}", f"{test_acc:.6f}", f"{lr:.8f}"])
        hist["train_loss"].append(train_ce); hist["train_acc"].append(train_acc)
        hist["test_loss"].append(test_ce);  hist["test_acc"].append(test_acc)

        print(f"[Metrics] train_ce={train_ce:.4f} | train_acc={train_acc*100:.2f}% | "
              f"val_ce={val_ce:.4f} | val_acc={val_acc*100:.2f}% | "
              f"test_ce={test_ce:.4f} | test_acc={test_acc*100:.2f}% | lr={lr:.6f}")

        torch.save({"model_state":model.state_dict(),"test_acc":test_acc,"epoch":epoch}, LAST_CKPT_PATH)
        # NOTE(review): the "best" checkpoint is selected on TEST accuracy, so
        # the test set doubles as a model-selection set -- confirm this is
        # intended (val_acc is available here).
        if SAVE_BEST and test_acc>best_test_acc:
            best_test_acc=test_acc
            torch.save({"model_state":model.state_dict(),"test_acc":best_test_acc,"epoch":epoch}, BEST_CKPT_PATH)
            print(f"  ↳ Saved new best to {BEST_CKPT_PATH} (test_acc={best_test_acc*100:.2f}%)")

        # per-index losses
        # Dicts keyed by base-dataset index; full_train uses the eval transform.
        ce_tr, z1_tr = per_index_both_losses(model, device, full_train, train_ids, batch_size=1024)
        ce_va, z1_va = per_index_both_losses(model, device, full_train, val_ids,   batch_size=1024)

        # A1: mean per-sample loss for each (bound set, loss type).
        A1 = {("train","01"): float(np.mean(list(z1_tr.values()))),
              ("train","ce"): float(np.mean(list(ce_tr.values()))),
              ("val","01")  : float(np.mean(list(z1_va.values()))),
              ("val","ce")  : float(np.mean(list(ce_va.values())))}

        # C is 1 for 0-1 loss and the max holdout cross-entropy for CE.
        C_vals = {"01":1.0, "ce":compute_C_holdout(model, device, hold_loader, "ce")}

        # ---- TD25 '(5)' rows ----
        for (whichS, loss_kind), out_csv in BOUND_TD25.items():
            ids = train_ids if whichS=="train" else val_ids
            n = len(ids)
            b = b_vals[whichS]
            uhat_by_seed = uhat_td25[whichS]
            grouping_for_S = groupings[whichS]
            # Pick the per-index loss dict matching this (set, loss) pair.
            per_idx = (z1_tr if whichS=="train" and loss_kind=="01" else
                       ce_tr if whichS=="train" else
                       z1_va if loss_kind=="01" else ce_va)
            A1_val = A1[(whichS, loss_kind)]
            C_used = C_vals[loss_kind]

            A2_s, A3_s, B5_s, Unc_s = {}, {}, {}, {}
            for s in SEEDS:
                g = grouping_for_S[s]
                # sum over i ∈ T (non-empty); empty clusters contribute 0 anyway
                F_Si=[]
                for k in range(K):
                    ids_k=g[str(k)]
                    if len(ids_k)==0: continue
                    vals=[per_idx[idx] for idx in ids_k]
                    F_Si.append(float(np.mean(vals)))
                # A2 = (b / n) * sum of per-cluster mean losses.
                A2 = (b/n) * float(np.sum(F_Si))
                A3 = compute_A3(C_used, uhat_by_seed[s], ALPHA, GAMMA_THEO, DELTA, n)
                B5 = A1_val + A2 + A3
                Unc = A2 + A3
                A2_s[s]=A2; A3_s[s]=A3; B5_s[s]=B5; Unc_s[s]=Unc

            # Column order must match the header written by init_bound_td25_csv.
            row=[epoch, whichS, loss_kind, f"{A1_val:.6f}", f"{C_used:.6f}", f"{b:.6f}",
                 f"{GAMMA_THEO:.8f}", f"{ALPHA:.6f}", f"{DELTA:.6f}",
                 f"{np.mean(list(B5_s.values())):.6f}", f"{np.mean(list(Unc_s.values())):.6f}",
                 f"{train_ce:.6f}", f"{train_acc:.6f}",
                 f"{val_ce:.6f}", f"{val_acc:.6f}",
                 f"{test_ce:.6f}", f"{test_acc:.6f}", f"{lr:.8f}"]
            for s in SEEDS:
                row += [f"{A2_s[s]:.6f}", f"{A3_s[s]:.6f}", f"{B5_s[s]:.6f}", f"{Unc_s[s]:.6f}",
                        f"{uhat_by_seed[s]:.6f}", f"{T_sizes[whichS][s]:d}"]
            append_row(out_csv, row)

        # ---- OLD '(3)' rows ----
        for (whichS, loss_kind), out_csv in BOUND_OLD.items():
            ids = train_ids if whichS=="train" else val_ids
            n = len(ids)
            grouping_for_S = groupings[whichS]
            n_i_map = n_i_lists[whichS]
            Tsize_map = T_sizes[whichS]
            A1_val = A1[(whichS, loss_kind)]
            C_used = C_vals[loss_kind]

            term2_s, g2_s, B3_s, uhat_old_s = {}, {}, {}, {}
            for s in SEEDS:
                n_i = n_i_map[s]
                # u_old does not depend on the model; only g2/term2 vary with
                # the epoch, through C_used.
                u_old = compute_uhat_old(n, n_i, GAMMA_THEO, DELTA, K)
                g2 = compute_g2_old(C_used, n, K, DELTA, Tsize_map[s], n_i)
                term2 = C_used * math.sqrt(max(0.0, u_old * ALPHA * math.log(GAMMA_THEO)))
                B3 = A1_val + term2 + g2
                uhat_old_s[s]=u_old; g2_s[s]=g2; term2_s[s]=term2; B3_s[s]=B3

            # Column order must match the header written by init_bound_old_csv.
            row=[epoch, whichS, loss_kind,
                 f"{A1_val:.6f}", f"{C_used:.6f}",
                 f"{GAMMA_THEO:.8f}", f"{ALPHA:.6f}", f"{DELTA:.6f}",
                 f"{np.mean(list(B3_s.values())):.6f}",
                 f"{np.mean(list(term2_s.values())):.6f}",
                 f"{np.mean(list(g2_s.values())):.6f}",
                 f"{train_ce:.6f}", f"{train_acc:.6f}",
                 f"{val_ce:.6f}", f"{val_acc:.6f}",
                 f"{test_ce:.6f}", f"{test_acc:.6f}", f"{lr:.8f}"]
            for s in SEEDS:
                row += [f"{term2_s[s]:.6f}", f"{g2_s[s]:.6f}", f"{B3_s[s]:.6f}",
                        f"{uhat_old_s[s]:.6f}", f"{T_sizes[whichS][s]:d}"]
            append_row(out_csv, row)

    # plots
    # Both figures are saved to disk; plt.show() is called once at the end.
    epochs=range(1, NUM_EPOCHS+1)
    plt.figure(figsize=(8,5)); plt.plot(epochs,hist["train_loss"],label="Train CE (eval)")
    plt.plot(epochs,hist["test_loss"],label="Test CE"); plt.xlabel("Epoch"); plt.ylabel("Loss")
    plt.title("FMNIST: Train vs Test Loss"); plt.legend(); plt.grid(True); plt.tight_layout(); plt.savefig(LOSS_PLOT,dpi=150)
    plt.figure(figsize=(8,5)); plt.plot(epochs,[a*100 for a in hist["train_acc"]],label="Train Acc (eval)")
    plt.plot(epochs,[a*100 for a in hist["test_acc"]],label="Test Acc"); plt.xlabel("Epoch"); plt.ylabel("Accuracy (%)")
    plt.title("FMNIST: Train vs Test Accuracy"); plt.legend(); plt.grid(True); plt.tight_layout(); plt.savefig(ACC_PLOT,dpi=150); plt.show()

    # Release cached accelerator memory; the MPS call is best-effort.
    if device.type=="cuda": torch.cuda.empty_cache()
    elif device.type=="mps":
        try: torch.mps.empty_cache()
        except Exception: pass

    print("\nTraining complete. Wrote:")
    print(f"  - {METRICS_CSV}")
    for _,f in BOUND_TD25.items(): print(f"  - {f}")
    for _,f in BOUND_OLD.items():  print(f"  - {f}")
    if os.path.exists(BEST_CKPT_PATH): print(f"Best checkpoint: {BEST_CKPT_PATH}")

# Run the training / bound-logging pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
