import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import pickle
from tqdm import tqdm
import numpy as np
from sklearn.metrics import silhouette_score
from scipy.cluster.hierarchy import linkage, cophenet
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import AgglomerativeClustering, KMeans
import math,time
from scipy.stats import spearmanr


# Module-level default compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from hflayers import HopfieldLayer
from hyper_hflayers import Hyperbolic_HopfieldLayer as Hyper_HopfieldLayer  # ensure the module path is correct

# Wrap dataset to return coarse label
from torchvision import datasets

import torch.nn.functional as F

# Directory where main() caches feature-driven hierarchies as
# feature_hier_seed{seed}.json files.
HIER_CACHE_DIR = "cifar_feature_hierarchy"  

def enforce_min_size(labels: np.ndarray,
                     feats: np.ndarray,
                     n_clusters: int,
                     min_size: int) -> np.ndarray:
    """Reassign points so that every cluster has at least ``min_size`` members.

    Clusters are visited in ascending index order. A cluster below ``min_size``
    repeatedly takes, from the currently largest cluster that still has more
    than ``min_size`` members (lowest index on ties), the point whose feature
    vector is closest in squared Euclidean distance to the small cluster's
    centroid. While the small cluster is empty, the global mean is used instead.

    Args:
        labels: Initial cluster id per point, length ``N``, values in
            ``[0, n_clusters)``. Not modified.
        feats: Feature matrix of shape ``[N, D]``; row ``i`` is point ``i``.
        n_clusters: Number of clusters.
        min_size: Minimum number of members required per cluster.

    Returns:
        A new ``int64`` label array of length ``N``.

    Raises:
        ValueError: If ``feats`` and ``labels`` disagree in length, a label is
            outside ``[0, n_clusters)``, or ``N < n_clusters * min_size``.
    """
    labels = np.asarray(labels, dtype=np.int64).copy()
    feats = np.asarray(feats)
    N = labels.shape[0]
    if feats.shape[0] != N:
        raise ValueError(f"feats has {feats.shape[0]} rows but labels has {N} entries")
    # Negative labels would otherwise silently index the last cluster.
    if N and (labels.min() < 0 or labels.max() >= n_clusters):
        raise ValueError(
            f"labels must lie in [0, {n_clusters}), got range [{labels.min()}, {labels.max()}]"
        )
    if N < n_clusters * min_size:
        raise ValueError(
            f"Cannot enforce min_size={min_size} with N={N}, n_clusters={n_clusters}"
        )

    # Member indices per cluster, each list in ascending point order.
    clusters = [[] for _ in range(n_clusters)]
    for i, c in enumerate(labels):
        clusters[int(c)].append(i)

    # A cluster never drops below min_size once visited (donors must exceed it),
    # so one ascending pass is enough.
    for c in range(n_clusters):
        while len(clusters[c]) < min_size:
            sizes = [len(members) for members in clusters]
            # N >= n_clusters * min_size guarantees some cluster exceeds min_size
            # whenever one is below it, so `donors` is never empty here.
            donors = [d for d, s in enumerate(sizes) if s > min_size]
            d = max(donors, key=sizes.__getitem__)  # first (lowest index) on ties

            members_c = clusters[c]
            centroid = feats[members_c].mean(axis=0) if members_c else feats.mean(axis=0)
            donor_members = clusters[d]
            dists = ((feats[donor_members] - centroid[None, :]) ** 2).sum(axis=1)
            j = donor_members[int(dists.argmin())]

            donor_members.remove(j)
            members_c.append(j)

    new_labels = np.empty_like(labels)
    for c, members in enumerate(clusters):
        new_labels[members] = c
    return new_labels


def save_feature_hierarchy(path, fine_to_coarse_map, coarse_to_super, super_to_top):
    """Write a fine→coarse→super→top hierarchy to ``path`` as UTF-8 JSON.

    Args:
        path: Destination file. Missing parent directories are created; a bare
            filename is written to the current directory.
        fine_to_coarse_map: Either a sequence whose ``i``-th entry is the coarse
            id of fine class ``i``, or a dict keyed by fine ids ``0..n-1``.
        coarse_to_super: Mapping coarse id -> super id.
        super_to_top: Mapping super id -> top id.

    The file contains ``fine_to_coarse`` (list of ints), plus ``coarse_to_super``
    and ``super_to_top`` as objects with stringified integer keys (JSON object
    keys must be strings); ``load_feature_hierarchy`` reads this format back.
    """
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises, so skip when there is no directory part
        os.makedirs(parent, exist_ok=True)

    if isinstance(fine_to_coarse_map, dict):
        # Iterating a dict yields its keys, so read the values by fine id.
        fine_to_coarse_list = [
            int(fine_to_coarse_map[i]) for i in range(len(fine_to_coarse_map))
        ]
    else:
        fine_to_coarse_list = [int(x) for x in fine_to_coarse_map]

    data = {
        "fine_to_coarse": fine_to_coarse_list,
        "coarse_to_super": {str(k): int(v) for k, v in coarse_to_super.items()},
        "super_to_top": {str(k): int(v) for k, v in super_to_top.items()},
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)


def load_feature_hierarchy(path):
    """Read a hierarchy JSON file in the format produced by ``save_feature_hierarchy``.

    Returns:
        ``(fine_to_coarse_map, coarse_to_super, super_to_top)``: a list of ints,
        and two dicts whose string keys from the JSON are converted back to ints
        (values are converted to ints as well).
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    def _int_dict(obj):
        return {int(key): int(val) for key, val in obj.items()}

    return (
        [int(v) for v in payload["fine_to_coarse"]],
        _int_dict(payload["coarse_to_super"]),
        _int_dict(payload["super_to_top"]),
    )

# ---------------- Poincaré Ball ops (curvature = -c, c>0) ----------------
class PoincareBall:
    """Poincaré-ball model of hyperbolic space with curvature ``-c`` (``c > 0``).

    The ball has radius ``1 / sqrt(c)``. Every method operates on the last
    tensor dimension and broadcasts over leading dimensions. Methods returning
    ball points pass them through ``_proj_with_margin``, keeping results
    strictly inside radius ``(1 - 1e-5) / sqrt(c)`` so that ``atanh`` and the
    conformal factor stay finite.

    Args:
        c: Positive curvature magnitude.
        eps: Lower clamp for norms in ``_proj_with_margin`` and for the
            denominator in ``dist``.
    """

    def __init__(self, c: float, eps: float = 1e-6):
        self.c = float(c)
        self.eps = eps
        self.radius = 1.0 / (self.c ** 0.5)
        self._M = 1.0 - 1e-5  # safety margin: keep points within radius * (1 - 1e-5)

    def _proj_with_margin(self, x):
        """Rescale points whose norm exceeds ``_M * radius`` onto that sphere."""
        # Keeping points strictly inside avoids atanh(1) and lambda_x blow-ups.
        r = x.norm(dim=-1, keepdim=True).clamp_min(self.eps)
        max_r = self._M * self.radius
        scale = torch.where(r > max_r, max_r / r, torch.ones_like(r))
        return x * scale

    def lambda_x(self, x):
        """Conformal factor ``2 / (1 - c||x||^2)`` (denominator clamped to >= 1e-6)."""
        x2 = (x * x).sum(dim=-1, keepdim=True)
        return 2.0 / (1.0 - self.c * x2).clamp_min(1e-6)

    def mobius_add(self, x, y):
        """Möbius addition ``x ⊕_c y``, projected back inside the ball."""
        c = self.c
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        xy = (x * y).sum(dim=-1, keepdim=True)
        num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
        den = (1 + 2 * c * xy + (c ** 2) * x2 * y2).clamp_min(1e-6)
        out = num / den
        return self._proj_with_margin(out)

    def mobius_neg(self, x):
        """Möbius inverse ``-x`` (projected like the other point-valued ops)."""
        return self._proj_with_margin(-x)

    def exp0(self, v):
        """Exponential map at the origin: ``tanh(sqrt(c)||v||) * v / (sqrt(c)||v||)``."""
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = (self.c ** 0.5) * vnorm
        coef = torch.tanh(t) / ((self.c ** 0.5) * vnorm)
        x = coef * v
        return self._proj_with_margin(x)

    def log0(self, x):
        """Logarithmic map at the origin (inverse of ``exp0``)."""
        # Project first and clamp the atanh argument strictly below 1.
        x = self._proj_with_margin(x)
        xnorm = x.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        arg = (self.c ** 0.5) * xnorm
        arg = torch.clamp(arg, 0.0, 1.0 - 1e-7)
        coef = torch.atanh(arg) / ((self.c ** 0.5) * xnorm)
        return coef * x

    def exp_p(self, p, v):
        """Exponential map at ``p``: ``p ⊕ tanh(sqrt(c) λ_p ||v|| / 2) v / (sqrt(c)||v||)``."""
        lam = self.lambda_x(p)
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = 0.5 * (self.c ** 0.5) * lam * vnorm
        coef = torch.tanh(t) / ((self.c ** 0.5) * vnorm)
        delta = coef * v
        return self.mobius_add(p, delta)

    def log_p(self, p, x):
        """Logarithmic map at ``p`` (inverse of ``exp_p``)."""
        lam = self.lambda_x(p)
        y = self.mobius_add(self.mobius_neg(p), x)
        ynorm = y.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        arg = (self.c ** 0.5) * ynorm
        arg = torch.clamp(arg, 0.0, 1.0 - 1e-7)
        coef = (2.0 / ((self.c ** 0.5) * lam)) * (torch.atanh(arg) / ynorm)
        return coef * y

    def dist(self, x, y):
        """Geodesic distance, returned with a trailing singleton dimension.

        d_c(x, y) = (1/sqrt(c)) * arcosh(1 + 2c||x-y||^2 / ((1-c||x||^2)(1-c||y||^2)))

        which equals ``lambda_x(x) * ||log_p(x, y)||``. The arcosh argument is
        clamped to ``>= 1 + 1e-6`` to keep gradients finite, so ``dist(x, x)``
        is about ``1.4e-3 / sqrt(c)`` rather than exactly 0.
        """
        c = self.c
        diff2 = ((x - y) ** 2).sum(dim=-1, keepdim=True)
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        num = 2 * c * diff2
        den = ((1 - c * x2) * (1 - c * y2)).clamp_min(self.eps)
        z = 1 + num / den
        # The 1/sqrt(c) factor makes this the true distance for any curvature.
        return torch.acosh(z.clamp_min(1 + 1e-6)) / (c ** 0.5)


# ---------------- Hyperbolic Attention (baseline) ----------------
class HyperbolicAttentionLayer(nn.Module):
    """Hyperbolic attention from input queries to a learnable memory, with a gated residual.

    Queries are projected from the (optionally layer-normed) input; keys and
    values are projected from ``n_mem`` learnable memory vectors. All three are
    multiplied by learnable scales (clamped to [0.5, 10]) and mapped into the
    Poincaré ball with ``exp0``. Scores are negative (squared, if
    ``use_sqdist``) hyperbolic distances, standardised per query, multiplied by
    a learnable temperature (clamped to [0.05, 50]), clipped to [-20, 20] and
    softmax-ed. The output is the attention-weighted mean of ``log_q(values)``
    mapped through ``exp_q`` and then ``log0``, plus
    ``sigmoid(alpha_skip) * log0(query)``.

    Args:
        feat_dim: Feature dimension ``D``.
        n_mem: Number of memory slots ``N``.
        c: Ball curvature parameter.
        tau_init: Initial softmax temperature.
        dropout: Dropout probability applied to the attention weights.
        score_center: Stored on the module but not used by ``forward``.
        use_sqdist: Use squared distances as scores instead of distances.
        layernorm: Apply ``nn.LayerNorm`` to the input before the query projection.
    """
    def __init__(self, feat_dim: int, n_mem: int, c: float = 1.0,
                 tau_init: float = 5.0, dropout: float = 0.0,
                 score_center: str = "none",    # 'none' | 'min' | 'mean'
                 use_sqdist: bool = True,
                 layernorm: bool = True):
        super().__init__()
        self.ball = PoincareBall(c)
        self.use_sqdist = use_sqdist
        self.score_center = score_center

        # Linear projections
        self.q_proj = nn.Linear(feat_dim, feat_dim, bias=False)
        self.k_proj = nn.Linear(feat_dim, feat_dim, bias=False)
        self.v_proj = nn.Linear(feat_dim, feat_dim, bias=False)

        # Memory vectors (tangent-space params), small-scale init
        self.mem = nn.Parameter(torch.randn(n_mem, feat_dim) * 0.02)

        # LayerNorm / Dropout
        self.ln = nn.LayerNorm(feat_dim) if layernorm else nn.Identity()
        self.dropout = nn.Dropout(dropout)

        # Learnable temperature (log-param for stability)
        self.log_tau = nn.Parameter(torch.log(torch.tensor(float(tau_init))))

        # Learnable scales
        self.q_log_scale = nn.Parameter(torch.tensor(0.0))   # exp->1.0
        self.k_log_scale = nn.Parameter(torch.tensor(1.1))   # exp->~3.0
        self.v_log_scale = nn.Parameter(torch.tensor(1.1))   # exp->~3.0

        # Residual gate
        self.alpha_skip = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        """Attend from every input position to the memory.

        Args:
            x: ``[B, D]`` or ``[B, L, D]`` tangent-space features.

        Returns:
            Tangent-space output with the same shape as ``x``.
        """
        squeeze = x.dim() == 2
        if squeeze:
            x = x.unsqueeze(1)
        B, L, D = x.shape

        # Each position is an independent query: flatten to M = B*L rows.
        q_tan = self.q_proj(self.ln(x)).reshape(B * L, D)  # [M,D]
        k_tan = self.k_proj(self.mem)                      # [N,D]
        v_tan = self.v_proj(self.mem)                      # [N,D]

        s_q = torch.exp(self.q_log_scale).clamp(0.5, 10.0)
        s_k = torch.exp(self.k_log_scale).clamp(0.5, 10.0)
        s_v = torch.exp(self.v_log_scale).clamp(0.5, 10.0)

        q_ball = self.ball.exp0(s_q * q_tan)  # [M,D]
        k_ball = self.ball.exp0(s_k * k_tan)  # [N,D]
        v_ball = self.ball.exp0(s_v * v_tan)  # [N,D]

        # Broadcast [M,1,D] against [1,N,D] instead of materialising copies.
        d = self.ball.dist(q_ball.unsqueeze(1), k_ball.unsqueeze(0)).squeeze(-1)  # [M,N]
        scores = -(d ** 2) if self.use_sqdist else -d

        scores = scores - scores.max(dim=-1, keepdim=True).values
        if scores.shape[-1] > 1:
            # Per-query standardisation. Skipped for a single memory slot, where
            # the unbiased std is NaN and the softmax is 1 regardless.
            mu = scores.mean(dim=-1, keepdim=True)
            sigma = scores.std(dim=-1, keepdim=True).clamp_min(1e-6)
            scores = (scores - mu) / sigma

        tau = torch.exp(self.log_tau).clamp(0.05, 50.0)
        scores = (tau * scores).clamp(-20.0, 20.0)

        attn = self.dropout(torch.softmax(scores, dim=-1))  # [M,N]

        # Attention-weighted mean of the values in the tangent space at the query.
        log_q_v = self.ball.log_p(q_ball.unsqueeze(1), v_ball.unsqueeze(0))  # [M,N,D]
        z_tan_q = torch.bmm(attn.unsqueeze(1), log_q_v).squeeze(1)          # [M,D]
        z_ball = self.ball.exp_p(q_ball, z_tan_q)                           # [M,D]
        out_tan0 = self.ball.log0(z_ball)                                   # [M,D]

        gate = torch.sigmoid(self.alpha_skip)
        out_tan0 = out_tan0 + gate * self.ball.log0(q_ball)

        out_tan0 = out_tan0.reshape(B, L, D)
        return out_tan0.squeeze(1) if squeeze else out_tan0


class MobiusLinear(nn.Module):
    """Linear layer applied in the tangent space at the origin of the Poincaré ball.

    ``forward`` takes ball points, maps them to the origin's tangent space with
    ``log0``, applies an ordinary ``nn.Linear`` (so the bias is a tangent-space
    offset), and maps back with ``exp0``. Tangent vectors are clipped
    elementwise to [-20, 20] on both sides of the linear layer.
    """

    def __init__(self, in_dim, out_dim, c=1.0, bias=True):
        super().__init__()
        self.ball = PoincareBall(c)
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        # Re-initialise: Kaiming-uniform weights, zero bias.
        nn.init.kaiming_uniform_(self.lin.weight, a=0.2)
        if bias:
            nn.init.zeros_(self.lin.bias)

    def forward(self, x_ball):
        ball = self.ball
        inside = ball._proj_with_margin(x_ball)
        tangent_in = ball.log0(inside).clamp(-20.0, 20.0)
        tangent_out = self.lin(tangent_in).clamp(-20.0, 20.0)
        return ball.exp0(tangent_out)


class HypNNBlock(nn.Module):
    """Two-layer hyperbolic MLP that returns a tangent-space vector.

    The input is layer-normalised (no affine parameters) and mapped into the
    ball. Two ``MobiusLinear`` layers follow; after each, the result is taken
    back to the origin's tangent space, passed through ``LayerNorm`` and clipped
    to [-20, 20]. The hidden layer also applies ReLU and dropout. The final
    tangent vector is returned for use by linear heads. Accepts ``[B, D]``
    (returned as ``[B, D]``) or ``[B, L, D]``.
    """

    def __init__(self, feat_dim: int, hidden: int = 512, c: float = 1.0, dropout: float = 0.1):
        super().__init__()
        self.ball = PoincareBall(c)
        self.fc1 = MobiusLinear(feat_dim, hidden, c=c)
        self.fc2 = MobiusLinear(hidden, feat_dim, c=c)
        self.ln1 = nn.LayerNorm(hidden)
        self.ln2 = nn.LayerNorm(feat_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        was_2d = x.dim() == 2
        if was_2d:
            x = x.unsqueeze(1)

        ball = self.ball
        x = F.layer_norm(x, x.shape[-1:])

        # Hidden layer: Möbius linear, then LN -> clip -> ReLU in tangent space.
        hidden = ball.log0(self.fc1(ball.exp0(x)))
        hidden = F.relu(torch.clamp(self.ln1(hidden), -20.0, 20.0))
        # Dropout acts on ball coordinates; fc2 (a MobiusLinear) passes its
        # input through _proj_with_margin before using it.
        hidden_ball = self.dropout(ball.exp0(hidden))

        # Output layer: back to tangent space, normalised and clipped.
        out = ball.log0(self.fc2(hidden_ball))
        out = torch.clamp(self.ln2(out), -20.0, 20.0)

        return out.squeeze(1) if was_2d else out


# --------- Hand-crafted semantic hierarchy: coarse_to_super / super_to_top (original version) ----------
# Maps each of the 20 CIFAR-100 coarse classes (keys) to one of 7 super-classes
# (values). ThreeLevelCIFAR100 looks this dict up at access time, and main() may
# rebind it via `global` to a feature-driven hierarchy.
coarse_to_super = {
    # “Animate” super-class
    0: 0,  # aquatic_mammals         → aquatic
    1: 0,  # fish                    → aquatic

    7: 1,  # insects                 → small
    13: 1,  # non_insect_invertebrates → small
    16: 1,  # small_mammals           → small

    8: 2,  # large_carnivores        → large
    11: 2,  # large_omnivores         → large
    12: 2,  # medium_mammals          → large
    14: 2,  # people                  → large
    15: 2,  # reptiles                → large

    # “Plants” super-class
    2: 3,  # flowers                 → plants
    4: 3,  # fruits_and_vegetables   → plants
    17: 3,  # trees                   → plants

    # “Household & Furniture” super-class
    3: 4,  # food_containers         → home_items
    5: 4,  # household_electrical_devices → home_items
    6: 4,  # household_furniture     → home_items

    # “Vehicles” super-class
    9: 5,  # large_man_made_outdoor_things → vehicles
    18: 5,  # vehicles_1              → vehicles
    19: 5,  # vehicles_2              → vehicles

    # “Natural Scenes” super-class
    10: 6  # large_natural_outdoor_scenes → natural_scenes
}

# Maps the 7 super-classes to 3 top-level groups: 0 = animals,
# 1 = plants and natural scenes, 2 = home items and vehicles.
# TopLevelCIFAR100 looks this dict up at access time; main() may rebind it via `global`.
super_to_top = {
    0: 0,  # aquatic / small / large animals → animals
    1: 0,
    2: 0,

    3: 1,  # plants
    6: 1,  # natural scenes

    4: 2,  # home_items
    5: 2,  # vehicles
}


# --------- fine→coarse→super→top ----------
def _group_means(feats, labels, n_groups):
    """Return the ``[n_groups, D]`` float32 per-group means of the rows of ``feats``.

    A group with no members gets the mean of all rows.
    """
    out = np.zeros((n_groups, feats.shape[1]), dtype=np.float32)
    for g in range(n_groups):
        members = np.where(labels == g)[0]
        out[g] = feats[members].mean(axis=0) if len(members) else feats.mean(axis=0)
    return out


def _kmeans_with_min_size(feats, n_clusters, min_size, seed):
    """Cluster the rows of ``feats`` with KMeans (10 restarts), then apply ``enforce_min_size``."""
    init_labels = KMeans(
        n_clusters=n_clusters,
        random_state=seed,
        n_init=10,
    ).fit_predict(feats)
    return enforce_min_size(
        labels=init_labels,
        feats=feats,
        n_clusters=n_clusters,
        min_size=min_size,
    )


def build_feature_hierarchy_from_backbone(
    base_train_ds,
    num_fine: int = 100,
    num_coarse: int = 20,
    num_super: int = 7,
    num_top: int = 3,
    seed: int = 0,
    min_fine_per_coarse: int = 4,
    min_coarse_per_super: int = 2,
    min_super_per_top: int = 2,
):
    """Build a fine→coarse→super→top label hierarchy by clustering backbone features.

    1. Run ResNet18 (ImageNet-pretrained if those weights can be loaded,
       otherwise randomly initialised) over ``base_train_ds`` and average the
       512-d features per fine label.
    2. Cluster the fine-class means into ``num_coarse`` groups, the resulting
       coarse means into ``num_super`` groups, and the super means into
       ``num_top`` groups. Each level uses KMeans (``random_state=seed``,
       ``n_init=10``) followed by ``enforce_min_size``.

    Args:
        base_train_ds: Dataset yielding ``(image_tensor, fine_label)`` pairs;
            fine labels are assumed to lie in ``[0, num_fine)``.
        num_fine, num_coarse, num_super, num_top: Number of classes per level.
        seed: KMeans random state.
        min_fine_per_coarse, min_coarse_per_super, min_super_per_top: Minimum
            cluster size enforced at each level.

    Returns:
        fine_to_coarse: list[int] of length num_fine, fine_label -> coarse_id
        coarse_to_super: dict[int -> int]
        super_to_top:    dict[int -> int]
    """
    local_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Best effort: fall back to random weights (with a warning) if the
    # pretrained weights cannot be obtained.
    try:
        backbone = models.resnet18(pretrained=True)
        print("[Hierarchy] Using ImageNet-pretrained ResNet18 for feature clustering.")
    except Exception:
        backbone = models.resnet18(pretrained=False)
        print("[Hierarchy] WARNING: failed to load pretrained weights, using random ResNet18.")

    feat_dim = backbone.fc.in_features
    backbone.fc = nn.Identity()
    backbone.to(local_device)
    backbone.eval()

    loader = DataLoader(base_train_ds, batch_size=256, shuffle=False, num_workers=4)

    # 1) Per-fine-class feature sums and sample counts.
    sums = np.zeros((num_fine, feat_dim), dtype=np.float32)
    counts = np.zeros(num_fine, dtype=np.int64)

    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Extracting features for hierarchy", unit="batch"):
            feats = backbone(imgs.to(local_device)).cpu().numpy()  # [B, feat_dim]
            labels_np = labels.numpy().astype(int)                 # [B]
            # Unbuffered scatter-add: repeated labels within a batch all accumulate.
            np.add.at(sums, labels_np, feats)
            np.add.at(counts, labels_np, 1)

    means = sums / np.maximum(counts[:, None], 1)  # [num_fine, feat_dim]

    # 2) fine → coarse: KMeans, then enforce a minimum number of fine classes per coarse.
    print(f"[Hierarchy] Clustering fine → coarse (KMeans + min {min_fine_per_coarse} fine/cluster)...")
    coarse_labels = _kmeans_with_min_size(means, num_coarse, min_fine_per_coarse, seed)
    fine_to_coarse = coarse_labels.tolist()
    coarse_means = _group_means(means, coarse_labels, num_coarse)

    # 3) coarse → super
    print(f"[Hierarchy] Clustering coarse → super (KMeans + min {min_coarse_per_super} coarse/cluster)...")
    super_labels = _kmeans_with_min_size(coarse_means, num_super, min_coarse_per_super, seed)
    coarse_to_super_new = {int(c): int(s) for c, s in enumerate(super_labels)}
    super_means = _group_means(coarse_means, super_labels, num_super)

    # 4) super → top
    print(f"[Hierarchy] Clustering super → top (KMeans + min {min_super_per_top} super/cluster)...")
    top_labels = _kmeans_with_min_size(super_means, num_top, min_super_per_top, seed)
    super_to_top_new = {int(s): int(t) for s, t in enumerate(top_labels)}

    coarse_sizes = np.bincount(coarse_labels, minlength=num_coarse)
    super_sizes = np.bincount(super_labels, minlength=num_super)
    top_sizes = np.bincount(top_labels, minlength=num_top)
    print("[Hierarchy] coarse cluster sizes:", coarse_sizes.tolist())
    print("[Hierarchy] super cluster sizes:", super_sizes.tolist())
    print("[Hierarchy] top   cluster sizes:", top_sizes.tolist())

    return fine_to_coarse, coarse_to_super_new, super_to_top_new


# CIFAR-100 → (fine, coarse, super_coarse)
class ThreeLevelCIFAR100(datasets.CIFAR100):
    """CIFAR-100 whose items are ``(img, fine, coarse, super_coarse)``.

    Args:
        root, train, transform, download: Forwarded to ``datasets.CIFAR100``.
        coarse_labels: Per-sample coarse labels indexed by dataset index; must
            contain exactly one entry per sample of this split.

    Raises:
        ValueError: If ``coarse_labels`` is None or its length differs from
            the number of samples.

    ``super_coarse`` is looked up in the module-level ``coarse_to_super`` dict
    when an item is fetched, so rebinding that global changes the mapping.
    """

    def __init__(self, root, train, transform, coarse_labels, download=False):
        super().__init__(root, train=train, transform=transform, download=download)
        if coarse_labels is None:
            raise ValueError("coarse_labels is required")
        if len(coarse_labels) != len(self):
            # E.g. train-split labels (50000) passed to the test split (10000)
            # would otherwise be silently misaligned instead of failing.
            raise ValueError(
                f"coarse_labels has {len(coarse_labels)} entries, "
                f"but the dataset has {len(self)} samples"
            )
        self.coarse_labels = coarse_labels

    def __getitem__(self, idx):
        img, fine = super().__getitem__(idx)
        coarse = self.coarse_labels[idx]
        super_coarse = coarse_to_super[coarse]
        return img, fine, coarse, super_coarse


class TopLevelCIFAR100(ThreeLevelCIFAR100):
    """``ThreeLevelCIFAR100`` that additionally yields the top-level label.

    Items are ``(img, fine, coarse, super_coarse, top)``, where ``top`` comes
    from the module-level ``super_to_top`` dict at access time.
    """

    def __getitem__(self, idx):
        img, fine, coarse, super_coarse = super().__getitem__(idx)
        top_label = super_to_top[super_coarse]
        return (img, fine, coarse, super_coarse, top_label)


class HierarchicalCNN(nn.Module):
    """ResNet18 backbone, a pluggable feature block, and four linear heads.

    ``forward`` returns logits ``(top, super_coarse, coarse, fine)``.
    """

    def __init__(self, baseline: str = "HAMN", hyper_c=1.0, clip_r=0.9, lr=1.0, quantity_n: int = 1, tau: float = 1.0,
                 num_top: int = 3, num_super: int = 7, num_coarse: int = 20, num_fine: int = 100):
        """
        Args:
            baseline: 'HAMN' | 'HypAttn' | 'HypNN' | 'MHN_Euc' — the block placed
                between the backbone and the heads.
            hyper_c: Curvature parameter passed to the hyperbolic blocks.
            clip_r: For 'HAMN', multiplied by the ball radius 1/sqrt(hyper_c) and
                passed as that layer's ``clip_r``.
            lr: Passed through to the 'HAMN' layer.
            quantity_n: Memory-size multiplier; memory-based blocks ('HAMN',
                'HypAttn', 'MHN_Euc') get ``int(num_fine * quantity_n)`` slots.
            tau: Initial temperature for 'HypAttn'.
            num_top, num_super, num_coarse, num_fine: Output size of each head;
                the defaults match the 3/7/20/100-level CIFAR-100 hierarchy.

        Raises:
            ValueError: For an unknown ``baseline``.
        """
        super().__init__()
        self.baseline = baseline
        self.backbone = models.resnet18(pretrained=False)
        feat_dim = self.backbone.fc.in_features  # 512 for ResNet18
        self.backbone.fc = nn.Identity()
        quantity_fine = int(num_fine * quantity_n)

        if baseline == "HAMN":
            self.block = Hyper_HopfieldLayer(
                input_size=feat_dim,
                hidden_size=feat_dim,
                output_size=feat_dim,
                num_heads=4,
                hyper_c=hyper_c,
                clip_r=1.0 / math.sqrt(hyper_c) * clip_r,
                lr=lr,
                association_activation='relu',
                dropout=0.2,
                quantity=quantity_fine,
                batch_first=True,
                train_c=False,
                input_as_hyper=False,
                out_as_hyper=False
            )
        elif baseline == "HypAttn":
            self.block = HyperbolicAttentionLayer(
                feat_dim=feat_dim,
                n_mem=quantity_fine,
                c=hyper_c,
                tau_init=tau,
                dropout=0.1
            )
        elif baseline == "HypNN":
            self.block = HypNNBlock(
                feat_dim=feat_dim,
                hidden=512,
                c=hyper_c,
                dropout=0.1
            )
        elif baseline == "MHN_Euc":
            self.block = HopfieldLayer(
                input_size=feat_dim,
                hidden_size=feat_dim,
                output_size=feat_dim,
                num_heads=8,
                scaling=0.01,
                dropout=0.2,
                association_activation='relu',
                quantity=quantity_fine,
                batch_first=True,
            )
        else:
            raise ValueError(f"Unknown baseline: {baseline}")

        # Multi-head classifiers (top/super/coarse/fine)
        self.head_top = nn.Linear(feat_dim, num_top)
        self.head_super_coarse = nn.Linear(feat_dim, num_super)
        self.head_coarse = nn.Linear(feat_dim, num_coarse)
        self.head_fine = nn.Linear(feat_dim, num_fine)

    def forward(self, x):
        """Return ``(top, super_coarse, coarse, fine)`` logits for images ``x``."""
        feats = self.backbone(x)    # [B, feat_dim]
        feats = feats.unsqueeze(1)  # [B, 1, feat_dim]: the block sees one token per image

        feats_out = self.block(feats).squeeze(1)  # [B, feat_dim]

        return (
            self.head_top(feats_out),
            self.head_super_coarse(feats_out),
            self.head_coarse(feats_out),
            self.head_fine(feats_out)
        )


def train_epoch(model, train_loader, optimizer, criterion, device):
    """Train ``model`` for one epoch on the summed four-level classification loss.

    Each batch must yield ``(imgs, fine, coarse, super_coarse, top)``. The loss
    is ``criterion`` applied to each of the four heads and summed; gradients are
    clipped to a total norm of 5.0 before ``optimizer.step()``.

    Args:
        model: Module returning ``(top, super_coarse, coarse, fine)`` logits.
        train_loader: Iterable of batches as described above.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss callable ``criterion(logits, targets)``.
        device: Unused; batches are moved to the device of the model's first
            parameter instead. Kept for call compatibility.

    Returns:
        Mean loss per batch, or 0.0 if ``train_loader`` yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    prog = tqdm(train_loader, desc="Training", unit="batch")

    model_device = next(model.parameters()).device

    for imgs, fine, coarse, super_coarse, top in prog:
        imgs, fine, coarse, super_coarse, top = (
            t.to(model_device, non_blocking=True)
            for t in (imgs, fine, coarse, super_coarse, top)
        )

        top_out, super_coarse_out, coarse_out, fine_out = model(imgs)
        loss = (
            criterion(top_out, top)
            + criterion(super_coarse_out, super_coarse)
            + criterion(coarse_out, coarse)
            + criterion(fine_out, fine)
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        # Count batches ourselves: tqdm only syncs prog.n at display refreshes,
        # so it lags the true number of processed batches.
        prog.set_postfix(loss=total_loss / num_batches)

    avg = total_loss / max(num_batches, 1)
    print(f"Train Loss: {avg:.4f}")
    return avg


def test(model, test_loader, device):
    """Measure flat top-1 accuracy of the four heads over ``test_loader``.

    Each batch must yield ``(imgs, fine, coarse, super_coarse, top)``; tensors
    are moved to ``device``. Running accuracies are shown on the progress bar
    and the final ones are printed.

    Returns:
        ``(top_acc, super_acc, coarse_acc, fine_acc)`` as floats.
    """
    model.eval()
    hits = {"top": 0, "super": 0, "coarse": 0, "fine": 0}
    seen = 0
    bar = tqdm(test_loader, desc="Testing", unit="batch")

    with torch.no_grad():
        for imgs, fine, coarse, super_coarse, top in bar:
            imgs = imgs.to(device)
            targets = {
                "fine": fine.to(device),
                "coarse": coarse.to(device),
                "super": super_coarse.to(device),
                "top": top.to(device),
            }

            top_out, super_out, coarse_out, fine_out = model(imgs)
            logits = {"top": top_out, "super": super_out, "coarse": coarse_out, "fine": fine_out}

            for level, out in logits.items():
                hits[level] += (out.argmax(dim=1) == targets[level]).sum().item()
            seen += targets["fine"].size(0)

            bar.set_postfix(
                top_acc=f"{hits['top'] / seen:.4f}",
                super_acc=f"{hits['super'] / seen:.4f}",
                coarse_acc=f"{hits['coarse'] / seen:.4f}",
                fine_acc=f"{hits['fine'] / seen:.4f}",
            )

    top_acc = hits["top"] / seen
    super_acc = hits["super"] / seen
    coarse_acc = hits["coarse"] / seen
    fine_acc = hits["fine"] / seen
    print(
        f"Acc | Top: {top_acc:.4f}, "
        f"Super: {super_acc:.4f}, "
        f"Coarse: {coarse_acc:.4f}, "
        f"Fine: {fine_acc:.4f}"
    )
    return top_acc, super_acc, coarse_acc, fine_acc


def _condensed_index(n, i, j):
    """Position of the pair ``(i, j)``, ``i != j``, in a condensed ``pdist`` vector over ``n`` points."""
    a, b = (i, j) if i < j else (j, i)
    return n * a - a * (a + 1) // 2 + (b - a - 1)


def _ultrametricity_ratio(condensed, n, max_samples=10000, seed=42):
    """Return 1 - (fraction of sampled triples with d(i,j) > max(d(i,k), d(j,k))).

    ``condensed`` is a ``pdist`` vector over ``n`` points. Draws
    ``min(max_samples, C(n, 3))`` triples of distinct indices with
    ``np.random.default_rng(seed)``. Returns NaN when ``n < 3``.
    """
    num_samples = min(max_samples, n * (n - 1) * (n - 2) // 6)
    if num_samples <= 0:
        return float("nan")
    rng = np.random.default_rng(seed)
    viol = 0
    for _ in range(num_samples):
        i, j, k = rng.choice(n, size=3, replace=False)
        d_ij = condensed[_condensed_index(n, i, j)]
        d_ik = condensed[_condensed_index(n, i, k)]
        d_jk = condensed[_condensed_index(n, j, k)]
        if d_ij > max(d_ik, d_jk):
            viol += 1
    return 1.0 - viol / num_samples


def test_with_structure_metrics(model, test_loader, device):
    """Evaluate flat accuracy and hierarchy-structure metrics of backbone features.

    Batches must yield ``(imgs, fine, coarse, super_coarse, top)`` (as
    TopLevelCIFAR100 does). Computes:
      1) flat top-1 accuracy of each head (top, super, coarse, fine);
      2) silhouette score of ``model.backbone`` features under each label level;
      3) cophenetic correlation of average-linkage clustering of those features;
      4) ultrametricity ratio: 1 - fraction of up to 10000 sampled triples
         (i, j, k) with d(i,j) > max(d(i,k), d(j,k)); NaN with < 3 samples.

    Returns:
        ``(flat_top, flat_super, flat_coarse, flat_fine, sil_top, sil_super,
        sil_coarse, sil_fine, coph_corr, ultra_ratio, test_loader)``.
    """
    model.eval()
    feats_list, fine_list, coarse_list, super_list, top_list = [], [], [], [], []
    correct_top = correct_super = correct_c = correct_f = total = 0

    # model(imgs) already runs model.backbone on imgs (HierarchicalCNN.forward
    # does); record that output rather than running the backbone a second time.
    captured = []
    hook = model.backbone.register_forward_hook(
        lambda _module, _inputs, output: captured.append(output)
    )
    try:
        with torch.no_grad():
            for imgs, fine, coarse, super_coarse, top in tqdm(test_loader, desc="Evaluating", unit="batch"):
                imgs = imgs.to(device)
                fine = fine.to(device)
                coarse = coarse.to(device)
                super_coarse = super_coarse.to(device)
                top = top.to(device)

                captured.clear()
                top_out, super_out, coarse_out, fine_out = model(imgs)

                correct_top += (top_out.argmax(dim=1) == top).sum().item()
                correct_super += (super_out.argmax(dim=1) == super_coarse).sum().item()
                correct_c += (coarse_out.argmax(dim=1) == coarse).sum().item()
                correct_f += (fine_out.argmax(dim=1) == fine).sum().item()
                total += fine.size(0)

                # Reuse the recorded output only if the backbone ran exactly once;
                # otherwise fall back to an explicit backbone pass.
                feats = captured[0] if len(captured) == 1 else model.backbone(imgs)
                feats_list.append(feats.cpu().numpy())
                fine_list.extend(fine.cpu().numpy())
                coarse_list.extend(coarse.cpu().numpy())
                super_list.extend(super_coarse.cpu().numpy())
                top_list.extend(top.cpu().numpy())
    finally:
        hook.remove()

    flat_top = correct_top / total
    flat_super = correct_super / total
    flat_coarse = correct_c / total
    flat_fine = correct_f / total

    print(f"Flat Acc | Top: {flat_top:.4f}, Super: {flat_super:.4f}, Coarse: {flat_coarse:.4f}, Fine: {flat_fine:.4f}")

    feats = np.concatenate(feats_list, axis=0)
    labels_top = np.array(top_list)
    labels_super = np.array(super_list)
    labels_coarse = np.array(coarse_list)
    labels_fine = np.array(fine_list)

    sil_top = silhouette_score(feats, labels_top)
    sil_super = silhouette_score(feats, labels_super)
    sil_coarse = silhouette_score(feats, labels_coarse)
    sil_fine = silhouette_score(feats, labels_fine)
    print(f"Silhouette | Top: {sil_top:.4f}, Super: {sil_super:.4f}, Coarse: {sil_coarse:.4f}, Fine: {sil_fine:.4f}")

    # Condensed Euclidean distances, computed once and shared by linkage,
    # cophenet and the ultrametricity check (no N x N matrix is built).
    dists = pdist(feats)
    Z = linkage(dists, method="average")
    coph_corr, _ = cophenet(Z, dists)
    print(f"Cophenetic Correlation: {coph_corr:.4f}")

    ultra_ratio = _ultrametricity_ratio(dists, feats.shape[0])
    print(f"Ultrametricity (1 - violation_rate): {ultra_ratio:.4f}")

    return flat_top, flat_super, flat_coarse, flat_fine, sil_top, sil_super, sil_coarse, sil_fine, coph_corr, ultra_ratio, test_loader


# ─── main ───────────────────────────────────────────────────────
def main(baseline="HAMN",
         hyper_c=1.0, clip_r=0.9, lr=1.0, quantity_n: int = 1, tau: float = 1.0,
         checkpoint_path: str = "checkpoints.pth",
         hierarchy_mode: str = "feature",   # "manual" | "feature"
         hierarchy_seed: int = 0):
    """Train a ``HierarchicalCNN`` on CIFAR-100 and evaluate it with structure metrics.

    Fine CIFAR-100 labels are mapped to coarse labels either from the official
    dataset pickles (``"manual"``) or from a feature-driven hierarchy that is
    loaded from / saved to a JSON cache (``"feature"``). The model is then
    trained; whenever the sum of the four flat accuracies (top + super +
    coarse + fine) improves, the state dict is saved. The best checkpoint is
    reloaded and evaluated with ``test_with_structure_metrics``.

    Args:
        baseline: Model variant name, forwarded to ``HierarchicalCNN``.
        hyper_c, clip_r, lr, quantity_n, tau: Forwarded unchanged to
            ``HierarchicalCNN``. Note that ``lr`` is a model constructor
            argument; the optimizer learning rates are fixed below
            (1e-3 for backbone/heads, 3e-4 for ``model.block``).
        checkpoint_path: State-dict file. Loaded before training if it
            exists, overwritten whenever the summed flat accuracy improves,
            and reloaded for the final evaluation.
        hierarchy_mode: ``"manual"`` or ``"feature"``.
        hierarchy_seed: Seed passed to the feature-hierarchy builder; also
            part of the cache file name.

    Returns:
        The 11-tuple produced by ``test_with_structure_metrics``:
        ``(flat_top, flat_super, flat_coarse, flat_fine, sil_top, sil_super,
        sil_coarse, sil_fine, coph_corr, ultra_ratio, test_loader)``.

    Raises:
        ValueError: If ``hierarchy_mode`` is not recognised.

    Side effects:
        In ``"feature"`` mode, rebinds the module globals ``coarse_to_super``
        and ``super_to_top``.
    """
    global coarse_to_super, super_to_top

    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])

    base_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    base_test = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

    if hierarchy_mode == "manual":
        # Read the official coarse labels directly from the raw CIFAR-100
        # pickles that torchvision downloaded above. Use `with` so the file
        # handles are closed rather than leaked.
        with open('./data/cifar-100-python/train', 'rb') as f:
            train_raw = pickle.load(f, encoding='bytes')
        with open('./data/cifar-100-python/test', 'rb') as f:
            test_raw = pickle.load(f, encoding='bytes')
        train_coarse = train_raw[b'coarse_labels']
        test_coarse = test_raw[b'coarse_labels']
        print("[Hierarchy] Using MANUAL CIFAR hierarchy (official coarse + semantic grouping).")

    elif hierarchy_mode == "feature":
        cache_name = f"feature_hier_seed{hierarchy_seed}.json"
        cache_path = os.path.join(HIER_CACHE_DIR, cache_name)

        if os.path.isfile(cache_path):
            print(f"[Hierarchy] Loading FEATURE-DRIVEN hierarchy from cache: {cache_path}")
            fine_to_coarse_map, coarse_to_super, super_to_top = load_feature_hierarchy(cache_path)
        else:
            print("[Hierarchy] Building FEATURE-DRIVEN hierarchy from backbone features...")
            fine_to_coarse_map, coarse_to_super, super_to_top = build_feature_hierarchy_from_backbone(
                base_train_ds=base_train,
                num_fine=100,
                num_coarse=20,
                num_super=7,
                num_top=3,
                seed=hierarchy_seed,
            )
            # Make sure the cache directory exists before writing into it.
            os.makedirs(HIER_CACHE_DIR, exist_ok=True)
            save_feature_hierarchy(cache_path, fine_to_coarse_map, coarse_to_super, super_to_top)
            print(f"[Hierarchy] Done. Saved feature-driven hierarchy to {cache_path}")

        train_coarse = [fine_to_coarse_map[int(y)] for y in base_train.targets]
        test_coarse = [fine_to_coarse_map[int(y)] for y in base_test.targets]

    else:
        raise ValueError(f"Unknown hierarchy_mode={hierarchy_mode}")

    # Dataset/Loader
    train_ds = TopLevelCIFAR100('./data', train=True, transform=transform, coarse_labels=train_coarse)
    test_ds = TopLevelCIFAR100('./data', train=False, transform=transform, coarse_labels=test_coarse)

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)

    model = HierarchicalCNN(
        baseline=baseline,
        hyper_c=hyper_c,
        clip_r=clip_r,
        lr=lr,
        quantity_n=quantity_n,
        tau=tau
    ).to(device)

    if os.path.isfile(checkpoint_path):
        print("=> loading checkpoint", checkpoint_path)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    def _split_decay(named_params):
        """Split trainable params into (decay, no_decay); 1-D/bias/norm params skip weight decay."""
        decay, no_decay = [], []
        for n, p in named_params:
            if not p.requires_grad:
                continue
            if p.ndim == 1 or n.endswith(".bias") or "bn" in n.lower() or "norm" in n.lower():
                no_decay.append(p)
            else:
                decay.append(p)
        return decay, no_decay

    back_named = list(model.backbone.named_parameters())
    head_named = (
        list(model.head_top.named_parameters()) +
        list(model.head_super_coarse.named_parameters()) +
        list(model.head_coarse.named_parameters()) +
        list(model.head_fine.named_parameters())
    )
    block_named = list(model.block.named_parameters())

    decay_base, nodecay_base = _split_decay(back_named + head_named)
    decay_block, nodecay_block = _split_decay(block_named)

    # Backbone + heads and the `block` module get separate LR / weight-decay groups.
    optimizer = optim.AdamW(
        [
            {"params": decay_base,   "lr": 1e-3, "weight_decay": 5e-4},
            {"params": nodecay_base, "lr": 1e-3, "weight_decay": 0.0},
            {"params": decay_block,  "lr": 3e-4, "weight_decay": 1e-4},
            {"params": nodecay_block, "lr": 3e-4, "weight_decay": 0.0},
        ],
        betas=(0.9, 0.999), eps=1e-8
    )

    criterion = nn.CrossEntropyLoss()

    # NOTE(review): starting from -1 means epoch 1 always overwrites
    # checkpoint_path, even if a previously loaded checkpoint scored higher.
    best_acc = -1.0

    # NOTE(review): range(1, 30) runs 29 epochs — confirm this is intended.
    for epoch in range(1, 30):
        print(f"Epoch {epoch}")
        train_epoch(model, train_loader, optimizer, criterion, device)

        flat_top, flat_super, flat_coarse, flat_fine = test(model, test_loader, device)
        # Model selection criterion: sum of the four flat accuracies.
        score = flat_top + flat_super + flat_coarse + flat_fine

        if score > best_acc:
            best_acc = score
            torch.save(model.state_dict(), checkpoint_path)
            print(f"=> New best sum(flat_acc)={best_acc:.4f}, saved to {checkpoint_path}")

    print("=> loading best checkpoint", checkpoint_path)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    return test_with_structure_metrics(model, test_loader, device)


if __name__ == '__main__':
    # Script entry point: train/evaluate each baseline several times, aggregate
    # the metrics (mean/std), and append the summaries to result files.

    baselines = ["HAMN"]  # full set: ["HAMN", "HypAttn", "HypNN", "MHN_Euc"]

    # Hyper-parameters; only the first entry of each list is used.
    best_params = {
        "hyper_c":    [0.8],
        "clip_r":     [0.9],
        "lr":         [1.0],
        "quantity_n": [2],
        "tau":        [4.0],
    }

    # Number of independent training runs aggregated per baseline.
    num_repeats = 10
    result_root = "hyper_deem4_result_other4"

    # Metric names in the order main() returns them (the trailing test_loader
    # is excluded). These strings are the keys of the "means"/"stds" JSON dicts.
    metric_names = (
        "flat_top",
        "flat_super_coarse_acc",
        "flat_coarse_acc",
        "flat_fine_acc",
        "sil_top",
        "sil_super_coarse",
        "sil_coarse",
        "sil_fine",
        "coph_corr",
        "ultra_ratio",
    )

    # The hyper-parameters do not change between runs, so read them once.
    hyper_c    = best_params["hyper_c"][0]
    clip_r     = best_params["clip_r"][0]
    lr         = best_params["lr"][0]
    quantity_n = best_params["quantity_n"][0]
    tau        = best_params["tau"][0]

    os.makedirs(result_root, exist_ok=True)
    all_summaries = {}

    for baseline in baselines:
        print(f"\n========== Running baseline: {baseline} ==========")

        best_score = -1.0
        best_checkpoint = ""

        ckpt_dir = f"{result_root}/{baseline}/best_checkpoints"
        os.makedirs(ckpt_dir, exist_ok=True)

        metrics = []
        for i in range(num_repeats):
            run_ckpt = (
                f"{ckpt_dir}/"
                f"{baseline}_hc{hyper_c}_cr{clip_r}"
                f"_lr{lr}_qn{quantity_n}"
                f"_run{i + 1}_tau{tau}.pth"
            )

            # main() returns 10 metrics followed by the test DataLoader.
            *run_metrics, _test_loader = main(
                baseline=baseline,
                hyper_c=hyper_c,
                clip_r=clip_r,
                lr=lr,
                quantity_n=quantity_n,
                tau=tau,
                checkpoint_path=run_ckpt,
                hierarchy_mode="feature",
                hierarchy_seed=0,
            )
            metrics.append(run_metrics)

            # Track the best run by flat top-level accuracy (first metric).
            flat_top = run_metrics[0]
            if flat_top > best_score:
                best_score = flat_top
                best_checkpoint = run_ckpt

        metrics = np.array(metrics)
        means = metrics.mean(axis=0)
        stds = metrics.std(axis=0)

        summary = {
            "baseline": baseline,
            "param_set": best_params,
            "means": {name: float(v) for name, v in zip(metric_names, means)},
            "stds": {name: float(v) for name, v in zip(metric_names, stds)},
            "best_checkpoint": best_checkpoint,
            "best_score_by_flat_top": best_score,
        }

        all_summaries[baseline] = summary

        result_path = f"{result_root}/{baseline}/Result.txt"
        os.makedirs(os.path.dirname(result_path), exist_ok=True)
        with open(result_path, "a", encoding="utf-8") as f:
            # Report the actual number of aggregated runs.
            f.write(f"\n\nRepeat {len(metrics)} runs summary:\n")
            json.dump(summary, f, indent=2)

        print(f"[{baseline}] Repeated evaluation done. Summary saved to {result_path}")

    summary_path = f"{result_root}/summary_all_baselines.json"

    # Load previous records so this invocation is appended, not overwritten.
    old = []
    if os.path.exists(summary_path):
        try:
            with open(summary_path, "r", encoding="utf-8") as f:
                old = json.load(f)
        except json.JSONDecodeError:
            old = []
    # A file holding a single JSON value (e.g. one dict) is wrapped into a list.
    if not isinstance(old, list):
        old = [old]

    record = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "summaries": all_summaries,
    }
    old.append(record)

    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    # Write to a temp file, then os.replace it over the target so a crash
    # mid-write does not corrupt the existing summary file.
    tmp_path = summary_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(old, f, indent=2)
    os.replace(tmp_path, summary_path)

    print("All baselines finished. Appended summary to", summary_path)

