# wordnet_hypernym_experiment.py
import os
import json
import math
import random,time
from typing import List, Tuple, Dict, Optional

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib
matplotlib.use("Agg") 
import matplotlib.pyplot as plt

# ---------- WordNet ----------
import nltk
from nltk.corpus import wordnet as wn

# ---------- Optional text encoders ----------
try:
    from sentence_transformers import SentenceTransformer
    HAS_SBERT = True
except Exception:
    HAS_SBERT = False

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize as sk_normalize

# Default compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------- Your Hopfield modules ----------
try:
    from hflayers import HopfieldLayer               # Euclidean MHN
except Exception:
    HopfieldLayer = None

try:
    from hyper_hflayers import Hyperbolic_HopfieldLayer as Hyper_HopfieldLayer  # HAMN
except Exception:
    Hyper_HopfieldLayer = None

try:
    from Uhop import LearnableHopfield, uniform_loss
except Exception:
    # Uhop is optional. Bind *both* imported names so that later code can test
    # them against None instead of hitting a NameError on `uniform_loss`.
    LearnableHopfield = None
    uniform_loss = None

# ===================== Poincaré ball ops =====================
class PoincareBall:
    """Poincaré ball of curvature -c (radius 1/sqrt(c)) with basic gyrovector ops.

    Every op works on the last dimension and broadcasts over leading ones.
    Results that live on the ball are pulled back inside ``_M * radius``
    (a 1e-5 relative margin from the boundary) so that the conformal factor
    and ``atanh`` stay finite.
    """

    def __init__(self, c: float, eps: float = 1e-6):
        self.c = float(c)
        self.eps = eps
        self.radius = 1.0 / (self.c ** 0.5)
        # Points are kept strictly inside _M * radius.
        self._M = 1.0 - 1e-5

    def _proj_with_margin(self, x):
        """Rescale points with norm > _M * radius back onto that sphere."""
        r = x.norm(dim=-1, keepdim=True).clamp_min(self.eps)
        max_r = self._M * self.radius
        scale = torch.where(r > max_r, max_r / r, torch.ones_like(r))
        return x * scale

    def lambda_x(self, x):
        """Conformal factor 2 / (1 - c|x|^2), keepdim on the last axis."""
        x2 = (x * x).sum(dim=-1, keepdim=True)
        return 2.0 / (1.0 - self.c * x2).clamp_min(1e-6)

    def mobius_add(self, x, y):
        """Möbius addition x ⊕ y, projected back inside the ball."""
        c = self.c
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        xy = (x * y).sum(dim=-1, keepdim=True)
        num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
        den = (1 + 2 * c * xy + (c ** 2) * x2 * y2).clamp_min(1e-6)
        out = num / den
        return self._proj_with_margin(out)

    def mobius_neg(self, x):
        """Möbius negation (-x), projected inside the ball."""
        return self._proj_with_margin(-x)

    def exp0(self, v):
        """Exponential map at the origin: tangent vector -> ball point."""
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = (self.c ** 0.5) * vnorm
        coef = torch.tanh(t) / ((self.c ** 0.5) * vnorm)
        x = coef * v
        return self._proj_with_margin(x)

    def log0(self, x):
        """Logarithmic map at the origin: ball point -> tangent vector."""
        x = self._proj_with_margin(x)
        xnorm = x.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        arg = (self.c ** 0.5) * xnorm
        arg = torch.clamp(arg, 0.0, 1.0 - 1e-7)  # keep atanh finite
        coef = torch.atanh(arg) / ((self.c ** 0.5) * xnorm)
        return coef * x

    def exp_p(self, p, v):
        """Exponential map at base point ``p`` applied to tangent vector ``v``."""
        lam = self.lambda_x(p)
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = 0.5 * (self.c ** 0.5) * lam * vnorm
        coef = torch.tanh(t) / ((self.c ** 0.5) * vnorm)
        delta = coef * v
        return self.mobius_add(p, delta)

    def log_p(self, p, x):
        """Logarithmic map at base point ``p``: ball point ``x`` -> tangent at ``p``."""
        lam = self.lambda_x(p)
        y = self.mobius_add(self.mobius_neg(p), x)
        ynorm = y.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        arg = (self.c ** 0.5) * ynorm
        arg = torch.clamp(arg, 0.0, 1.0 - 1e-7)
        coef = (2.0 / ((self.c ** 0.5) * lam)) * (torch.atanh(arg) / ynorm)
        return coef * y

    def dist(self, x, y):
        """Geodesic distance on the ball (keepdim on the last axis).

        d_c(x, y) = arccosh(1 + 2c|x-y|^2 / ((1-c|x|^2)(1-c|y|^2))) / sqrt(c).
        The 1/sqrt(c) factor makes the result consistent with ``exp0``/``log0``
        for any curvature (it was previously missing, which is only correct
        for c == 1). The arccosh argument is clamped to 1 + 1e-6 to avoid the
        infinite derivative at 1, so d(x, x) is a small positive number.
        """
        c = self.c
        diff2 = ((x - y) ** 2).sum(dim=-1, keepdim=True)
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        num = 2 * c * diff2
        den = ((1 - c * x2) * (1 - c * y2)).clamp_min(self.eps)
        z = 1 + num / den
        return torch.acosh(z.clamp_min(1 + 1e-6)) / (c ** 0.5)


# ===================== Hyperbolic baselines =====================
class HyperbolicAttentionLayer(nn.Module):
    """Single-query attention over a learnable memory, computed on the Poincaré ball.

    The (LayerNormed) input gives the query; ``self.mem`` gives keys and values.
    All three are linearly projected in tangent space, scaled by learnable
    factors (clamped to [0.5, 10]) and mapped onto the ball with ``exp0``.
    Scores are the negative squared hyperbolic query-key distances, z-scored
    over the memory axis and multiplied by a learnable inverse temperature.
    Values are aggregated in the tangent space at the query
    (log_q -> weighted sum -> exp_q), and the output is returned in the tangent
    space at the origin plus a sigmoid-gated skip of the query.

    Args:
        feat_dim: feature dimension D of the input, memory and output.
        n_mem: number of memory slots N.
        c: ball curvature parameter.
        tau_init: initial inverse temperature (learned in log space).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, feat_dim: int, n_mem: int, c: float = 1.0,
                 tau_init: float = 1.0, dropout: float = 0.0):
        super().__init__()
        self.ball = PoincareBall(c)
        self.q_proj = nn.Linear(feat_dim, feat_dim, bias=False)
        self.k_proj = nn.Linear(feat_dim, feat_dim, bias=False)
        self.v_proj = nn.Linear(feat_dim, feat_dim, bias=False)
        self.mem   = nn.Parameter(torch.randn(n_mem, feat_dim) * 0.02)
        self.ln    = nn.LayerNorm(feat_dim)
        self.dropout = nn.Dropout(dropout)

        # Learnable scalars: temperature and projection scales live in log
        # space; alpha_skip is the logit of the residual gate.
        self.log_tau     = nn.Parameter(torch.log(torch.tensor(float(tau_init))))
        self.q_log_scale = nn.Parameter(torch.tensor(0.0))
        self.k_log_scale = nn.Parameter(torch.tensor(0.2))
        self.v_log_scale = nn.Parameter(torch.tensor(0.2))
        self.alpha_skip  = nn.Parameter(torch.tensor(0.0))

    @torch.no_grad()
    def set_memory(self, mem_tensor: torch.Tensor):
        """Overwrite the memory with ``mem_tensor`` of shape [N', D].

        When N' differs from the current slot count a new Parameter is
        registered; an optimizer built before this call keeps the old one.
        """
        assert mem_tensor.dim() == 2 and mem_tensor.size(1) == self.mem.size(1)
        if mem_tensor.size(0) != self.mem.size(0):
            self.mem = nn.Parameter(
                mem_tensor.detach().clone().to(device=self.mem.device, dtype=self.mem.dtype)
            )
        else:
            # Copy into the existing storage. Assigning `.data` would make the
            # parameter alias the caller's tensor (and adopt its dtype).
            self.mem.copy_(mem_tensor)

    def forward(self, x):  # x: [B,D] or [B,1,D]
        squeeze = x.dim() == 2
        if squeeze:
            x = x.unsqueeze(1)
        B = x.size(0)

        x_ln  = self.ln(x)
        q_tan = self.q_proj(x_ln)
        k_tan = self.k_proj(self.mem)
        v_tan = self.v_proj(self.mem)

        s_q = torch.exp(self.q_log_scale).clamp(0.5, 10.0)
        s_k = torch.exp(self.k_log_scale).clamp(0.5, 10.0)
        s_v = torch.exp(self.v_log_scale).clamp(0.5, 10.0)

        q_ball = self.ball.exp0(s_q * q_tan)   # [B,1,D]
        k_ball = self.ball.exp0(s_k * k_tan)   # [N,D]
        v_ball = self.ball.exp0(s_v * v_tan)   # [N,D]

        n_mem = k_ball.size(0)
        q_rep = q_ball.expand(-1, n_mem, -1)                   # [B,N,D] (no copy)
        k_rep = k_ball.unsqueeze(0).expand(B, -1, -1)          # [B,N,D]
        d     = self.ball.dist(q_rep, k_rep).squeeze(-1)       # [B,N]

        # z-score over the memory axis; with a single slot the std is NaN, so
        # only centre (giving all-zero scores -> attention weight 1).
        scores = -(d ** 2)
        mu     = scores.mean(dim=-1, keepdim=True)
        if n_mem > 1:
            sigma  = scores.std(dim=-1, keepdim=True).clamp_min(1e-6)
            scores = (scores - mu) / sigma
        else:
            scores = scores - mu

        tau    = torch.exp(self.log_tau).clamp(0.05, 10.0)
        scores = (tau * scores).clamp(-20.0, 20.0)

        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Weighted Fréchet-style mean in the tangent space at the query.
        v_ball_exp = v_ball.unsqueeze(0).expand(B, -1, -1)
        q_center   = q_ball.squeeze(1)
        log_q_v    = self.ball.log_p(q_center.unsqueeze(1), v_ball_exp)  # [B,N,D]
        z_tan_q    = torch.bmm(attn.unsqueeze(1), log_q_v).squeeze(1)    # [B,D]
        z_ball     = self.ball.exp_p(q_center, z_tan_q)                   # [B,D]
        out_tan0   = self.ball.log0(z_ball).unsqueeze(1)                  # [B,1,D]

        # Gated residual of the query (in origin tangent space).
        q_tan0 = self.ball.log0(q_ball)
        gate   = torch.sigmoid(self.alpha_skip)
        out_tan0 = out_tan0 + gate * q_tan0

        return out_tan0.squeeze(1) if squeeze else out_tan0


class MobiusLinear(nn.Module):
    """Möbius linear map on the Poincaré ball: log0 -> Euclidean Linear -> exp0.

    Tangent-space activations are clamped to [-20, 20] both before and after
    the linear map to keep the exp/log maps numerically well-behaved.
    """

    def __init__(self, in_dim, out_dim, c=1.0, bias=True):
        super().__init__()
        self.ball = PoincareBall(c)
        self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        nn.init.kaiming_uniform_(self.lin.weight, a=0.2)
        if bias:
            nn.init.zeros_(self.lin.bias)

    def forward(self, x_ball):
        ball = self.ball
        # ball -> tangent at origin (after pulling the input inside the margin)
        tangent_in = ball.log0(ball._proj_with_margin(x_ball)).clamp(-20.0, 20.0)
        tangent_out = self.lin(tangent_in).clamp(-20.0, 20.0)
        # tangent -> ball
        return ball.exp0(tangent_out)


class HypNNBlock(nn.Module):
    """Two-layer hyperbolic MLP built from MobiusLinear layers.

    The (LayerNormed) input is treated as a tangent vector at the origin and
    mapped onto the ball. The hidden layer uses LayerNorm + ReLU in tangent
    space. The result is returned as an origin tangent vector of shape [B, D],
    LayerNormed and clamped to [-20, 20].
    """

    def __init__(self, feat_dim: int, hidden: int = 512, c: float = 1.0, dropout: float = 0.1):
        super().__init__()
        self.ball = PoincareBall(c)
        self.fc1 = MobiusLinear(feat_dim, hidden, c=c)
        self.fc2 = MobiusLinear(hidden, feat_dim, c=c)
        self.ln1 = nn.LayerNorm(hidden)
        self.ln2 = nn.LayerNorm(feat_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # x: [B,D] or [B,1,D]
        h = x.unsqueeze(1) if x.dim() == 2 else x
        h = F.layer_norm(h, h.shape[-1:])
        ball = self.ball

        # Layer 1: Möbius linear, then LayerNorm + ReLU in tangent space.
        hidden = ball.log0(self.fc1(ball.exp0(h)))
        hidden = F.relu(self.ln1(hidden.clamp(-20, 20)))
        # NOTE(review): dropout here acts on ball coordinates; fc2 re-projects
        # any point pushed past the boundary margin.
        hidden_ball = self.dropout(ball.exp0(hidden))

        # Layer 2: back to feat_dim, returned in tangent space.
        out = ball.log0(self.fc2(hidden_ball))
        out = self.ln2(out).clamp(-20, 20)
        return out.squeeze(1)

class UHopBlock(nn.Module):
    """Wrapper around ``Uhop.LearnableHopfield`` with a learnable memory bank.

    ``forward`` passes the input together with ``n_mem`` learnable memory
    vectors (broadcast over the batch) to the wrapped layer, then applies
    LayerNorm. The exact attention semantics are defined by LearnableHopfield
    (not visible here). ``kernel_forward`` exposes the layer's
    ``uniform_forward`` used for kernel pre-training.
    """

    def __init__(
        self,
        feat_dim: int,
        n_mem: int,
        n_heads: int = 4,
        dropout: float = 0.1,
        mode: str = "softmax",
        kernel: str = "lin",
        scale: Optional[float] = None,
    ):
        super().__init__()
        assert LearnableHopfield is not None, "LearnableHopfield not found. Please check Uhop.py / layers.py."

        hop_config = dict(
            d_model=feat_dim,
            n_heads=n_heads,
            d_keys=None,
            d_values=None,
            mix=True,
            update_steps=1,
            dropout=dropout,
            mode=mode,
            kernel=kernel,
            scale=scale,
        )
        self.hop = LearnableHopfield(**hop_config)

        self.memory = nn.Parameter(torch.randn(1, n_mem, feat_dim) * 0.02)  # [1, n_mem, D]
        self.ln = nn.LayerNorm(feat_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Accept [B, D] or [B, L, D]; return the same rank as the input."""
        added_seq_dim = x.dim() == 2
        queries = x.unsqueeze(1) if added_seq_dim else x

        batch, _, _ = queries.shape
        bank = self.memory.expand(batch, -1, -1)  # [B, n_mem, D]

        out = self.ln(self.hop(queries, bank))
        return out.squeeze(1) if added_seq_dim else out

    def kernel_forward(self, Y: torch.Tensor) -> torch.Tensor:
        """Run ``hop.uniform_forward`` on ``Y``; a 2-D input gets a leading batch dim."""
        batched = Y.unsqueeze(0) if Y.dim() == 2 else Y
        return self.hop.uniform_forward(batched)



# ===================== Text encoder =====================
class TextEncoder:
    """Encode texts as a dense numpy array [N, d_in], one row per text.

    With sentence_transformers importable and ``mode`` in {"sbert", "auto"},
    uses the ``all-MiniLM-L6-v2`` SentenceTransformer and L2-normalises each
    row. Otherwise it falls back to a 1-2-gram, lowercased TF-IDF vectoriser
    limited to ``max_features`` terms, returning float32 dense arrays.
    """

    def __init__(self, mode: str = "auto", max_features: int = 30000):
        use_sbert = HAS_SBERT and mode in ("sbert", "auto")
        self.mode = "sbert" if use_sbert else "tfidf"
        self.vec = None
        self.model = None
        if use_sbert:
            self.model = SentenceTransformer("all-MiniLM-L6-v2")
        else:
            self.vec = TfidfVectorizer(max_features=max_features, ngram_range=(1,2), lowercase=True)

    def _encode_sbert(self, texts: List[str], show_progress: bool) -> np.ndarray:
        """Embed with the SentenceTransformer and L2-normalise (cosine-friendly)."""
        raw = self.model.encode(texts, batch_size=64, show_progress_bar=show_progress, convert_to_numpy=True)
        return sk_normalize(raw)

    def fit_transform(self, texts: List[str]) -> np.ndarray:
        """Fit the TF-IDF vocabulary (no-op for SBERT) and encode ``texts``."""
        if self.mode == "sbert":
            return self._encode_sbert(texts, True)
        return self.vec.fit_transform(texts).astype(np.float32).toarray()

    def transform(self, texts: List[str]) -> np.ndarray:
        """Encode ``texts`` with the already-fitted encoder."""
        if self.mode == "sbert":
            return self._encode_sbert(texts, False)
        return self.vec.transform(texts).astype(np.float32).toarray()


# ===================== WordNet dataset builder =====================
def synset_text(s):
    """Render a synset as "lemma, lemma. definition. example example".

    Underscores in lemma names become spaces; when there are no examples the
    trailing ". " is stripped.
    """
    lemma_names = [lemma.name().replace("_", " ") for lemma in s.lemmas()]
    definition = s.definition()
    examples = s.examples()
    pieces = [", ".join(lemma_names), definition, " ".join(examples) if examples else ""]
    return ". ".join(pieces).strip()

def pick_primary_hypernym(s, strategy="first"):
    """Return one hypernym of ``s``, or None when it has none.

    strategy="max_depth" picks the hypernym with the largest ``min_depth()``
    (the most specific parent; the first one wins ties). Any other value picks
    the first listed hypernym.
    """
    parents = s.hypernyms()
    if not parents:
        return None
    if strategy != "max_depth":
        return parents[0]
    # max() keeps the first maximal element, matching a strict ">" scan.
    return max(parents, key=lambda h: h.min_depth())

def collect_wordnet_pairs(pos="n", primary="first", min_freq=5, max_nodes=20000):
    """Build (synset -> primary hypernym) classification samples from WordNet.

    Synsets of part-of-speech ``pos`` are shuffled with a fixed seed (42). The
    first ``max_nodes`` that have a hypernym are taken, and each one's parent
    is chosen by ``pick_primary_hypernym(..., primary)``. Only samples whose
    parent occurs at least ``min_freq`` times are kept.

    Returns:
        X2: list of (synset offset, synset text) per kept sample.
        Y_ids: int64 array of contiguous label ids, aligned with X2.
        y2id: {hypernym offset: label id}, ids assigned in sorted-offset order.
    """
    from collections import Counter

    pool = list(wn.all_synsets(pos=pos))
    random.Random(42).shuffle(pool)

    samples, parent_offsets = [], []
    for syn in pool:
        if len(samples) >= max_nodes:
            break
        parent = pick_primary_hypernym(syn, primary)
        if parent is None:
            continue
        samples.append((syn.offset(), synset_text(syn)))
        parent_offsets.append(parent.offset())

    # Label space: hypernyms seen at least min_freq times.
    frequency = Counter(parent_offsets)
    frequent = {off for off, n in frequency.items() if n >= min_freq}
    kept = [(sample, par) for sample, par in zip(samples, parent_offsets) if par in frequent]

    y2id = {off: i for i, off in enumerate(sorted(frequent))}
    X2 = [sample for sample, _ in kept]
    Y_ids = np.array([y2id[par] for _, par in kept], dtype=np.int64)
    return X2, Y_ids, y2id

class WNHypDataset(Dataset):
    """In-memory dataset of (float32 feature row, int64 label) pairs.

    Both inputs are copied into tensors at construction time.
    """

    def __init__(self, feats: np.ndarray, labels: np.ndarray):
        self.feats = torch.tensor(feats, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, i):
        feature_row = self.feats[i]
        label = self.labels[i]
        return feature_row, label

def split_indices(N, ratios=(0.8, 0.1, 0.1), seed=42):
    """Shuffle range(N) with ``random.Random(seed)`` and cut it into train/val/test lists.

    Train and val sizes are ``int(N * ratio)``; the test split takes the rest,
    so the third ratio is only unpacked, not used.
    """
    train_frac, val_frac, _test_frac = ratios
    order = list(range(N))
    random.Random(seed).shuffle(order)
    cut_train = int(N * train_frac)
    cut_val = cut_train + int(N * val_frac)
    return order[:cut_train], order[cut_train:cut_val], order[cut_val:]


# ===================== Model =====================
class WNHypModel(nn.Module):
    """Hypernym classifier over text features.

    Pipeline: Linear(d_in -> 512) -> middle block -> Linear(512 -> n_labels).
    ``baseline`` selects the middle block:

    * "HAMN"    -- ``Hyper_HopfieldLayer`` from hyper_hflayers
    * "HypAttn" -- HyperbolicAttentionLayer
    * "HypNN"   -- HypNNBlock
    * "MHN_Euc" -- ``HopfieldLayer`` from hflayers
    * "UHop"    -- UHopBlock

    The memory-based blocks are sized by ``n_labels`` (``quantity`` / ``n_mem``).
    ``hyper_c``, ``clip_r``, ``lr`` and ``tau`` are forwarded to the blocks
    that take them. An unknown ``baseline`` raises ValueError.
    """

    def __init__(self, d_in: int, n_labels: int,
                 baseline: str = "HAMN",
                 hyper_c: float = 1.0,
                 clip_r: float = 0.9,
                 lr: float = 1.0,
                 tau: float = 1.0):
        super().__init__()
        self.baseline = baseline
        self.d_mid = 512
        self.proj = nn.Linear(d_in, self.d_mid)
        self.block = self._make_block(baseline, n_labels, hyper_c, clip_r, lr, tau)
        self.head = nn.Linear(self.d_mid, n_labels)

    def _make_block(self, baseline, n_labels, hyper_c, clip_r, lr, tau):
        """Instantiate the middle block named by ``baseline``."""
        width = self.d_mid
        if baseline == "HAMN":
            assert Hyper_HopfieldLayer is not None, "hyper_hflayers not found"
            return Hyper_HopfieldLayer(
                input_size=width, hidden_size=width, output_size=width,
                num_heads=4, hyper_c=hyper_c,
                # clip radius expressed as a fraction of the ball radius 1/sqrt(c)
                clip_r=1.0 / math.sqrt(hyper_c) * clip_r,
                lr=lr, association_activation='relu', dropout=0.2,
                quantity=n_labels, batch_first=True,
                train_c=False, input_as_hyper=False, out_as_hyper=False
            )
        if baseline == "HypAttn":
            return HyperbolicAttentionLayer(feat_dim=width, n_mem=n_labels, c=hyper_c, tau_init=tau, dropout=0.1)
        if baseline == "HypNN":
            return HypNNBlock(feat_dim=width, hidden=512, c=hyper_c, dropout=0.1)
        if baseline == "MHN_Euc":
            assert HopfieldLayer is not None, "hflayers not found"
            return HopfieldLayer(
                input_size=width, hidden_size=width, output_size=width,
                num_heads=8, scaling=0.01, dropout=0.2,
                association_activation='relu', quantity=n_labels, batch_first=True
            )
        if baseline == "UHop":
            return UHopBlock(
                feat_dim=width,
                n_mem=n_labels,
                n_heads=4,
                dropout=0.1,
                mode="softmax",
                kernel="lin",
                scale=None,
            )
        raise ValueError(baseline)

    @torch.no_grad()
    def maybe_update_class_prototypes(self, class_proto: Optional[torch.Tensor]):
        """Push ``class_proto`` into the block's memory when the block supports ``set_memory``."""
        if class_proto is None or not hasattr(self.block, "set_memory"):
            return
        target_device = next(self.parameters()).device
        self.block.set_memory(class_proto.to(target_device))

    def forward(self, x):  # x: [B, d_in]
        h = self.proj(x).unsqueeze(1)          # [B,1,512]
        if self.baseline in ("HypAttn", "HypNN"):
            # these blocks take and return flat [B,512]
            h = self.block(h.squeeze(1))
        elif self.baseline in ("HAMN", "MHN_Euc", "UHop"):
            # sequence-style blocks: [B,1,512] -> [B,1,512]
            h = self.block(h).squeeze(1)
        return self.head(h)


# ===================== Training / Eval =====================
@torch.no_grad()
def compute_class_prototypes(feats: torch.Tensor, labels: torch.Tensor, n_labels: int) -> torch.Tensor:
    """Return the L2-normalised mean feature vector of every class.

    Args:
        feats: [N, D] features.
        labels: [N] integer class ids; ids outside [0, n_labels) are ignored.
        n_labels: number of classes (rows of the result).

    Returns:
        [n_labels, D] tensor on ``feats``' device and dtype. Classes without
        samples get an all-zero row.
    """
    labels = labels.to(device=feats.device, dtype=torch.long)
    valid = (labels >= 0) & (labels < n_labels)
    feats_v, labels_v = feats[valid], labels[valid]

    # One vectorised pass instead of a Python loop over classes.
    sums = torch.zeros(n_labels, feats.size(1), device=feats.device, dtype=feats.dtype)
    sums.index_add_(0, labels_v, feats_v)
    counts = torch.bincount(labels_v, minlength=n_labels).clamp_min(1).to(sums.dtype)
    proto = sums / counts.unsqueeze(1)
    return F.normalize(proto, dim=1, eps=1e-6)

def train_epoch(model, loader, opt, criterion):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Gradients are clipped to a global norm of 5.0 before each step. Batches
    are moved to the module-level ``device``.
    """
    model.train()
    loss_total = 0.0
    for inputs, targets in tqdm(loader, desc="Train", leave=False):
        inputs = inputs.to(device)
        targets = targets.to(device)

        opt.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        opt.step()

        loss_total += batch_loss.item()
    return loss_total / max(1, len(loader))

# Process-wide caches for the hierarchy-distance metric. The WordNet graph is
# static, so both are filled lazily once and reused by every evaluate() call.
_NOUN_ADJACENCY: Optional[Dict[int, Tuple[int, ...]]] = None
_HIER_DIST_CACHE: Dict[Tuple[int, int], int] = {}
_HIER_DIST_FALLBACK = 10  # distance reported for unknown offsets / no path


def _noun_adjacency() -> Dict[int, Tuple[int, ...]]:
    """Return {noun synset offset: offsets of its hypernyms and hyponyms}, built once."""
    global _NOUN_ADJACENCY
    if _NOUN_ADJACENCY is None:
        _NOUN_ADJACENCY = {
            syn.offset(): tuple({nb.offset() for nb in syn.hypernyms() + syn.hyponyms()})
            for syn in wn.all_synsets(pos="n")
        }
    return _NOUN_ADJACENCY


def _hierarchy_distance(off_a: int, off_b: int) -> int:
    """Return the shortest path length between two noun synsets over hypernym/hyponym links.

    Returns 0 for identical offsets. Returns ``_HIER_DIST_FALLBACK`` when either
    offset is not a noun synset or no path exists. Results are memoised per
    ordered pair.
    """
    if off_a == off_b:
        return 0
    key = (off_a, off_b)
    cached = _HIER_DIST_CACHE.get(key)
    if cached is not None:
        return cached

    from collections import deque

    adj = _noun_adjacency()
    dist = _HIER_DIST_FALLBACK
    if off_a in adj and off_b in adj:
        seen = {off_a}
        queue = deque([(off_a, 0)])
        while queue:  # plain BFS
            cur, d = queue.popleft()
            if cur == off_b:
                dist = d
                break
            for nb in adj.get(cur, ()):
                if nb not in seen:
                    seen.add(nb)
                    queue.append((nb, d + 1))
    _HIER_DIST_CACHE[key] = dist
    return dist


@torch.no_grad()
def evaluate(model, loader, label_id_to_synset: Dict[int, int]):
    """Return (accuracy, mean hierarchy distance) of ``model`` on ``loader``.

    ``label_id_to_synset`` maps a label id to its noun synset offset. For every
    sample, the WordNet graph distance between the predicted and true label
    synsets is accumulated (0 when correct). Previously the label *ids* were
    looked up instead of their offsets, so every miss scored the fallback 10.
    """
    model.eval()
    correct, total = 0, 0
    sum_hdist = 0.0
    for x, y in tqdm(loader, desc="Eval", leave=False):
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

        for pi, yi in zip(pred.tolist(), y.tolist()):
            sum_hdist += _hierarchy_distance(label_id_to_synset[pi], label_id_to_synset[yi])

    acc = correct / max(1, total)
    avg_hdist = sum_hdist / max(1, total)
    return acc, avg_hdist


def learn_kernel_uhop_text(model: WNHypModel,
                           loader: DataLoader,
                           device: torch.device,
                           kernel_epoch: int = 50):
    """Pre-train only the UHop kernel with a uniformity loss.

    The projection output is computed without gradients. The batch is passed
    through ``model.block.kernel_forward`` as one [1, B, 512] set, and the
    L2-normalised result is scored with ``uniform_loss``. Only
    ``model.block.hop.kernel`` parameters are updated (SGD, lr 0.1, momentum
    0.9). The model is left in train mode. Mean loss is printed per epoch.
    """
    # Assertion message means: "only applicable to UHopBlock".
    assert isinstance(model.block, UHopBlock), "learn_kernel_uhop_text 只适用于 UHopBlock."

    kernel_opt = optim.SGD(list(model.block.hop.kernel.parameters()), lr=0.1, momentum=0.9)

    model.train()
    for epoch in tqdm(range(kernel_epoch), desc="UHop kernel pretrain (WordNet)", unit="epoch"):
        epoch_losses = []
        for batch_x, _labels in loader:
            batch_x = batch_x.to(device)
            kernel_opt.zero_grad()

            # Frozen projection: gradients flow only into the kernel.
            with torch.no_grad():
                feats = model.proj(batch_x)            # [B,512]
            feats = feats.detach()

            memory = model.block.kernel_forward(feats.unsqueeze(0))   # [1, B, 512]
            normed = F.normalize(memory.squeeze(0), dim=-1)
            loss = uniform_loss(normed)

            loss.backward()
            kernel_opt.step()
            epoch_losses.append(loss.item())

        print(f"[UHop kernel pretrain] epoch {epoch+1}/{kernel_epoch}, "
              f"uniform loss = {np.mean(epoch_losses):.4f}")


# ===================== Euclidean ontology encoder (WordNet graph) =====================
def build_ontology_graph(offsets, pos: str = "n"):
    """Build the undirected hypernym/hyponym graph induced on ``offsets``.

    Returns:
        off2idx: {offset: node index}, indices assigned in sorted-offset order.
        idx2off: sorted list of the unique offsets (node index -> offset).
        edges: list of (i, j) index pairs, each undirected link stored in
            both directions; self-loops are skipped.

    Offsets that cannot be resolved to a synset keep a node but get no edges.
    """
    idx2off = sorted(set(offsets))
    off2idx = {off: i for i, off in enumerate(idx2off)}

    edges = set()
    for off in idx2off:
        try:
            syn = wn.synset_from_pos_and_offset(pos, off)
        except Exception:
            continue
        src = off2idx[off]
        for other in syn.hypernyms() + syn.hyponyms():
            dst = off2idx.get(other.offset())
            if dst is None or dst == src:
                continue
            edges.add((src, dst))
            edges.add((dst, src))
    return off2idx, idx2off, list(edges)


class OntologyEmbedding(nn.Module):
    """Trainable table of node vectors (Xavier-uniform init), returned unit-normalised."""

    def __init__(self, n_nodes: int, dim: int = 64):
        super().__init__()
        self.emb = nn.Embedding(n_nodes, dim)
        nn.init.xavier_uniform_(self.emb.weight)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Look up ``idx`` (any integer shape) and L2-normalise the last axis."""
        vectors = self.emb(idx)
        return F.normalize(vectors, dim=-1)


def train_ontology_encoder(
    edges,
    n_nodes: int,
    dim: int = 64,
    neg_k: int = 10,
    epochs: int = 5,
    batch_size: int = 1024,
    lr: float = 1e-2,
    device: torch.device = device,
):
    """Learn unit-norm node embeddings with a sampled-softmax edge objective.

    For each directed edge (anchor, positive), ``neg_k`` negatives are drawn
    uniformly from all nodes (they may coincide with the positive). The
    cross-entropy loss ranks the positive's dot product first. Adam is used;
    edges are reshuffled each epoch with ``np.random``.

    Returns:
        float32 numpy array [n_nodes, dim] of the final (normalised) embeddings.

    Raises:
        ValueError: if ``edges`` is empty.
    """
    if len(edges) == 0:
        raise ValueError("Ontology graph has no edges.")

    encoder = OntologyEmbedding(n_nodes, dim).to(device)
    optimizer = optim.Adam(encoder.parameters(), lr=lr)
    edge_arr = np.array(edges, dtype=np.int64)
    n_edges = edge_arr.shape[0]

    for ep in range(1, epochs + 1):
        shuffled = edge_arr[np.random.permutation(n_edges)]
        loss_sum, steps = 0.0, 0

        for start in range(0, n_edges, batch_size):
            chunk = shuffled[start:start + batch_size]
            anchors = torch.from_numpy(chunk[:, 0]).to(device=device, dtype=torch.long)
            positives = torch.from_numpy(chunk[:, 1]).to(device=device, dtype=torch.long)
            n_batch = anchors.size(0)
            negatives = torch.randint(0, n_nodes, (n_batch, neg_k), device=device, dtype=torch.long)

            z_anchor = encoder(anchors)          # [B, d]
            z_pos = encoder(positives)           # [B, d]
            z_neg = encoder(negatives)           # [B, neg_k, d]

            pos_score = (z_anchor * z_pos).sum(-1, keepdim=True)        # [B, 1]
            neg_score = (z_anchor.unsqueeze(1) * z_neg).sum(-1)         # [B, neg_k]
            logits = torch.cat([pos_score, neg_score], dim=1)           # positive is class 0
            target = torch.zeros(n_batch, dtype=torch.long, device=device)

            loss = F.cross_entropy(logits, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            steps += 1

        print(f"[OntEnc] epoch {ep:02d}/{epochs}  loss={loss_sum / max(1, steps):.4f}")

    with torch.no_grad():
        table = encoder(torch.arange(n_nodes, device=device, dtype=torch.long))
    return table.cpu().numpy().astype(np.float32)

def split_decay(named_params):
    """Split trainable parameters into (decay, no_decay) lists for AdamW groups.

    Frozen parameters are dropped. Weight decay is skipped for biases, any
    parameter whose name contains "norm", and every parameter with at most one
    dimension. This includes 0-d scalars such as learnable log-temperatures,
    log-scales and gate logits, which decay would otherwise drag toward 0.
    """
    decay, no_decay = [], []
    for name, param in named_params:
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name.endswith(".bias") or "norm" in name.lower():
            no_decay.append(param)
        else:
            decay.append(param)
    return decay, no_decay


# ===================== Main =====================
def _ontology_features(X_pairs, y_ids, label_id_to_synset, heldout_idx,
                       ontology_dim, ontology_epochs, ontology_neg_k):
    """Train Euclidean WordNet-graph embeddings; return one [ontology_dim] row per sample.

    The graph spans every sample synset and every label synset. Before
    training, the edge linking each sample in ``heldout_idx`` to its own target
    hypernym is removed. That edge always exists (the target is one of the
    sample's hypernyms), so keeping it would fit evaluation samples' embeddings
    against their answers (label leakage).
    """
    sample_offsets = [sid for (sid, _txt) in X_pairs]
    all_offsets = set(sample_offsets) | set(label_id_to_synset.values())
    off2idx, _idx2off, edges = build_ontology_graph(all_offsets, pos="n")

    leaked = set()
    for i in heldout_idx:
        a = off2idx[sample_offsets[i]]
        b = off2idx[label_id_to_synset[int(y_ids[i])]]
        leaked.add((a, b))
        leaked.add((b, a))
    n_before = len(edges)
    if leaked:
        edges = [e for e in edges if e not in leaked]
    print(f"[OntEnc] nodes={len(off2idx)}, edges={len(edges)} "
          f"(removed {n_before - len(edges)} held-out label edges)")

    ont_emb = train_ontology_encoder(
        edges=edges,
        n_nodes=len(off2idx),
        dim=ontology_dim,
        neg_k=ontology_neg_k,
        epochs=ontology_epochs,
        batch_size=1024,
        lr=1e-2,
        device=device,
    )  # [n_nodes, ontology_dim]

    ont_feats = np.zeros((len(X_pairs), ontology_dim), dtype=np.float32)
    for i, sid in enumerate(sample_offsets):
        nid = off2idx.get(sid)
        if nid is not None:
            ont_feats[i] = ont_emb[nid]
    return ont_feats


def _load_or_pretrain_uhop_kernel(model, tr_ld, out_dir):
    """Load the cached UHop kernel weights, or pre-train them and save *only* the kernel.

    Checkpoints written by earlier versions held the whole model state dict.
    Those crash on load whenever d_in or n_labels change, so only their
    ``block.hop.kernel.*`` entries are used.
    Assumes ``model.block.hop.kernel`` is an nn.Module (learn_kernel_uhop_text
    optimises its ``.parameters()``) -- TODO confirm against Uhop.
    """
    kernel = model.block.hop.kernel
    kernel_ckpt = os.path.join(out_dir, "uhop_kernel_only_wordnet.pth")
    if os.path.isfile(kernel_ckpt):
        print("=> loading UHop kernel checkpoint", kernel_ckpt)
        state = torch.load(kernel_ckpt, map_location=device)
        prefix = "block.hop.kernel."
        if any(k.startswith(prefix) for k in state):
            state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
        kernel.load_state_dict(state)
    else:
        print("=> UHop baseline: pretraining kernel with uniform loss (WordNet)")
        learn_kernel_uhop_text(model, tr_ld, device, kernel_epoch=50)
        os.makedirs(out_dir, exist_ok=True)
        torch.save(kernel.state_dict(), kernel_ckpt)
        print("=> saved UHop kernel checkpoint to", kernel_ckpt)


def _fit_baseline(model, tr_ld, va_ld, epochs, best_ckpt, label_id_to_synset, name_for_logs):
    """Train ``model`` with AdamW and restore the checkpoint with the best validation score.

    The selection score is ``val_acc - 0.01 * val_hdist``.
    Returns (history dict, best score); the score stays -1.0 when no epoch ran.
    """
    decay, no_decay = split_decay(model.named_parameters())
    opt = optim.AdamW(
        [{"params": decay,    "lr": 2e-3, "weight_decay": 5e-4},
         {"params": no_decay, "lr": 2e-3, "weight_decay": 0.0}],
        betas=(0.9, 0.999), eps=1e-8
    )
    criterion = nn.CrossEntropyLoss()

    best_val, saved = -1.0, False
    hist = {"train_loss": [], "val_acc": [], "val_hdist": []}
    for ep in range(1, epochs + 1):
        tr_loss = train_epoch(model, tr_ld, opt, criterion)
        val_acc, val_hdist = evaluate(model, va_ld, label_id_to_synset)

        hist["train_loss"].append(float(tr_loss))
        hist["val_acc"].append(float(val_acc))
        hist["val_hdist"].append(float(val_hdist))
        print(f"[{name_for_logs}] Epoch {ep:02d} | loss {tr_loss:.4f} | "
              f"val_acc {val_acc:.4f} | val_hdist {val_hdist:.3f}")

        score = val_acc - 0.01 * val_hdist
        if score > best_val:
            best_val = score
            torch.save(model.state_dict(), best_ckpt)
            saved = True

    # Only reload a checkpoint written by this run; a same-named file left by an
    # earlier run must not be picked up when no epoch was trained.
    if saved:
        model.load_state_dict(torch.load(best_ckpt, map_location=device))
    return hist, best_val


def _plot_histories(histories, out_dir):
    """Save train-loss / val-acc / val-hdist curves of every run to a timestamped PNG."""
    ts = time.strftime("%Y%m%d_%H%M%S")
    fig, axs = plt.subplots(3, 1, figsize=(9, 12), sharex=True)
    panels = (("train_loss", "Train Loss"), ("val_acc", "Val Acc"), ("val_hdist", "Val Hier Dist (↓)"))
    for ax, (key, ylabel) in zip(axs, panels):
        for bl, h in histories.items():
            ax.plot(range(1, len(h[key]) + 1), h[key], label=bl, linewidth=1.8)
        ax.set_ylabel(ylabel)
        ax.grid(True, linestyle="--", alpha=0.3)
    axs[0].legend()
    axs[-1].set_xlabel("Epoch")

    fig.tight_layout()
    curves_path = os.path.join(out_dir, f"curves_all_{ts}.png")
    fig.savefig(curves_path, dpi=200)
    plt.close(fig)  # release the figure; main() may be called repeatedly
    print("Saved curves ->", curves_path)


def _append_summary(out_dir, all_summaries):
    """Append a timestamped record to ``<out_dir>/summary.json`` atomically; return its path."""
    summary_path = os.path.join(out_dir, "summary.json")
    record = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "summaries": all_summaries,
    }

    history = []
    if os.path.exists(summary_path):
        try:
            with open(summary_path, "r", encoding="utf-8") as f:
                history = json.load(f)
        except json.JSONDecodeError:
            history = []
    if not isinstance(history, list):
        history = [history]  # keep a legacy single record (or any other value)
    history.append(record)

    os.makedirs(out_dir, exist_ok=True)
    tmp_path = summary_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, summary_path)
    return summary_path


def main(
    primary_parent_strategy="first",   # 'first' | 'max_depth'
    encoder_mode="auto",               # 'auto' -> try sbert, else tfidf
    min_label_freq=5,
    max_nodes=20000,
    batch_size=256,
    epochs=15,
    baselines=("MHN_Euc","HypAttn","HypNN","HAMN"),
    c_hyp=1.0, clip_r=0.9, lr_cccp=1.0, tau=1.0,
    out_dir="wn_hypernym_results",
    use_ontology_encoder: bool = False,
    ontology_dim: int = 64,
    ontology_epochs: int = 5,
    ontology_neg_k: int = 10,
    baselines_use_ontology=("MHN_Euc","HypAttn","HypNN","HAMN"),
    ontology_exclude_heldout_label_edges: bool = True,
):
    """Run the WordNet primary-hypernym classification benchmark.

    1. Build (synset text -> primary hypernym) samples.
    2. Encode the texts.
    3. Optionally append WordNet-graph embeddings.
    4. For each name in ``baselines``, train a WNHypModel, keep the epoch with
       the best ``val_acc - 0.01 * val_hdist``, and report test accuracy and
       mean hierarchy distance.
    5. Save the curves as a PNG and append a record to ``<out_dir>/summary.json``.

    ``ontology_exclude_heldout_label_edges`` removes, before training the
    ontology encoder, the graph edge between each val/test sample and its own
    target hypernym (prevents label leakage). Pass False to reproduce the old
    leaky setup.
    """
    os.makedirs(out_dir, exist_ok=True)
    # NLTK data
    try:
        _ = wn.synsets("dog")
    except LookupError:
        nltk.download('wordnet')

    # -------- Build pairs (synset -> primary hypernym) --------
    X_pairs, y_ids, y2id = collect_wordnet_pairs(
        pos="n", primary=primary_parent_strategy, min_freq=min_label_freq, max_nodes=max_nodes
    )
    texts = [t for (_off, t) in X_pairs]
    label_id_to_synset = {i: off for off, i in y2id.items()}
    n_labels = len(y2id)
    print(f"[Data] samples={len(texts)}, labels={n_labels} (min_freq={min_label_freq})")

    enc = TextEncoder(mode=encoder_mode)
    feats_text = enc.fit_transform(texts).astype(np.float32)
    print(f"[Encoder] mode={enc.mode}, d_in={feats_text.shape[1]}")

    # -------- Splits --------
    idx_tr, idx_va, idx_te = split_indices(len(texts), ratios=(0.8, 0.1, 0.1), seed=42)
    # int64 index arrays (np.array([]) would be float64 and unusable as an index)
    split_arrays = [np.asarray(ix, dtype=np.int64) for ix in (idx_tr, idx_va, idx_te)]

    feats_with_ont = None
    if use_ontology_encoder:
        print("[OntEnc] building WordNet ontology graph and training Euclidean embeddings ...")
        heldout = (idx_va + idx_te) if ontology_exclude_heldout_label_edges else []
        ont_feats = _ontology_features(
            X_pairs, y_ids, label_id_to_synset, heldout,
            ontology_dim, ontology_epochs, ontology_neg_k,
        )
        feats_with_ont = np.concatenate([feats_text, ont_feats], axis=1).astype(np.float32)
        print(f"[OntEnc] stacked feature dim: {feats_with_ont.shape[1]}")

    all_summaries = {}
    histories = {}
    for baseline in baselines:
        use_ont_here = feats_with_ont is not None and baseline in baselines_use_ontology
        name_for_logs = baseline + "+OntEuc" if use_ont_here else baseline
        print(f"\n========== Baseline: {name_for_logs} ==========")

        feats_used = feats_with_ont if use_ont_here else feats_text
        tr_ld, va_ld, te_ld = (
            DataLoader(WNHypDataset(feats_used[ix], y_ids[ix]),
                       batch_size=batch_size, shuffle=shuffle, num_workers=2)
            for ix, shuffle in zip(split_arrays, (True, False, False))
        )

        model = WNHypModel(
            d_in=feats_used.shape[1], n_labels=n_labels, baseline=baseline,
            hyper_c=c_hyp, clip_r=clip_r, lr=lr_cccp, tau=tau
        ).to(device)

        if baseline == "UHop":
            _load_or_pretrain_uhop_kernel(model, tr_ld, out_dir)

        best_ckpt = os.path.join(out_dir, f"{name_for_logs}_best.pth")
        hist, best_val = _fit_baseline(
            model, tr_ld, va_ld, epochs, best_ckpt, label_id_to_synset, name_for_logs
        )
        histories[name_for_logs] = hist

        te_acc, te_hdist = evaluate(model, te_ld, label_id_to_synset)
        print(f"[{name_for_logs}] TEST  acc={te_acc:.4f}  avg_hdist={te_hdist:.3f}")

        all_summaries[name_for_logs] = {
            "test_acc": float(te_acc),
            "test_avg_hdist": float(te_hdist),
            "val_best_score": float(best_val),
            "encoder": enc.mode,
            "labels": n_labels,
            "samples": len(texts),
            "use_ontology_encoder": bool(use_ont_here),
            "ontology_dim": ontology_dim if use_ont_here else 0,
        }

    _plot_histories(histories, out_dir)
    summary_path = _append_summary(out_dir, all_summaries)
    print("Appended. Summary ->", summary_path)



if __name__ == "__main__":
    # Train every baseline, each also fed the WordNet-graph (Euclidean
    # ontology encoder) features.
    every_baseline = ("MHN_Euc", "HypAttn", "HypNN", "HAMN", "UHop")
    main(
        primary_parent_strategy="first",
        encoder_mode="auto",
        min_label_freq=5,
        max_nodes=20000,
        batch_size=128,
        epochs=100,
        baselines=every_baseline,
        c_hyp=0.9,
        clip_r=0.9,
        lr_cccp=1.0,
        tau=1.0,
        out_dir="wn_hypernym_results6",
        use_ontology_encoder=True,      # enable the Euclidean ontology encoder
        ontology_dim=64,
        ontology_epochs=15,
        ontology_neg_k=10,
        baselines_use_ontology=every_baseline,
    )

