import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import pickle
from tqdm import tqdm
import numpy as np
from sklearn.metrics import silhouette_score
from scipy.cluster.hierarchy import linkage, cophenet
from scipy.spatial.distance import pdist, squareform
import math

# Run on the default CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from hflayers import HopfieldLayer
from hyper_hflayers import Hyperbolic_HopfieldLayer as Hyper_HopfieldLayer # ensure the module path is correct
from Uhop import LearnableHopfield,uniform_loss

# Wrap dataset to return coarse label
from torchvision import datasets

import torch.nn.functional as F


# ---------------- Poincaré Ball ops (curvature = -c, c>0) ----------------
class PoincareBall:
    """
    Closed-form operations on the Poincaré ball of curvature ``-c`` (``c > 0``).

    Every reduction is over the last dimension with ``keepdim=True``, so leading (batch)
    dimensions broadcast. Results that land on the ball are radially projected to lie strictly
    inside ``radius * (1 - 1e-5)`` to avoid ``atanh(1)`` and conformal-factor blow-ups.

    Args:
        c: Strictly positive curvature magnitude; the ball has radius ``1 / sqrt(c)``.
        eps: Lower clamp for point norms in ``_proj_with_margin`` and for the denominator
            product in ``dist``.

    Raises:
        ValueError: If ``c`` is not strictly positive (or is NaN). For such values
            ``c ** 0.5`` is zero or a Python complex number, and every operation would break.
    """

    def __init__(self, c: float, eps: float = 1e-6):
        self.c = float(c)
        if not self.c > 0.0:  # also rejects NaN
            raise ValueError(f"PoincareBall requires curvature magnitude c > 0, got {c!r}")
        self.eps = eps
        self.sqrt_c = self.c ** 0.5
        self.radius = 1.0 / self.sqrt_c
        self._M = 1.0 - 1e-5  # safety margin: points are kept within radius * (1 - 1e-5)

    def _proj_with_margin(self, x):
        """Radially rescale points whose norm exceeds ``radius * _M`` back onto that sphere."""
        r = x.norm(dim=-1, keepdim=True).clamp_min(self.eps)
        max_r = self._M * self.radius
        scale = torch.where(r > max_r, max_r / r, torch.ones_like(r))
        return x * scale

    def lambda_x(self, x):
        """Conformal factor ``2 / (1 - c||x||^2)``; the denominator is clamped at 1e-6."""
        x2 = (x * x).sum(dim=-1, keepdim=True)
        return 2.0 / (1.0 - self.c * x2).clamp_min(1e-6)

    def mobius_add(self, x, y):
        """Möbius addition ``x ⊕ y``, projected back inside the ball margin."""
        c = self.c
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        xy = (x * y).sum(dim=-1, keepdim=True)
        num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
        den = (1 + 2 * c * xy + (c ** 2) * x2 * y2).clamp_min(1e-6)
        return self._proj_with_margin(num / den)

    def mobius_neg(self, x):
        """Möbius negation ``-x`` (projected, for consistency with the other ops)."""
        return self._proj_with_margin(-x)

    def exp0(self, v):
        """Exponential map at the origin: ``tanh(sqrt(c)|v|) * v / (sqrt(c)|v|)``."""
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = self.sqrt_c * vnorm
        coef = torch.tanh(t) / (self.sqrt_c * vnorm)
        return self._proj_with_margin(coef * v)

    def log0(self, x):
        """Logarithmic map at the origin; the atanh argument is clamped below 1."""
        x = self._proj_with_margin(x)
        xnorm = x.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        arg = torch.clamp(self.sqrt_c * xnorm, 0.0, 1.0 - 1e-7)
        coef = torch.atanh(arg) / (self.sqrt_c * xnorm)
        return coef * x

    def exp_p(self, p, v):
        """Exponential map at ``p``: ``p ⊕ tanh(sqrt(c) λ_p |v| / 2) * v / (sqrt(c)|v|)``."""
        lam = self.lambda_x(p)
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = 0.5 * self.sqrt_c * lam * vnorm
        coef = torch.tanh(t) / (self.sqrt_c * vnorm)
        return self.mobius_add(p, coef * v)

    def log_p(self, p, x):
        """Logarithmic map at ``p``: ``2/(sqrt(c) λ_p) * atanh(sqrt(c)|y|) * y/|y|``, ``y = (-p) ⊕ x``."""
        lam = self.lambda_x(p)
        y = self.mobius_add(self.mobius_neg(p), x)
        ynorm = y.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        arg = torch.clamp(self.sqrt_c * ynorm, 0.0, 1.0 - 1e-7)
        coef = (2.0 / (self.sqrt_c * lam)) * (torch.atanh(arg) / ynorm)
        return coef * y

    def dist(self, x, y):
        """
        Geodesic distance with a trailing singleton dim:
        ``arcosh(1 + 2c||x-y||^2 / ((1-c||x||^2)(1-c||y||^2)))``.

        The arcosh argument is clamped to ``>= 1 + 1e-6`` (finite gradient), so identical points
        return ~1.4e-3 rather than exactly 0.
        """
        c = self.c
        diff2 = ((x - y) ** 2).sum(dim=-1, keepdim=True)
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        num = 2 * c * diff2
        den = ((1 - c * x2) * (1 - c * y2)).clamp_min(self.eps)
        z = 1 + num / den
        return torch.acosh(z.clamp_min(1 + 1e-6))


# ---------------- Hyperbolic Attention (baseline) ----------------
class HyperbolicAttentionLayer(nn.Module):
    """
    Distance-based attention from input queries to a learned memory, computed on the Poincaré ball.

    Pipeline (per query position):
      - q is a linear projection of the (optionally LayerNorm-ed) input; k/v are linear projections
        of the learnable memory ``mem``. Each is multiplied by a learnable scale clamped to
        [0.5, 10] and mapped onto the ball with ``exp0``. Larger scales push points outward and
        widen the range of distances.
      - scoring: -d^2 (or -d when ``use_sqdist=False``) → subtract the per-query max → z-score over
        the memory axis → multiply by a learnable temperature → clamp to [-20, 20] → softmax.
      - Fréchet-style aggregation in each query's tangent space: z = exp_q( Σ_i a_i * log_q(v_i) ).
      - The result is mapped back to the tangent space at the origin, and a gated residual
        ``sigmoid(alpha_skip) * log0(q)`` is added (gate starts at 0.5, keeps gradients flowing early).

    Args:
        feat_dim: Feature dimension D of inputs, projections and memory slots.
        n_mem: Number of memory slots N.
        c: Curvature magnitude of the ball.
        tau_init: Initial softmax temperature (stored as a log-parameter).
        dropout: Dropout probability applied to the attention weights.
        score_center: Accepted for API compatibility but unused. The z-score step is invariant to
            any per-query shift, so an extra centering step would not change the scores.
        use_sqdist: Score with -d^2 (True) or -d (False).
        layernorm: Apply LayerNorm to the input before the query projection.
    """
    def __init__(self, feat_dim: int, n_mem: int, c: float = 1.0,
                 tau_init: float = 5.0, dropout: float = 0.0,
                 score_center: str = "none",    # 'none' | 'min' | 'mean' (unused, see docstring)
                 use_sqdist: bool = True,
                 layernorm: bool = True):
        super().__init__()
        self.ball = PoincareBall(c)
        self.use_sqdist = use_sqdist
        self.score_center = score_center

        # Linear projections (tangent space at the origin)
        self.q_proj = nn.Linear(feat_dim, feat_dim, bias=False)
        self.k_proj = nn.Linear(feat_dim, feat_dim, bias=False)
        self.v_proj = nn.Linear(feat_dim, feat_dim, bias=False)

        # Memory vectors (tangent-space params), small-scale init
        self.mem = nn.Parameter(torch.randn(n_mem, feat_dim) * 0.02)

        # LayerNorm / Dropout
        self.ln = nn.LayerNorm(feat_dim) if layernorm else nn.Identity()
        self.dropout = nn.Dropout(dropout)

        # Learnable temperature (log-param for stability)
        self.log_tau = nn.Parameter(torch.log(torch.tensor(float(tau_init))))

        # Learnable log-scales (exp(0.0) = 1.0, exp(1.1) ≈ 3.0)
        self.q_log_scale = nn.Parameter(torch.tensor(0.0))
        self.k_log_scale = nn.Parameter(torch.tensor(1.1))
        self.v_log_scale = nn.Parameter(torch.tensor(1.1))

        # Residual-gate logit: the gate is sigmoid(alpha_skip) in (0, 1), i.e. 0.5 at init
        self.alpha_skip = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        """
        Args:
            x: [B, D] or [B, L, D] input features (L = 1 is the original use case).

        Returns:
            Tangent vectors at the origin: [B, D] for 2-D input, otherwise [B, L, D].
        """
        squeeze = x.dim() == 2
        if squeeze:
            x = x.unsqueeze(1)  # [B,1,D]

        x_ln = self.ln(x)                 # [B,L,D]
        q_tan = self.q_proj(x_ln)         # [B,L,D]
        k_tan = self.k_proj(self.mem)     # [N,D]
        v_tan = self.v_proj(self.mem)     # [N,D]

        s_q = torch.exp(self.q_log_scale).clamp(0.5, 10.0)
        s_k = torch.exp(self.k_log_scale).clamp(0.5, 10.0)
        s_v = torch.exp(self.v_log_scale).clamp(0.5, 10.0)

        q_ball = self.ball.exp0(s_q * q_tan)  # [B,L,D]
        k_ball = self.ball.exp0(s_k * k_tan)  # [N,D]
        v_ball = self.ball.exp0(s_v * v_tan)  # [N,D]

        # Pairwise distances via broadcasting [B,L,1,D] vs [N,D] (no materialized repeat)
        q_pair = q_ball.unsqueeze(-2)                          # [B,L,1,D]
        d = self.ball.dist(q_pair, k_ball).squeeze(-1)         # [B,L,N]
        scores = -(d ** 2) if self.use_sqdist else -d          # [B,L,N]

        if scores.shape[-1] > 1:
            # numeric stabilization: subtract max → z-score
            scores = scores - scores.max(dim=-1, keepdim=True).values
            mu = scores.mean(dim=-1, keepdim=True)
            sigma = scores.std(dim=-1, keepdim=True).clamp_min(1e-6)
            scores = (scores - mu) / sigma
        else:
            # A single memory slot: unbiased std is NaN; the softmax weight is 1 regardless.
            scores = torch.zeros_like(scores)

        tau = torch.exp(self.log_tau).clamp(0.05, 50.0)
        scores = (tau * scores).clamp(-20.0, 20.0)

        attn = torch.softmax(scores, dim=-1)  # [B,L,N]
        attn = self.dropout(attn)

        # Weighted sum of log_q(v_i) in the tangent space at each query, then back onto the ball
        log_q_v = self.ball.log_p(q_pair, v_ball)                  # [B,L,N,D]
        z_tan_q = torch.einsum("bln,blnd->bld", attn, log_q_v)     # [B,L,D]
        z_ball = self.ball.exp_p(q_ball, z_tan_q)                  # [B,L,D]

        # Output in the tangent space at the origin
        out_tan0 = self.ball.log0(z_ball)                          # [B,L,D]

        # Gated residual with the (non-detached) query so early gradients reach q_proj
        q_tan0 = self.ball.log0(q_ball)                            # [B,L,D]
        gate = torch.sigmoid(self.alpha_skip)
        out_tan0 = out_tan0 + gate * q_tan0

        return out_tan0.squeeze(1) if squeeze else out_tan0


class MobiusLinear(nn.Module):
    """
    Linear layer acting on Poincaré-ball points through the tangent space at the origin.

    forward: project inside the ball margin → ``log0`` → clamp each coordinate to [-20, 20]
    → ``nn.Linear`` → clamp again → ``exp0``. The result is a point on the ball.

    Args:
        in_dim: Input dimension.
        out_dim: Output dimension.
        c: Curvature magnitude of the ball.
        bias: Whether the inner linear map has a bias (zero-initialized).
    """

    def __init__(self, in_dim, out_dim, c=1.0, bias=True):
        super().__init__()
        self.ball = PoincareBall(c)
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        # Kaiming-uniform re-init with negative_slope a=0.2.
        # NOTE(review): this bound is larger than nn.Linear's default (a=sqrt(5)); it is not a
        # "small-scale" init as an earlier comment claimed — confirm that is intended.
        nn.init.kaiming_uniform_(self.lin.weight, a=0.2)
        if bias:
            nn.init.zeros_(self.lin.bias)

    def forward(self, x_ball):
        ball = self.ball
        inside = ball._proj_with_margin(x_ball)
        # Element-wise clamping (not a norm clamp) keeps tangent coordinates bounded.
        tangent_in = torch.clamp(ball.log0(inside), -20.0, 20.0)
        tangent_out = torch.clamp(self.lin(tangent_in), -20.0, 20.0)
        return ball.exp0(tangent_out)


class HypNNBlock(nn.Module):
    """
    Two MobiusLinear layers with LayerNorm + ReLU applied in the tangent space at the origin.

    Input [B, D] or [B, L, D] Euclidean features. The output has the same rank and is a tangent
    vector (``log0`` of the second layer, LayerNorm-ed and clamped), suitable for linear heads.

    Args:
        feat_dim: Input/output feature dimension.
        hidden: Hidden width between the two MobiusLinear layers.
        c: Curvature magnitude of the ball.
        dropout: Dropout probability on the hidden tangent activations.
    """

    def __init__(self, feat_dim: int, hidden: int = 512, c: float = 1.0, dropout: float = 0.1):
        super().__init__()
        self.ball = PoincareBall(c)
        self.fc1 = MobiusLinear(feat_dim, hidden, c=c)
        self.fc2 = MobiusLinear(hidden, feat_dim, c=c)
        self.ln1 = nn.LayerNorm(hidden)
        self.ln2 = nn.LayerNorm(feat_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        squeeze = x.dim() == 2
        if squeeze:
            x = x.unsqueeze(1)

        # (1) Parameter-free LayerNorm on the backbone features before exp0 to avoid huge vectors
        x = F.layer_norm(x, x.shape[-1:])
        x_ball = self.ball.exp0(x)

        # (2) first layer; normalization and non-linearity in the tangent space
        h1_ball = self.fc1(x_ball)
        h1_tan = self.ball.log0(h1_ball)
        h1_tan = self.ln1(h1_tan)
        h1_tan = torch.clamp(h1_tan, -20.0, 20.0)
        h1_tan = F.relu(h1_tan)
        # Dropout in the tangent space: on ball coordinates the 1/(1-p) rescale of kept units can
        # push points outside the ball, where the geometry is undefined.
        h1_tan = self.dropout(h1_tan)
        h1_ball = self.ball.exp0(h1_tan)

        # (3) second layer; output stays in the tangent space
        h2_ball = self.fc2(h1_ball)
        out_tan = self.ball.log0(h2_ball)
        out_tan = self.ln2(out_tan)
        out_tan = torch.clamp(out_tan, -20.0, 20.0)

        return out_tan.squeeze(1) if squeeze else out_tan


class UHopBlock(nn.Module):
    """
    Retrieval block wrapping ``LearnableHopfield`` over a learnable memory bank.

    The memory is a ``[1, n_mem, feat_dim]`` parameter expanded across the batch. The Hopfield
    output is passed through LayerNorm. A 2-D ``[B, D]`` input is treated as ``[B, 1, D]`` and the
    singleton axis is removed again on output.
    """

    def __init__(
        self,
        feat_dim: int,
        n_mem: int,
        n_heads: int = 4,
        dropout: float = 0.1,
        mode: str = "softmax",
        kernel: str = "lin",
        scale: float | None = None,
    ):
        super().__init__()
        # Fixed settings: d_keys/d_values left as None, head mixing on, a single update step.
        hop_config = dict(
            d_model=feat_dim,
            n_heads=n_heads,
            d_keys=None,
            d_values=None,
            mix=True,
            update_steps=1,
            dropout=dropout,
            mode=mode,
            kernel=kernel,
            scale=scale,
        )
        self.hop = LearnableHopfield(**hop_config)

        # Memory slots with small-magnitude (std 0.02) random init
        self.memory = nn.Parameter(torch.randn(1, n_mem, feat_dim) * 0.02)
        self.ln = nn.LayerNorm(feat_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        flat_input = x.dim() == 2
        queries = x.unsqueeze(1) if flat_input else x
        batch_size, _, _ = queries.shape
        memory = self.memory.expand(batch_size, -1, -1)  # [B, n_mem, D]
        retrieved = self.ln(self.hop(queries, memory))
        return retrieved.squeeze(1) if flat_input else retrieved

    def kernel_forward(self, Y: torch.Tensor) -> torch.Tensor:
        """Run ``LearnableHopfield.uniform_forward``; a 2-D ``[N, D]`` input gains a leading axis."""
        batched = Y.unsqueeze(0) if Y.dim() == 2 else Y
        return self.hop.uniform_forward(batched)



# Mapping from CIFAR-100 coarse labels (the 20 standard superclass indices) to the
# 7 intermediate "super-coarse" groups used by this script.
# coarse_to_super[coarse_label] = super_class_label
coarse_to_super = {
    # Animal groups: super-classes 0-2
    0: 0,  # aquatic_mammals         → aquatic
    1: 0,  # fish                    → aquatic

    7: 1,  # insects                 → small
    13: 1,  # non_insect_invertebrates → small
    16: 1,  # small_mammals           → small

    8: 2,  # large_carnivores        → large
    11: 2,  # large_omnivores         → large
    12: 2,  # medium_mammals          → large
    14: 2,  # people                  → large
    15: 2,  # reptiles                → large

    # Plants: super-class 3
    2: 3,  # flowers                 → plants
    4: 3,  # fruits_and_vegetables   → plants
    17: 3,  # trees                   → plants

    # Household & furniture: super-class 4
    3: 4,  # food_containers         → home_items
    5: 4,  # household_electrical_devices → home_items
    6: 4,  # household_furniture     → home_items

    # Vehicles / large man-made things: super-class 5
    9: 5,  # large_man_made_outdoor_things → vehicles
    18: 5,  # vehicles_1              → vehicles
    19: 5,  # vehicles_2              → vehicles

    # Natural scenes: super-class 6
    10: 6  # large_natural_outdoor_scenes → natural_scenes
}

# Mapping from super-classes to top-level classes:
# super_to_top[super_class_label] = top_class_label
#   top 0 = animals           <- aquatic (0), small land/air animals (1), large vertebrates & humans (2)
#   top 1 = plants_and_nature <- plants (3), natural_scenes (6)
#   top 2 = man_made_objects  <- household & furniture (4), vehicles (5)
super_to_top = {
    super_label: top_label
    for top_label, super_labels in (
        (0, (0, 1, 2)),
        (1, (3, 6)),
        (2, (4, 5)),
    )
    for super_label in super_labels
}


# CIFAR-100 → (fine, coarse, super_coarse)
class ThreeLevelCIFAR100(datasets.CIFAR100):
    """
    CIFAR-100 dataset that yields ``(img, fine, coarse, super_coarse)`` per sample.

    ``img`` and ``fine`` come from the parent class. ``coarse`` is looked up in the supplied
    per-sample ``coarse_labels`` sequence, and ``super_coarse = coarse_to_super[coarse]``.

    Args:
        root, train, transform, download: Forwarded to ``datasets.CIFAR100``.
        coarse_labels: One coarse label per sample, aligned with the dataset's indices.

    Raises:
        ValueError: If ``coarse_labels`` is None or its length differs from the dataset length.
    """

    def __init__(self, root, train, transform, coarse_labels, download=False):
        super().__init__(root, train=train, transform=transform, download=download)
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if coarse_labels is None:
            raise ValueError("coarse_labels must be provided (one coarse label per sample)")
        if len(coarse_labels) != len(self):
            raise ValueError(
                f"coarse_labels has {len(coarse_labels)} entries but the dataset has {len(self)} samples"
            )
        self.coarse_labels = coarse_labels

    def __getitem__(self, idx):
        img, fine = super().__getitem__(idx)
        coarse = self.coarse_labels[idx]

        super_coarse = coarse_to_super[coarse]
        return img, fine, coarse, super_coarse


class TopLevelCIFAR100(ThreeLevelCIFAR100):
    """Extends ThreeLevelCIFAR100 samples with a fifth element: ``super_to_top[super_coarse]``."""

    def __getitem__(self, idx):
        sample = super().__getitem__(idx)
        _img, _fine, _coarse, super_coarse = sample
        return (*sample, super_to_top[super_coarse])


class HierarchicalCNN(nn.Module):
    """ResNet-18 features → interchangeable memory/attention block → four linear classification heads."""

    def __init__(self, baseline: str = "HAMN", hyper_c=1.0, clip_r=0.9, lr=1.0, quantity_n: int = 1, tau: float = 1.0):
        """
        baseline: 'HAMN' | 'HypAttn' | 'HypNN' | 'MHN_Euc' | 'UHop'
        hyper_c: curvature magnitude for the hyperbolic blocks.
        clip_r: HAMN clipping radius as a fraction of the ball radius 1/sqrt(hyper_c).
        lr: passed to the HAMN layer as its ``lr`` argument (not the optimizer LR).
        quantity_n: memory size multiplier; the blocks get int(100 * quantity_n) slots.
        tau: initial temperature for HypAttn.

        Raises ValueError for an unknown ``baseline``.
        """
        super().__init__()
        self.baseline = baseline
        self.backbone = models.resnet18(pretrained=False)
        self.backbone.fc = nn.Identity()
        feat_dim = 512
        quantity_fine = int(100 * quantity_n)

        # Builders are lazy so only the selected block's arguments are ever evaluated.
        block_builders = {
            # Hyperbolic Hopfield (HAMN)
            "HAMN": lambda: Hyper_HopfieldLayer(
                input_size=feat_dim,
                hidden_size=feat_dim,
                output_size=feat_dim,
                num_heads=4,
                hyper_c=hyper_c,
                clip_r=1.0 / math.sqrt(hyper_c) * clip_r,
                lr=lr,
                association_activation='relu',
                dropout=0.2,
                quantity=quantity_fine,
                batch_first=True,
                train_c=False,
                input_as_hyper=False,
                out_as_hyper=False,
            ),
            # Hyperbolic attention over a learned memory
            "HypAttn": lambda: HyperbolicAttentionLayer(
                feat_dim=feat_dim,
                n_mem=quantity_fine,
                c=hyper_c,
                tau_init=tau,
                dropout=0.1,
            ),
            # Lightweight hyperbolic MLP block
            "HypNN": lambda: HypNNBlock(
                feat_dim=feat_dim,
                hidden=512,
                c=hyper_c,
                dropout=0.1,
            ),
            # Euclidean Hopfield reference
            "MHN_Euc": lambda: HopfieldLayer(
                input_size=feat_dim,
                hidden_size=feat_dim,
                output_size=feat_dim,
                num_heads=8,
                scaling=0.01,
                dropout=0.2,
                association_activation='relu',
                quantity=quantity_fine,
                batch_first=True,
            ),
            # U-Hop baseline (LearnableHopfield)
            "UHop": lambda: UHopBlock(
                feat_dim=feat_dim,
                n_mem=quantity_fine,
                n_heads=4,
                dropout=0.1,
                mode="softmax",
                kernel="lin",
                scale=None,
            ),
        }
        build_block = block_builders.get(baseline)
        if build_block is None:
            raise ValueError(f"Unknown baseline: {baseline}")
        self.block = build_block()

        # One linear head per label level: top (3), super-coarse (7), coarse (20), fine (100)
        self.head_top = nn.Linear(feat_dim, 3)
        self.head_super_coarse = nn.Linear(feat_dim, 7)
        self.head_coarse = nn.Linear(feat_dim, 20)
        self.head_fine = nn.Linear(feat_dim, 100)

    def forward(self, x):
        """Return ``(top, super_coarse, coarse, fine)`` logits for an image batch."""
        pooled = self.backbone(x).unsqueeze(1)        # [B,1,512] common block interface
        shared = self.block(pooled).squeeze(1)        # [B,512]
        heads = (self.head_top, self.head_super_coarse, self.head_coarse, self.head_fine)
        return tuple(head(shared) for head in heads)


def train_epoch(model, train_loader, optimizer, criterion, device):
    """
    Run one training epoch with the summed four-level cross-entropy objective.

    Each batch must unpack as ``(imgs, fine, coarse, super_coarse, top)``. The loss is the
    unweighted sum of ``criterion`` over the top / super-coarse / coarse / fine outputs. Gradients
    are clipped to a global L2 norm of 5.0 before every optimizer step.

    Args:
        model: Module returning ``(top, super_coarse, coarse, fine)`` logits.
        train_loader: Iterable of batches (typically a DataLoader).
        optimizer: Optimizer over the model's parameters.
        criterion: Loss callable ``criterion(logits, targets)``.
        device: Unused; batches are moved to the device of the model's first parameter.
            Kept for signature compatibility.

    Returns:
        Mean batch loss over the processed batches (0.0 if the loader yields none).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    prog = tqdm(train_loader, desc="Training", unit="batch")

    model_device = next(model.parameters()).device  # the actual device of the model

    for num_batches, (imgs, fine, coarse, super_coarse, top) in enumerate(prog, start=1):
        # Move inputs to the *model device* for consistency
        imgs = imgs.to(model_device, non_blocking=True)
        fine = fine.to(model_device, non_blocking=True)
        coarse = coarse.to(model_device, non_blocking=True)
        super_coarse = super_coarse.to(model_device, non_blocking=True)
        top = top.to(model_device, non_blocking=True)

        top_out, super_coarse_out, coarse_out, fine_out = model(imgs)
        loss = (
            criterion(top_out, top)
            + criterion(super_coarse_out, super_coarse)
            + criterion(coarse_out, coarse)
            + criterion(fine_out, fine)
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        total_loss += loss.item()
        # Count batches ourselves: tqdm only syncs ``prog.n`` on display refreshes, so it can
        # lag behind the number of batches actually processed.
        prog.set_postfix(loss=total_loss / num_batches)

    avg = total_loss / num_batches if num_batches else 0.0
    print(f"Train Loss: {avg:.4f}")
    return avg


# 5. Testing loop
def test(model, test_loader, device):
    """
    Evaluate flat accuracy of each head (independent argmax per level) on ``test_loader``.

    Each batch must unpack as ``(imgs, fine, coarse, super_coarse, top)``; tensors are moved to
    ``device``. Running accuracies are shown on the progress bar; the final values are printed.

    Returns:
        ``(top_acc, super_acc, coarse_acc, fine_acc)`` as floats.
    """
    model.eval()
    levels = ("top", "super", "coarse", "fine")
    hits = dict.fromkeys(levels, 0)
    total = 0
    prog = tqdm(test_loader, desc="Testing", unit="batch")

    with torch.no_grad():
        for imgs, fine, coarse, super_coarse, top in prog:
            targets = {
                "top": top.to(device),
                "super": super_coarse.to(device),
                "coarse": coarse.to(device),
                "fine": fine.to(device),
            }
            top_out, super_out, coarse_out, fine_out = model(imgs.to(device))
            logits = {"top": top_out, "super": super_out, "coarse": coarse_out, "fine": fine_out}

            for level in levels:
                hits[level] += (logits[level].argmax(dim=1) == targets[level]).sum().item()
            total += targets["fine"].size(0)

            prog.set_postfix(**{f"{level}_acc": f"{hits[level] / total:.4f}" for level in levels})

    top_acc, super_acc, coarse_acc, fine_acc = (hits[level] / total for level in levels)
    print(
        f"Acc | Top: {top_acc:.4f}, "
        f"Super: {super_acc:.4f}, "
        f"Coarse: {coarse_acc:.4f}, "
        f"Fine: {fine_acc:.4f}"
    )
    return top_acc, super_acc, coarse_acc, fine_acc

def learn_kernel_uhop(model, train_loader, device, kernel_epoch: int = 50):
    """
    Pre-train only the U-Hop kernel with ``uniform_loss`` on backbone features.

    Only ``model.block.hop.kernel`` parameters are optimized (SGD, lr=0.1, momentum=0.9). Backbone
    features are computed under ``no_grad``. Each batch is passed to ``kernel_forward`` as one
    ``[1, B, 512]`` set; its outputs are L2-normalized before the loss. The mean loss is printed
    per epoch. The model is left in train mode.
    """
    assert isinstance(model.block, UHopBlock), "learn_kernel_uhop in UHopBlock"

    optimizer = torch.optim.SGD(list(model.block.hop.kernel.parameters()), lr=0.1, momentum=0.9)

    model.train()
    epoch_bar = tqdm(range(kernel_epoch), desc="UHop kernel pretrain (epochs)", unit="epoch")
    for epoch in epoch_bar:
        batch_losses = []
        batch_bar = tqdm(train_loader, desc=f"kernel epoch {epoch+1}/{kernel_epoch}", unit="batch", leave=False)
        for imgs, _fine, _coarse, _super_coarse, _top in batch_bar:
            optimizer.zero_grad()

            with torch.no_grad():
                backbone_feats = model.backbone(imgs.to(device))  # [B, 512]

            # The whole batch is treated as a single set: [1, B, 512]
            memory = model.block.kernel_forward(backbone_feats.detach().unsqueeze(0))
            embeddings = F.normalize(memory.squeeze(0), dim=-1)  # [B, 512], unit norm
            loss = uniform_loss(embeddings)

            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        print(f"[UHop kernel pretrain] epoch {epoch+1}/{kernel_epoch}, "
              f"uniform loss = {np.mean(batch_losses):.4f}")



def test_with_structure_metrics(model, test_loader, device):
    """
    TopLevelCIFAR100:
      1) (flat_top, flat_super, flat_coarse, flat_fine)
      2) silhouette score
      3) cophenetic correlation
      4) ultrametricity ratio

    Returns:
        flat_top, flat_super, flat_coarse, flat_fine,
        sil_top, sil_super, sil_coarse, sil_fine,
        coph_corr, ultra_ratio
    """
    model.eval()
    feats_list, fine_list, coarse_list, super_list, top_list = [], [], [], [], []
    correct_top = correct_super = correct_c = correct_f = total = 0

    with torch.no_grad():
        for imgs, fine, coarse, super_coarse, top in tqdm(test_loader, desc="Evaluating", unit="batch"):
            imgs = imgs.to(device)
            fine = fine.to(device)
            coarse = coarse.to(device)
            super_coarse = super_coarse.to(device)
            top = top.to(device)

            # forward
            top_out, super_out, coarse_out, fine_out = model(imgs)

            # predict
            pred_top = top_out.argmax(dim=1)
            pred_super = super_out.argmax(dim=1)
            pred_c = coarse_out.argmax(dim=1)
            pred_f = fine_out.argmax(dim=1)

            # accumulate flat accuracy
            correct_top += (pred_top == top).sum().item()
            correct_super += (pred_super == super_coarse).sum().item()
            correct_c += (pred_c == coarse).sum().item()
            correct_f += (pred_f == fine).sum().item()
            total += fine.size(0)

            # collect embeddings & labels for structure metrics
            feats = model.backbone(imgs)
            feats_list.append(feats.cpu().numpy())
            fine_list.extend(fine.cpu().numpy())
            coarse_list.extend(coarse.cpu().numpy())
            super_list.extend(super_coarse.cpu().numpy())
            top_list.extend(top.cpu().numpy())

    # flat accuracies
    flat_top = correct_top / total
    flat_super = correct_super / total
    flat_coarse = correct_c / total
    flat_fine = correct_f / total

    print(f"Flat Acc | Top: {flat_top:.4f}, Super: {flat_super:.4f}, Coarse: {flat_coarse:.4f}, Fine: {flat_fine:.4f}")

    # prepare for structure metrics
    feats = np.concatenate(feats_list, axis=0)
    labels_top = np.array(top_list)
    labels_super = np.array(super_list)
    labels_coarse = np.array(coarse_list)
    labels_fine = np.array(fine_list)

    # silhouette
    sil_top = silhouette_score(feats, labels_top)
    sil_super = silhouette_score(feats, labels_super)
    sil_coarse = silhouette_score(feats, labels_coarse)
    sil_fine = silhouette_score(feats, labels_fine)
    print(f"Silhouette | Top: {sil_top:.4f}, Super: {sil_super:.4f}, Coarse: {sil_coarse:.4f}, Fine: {sil_fine:.4f}")

    # cophenetic
    Z = linkage(feats, method="average")
    coph_corr, _ = cophenet(Z, pdist(feats))
    print(f"Cophenetic Correlation: {coph_corr:.4f}")

    # ultrametricity
    D = squareform(pdist(feats))
    N = D.shape[0]
    viol = 0
    num_samples = min(10000, N * (N - 1) * (N - 2) // 6)
    rng = np.random.default_rng(42)
    for _ in range(num_samples):
        i, j, k = rng.choice(N, size=3, replace=False)
        if D[i, j] > max(D[i, k], D[j, k]):
            viol += 1
    ultra_ratio = 1.0 - viol / num_samples
    print(f"Ultrametricity (1 - violation_rate): {ultra_ratio:.4f}")

    return flat_top, flat_super, flat_coarse, flat_fine, sil_top, sil_super, sil_coarse, sil_fine, coph_corr, ultra_ratio


# ─── 4. main ───────────────────────────────────────────────────────
def _load_cifar100_coarse_labels(root='./data'):
    """Return ``(train_coarse, test_coarse)`` read from the raw CIFAR-100 pickles under ``root``."""
    # NOTE: pickle can execute code from a crafted file; this is only meant for the archive
    # torchvision downloaded into ``root``.
    base = os.path.join(root, 'cifar-100-python')
    with open(os.path.join(base, 'train'), 'rb') as fh:
        train_raw = pickle.load(fh, encoding='bytes')
    with open(os.path.join(base, 'test'), 'rb') as fh:
        test_raw = pickle.load(fh, encoding='bytes')
    return train_raw[b'coarse_labels'], test_raw[b'coarse_labels']


def _build_optimizer(model):
    """AdamW with four groups: (backbone + heads | block) x (weight decay | no weight decay)."""

    def _split_decay(named_params):
        # Biases, 1-D params and (Layer/Batch)Norm params get no weight decay.
        decay, no_decay = [], []
        for n, p in named_params:
            if not p.requires_grad:
                continue
            if p.ndim == 1 or n.endswith(".bias") or "bn" in n.lower() or "norm" in n.lower():
                no_decay.append(p)
            else:
                decay.append(p)
        return decay, no_decay

    back_named = list(model.backbone.named_parameters())
    head_named = (
        list(model.head_top.named_parameters()) +
        list(model.head_super_coarse.named_parameters()) +
        list(model.head_coarse.named_parameters()) +
        list(model.head_fine.named_parameters())
    )
    block_named = list(model.block.named_parameters())

    decay_base, nodecay_base = _split_decay(back_named + head_named)
    decay_block, nodecay_block = _split_decay(block_named)

    return optim.AdamW(
        [
            {"params": decay_base,   "lr": 1e-3, "weight_decay": 5e-4},
            {"params": nodecay_base, "lr": 1e-3, "weight_decay": 0.0},
            {"params": decay_block,  "lr": 3e-4, "weight_decay": 1e-4},
            {"params": nodecay_block,"lr": 3e-4, "weight_decay": 0.0},
        ],
        betas=(0.9, 0.999), eps=1e-8
    )


def main(baseline="HAMN", hyper_c=1.0, clip_r=0.9, lr=1.0, quantity_n: int = 1, tau: float = 1.0,
         checkpoint_path: str = "checkpoints.pth", num_epochs: int = 29):
    """
    Train one HierarchicalCNN configuration on CIFAR-100 and evaluate its best checkpoint.

    The checkpoint with the highest *sum* of the four flat accuracies is saved to
    ``checkpoint_path``, reloaded at the end and evaluated with ``test_with_structure_metrics``.
    An existing ``checkpoint_path`` is loaded before training resumes.

    Args:
        baseline, hyper_c, clip_r, lr, quantity_n, tau: Forwarded to ``HierarchicalCNN``.
        checkpoint_path: Where the best model state_dict is written.
        num_epochs: Number of training epochs (default 29, the count of the original
            ``range(1, 30)`` loop).

    Returns:
        The 10 values returned by ``test_with_structure_metrics``, followed by the test loader.

    Raises:
        ValueError: If ``num_epochs < 1`` (no checkpoint would be written to reload).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])

    # Ensure the raw archive is on disk (downloads if missing) before reading its pickles.
    datasets.CIFAR100(root='./data', train=True, download=True)
    datasets.CIFAR100(root='./data', train=False, download=True)
    train_coarse, test_coarse = _load_cifar100_coarse_labels('./data')

    train_ds = TopLevelCIFAR100('./data', train=True, transform=transform, coarse_labels=train_coarse)
    test_ds = TopLevelCIFAR100('./data', train=False, transform=transform, coarse_labels=test_coarse)

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)

    model = HierarchicalCNN(
        baseline=baseline,
        hyper_c=hyper_c,
        clip_r=clip_r,
        lr=lr,
        quantity_n=quantity_n,
        tau=tau
    ).to(device)

    if baseline == "UHop":
        ckpt_dir = os.path.dirname(checkpoint_path)
        kernel_ckpt = os.path.join(ckpt_dir, "uhop_kernel_only.pth")
        # NOTE(review): despite the file name, the *whole* model state_dict is saved/loaded here,
        # and the file is shared by every run whose checkpoint lives in ``ckpt_dir``, so later
        # runs start from the first run's initialization. Confirm this is intended.
        if os.path.isfile(kernel_ckpt):
            print("=> loading UHop kernel checkpoint", kernel_ckpt)
            model.load_state_dict(torch.load(kernel_ckpt, map_location=device))
        else:
            print("=> UHop baseline: pretraining kernel with uniform loss (U-Hop style)")
            learn_kernel_uhop(model, train_loader, device, kernel_epoch=50)

            # dirname is '' for a bare filename, and os.makedirs('') raises FileNotFoundError.
            if ckpt_dir:
                os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(model.state_dict(), kernel_ckpt)
            print("=> saved UHop kernel checkpoint to", kernel_ckpt)

    if os.path.isfile(checkpoint_path):
        print("=> loading checkpoint", checkpoint_path)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    optimizer = _build_optimizer(model)
    criterion = nn.CrossEntropyLoss()

    # Model selection criterion: sum of the four flat accuracies.
    best_acc = -1.0
    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}")
        train_epoch(model, train_loader, optimizer, criterion, device)

        flat_top, flat_super, flat_coarse, flat_fine = test(model, test_loader, device)
        score = flat_top + flat_super + flat_coarse + flat_fine

        if score > best_acc:
            best_acc = score
            torch.save(model.state_dict(), checkpoint_path)
            print(f"=> New best summed flat accuracy={best_acc:.4f}, saved to {checkpoint_path}")

    print("=> loading best checkpoint", checkpoint_path)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    flat_top, flat_super, flat_coarse, flat_fine, sil_top, sil_super, sil_coarse, sil_fine, coph_corr, ultra_ratio = test_with_structure_metrics(
        model, test_loader, device)

    return flat_top, flat_super, flat_coarse, flat_fine, sil_top, sil_super, sil_coarse, sil_fine, coph_corr, ultra_ratio, test_loader


if __name__ == '__main__':
    # Baselines to benchmark, in run order; edit this list as needed.
    baselines = ["MHN_Euc", "UHop", "HAMN", "HypAttn", "HypNN"]

    # One fixed hyper-parameter setting; each list holds a single candidate (only index 0 is used).
    best_params = {
        "hyper_c":    [0.8],
        "clip_r":     [0.9],
        "lr":         [1.0],
        "quantity_n": [2],
        "tau":        [1.0],
    }

    # JSON keys for the 10 metrics returned by main(), in return order.
    metric_names = (
        "flat_top",
        "flat_super_coarse_acc",
        "flat_coarse_acc",
        "flat_fine_acc",
        "sil_top",
        "sil_super_coarse",
        "sil_coarse",
        "sil_fine",
        "coph_corr",
        "ultra_ratio",
    )
    n_repeats = 10

    os.makedirs("hyper_deem4_result", exist_ok=True)
    all_summaries = {}

    for baseline in baselines:
        print(f"\n========== Running baseline: {baseline} ==========")

        # Per-baseline checkpoint subdir.
        # NOTE(review): checkpoints go under "hopfield_deem4_result/" while results go under
        # "hyper_deem4_result/" — confirm the two roots are meant to differ.
        ckpt_dir = f"hopfield_deem4_result/{baseline}/best_checkpoints"
        os.makedirs(ckpt_dir, exist_ok=True)

        hyper_c, clip_r, lr, quantity_n, tau = (
            best_params[key][0] for key in ("hyper_c", "clip_r", "lr", "quantity_n", "tau")
        )

        best_score, best_checkpoint = -1.0, ""
        metrics = []
        for run in range(1, n_repeats + 1):
            run_ckpt = f"{ckpt_dir}/{baseline}_hc{hyper_c}_cr{clip_r}_lr{lr}_qn{quantity_n}_run{run}_tau{tau}.pth"

            *run_metrics, _test_loader = main(
                baseline=baseline,
                hyper_c=hyper_c,
                clip_r=clip_r,
                lr=lr,
                quantity_n=quantity_n,
                tau=tau,
                checkpoint_path=run_ckpt
            )
            metrics.append(run_metrics)

            # Best run is chosen by flat_top (first metric).
            if run_metrics[0] > best_score:
                best_score = run_metrics[0]
                best_checkpoint = run_ckpt

        metrics = np.array(metrics)  # shape (n_repeats, 10)
        means = metrics.mean(axis=0)
        stds = metrics.std(axis=0)  # population std (ddof=0)

        summary = {
            "baseline": baseline,
            "param_set": best_params,
            "means": dict(zip(metric_names, means)),
            "stds": dict(zip(metric_names, stds)),
            "best_checkpoint": best_checkpoint,
            "best_score_by_flat_top": best_score,
        }
        all_summaries[baseline] = summary

        # Append this baseline's summary to its result file
        result_path = f"hyper_deem4_result/{baseline}/Result.txt"
        os.makedirs(os.path.dirname(result_path), exist_ok=True)
        with open(result_path, "a") as f:
            f.write("\n\nRepeat 10 runs summary:\n")
            json.dump(summary, f, indent=2)

        print(f"[{baseline}] Repeated evaluation done. Summary saved to {result_path}")

    # Combined summary across all baselines
    with open("hyper_deem4_result/summary_all_baselines.json", "w") as f:
        json.dump(all_summaries, f, indent=2)
    print("All baselines finished. Combined summary saved to hyper_deem4_result/summary_all_baselines.json")
