import os
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import pickle
from tqdm import tqdm
import numpy as np
from sklearn.metrics import silhouette_score
from scipy.cluster.hierarchy import linkage, cophenet
from scipy.spatial.distance import pdist, squareform
import math

# Module-wide compute device: CUDA when torch reports it available, else CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

from hflayers import HopfieldLayer
from hyper_hflayers import Hyperbolic_HopfieldLayer as Hyper_HopfieldLayer  # ensure the module path is correct
from Uhop import LearnableHopfield,uniform_loss

# ---------------- Poincaré Ball ops (curvature = -c, c>0) ----------------
class PoincareBall:
    """Numerically guarded operations on the Poincaré ball of curvature ``-c``.

    Points live in the open ball of radius ``1 / sqrt(c)``. Every method acts
    on the last tensor dimension and broadcasts over the leading ones.

    Args:
        c: Positive curvature magnitude (the ball has curvature ``-c``).
        eps: Lower clamp used for norms in the projection and for the
            denominator of :meth:`dist`.

    Raises:
        ValueError: If ``c`` is not strictly positive (zero, negative or NaN).
    """

    def __init__(self, c: float, eps: float = 1e-6):
        c = float(c)
        # ``not c > 0`` also rejects NaN. Without this check c == 0 fails with
        # a bare ZeroDivisionError and c < 0 silently yields a complex radius.
        if not c > 0.0:
            raise ValueError(f"curvature parameter c must be positive, got {c!r}")
        self.c = c
        self.eps = eps
        self._sqrt_c = c ** 0.5
        self.radius = 1.0 / self._sqrt_c
        self._M = 1.0 - 1e-5  # safety margin factor

    def _proj_with_margin(self, x):
        """Project to the ball with a small margin to avoid numerical issues near the boundary."""
        r = x.norm(dim=-1, keepdim=True).clamp_min(self.eps)
        max_r = self._M * self.radius
        # Only points outside the shrunken ball are rescaled; the rest pass through.
        scale = torch.where(r > max_r, max_r / r, torch.ones_like(r))
        return x * scale

    def lambda_x(self, x):
        """Conformal factor λ(x) = 2 / (1 - c * ||x||^2), denominator clamped at 1e-6."""
        x2 = (x * x).sum(dim=-1, keepdim=True)
        return 2.0 / (1.0 - self.c * x2).clamp_min(1e-6)

    def mobius_add(self, x, y):
        """Möbius addition x ⊕ y on the ball; the result is re-projected."""
        c = self.c
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        xy = (x * y).sum(dim=-1, keepdim=True)
        num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
        den = (1 + 2 * c * xy + (c ** 2) * x2 * y2).clamp_min(1e-6)
        return self._proj_with_margin(num / den)

    def mobius_neg(self, x):
        """Möbius inverse of ``x`` (plain negation), re-projected."""
        return self._proj_with_margin(-x)

    def exp0(self, v):
        """Exponential map at the origin: tangent vector -> point on the ball."""
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = self._sqrt_c * vnorm
        coef = torch.tanh(t) / (self._sqrt_c * vnorm)
        return self._proj_with_margin(coef * v)

    def log0(self, x):
        """Logarithmic map at the origin: point on the ball -> tangent vector."""
        x = self._proj_with_margin(x)
        xnorm = x.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        # Keep atanh's argument strictly below 1 so it stays finite.
        arg = torch.clamp(self._sqrt_c * xnorm, 0.0, 1.0 - 1e-7)
        coef = torch.atanh(arg) / (self._sqrt_c * xnorm)
        return coef * x

    def exp_p(self, p, v):
        """Exponential map at p: move from ``p`` along tangent vector ``v``."""
        lam = self.lambda_x(p)
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = 0.5 * self._sqrt_c * lam * vnorm
        coef = torch.tanh(t) / (self._sqrt_c * vnorm)
        return self.mobius_add(p, coef * v)

    def log_p(self, p, x):
        """Logarithmic map at p: tangent vector at ``p`` pointing towards ``x``."""
        lam = self.lambda_x(p)
        y = self.mobius_add(self.mobius_neg(p), x)
        ynorm = y.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        arg = torch.clamp(self._sqrt_c * ynorm, 0.0, 1.0 - 1e-7)
        coef = (2.0 / (self._sqrt_c * lam)) * (torch.atanh(arg) / ynorm)
        return coef * y

    def dist(self, x, y):
        """Geodesic distance on the ball; keeps a trailing dimension of size 1.

        The ``acosh`` argument is clamped to at least ``1 + 1e-6``, so the
        distance between identical points is a small positive number rather
        than exactly zero.
        """
        c = self.c
        diff2 = ((x - y) ** 2).sum(dim=-1, keepdim=True)
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        num = 2 * c * diff2
        den = ((1 - c * x2) * (1 - c * y2)).clamp_min(self.eps)
        z = 1 + num / den
        return torch.acosh(z.clamp_min(1 + 1e-6))


# ---------------- Hyperbolic Attention (block of the "HypAttn" baseline) ----------------
class HyperbolicAttentionLayer(nn.Module):
    """
    Hyperbolic attention over a learnable memory, with learnable scales,
    temperature and gated residual:
      - project q/k/v in tangent space, push onto the ball with learnable scales;
      - score with -d^2 or -d, normalize via z-score and temperature, softmax;
      - attention-weighted sum in the query's tangent space, mapped back to the ball;
      - gated residual (sigmoid gate on the query) to keep gradients flowing early.

    Args:
        feat_dim: Feature dimension of inputs, memory slots and outputs.
        n_mem: Number of learnable memory slots (used as keys and values).
        c: Curvature magnitude of the Poincaré ball.
        tau_init: Initial temperature; stored as ``log_tau`` and learned.
        dropout: Dropout probability applied to the attention weights.
        score_center: Stored on the instance but not read by ``forward``.
        use_sqdist: Score with ``-d**2`` when True, ``-d`` otherwise.
        layernorm: Apply LayerNorm to the input before the query projection.
    """
    def __init__(self, feat_dim: int, n_mem: int, c: float = 1.0,
                 tau_init: float = 5.0, dropout: float = 0.0,
                 score_center: str = "none",    # 'none' | 'min' | 'mean'
                 use_sqdist: bool = True,
                 layernorm: bool = True):
        super().__init__()
        self.ball = PoincareBall(c)
        self.use_sqdist = use_sqdist
        self.score_center = score_center

        self.q_proj = nn.Linear(feat_dim, feat_dim, bias=False)
        self.k_proj = nn.Linear(feat_dim, feat_dim, bias=False)
        self.v_proj = nn.Linear(feat_dim, feat_dim, bias=False)

        # Learnable memory slots; keys and values are both projected from them.
        self.mem = nn.Parameter(torch.randn(n_mem, feat_dim) * 0.02)

        self.ln = nn.LayerNorm(feat_dim) if layernorm else nn.Identity()
        self.dropout = nn.Dropout(dropout)

        self.log_tau = nn.Parameter(torch.log(torch.tensor(float(tau_init))))

        # Log-scales applied to tangent vectors before exp0 (clamped in forward).
        self.q_log_scale = nn.Parameter(torch.tensor(0.0))
        self.k_log_scale = nn.Parameter(torch.tensor(1.1))
        self.v_log_scale = nn.Parameter(torch.tensor(1.1))

        # Pre-sigmoid gate of the residual connection.
        self.alpha_skip = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        """Attend from ``x`` to the memory and return a tangent vector at the origin.

        Args:
            x: Tensor of shape ``[B, D]`` or ``[B, L, D]``. Every position
               attends to the memory independently, so any ``L`` is accepted.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 3-D.
        """
        squeeze = x.dim() == 2
        if squeeze:
            x = x.unsqueeze(1)
        if x.dim() != 3:
            raise ValueError(f"expected input of shape [B, D] or [B, L, D], got {tuple(x.shape)}")

        x_ln = self.ln(x)
        q_tan = self.q_proj(x_ln)          # [B, L, D]
        k_tan = self.k_proj(self.mem)      # [M, D]
        v_tan = self.v_proj(self.mem)      # [M, D]

        s_q = torch.exp(self.q_log_scale).clamp(0.5, 10.0)
        s_k = torch.exp(self.k_log_scale).clamp(0.5, 10.0)
        s_v = torch.exp(self.v_log_scale).clamp(0.5, 10.0)

        q_ball = self.ball.exp0(s_q * q_tan)   # [B, L, D]
        k_ball = self.ball.exp0(s_k * k_tan)   # [M, D]
        v_ball = self.ball.exp0(s_v * v_tan)   # [M, D]

        # Broadcast every query against every memory slot: [B, L, 1, D] vs [M, D].
        q_exp = q_ball.unsqueeze(2)
        d = self.ball.dist(q_exp, k_ball).squeeze(-1)   # [B, L, M]
        scores = -(d ** 2) if self.use_sqdist else -d

        # Per-query z-score over memory slots.
        # NOTE(review): with a single memory slot the default (unbiased) std of
        # one element is NaN, which presumably propagates to the output.
        scores = scores - scores.max(dim=-1, keepdim=True).values
        mu = scores.mean(dim=-1, keepdim=True)
        sigma = scores.std(dim=-1, keepdim=True).clamp_min(1e-6)
        scores = (scores - mu) / sigma

        tau = torch.exp(self.log_tau).clamp(0.05, 50.0)
        scores = (tau * scores).clamp(-20.0, 20.0)

        attn = torch.softmax(scores, dim=-1)   # [B, L, M]
        attn = self.dropout(attn)

        # Aggregate the values in the tangent space at each query, then map back.
        log_q_v = self.ball.log_p(q_exp, v_ball)                               # [B, L, M, D]
        z_tan_q = torch.matmul(attn.unsqueeze(-2), log_q_v).squeeze(-2)        # [B, L, D]
        z_ball = self.ball.exp_p(q_ball, z_tan_q)                              # [B, L, D]

        out_tan0 = self.ball.log0(z_ball)

        # Gated residual on the query, both expressed in the tangent space at 0.
        q_tan0 = self.ball.log0(q_ball)
        gate = torch.sigmoid(self.alpha_skip)
        out_tan0 = out_tan0 + gate * q_tan0

        return out_tan0.squeeze(1) if squeeze else out_tan0


# ---------------- Möbius Linear (building block of HypNNBlock) ----------------
class MobiusLinear(nn.Module):
    """Linear layer between ball points, applied in the tangent space at the origin.

    The input is mapped to the tangent space with ``log0``, passed through an
    ordinary ``nn.Linear`` and mapped back onto the ball with ``exp0``. Tangent
    vectors are clamped to ``[-20, 20]`` on both sides of the linear map.
    """

    def __init__(self, in_dim, out_dim, c=1.0, bias=True):
        super().__init__()
        self.ball = PoincareBall(c)
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        # Re-initialise on top of nn.Linear's default init.
        nn.init.kaiming_uniform_(self.lin.weight, a=0.2)
        if bias:
            nn.init.zeros_(self.lin.bias)

    def forward(self, x_ball):
        """Map ball points ``x_ball`` to ball points of the output dimension."""
        ball = self.ball
        tangent_in = ball.log0(ball._proj_with_margin(x_ball))
        tangent_in = torch.clamp(tangent_in, -20.0, 20.0)
        tangent_out = torch.clamp(self.lin(tangent_in), -20.0, 20.0)
        return ball.exp0(tangent_out)


# ---------------- Lightweight Hyperbolic NN block (block of the "HypNN" baseline) ----------------
class HypNNBlock(nn.Module):
    """Two MobiusLinear layers with LayerNorm + ReLU in tangent space.

    Takes a Euclidean feature vector, treats it as a tangent vector at the
    origin, and returns a tangent vector of the same shape.
    """

    def __init__(self, feat_dim: int, hidden: int = 512, c: float = 1.0, dropout: float = 0.1):
        super().__init__()
        self.ball = PoincareBall(c)
        self.fc1 = MobiusLinear(feat_dim, hidden, c=c)
        self.fc2 = MobiusLinear(hidden, feat_dim, c=c)
        self.ln1 = nn.LayerNorm(hidden)
        self.ln2 = nn.LayerNorm(feat_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the block on ``x`` of shape ``[B, D]`` or ``[B, L, D]``; shape is preserved."""
        was_2d = x.dim() == 2
        if was_2d:
            x = x.unsqueeze(1)

        ball = self.ball

        # Parameter-free layer norm over the feature axis, then onto the ball.
        normed = F.layer_norm(x, x.shape[-1:])
        points = ball.exp0(normed)

        # First hyperbolic layer; the non-linearity is applied in tangent space.
        hidden_tan = self.ln1(ball.log0(self.fc1(points)))
        hidden_tan = F.relu(torch.clamp(hidden_tan, -20.0, 20.0))
        hidden_ball = self.dropout(ball.exp0(hidden_tan))

        # Second hyperbolic layer; the result stays in tangent space.
        result = self.ln2(ball.log0(self.fc2(hidden_ball)))
        result = torch.clamp(result, -20.0, 20.0)

        if was_2d:
            return result.squeeze(1)
        return result

class UHopBlock(nn.Module):
    """LearnableHopfield retrieval against a learnable memory, followed by LayerNorm."""

    def __init__(
        self,
        feat_dim: int,
        n_mem: int,
        n_heads: int = 4,
        dropout: float = 0.1,
        mode: str = "softmax",
        kernel: str = "lin",
        scale: float | None = None,
    ):
        super().__init__()
        hop_config = dict(
            d_model=feat_dim,
            n_heads=n_heads,
            d_keys=None,
            d_values=None,
            mix=True,
            update_steps=1,
            dropout=dropout,
            mode=mode,
            kernel=kernel,
            scale=scale,
        )
        self.hop = LearnableHopfield(**hop_config)

        # One shared bank of n_mem slots, expanded over the batch in forward().
        self.memory = nn.Parameter(torch.randn(1, n_mem, feat_dim) * 0.02)
        self.ln = nn.LayerNorm(feat_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Query the memory with ``x`` ([B, D] or [B, 1, D]); the input rank is preserved."""
        was_2d = x.dim() == 2
        if was_2d:
            x = x.unsqueeze(1)

        batch_size, _, _ = x.shape
        stored = self.memory.expand(batch_size, -1, -1)   # [B, n_mem, D]

        retrieved = self.ln(self.hop(x, stored))

        if was_2d:
            return retrieved.squeeze(1)
        return retrieved

    def kernel_forward(self, Y: torch.Tensor) -> torch.Tensor:
        """Apply ``LearnableHopfield.uniform_forward`` to ``Y``, adding a batch axis to 2-D input."""
        patterns = Y.unsqueeze(0) if Y.dim() == 2 else Y   # [1, N, D] for 2-D input
        return self.hop.uniform_forward(patterns)



# ---------------- Dataset wrappers ----------------
# Mapping from CIFAR-100 coarse labels to super-classes:
#   coarse_to_super[coarse_label] = super_class_label
# Each row lists the coarse labels of one super-class; the row index is the
# super-class label.
_SUPER_CLASS_MEMBERS = (
    # "Animate" super-classes
    (0, 1),                 # 0 aquatic: aquatic_mammals, fish
    (7, 13, 16),            # 1 small: insects, non_insect_invertebrates, small_mammals
    (8, 11, 12, 14, 15),    # 2 large: large_carnivores, large_omnivores, medium_mammals, people, reptiles
    # "Plants" super-class
    (2, 4, 17),             # 3 plants: flowers, fruits_and_vegetables, trees
    # "Household & Furniture" super-class
    (3, 5, 6),              # 4 home_items: food_containers, household_electrical_devices, household_furniture
    # "Vehicles" super-class
    (9, 18, 19),            # 5 vehicles: large_man_made_outdoor_things, vehicles_1, vehicles_2
    # "Natural Scenes" super-class
    (10,),                  # 6 natural_scenes: large_natural_outdoor_scenes
)

coarse_to_super = {
    coarse_label: super_label
    for super_label, members in enumerate(_SUPER_CLASS_MEMBERS)
    for coarse_label in members
}


class ThreeLevelCIFAR100(datasets.CIFAR100):
    """CIFAR-100 that yields ``(image, fine_label, coarse_label, super_coarse_label)``.

    Args:
        root, train, transform, download: Forwarded to ``datasets.CIFAR100``.
        coarse_labels: Per-sample coarse label, indexed like the dataset.

    Raises:
        ValueError: If ``coarse_labels`` is missing or its length differs from
            the number of samples.
    """
    def __init__(self, root, train, transform, coarse_labels, download=False):
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if coarse_labels is None:
            raise ValueError("coarse_labels must be provided")
        super().__init__(root, train=train, transform=transform, download=download)
        # Catch misaligned label lists up front rather than mid-epoch (or never,
        # when the list is too long).
        if len(coarse_labels) != len(self):
            raise ValueError(
                f"coarse_labels has {len(coarse_labels)} entries "
                f"but the dataset has {len(self)} samples"
            )
        self.coarse_labels = coarse_labels

    def __getitem__(self, idx):
        img, fine = super().__getitem__(idx)
        # int() so array/tensor scalars also work as keys of coarse_to_super.
        coarse = int(self.coarse_labels[idx])
        super_coarse = coarse_to_super[coarse]
        return img, fine, coarse, super_coarse


# ---------------- Model ----------------
class HierarchicalCNN(nn.Module):
    """ResNet-18 backbone, one selectable memory/attention block, three linear heads.

    Args:
        baseline: Which block to insert: ``"HAMN"``, ``"HypAttn"``, ``"HypNN"``,
            ``"MHN_Euc"`` or ``"UHop"``.
        hyper_c: Curvature magnitude, used by the HAMN, HypAttn and HypNN blocks.
        clip_r: Clip radius as a fraction of the ball radius ``1/sqrt(hyper_c)``
            (HAMN only).
        lr: Passed to ``Hyper_HopfieldLayer`` as ``lr`` (HAMN only); this is not
            the optimizer learning rate.
        quantity_n: Memory size multiplier; blocks with a memory get
            ``int(100 * quantity_n)`` slots.
        tau: Initial temperature of the HypAttn block.

    Raises:
        ValueError: If ``baseline`` is not one of the names above.
    """
    def __init__(self, baseline: str = "HAMN", hyper_c=1.0, clip_r=0.9, lr=1.0, quantity_n: int = 1, tau: float = 1.0):
        super().__init__()
        # ResNet18 backbone (no final FC), randomly initialised. Called without
        # arguments because torchvision >= 0.13 deprecates `pretrained=`; the
        # default is "no pretrained weights" in both the old and new APIs.
        self.backbone = models.resnet18()
        self.backbone.fc = nn.Identity()
        feat_dim = 512

        quantity_fine = int(100 * quantity_n)

        if baseline == "HAMN":
            # Hyperbolic Hopfield (HAMN) module
            self.block = Hyper_HopfieldLayer(
                input_size=feat_dim,
                hidden_size=feat_dim,
                output_size=feat_dim,
                num_heads=4,
                hyper_c=hyper_c,
                clip_r=1.0 / math.sqrt(hyper_c) * clip_r,
                lr=lr,
                association_activation='relu',
                dropout=0.2,
                quantity=quantity_fine,
                batch_first=True,
                train_c=False,
                input_as_hyper=False,
                out_as_hyper=False
            )
        elif baseline == "HypAttn":
            # Strong baseline 1: Hyperbolic Attention
            self.block = HyperbolicAttentionLayer(
                feat_dim=feat_dim,
                n_mem=quantity_fine,
                c=hyper_c,
                tau_init=tau,
                dropout=0.1
            )
        elif baseline == "HypNN":
            # Strong baseline 2: Lightweight Hyperbolic NN block
            self.block = HypNNBlock(
                feat_dim=feat_dim,
                hidden=512,
                c=hyper_c,
                dropout=0.1
            )
        elif baseline == "MHN_Euc":
            # Euclidean MHN/Hopfield reference
            self.block = HopfieldLayer(
                input_size=feat_dim,
                hidden_size=feat_dim,
                output_size=feat_dim,
                num_heads=8,
                scaling=0.01,
                dropout=0.2,
                association_activation='relu',
                quantity=quantity_fine,
                batch_first=True,
            )
        elif baseline == "UHop":
            # U-Hop baseline: LearnableHopfield
            self.block = UHopBlock(
                feat_dim=feat_dim,
                n_mem=quantity_fine,
                n_heads=4,
                dropout=0.1,
                mode="softmax",
                kernel="lin",
                scale=None,
            )
        else:
            raise ValueError(f"Unknown baseline: {baseline}")

        # Heads for three levels: 7 super-coarse, 20 coarse, 100 fine classes.
        self.head_super_coarse = nn.Linear(feat_dim, 7)
        self.head_coarse = nn.Linear(feat_dim, 20)
        self.head_fine = nn.Linear(feat_dim, 100)

    def forward(self, x):
        """Return ``(super_coarse_logits, coarse_logits, fine_logits)`` for image batch ``x``."""
        feats = self.backbone(x)  # [B, 512]
        feats = feats.unsqueeze(1)  # [B,1,512] unify interface

        # Blocks are expected to return [B,1,512]; drop the length axis again.
        feats_out = self.block(feats).squeeze(1)  # [B,512]

        return (
            self.head_super_coarse(feats_out),
            self.head_coarse(feats_out),
            self.head_fine(feats_out)
        )

# ---------------- Train / Eval ----------------
def train_epoch(model, train_loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``train_loader``.

    The loss is the unweighted sum of ``criterion`` over the super-coarse,
    coarse and fine heads. Gradients are clipped to a global norm of 5.0.

    Returns:
        The mean per-batch loss of the epoch (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    prog = tqdm(train_loader, desc="Training", unit="batch")
    # Count batches ourselves: tqdm refreshes `prog.n` only periodically, so it
    # is not a reliable divisor for the running average.
    for num_batches, (imgs, fine, coarse, super_coarse) in enumerate(prog, start=1):
        imgs, fine, coarse, super_coarse = imgs.to(device), fine.to(device), coarse.to(device), super_coarse.to(device)

        # forward
        super_coarse_out, coarse_out, fine_out = model(imgs)

        # individual losses
        loss_super_coarse = criterion(super_coarse_out, super_coarse)
        loss_coarse = criterion(coarse_out, coarse)
        loss_fine = criterion(fine_out, fine)

        # joint loss
        loss = loss_super_coarse + loss_coarse + loss_fine

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        total_loss += loss.item()
        prog.set_postfix(loss=total_loss / num_batches)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    mean_loss = total_loss / max(num_batches, 1)
    print(f'Train Loss: {mean_loss:.4f}')
    return mean_loss


def test(model, test_loader, device):
    """Evaluate top-1 accuracy of the three heads on ``test_loader``.

    Returns:
        ``(super_coarse_acc, coarse_acc, fine_acc)`` as fractions in [0, 1];
        all zeros when the loader yields no samples.
    """
    model.eval()
    correct_super_coarse = correct_c = correct_f = total = 0
    prog = tqdm(test_loader, desc="Testing", unit="batch")

    with torch.no_grad():
        for imgs, fine, coarse, super_coarse in prog:
            imgs, fine, coarse, super_coarse = imgs.to(device), fine.to(device), coarse.to(device), super_coarse.to(device)

            super_coarse_out, coarse_out, fine_out = model(imgs)

            pred_super_coarse = super_coarse_out.argmax(dim=1)
            pred_c = coarse_out.argmax(dim=1)
            pred_f = fine_out.argmax(dim=1)

            correct_super_coarse += (pred_super_coarse == super_coarse).sum().item()
            correct_c += (pred_c == coarse).sum().item()
            correct_f += (pred_f == fine).sum().item()
            total += fine.size(0)

            # Running accuracies for the progress bar only.
            prog.set_postfix(
                super_coarse_acc=f"{correct_super_coarse / total:.4f}",
                coarse_acc=f"{correct_c / total:.4f}",
                fine_acc=f"{correct_f / total:.4f}",
            )

    # Final accuracies are computed after the loop so they are always bound;
    # max(..., 1) avoids dividing by zero on an empty loader.
    denom = max(total, 1)
    acc_super_coarse = correct_super_coarse / denom
    acc_c = correct_c / denom
    acc_f = correct_f / denom

    print(f'Super Coarse Acc: {acc_super_coarse:.4f}, Coarse Acc: {acc_c:.4f}, Fine Acc: {acc_f:.4f}')
    return acc_super_coarse, acc_c, acc_f

def learn_kernel_uhop(model, train_loader, device, kernel_epoch: int = 50):
    """Pre-train the kernel of a UHop block with the uniformity loss.

    Only ``model.block.hop.kernel`` parameters are optimised (SGD, lr 0.1,
    momentum 0.9). Backbone features are computed without gradients, passed
    through ``UHopBlock.kernel_forward``, L2-normalised and scored with
    ``uniform_loss``.

    Args:
        model: Model whose ``block`` is a ``UHopBlock``.
        train_loader: Yields ``(imgs, fine, coarse, super_coarse)``; only the
            images are used.
        device: Device the image batches are moved to.
        kernel_epoch: Number of passes over ``train_loader``.

    Raises:
        TypeError: If ``model.block`` is not a ``UHopBlock``.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not isinstance(model.block, UHopBlock):
        raise TypeError(
            "learn_kernel_uhop requires model.block to be a UHopBlock, "
            f"got {type(model.block).__name__}"
        )

    kernel_params = list(model.block.hop.kernel.parameters())
    opt = torch.optim.SGD(kernel_params, lr=0.1, momentum=0.9)

    # NOTE(review): this puts the whole model, backbone included, in training
    # mode, and torch.no_grad() below only disables gradient tracking. Layers
    # with train-time running statistics (e.g. BatchNorm) presumably still
    # update them during these passes; confirm this is intended.
    model.train()
    for epoch in tqdm(range(kernel_epoch), desc="UHop kernel pretrain (epochs)", unit="epoch"):
        unif_losses = []
        prog = tqdm(train_loader, desc=f"kernel epoch {epoch+1}/{kernel_epoch}", unit="batch", leave=False)
        for imgs, _fine, _coarse, _super_coarse in prog:
            imgs = imgs.to(device)

            opt.zero_grad()

            # Features are inputs only; no gradient reaches the backbone.
            with torch.no_grad():
                feats = model.backbone(imgs)   # [B, 512]

            # The batch itself is the pattern set handed to the kernel.
            memory = model.block.kernel_forward(feats.unsqueeze(0))  # [1, B, 512]

            s = F.normalize(memory.squeeze(0), dim=-1)   # [B, 512], unit norm
            loss = uniform_loss(s)

            loss.backward()
            opt.step()
            unif_losses.append(loss.item())

        # Avoid numpy's empty-mean warning when the loader yields nothing.
        mean_loss = float(np.mean(unif_losses)) if unif_losses else float("nan")
        print(f"[UHop kernel pretrain] epoch {epoch+1}/{kernel_epoch}, "
              f"uniform loss = {mean_loss:.4f}")


def test_with_structure_metrics(model, test_loader, device, ultra_rel_tol: float = 0.0):
    """Evaluate flat accuracies and structure metrics (silhouette, cophenetic, ultrametricity).

    The structure metrics are computed on the *backbone* features
    (``model.backbone(imgs)``), not on the output of ``model.block``.

    Args:
        model: Model with ``backbone`` and three-head ``forward``.
        test_loader: Yields ``(imgs, fine, coarse, super_coarse)``.
        device: Device the batches are moved to.
        ultra_rel_tol: Relative slack of the ultrametricity check. A sampled
            triplet ``(i, j, k)`` is a violation when
            ``D[i, j] > (1 + ultra_rel_tol) * max(D[i, k], D[j, k])``.
            The default 0.0 is the exact comparison.

    Returns:
        ``(super_coarse_acc, coarse_acc, fine_acc, sil_super_coarse,
        sil_coarse, sil_fine, cophenetic_corr, ultrametricity)``.
    """
    model.eval()
    feats_list, fine_list, coarse_list, super_coarse_list = [], [], [], []
    correct_super_c = correct_c = correct_f = total = 0

    with torch.no_grad():
        for imgs, fine, coarse, super_coarse in tqdm(test_loader, desc="Evaluating", unit="batch"):
            imgs, fine, coarse, super_coarse = imgs.to(device), fine.to(device), coarse.to(device), super_coarse.to(device)

            super_coarse_out, coarse_out, fine_out = model(imgs)
            pred_super_c = super_coarse_out.argmax(dim=1)
            pred_c = coarse_out.argmax(dim=1)
            pred_f = fine_out.argmax(dim=1)

            correct_super_c += (pred_super_c == super_coarse).sum().item()
            correct_c += (pred_c == coarse).sum().item()
            correct_f += (pred_f == fine).sum().item()
            total += fine.size(0)

            # Collect embeddings. This runs the backbone a second time, because
            # model(imgs) does not expose its intermediate features.
            feats = model.backbone(imgs)
            feats_list.append(feats.cpu().numpy())
            fine_list.extend(fine.cpu().numpy())
            coarse_list.extend(coarse.cpu().numpy())
            super_coarse_list.extend(super_coarse.cpu().numpy())

    denom = max(total, 1)  # avoid ZeroDivisionError on an empty loader
    flat_super_coarse_acc = correct_super_c / denom
    flat_coarse_acc = correct_c / denom
    flat_fine_acc = correct_f / denom
    print(f'Flat Super Coarse Acc: {flat_super_coarse_acc:.4f}, Flat Coarse Acc: {flat_coarse_acc:.4f}, Flat Fine Acc: {flat_fine_acc:.4f}')

    feats = np.concatenate(feats_list, axis=0)
    fine_labels = np.array(fine_list)
    coarse_labels = np.array(coarse_list)
    super_coarse_labels = np.array(super_coarse_list)

    # silhouette scores
    sil_super_coarse = silhouette_score(feats, super_coarse_labels)
    sil_coarse = silhouette_score(feats, coarse_labels)
    sil_fine = silhouette_score(feats, fine_labels)
    print(f"Silhouette Super Coarse: {sil_super_coarse:.4f}, Silhouette Coarse: {sil_coarse:.4f}, Silhouette Fine: {sil_fine:.4f}")

    # Pairwise Euclidean distances, computed once and shared by the linkage,
    # the cophenetic correlation and the ultrametricity check.
    condensed = pdist(feats)

    # cophenetic correlation (linkage accepts a condensed distance matrix)
    Z = linkage(condensed, method='average')
    coph_corr, _ = cophenet(Z, condensed)
    print(f"Cophenetic Correlation: {coph_corr:.4f}")

    # ultrametricity via random triplets
    # NOTE(review): with the exact comparison (ultra_rel_tol == 0) and no tied
    # distances, a random ordered triplet is flagged exactly when (i, j) is the
    # longest side, i.e. about 1/3 of the time for any data. The score then
    # sits near 0.667 and does not discriminate between models. Consider a
    # positive ultra_rel_tol.
    D = squareform(condensed)
    del condensed  # the square form is all that is needed from here on
    N = D.shape[0]
    viol = 0
    num_samples = min(10000, N*(N-1)*(N-2)//6)
    rng = np.random.default_rng(42)
    slack = 1.0 + ultra_rel_tol
    for _ in range(num_samples):
        i, j, k = rng.choice(N, size=3, replace=False)
        if D[i, j] > slack * max(D[i, k], D[j, k]):
            viol += 1
    # Fewer than three samples means no triplet can be drawn.
    ultra_ratio = 1.0 - viol/num_samples if num_samples else float("nan")
    print(f"Ultrametricity (1 - violation_rate): {ultra_ratio:.4f}")

    return flat_super_coarse_acc, flat_coarse_acc, flat_fine_acc, sil_super_coarse, sil_coarse, sil_fine, coph_corr, ultra_ratio


# ---------------- Main ----------------
def main(baseline="HAMN", hyper_c=1.0, clip_r=0.9, lr=1.0, quantity_n: int = 1, tau: float = 1.0,
         checkpoint_path: str = "checkpoints.pth"):
    """Train one baseline on three-level CIFAR-100 and evaluate its best checkpoint.

    Args:
        baseline, hyper_c, clip_r, lr, quantity_n, tau: Forwarded to
            ``HierarchicalCNN``.
        checkpoint_path: Where the best model (by summed test accuracy) is
            saved. If the file already exists it is loaded before training.

    Returns:
        ``(super_coarse_acc, coarse_acc, fine_acc, sil_super_coarse,
        sil_coarse, sil_fine, coph_corr, ultra_ratio, test_loader)`` for the
        best checkpoint.
    """
    ckpt_dir = os.path.dirname(checkpoint_path)
    # dirname is "" for a bare file name, and os.makedirs("") raises.
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    # Data (fine + coarse + super-coarse)
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])
    # Instantiated only to trigger the download; the wrapped datasets below
    # are the ones actually used.
    datasets.CIFAR100(root='./data', train=True, download=True)
    datasets.CIFAR100(root='./data', train=False, download=True)

    # Raw coarse_labels from the CIFAR-100 python files. pickle is applied
    # only to the dataset files fetched above, never to arbitrary input.
    with open('./data/cifar-100-python/train', 'rb') as fh:
        train_coarse = pickle.load(fh, encoding='bytes')[b'coarse_labels']
    with open('./data/cifar-100-python/test', 'rb') as fh:
        test_coarse = pickle.load(fh, encoding='bytes')[b'coarse_labels']

    # wrap to return (img, fine, coarse, super_coarse)
    train_ds = ThreeLevelCIFAR100('./data', train=True,  transform=transform, coarse_labels=train_coarse)
    test_ds  = ThreeLevelCIFAR100('./data', train=False, transform=transform, coarse_labels=test_coarse)

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True,  num_workers=4)
    test_loader  = DataLoader(test_ds,  batch_size=128, shuffle=False, num_workers=4)

    # `tau` is forwarded too; previously it was dropped, so the HypAttn block
    # always used the model's default temperature.
    model = HierarchicalCNN(baseline=baseline, hyper_c=hyper_c, clip_r=clip_r, lr=lr,
                            quantity_n=quantity_n, tau=tau).to(device)

    if baseline == "UHop":
        # NOTE(review): despite its name this file holds the full model
        # state_dict, and it is shared by every run writing to the same
        # directory, so later runs start from the first run's weights.
        # Confirm this is intended.
        kernel_ckpt = os.path.join(ckpt_dir, "uhop_kernel_only.pth")

        if os.path.isfile(kernel_ckpt):
            print("=> loading UHop kernel checkpoint", kernel_ckpt)
            model.load_state_dict(torch.load(kernel_ckpt, map_location=device))
        else:
            print("=> UHop baseline: pretraining kernel with uniform loss (U-Hop style)")
            learn_kernel_uhop(model, train_loader, device, kernel_epoch=50)

            torch.save(model.state_dict(), kernel_ckpt)
            print("=> saved UHop kernel checkpoint to", kernel_ckpt)

    # Resume from an existing checkpoint of the same run, if any.
    if os.path.isfile(checkpoint_path):
        print("=> loading checkpoint", checkpoint_path)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # ---- AdamW with split weight decay ----
    def _split_decay(named_params):
        """Split trainable parameters into (weight-decayed, not decayed) lists."""
        decay, no_decay = [], []
        for n, p in named_params:
            if not p.requires_grad:
                continue
            # do not apply weight decay to biases and (Layer/Batch)Norm params
            if p.ndim == 1 or n.endswith(".bias") or "bn" in n.lower() or "norm" in n.lower():
                no_decay.append(p)
            else:
                decay.append(p)
        return decay, no_decay

    back_named = list(model.backbone.named_parameters())
    head_named = (
        list(model.head_super_coarse.named_parameters()) +
        list(model.head_coarse.named_parameters()) +
        list(model.head_fine.named_parameters())
    )
    block_named = list(model.block.named_parameters())

    decay_base, nodecay_base = _split_decay(back_named + head_named)
    decay_block, nodecay_block = _split_decay(block_named)

    # Backbone and heads train at 1e-3; the inserted block at a lower 3e-4.
    optimizer = optim.AdamW(
        [
            {"params": decay_base,    "lr": 1e-3, "weight_decay": 5e-4},
            {"params": nodecay_base,  "lr": 1e-3, "weight_decay": 0.0},
            {"params": decay_block,   "lr": 3e-4, "weight_decay": 1e-4},
            {"params": nodecay_block, "lr": 3e-4, "weight_decay": 0.0},
        ],
        betas=(0.9, 0.999), eps=1e-8
    )
    criterion = nn.CrossEntropyLoss()

    best_sum = -1.0

    # NOTE(review): range(1, 30) runs 29 epochs; confirm 30 was not intended.
    for epoch in range(1, 30):
        print(f"Epoch {epoch}")
        train_epoch(model, train_loader, optimizer, criterion, device)
        flat_super_coarse_acc, flat_coarse_acc, flat_fine_acc = test(model, test_loader, device)

        # Keep the checkpoint with the highest summed accuracy over the three levels.
        cur_sum = flat_super_coarse_acc + flat_coarse_acc + flat_fine_acc
        if cur_sum > best_sum:
            best_sum = cur_sum
            torch.save(model.state_dict(), checkpoint_path)
            print(f"=> New best sum={best_sum:.4f}, saved to {checkpoint_path}")

    print("=> loading best checkpoint", checkpoint_path)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    flat_super_coarse_acc, flat_coarse_acc, flat_fine_acc, sil_super_coarse, sil_coarse, sil_fine, coph_corr, ultra_ratio = \
        test_with_structure_metrics(model, test_loader, device)

    return (flat_super_coarse_acc, flat_coarse_acc, flat_fine_acc,
            sil_super_coarse, sil_coarse, sil_fine, coph_corr, ultra_ratio, test_loader)


if __name__ == '__main__':
    # Baselines to run, in this order; edit the list as needed.
    baselines = ["UHop","HAMN", "HypAttn", "HypNN", "MHN_Euc"]

    # Single hyper-parameter setting shared by all baselines and repeats.
    best_params = {
        "hyper_c":    [1.0],
        "clip_r":     [0.9],
        "lr":         [1.0],
        "quantity_n": [2],
        "tau":        [1.0],
    }

    # Names of the eight metrics, in the order main() returns them.
    metric_names = (
        "flat_super_coarse_acc",
        "flat_coarse_acc",
        "flat_fine_acc",
        "sil_super_coarse",
        "sil_coarse",
        "sil_fine",
        "coph_corr",
        "ultra_ratio",
    )

    # Root result directory, created once.
    os.makedirs("hyper_deem3_result", exist_ok=True)

    # Per-baseline summaries, written to one combined file at the end.
    all_summaries = {}

    # The setting is the same for every run, so unpack it once.
    hyper_c = best_params["hyper_c"][0]
    clip_r = best_params["clip_r"][0]
    lr = best_params["lr"][0]
    quantity_n = best_params["quantity_n"][0]
    tau = best_params["tau"][0]

    for baseline in baselines:
        print(f"\n========== Running baseline: {baseline} ==========")

        # A sub-directory per baseline keeps checkpoint files from colliding.
        ckpt_dir = f"hyper_deem3_result/{baseline}/best_checkpoints"
        os.makedirs(ckpt_dir, exist_ok=True)

        best_score, best_checkpoint = -1.0, ""
        run_rows = []

        for run_no in range(1, 11):
            run_ckpt = (
                f"{ckpt_dir}/"
                f"{baseline}_hc{hyper_c}_cr{clip_r}"
                f"_lr{lr}_qn{quantity_n}"
                f"_run{run_no}_tau{tau}.pth"
            )

            # main() returns the eight metrics followed by the test loader,
            # which is not needed here.
            *run_metrics, _test_loader = main(
                baseline=baseline,
                hyper_c=hyper_c,
                clip_r=clip_r,
                lr=lr,
                quantity_n=quantity_n,
                tau=tau,
                checkpoint_path=run_ckpt,
            )
            run_rows.append(run_metrics)

            # The best run is the one with the highest fine accuracy.
            fine_acc = run_metrics[2]
            if fine_acc > best_score:
                best_score, best_checkpoint = fine_acc, run_ckpt

        table = np.array(run_rows)  # shape (10, 8)
        col_means = table.mean(axis=0)
        col_stds = table.std(axis=0)

        summary = {
            "baseline": baseline,
            "param_set": best_params,
            "means": dict(zip(metric_names, col_means)),
            "stds": dict(zip(metric_names, col_stds)),
            "best_checkpoint": best_checkpoint,
            "best_score_by_fine_acc": best_score,
        }
        all_summaries[baseline] = summary

        # Append this baseline's summary to its own Result.txt.
        result_path = f"hyper_deem3_result/{baseline}/Result.txt"
        os.makedirs(os.path.dirname(result_path), exist_ok=True)
        with open(result_path, "a") as f:
            f.write("\n\nRepeat 10 runs summary:\n")
            json.dump(summary, f, indent=2)

        print(f"[{baseline}] Repeated evaluation done. Summary saved to {result_path}")

    # Combined summary over all baselines.
    with open("hyper_deem3_result/summary_all_baselines.json", "w") as f:
        json.dump(all_summaries, f, indent=2)
    print("All baselines finished. Combined summary saved to hyper_deem3_result/summary_all_baselines.json")

