# kgc_wn18rr.py  (HypAttn sharpened, HAMN smaller eta, 5-panel curves)
import os, sys, json, math, time, argparse, random, contextlib
from typing import List, Tuple, Dict, Optional
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Single compute device shared by every model and tensor in this script.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# =========================
# Utils
# =========================
def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's RNGs (plus every CUDA device, if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def is_windows():
    """Return True on Windows, judged by ``os.name`` or ``sys.platform``."""
    if os.name == "nt":
        return True
    return sys.platform.startswith("win")
def ensure_dir(p):
    """Create directory *p* (and any parents) if missing; return *p* unchanged."""
    os.makedirs(p, exist_ok=True)
    return p

def find_wn18rr_dir(root):
    """Return ``root/WN18RR`` or ``root/wn18rr``, whichever exists first.

    Falls back to ``root/WN18RR`` when neither directory exists, so callers
    get a deterministic path to report in error messages.
    """
    candidates = [os.path.join(root, name) for name in ("WN18RR", "wn18rr")]
    return next((d for d in candidates if os.path.isdir(d)), candidates[0])

def autocast_ctx(enabled=True, dtype=torch.float16):
    """Return a CUDA autocast context in *dtype*, or a no-op context.

    The no-op ``contextlib.nullcontext()`` is used when *enabled* is falsy
    or no CUDA device is available.
    """
    use_cuda_amp = enabled and torch.cuda.is_available()
    if not use_cuda_amp:
        return contextlib.nullcontext()
    return torch.amp.autocast(device_type="cuda", dtype=dtype)

# =========================
# Poincaré ball ops
# =========================
class PoincareBall:
    """Poincaré-ball model of hyperbolic space with curvature ``-c`` (``c > 0``).

    Points live in the open ball of radius ``1/sqrt(c)``. All operations act on
    the last dimension and broadcast over the leading ones. Results of
    ``mobius_add`` / ``exp0`` are pulled back inside ``(1 - 1e-5) * radius`` so
    they never touch the boundary.

    Args:
        c: curvature magnitude.
        eps: small floor used by norm / denominator clamps.
    """

    def __init__(self, c: float, eps: float = 1e-6):
        self.c = float(c); self.eps = eps
        self.radius = 1.0 / (self.c ** 0.5); self._M = 1.0 - 1e-5

    def _proj_with_margin(self, x):
        """Rescale rows of *x* whose norm exceeds ``_M * radius`` onto that sphere."""
        r = x.norm(dim=-1, keepdim=True).clamp_min(self.eps)
        max_r = self._M * self.radius
        scale = torch.where(r > max_r, max_r / r, torch.ones_like(r))
        return x * scale

    def lambda_x(self, x):
        """Conformal factor ``2 / (1 - c||x||^2)``, shape ``[..., 1]``."""
        x2 = (x * x).sum(dim=-1, keepdim=True)
        return 2.0 / (1.0 - self.c * x2).clamp_min(1e-6)

    def mobius_add(self, x, y):
        """Möbius addition ``x ⊕_c y``, projected back inside the ball."""
        c = self.c
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        xy = (x * y).sum(dim=-1, keepdim=True)
        num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
        den = (1 + 2 * c * xy + (c**2) * x2 * y2).clamp_min(1e-6)
        out = num / den
        return self._proj_with_margin(out)

    def mobius_neg(self, x): return self._proj_with_margin(-x)

    def exp0(self, v):
        """Exponential map at the origin: tangent vector *v* -> point on the ball."""
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = (self.c ** 0.5) * vnorm
        coef = torch.tanh(t) / ((self.c ** 0.5) * vnorm)
        x = coef * v
        return self._proj_with_margin(x)

    def log0(self, x):
        """Logarithmic map at the origin: point on the ball -> tangent vector."""
        x = self._proj_with_margin(x)
        xnorm = x.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        arg = (self.c ** 0.5) * xnorm
        arg = torch.clamp(arg, 0.0, 1.0 - 1e-7)
        coef = torch.atanh(arg) / ((self.c ** 0.5) * xnorm)
        return coef * x

    def exp_p(self, p, v):
        """Exponential map at *p*: ``p ⊕ tanh(sqrt(c) λ_p ||v|| / 2) v / (sqrt(c)||v||)``."""
        lam = self.lambda_x(p)
        vnorm = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        t = 0.5 * (self.c ** 0.5) * lam * vnorm
        coef = torch.tanh(t) / ((self.c ** 0.5) * vnorm)
        delta = coef * v
        return self.mobius_add(p, delta)

    def log_p(self, p, x):
        """Logarithmic map at *p*; satisfies ``λ_p * ||log_p(p, x)|| == dist(p, x)``."""
        lam = self.lambda_x(p)
        y = self.mobius_add(self.mobius_neg(p), x)
        ynorm = y.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        arg = (self.c ** 0.5) * ynorm
        arg = torch.clamp(arg, 0.0, 1.0 - 1e-7)
        coef = (2.0 / ((self.c ** 0.5) * lam)) * (torch.atanh(arg) / ynorm)
        return coef * y

    def dist(self, x, y):
        """Geodesic distance on the ball, shape ``[..., 1]``.

        d_c(x, y) = (1/sqrt(c)) * arcosh(1 + 2c||x-y||^2 / ((1 - c||x||^2)(1 - c||y||^2)))

        The ``1/sqrt(c)`` factor makes this agree with ``log_p``/``exp_p`` for
        every curvature. Without it, distances are off by a factor ``sqrt(c)``
        whenever ``c != 1``. The arcosh argument is floored at ``1 + 1e-6`` to
        keep its gradient finite, so ``dist(x, x)`` is tiny but not exactly 0.
        """
        c = self.c
        diff2 = ((x - y) ** 2).sum(dim=-1, keepdim=True)
        x2 = (x * x).sum(dim=-1, keepdim=True)
        y2 = (y * y).sum(dim=-1, keepdim=True)
        num = 2 * c * diff2
        den = ((1 - c * x2) * (1 - c * y2)).clamp_min(self.eps)
        z = 1 + num / den
        return torch.acosh(z.clamp_min(1 + 1e-6)) / (c ** 0.5)

    def gyr(self, a, b, w):
        """Gyration ``gyr[a, b] w = ⊖(a ⊕ b) ⊕ (a ⊕ (b ⊕ w))``."""
        ab = self.mobius_add(a, b)
        bw = self.mobius_add(b, w)
        a_bw = self.mobius_add(a, bw)
        return self.mobius_add(self.mobius_neg(ab), a_bw)

    def ptransp(self, x, y, v):
        """Parallel transport of tangent vector *v* from *x* to *y*.

        NOTE(review): ``gyr`` is evaluated with ``mobius_add``, which projects
        into the ball. Tangent vectors with norm near or above ``radius``
        therefore get clipped. This method is not called anywhere in this file.
        """
        lam_x = self.lambda_x(x); lam_y = self.lambda_x(y)
        gyred = self.gyr(y, self.mobius_neg(x), v)
        return (lam_x / lam_y) * gyred

# =========================
# Decoders (4 baselines)
# =========================
class EuclidMemoryAttention(nn.Module):  # MHN_Euc
    """Euclidean memory attention: a query attends over K learned slots of its relation.

    Each relation ``r`` owns memory rows ``r*K .. r*K+K-1``. The layer-normalised
    query is projected, scored against projected keys by scaled dot product,
    and the softmax weights (after dropout) mix the projected values.

    Args:
        dim: embedding size.
        n_rel: number of relations (including reciprocal ones, if any).
        K: memory slots per relation.
        dropout: dropout probability applied to the attention weights,
            mirroring ``HyperbolicAttentionLayer``.
    """

    def __init__(self, dim, n_rel, K=16, dropout=0.1):
        super().__init__()
        # Creation order is kept as-is: it determines RNG consumption at init.
        self.dim = dim; self.K = K
        self.mem = nn.Embedding(n_rel * K, dim)
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.ln = nn.LayerNorm(dim); self.drop = nn.Dropout(dropout)
        nn.init.normal_(self.mem.weight, std=0.02)

    def forward(self, q, r):
        """Return the attention readout ``[B, d]`` for queries ``q [B, d]`` and relations ``r [B]``."""
        _, d = q.shape
        ql = self.q_proj(self.ln(q))
        M = self.mem.weight.view(-1, self.K, d)[r]        # [B,K,d]
        Kmat = self.k_proj(M); Vmat = self.v_proj(M)
        attn = torch.softmax((ql.unsqueeze(1) * Kmat).sum(-1) / math.sqrt(d), dim=1)
        # The configured dropout was previously constructed but never applied.
        attn = self.drop(attn)
        return (attn.unsqueeze(-1) * Vmat).sum(1)

class HyperbolicAttentionLayer(nn.Module):  # HypAttn (sharpened, no residual)
    """Relation-specific memory attention scored by Poincaré-ball distance.

    Each relation owns ``K`` learned memory rows (``mem``). For a query ``q``
    the layer:

    1. builds tangent vectors for query, keys and values (LayerNorm, linear
       projections, learned positive scales), norm-clips them to ``clip_tan``
       and lifts them onto the ball with ``exp0``;
    2. scores each key by ``-beta * d(q, k)^2`` (or ``-beta * d`` when
       ``use_d2=False``), normalises the scores per row (``attn_norm``), then
       applies ``softmax(tau * scores)`` followed by dropout;
    3. averages ``log_q(v_i)`` with those weights in the tangent space at ``q``,
       maps back with ``exp_q`` and returns the result in ``T_0`` via ``log0``.
       There is no residual connection to the input.

    NOTE(review): with ``attn_norm`` in {"zscore", "zscore_max"} the per-row
    standardisation removes any positive scale factor. ``beta`` therefore
    cancels out of the attention weights, and ``log_beta`` appears to receive
    no gradient from them. Confirm this is intended.
    """
    def __init__(
        self, dim, n_rel, K=16, c=1.0, tau=2.0, dropout=0.1,
        use_d2=True,                # True:  -β·d^2
        attn_norm="zscore_max",     # "zscore_max" | "zscore" | "center" | "max" | "none"
        beta_init=0.5,              # initial beta (stored as log_beta)
        mem_init_std=0.01,          # std of the memory-slot initialisation
        clip_tan=1.0                # max tangent norm before exp0 (None or <= 0 disables)
    ):
        super().__init__()
        self.ball = PoincareBall(c)
        self.dim = dim; self.K = K
        # K memory rows per relation; relation r uses rows r*K .. r*K+K-1.
        self.mem = nn.Embedding(n_rel * K, dim)
        nn.init.normal_(self.mem.weight, std=mem_init_std)

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)

        self.ln_q = nn.LayerNorm(dim)
        self.ln_kv = nn.LayerNorm(dim)   # shared by keys and values
        self.drop = nn.Dropout(dropout)  # applied to the attention weights


        # Softmax inverse temperature; clamped to [0.1, 5.0] where it is used.
        self.tau = float(tau)


        # beta = exp(log_beta), clamped to [beta_min, beta_max] in forward().
        self.log_beta = nn.Parameter(torch.log(torch.tensor(float(beta_init))))
        self.beta_min, self.beta_max = 0.05, 10.0

        self.use_d2 = bool(use_d2)
        self.attn_norm = attn_norm


        # Learned log-scales for the q/k/v tangent vectors (exp'd and clamped to [0.25, 10]).
        self.q_log_scale = nn.Parameter(torch.tensor(0.0))
        self.k_log_scale = nn.Parameter(torch.tensor(1.1))
        self.v_log_scale = nn.Parameter(torch.tensor(1.1))


        self.clip_tan = clip_tan
        self.tan_eps = 1e-5

    def _tan_clip(self, v: torch.Tensor) -> torch.Tensor:
        """Scale rows of *v* down so their L2 norm is at most about ``clip_tan``."""
        if self.clip_tan is None or self.clip_tan <= 0:
            return v
        norm = v.norm(dim=-1, keepdim=True)
        scale = self.clip_tan / (norm + self.tan_eps)
        scale = torch.clamp(scale, max=1.0)  # never scale up, only down
        return v * scale

    def _normalize_scores(self, scores: torch.Tensor):
        """Normalise attention logits per row according to ``self.attn_norm``.

        "zscore"/"zscore_max" standardise, "center" subtracts the mean, and
        "max"/"zscore_max" then subtract the row max. Any other value
        (e.g. "none") leaves the scores unchanged.
        """
        # scores: [B, K]
        if self.attn_norm in ("zscore", "zscore_max"):
            mu = scores.mean(dim=1, keepdim=True)
            sigma = scores.std(dim=1, keepdim=True).clamp_min(1e-6)
            scores = (scores - mu) / sigma
        elif self.attn_norm == "center":
            scores = scores - scores.mean(dim=1, keepdim=True)

        if self.attn_norm in ("max", "zscore_max"):
            scores = scores - scores.max(dim=1, keepdim=True).values
        return scores

    def forward(self, q, r):
        """Return the attended readout ``[B, d]`` (in T_0) for ``q [B, d]``, ``r [B]``."""
        # q: [B, d], r: [B]
        B, d = q.shape

        # Query: LayerNorm then projection. Keys/values: projection then LayerNorm.
        q_t = self.ln_q(q)                          # [B,d]
        M    = self.mem.weight.view(-1, self.K, d)[r]  # [B,K,d]
        k_t  = self.ln_kv(self.k_proj(M))
        v_t  = self.ln_kv(self.v_proj(M))


        s_q = torch.exp(self.q_log_scale).clamp(0.25, 10.0)
        s_k = torch.exp(self.k_log_scale).clamp(0.25, 10.0)
        s_v = torch.exp(self.v_log_scale).clamp(0.25, 10.0)

        q_tan = s_q * self.q_proj(q_t)   # [B,d]
        k_tan = s_k * k_t                # [B,K,d]
        v_tan = s_v * v_t                # [B,K,d]

        # NOTE(review): LayerNorm outputs likely have norm ~sqrt(d), so with
        # clip_tan=1.0 the k/v clip probably saturates and s_k/s_v have little
        # effect. Verify empirically.
        q_tan = self._tan_clip(q_tan)
        k_tan = self._tan_clip(k_tan)
        v_tan = self._tan_clip(v_tan)


        q_b = self.ball.exp0(q_tan)      # [B,d]
        k_b = self.ball.exp0(k_tan)      # [B,K,d]
        v_b = self.ball.exp0(v_tan)      # [B,K,d]


        q_rep = q_b.unsqueeze(1).expand_as(k_b)        # [B,K,d]
        d_geo = self.ball.dist(q_rep, k_b).squeeze(-1) # [B,K]

        beta  = torch.exp(self.log_beta).clamp(self.beta_min, self.beta_max)
        scores = - beta * (d_geo**2 if self.use_d2 else d_geo)  # [B,K]
        scores = self._normalize_scores(scores)


        tau = max(0.1, min(5.0, float(self.tau)))
        attn = torch.softmax(tau * scores, dim=1)      # [B,K]
        attn = self.drop(attn)

        # Weighted mean of the values in the tangent space at q, mapped back to
        # the ball at q, then expressed in the tangent space at the origin.
        log_q_v = self.ball.log_p(q_b.unsqueeze(1), v_b)          # [B,K,d]
        z_tan   = torch.bmm(attn.unsqueeze(1), log_q_v).squeeze(1)# [B,d]
        z_ball  = self.ball.exp_p(q_b, z_tan)                     # [B,d]
        out     = self.ball.log0(z_ball)                          # [B,d] in T_0


        return out

class HypNNBlock(nn.Module):  # HypNN
    """Feed-forward decoder that routes the query through the Poincaré ball.

    The input is layer-normalised, lifted with ``exp0`` and brought straight
    back with ``log0``. This is approximately the identity for small vectors,
    while large-norm inputs come back with a capped norm. The result then goes
    through Linear -> LayerNorm -> ReLU -> Dropout -> Linear.
    The relation argument ``r`` is ignored.
    """

    def __init__(self, dim, hidden=512, c=1.0, dropout=0.1):
        super().__init__()
        # Module creation order is kept: it fixes RNG consumption at init.
        self.ball = PoincareBall(c)
        self.fc1 = nn.Linear(dim, hidden, bias=True)
        self.fc2 = nn.Linear(hidden, dim, bias=True)
        self.ln1 = nn.LayerNorm(hidden)
        self.ln2 = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)
        nn.init.kaiming_uniform_(self.fc1.weight, a=0.2)
        for bias in (self.fc1.bias, self.fc2.bias):
            nn.init.zeros_(bias)

    def forward(self, q, r=None):
        tangent = self.ball.log0(self.ball.exp0(self.ln2(q)))
        hidden = self.drop(F.relu(self.ln1(self.fc1(tangent))))
        return self.fc2(hidden)

class HAMN_Algo1(nn.Module):
    """Hyperbolic associative-memory decoder with a CCCP-style fixed-point update.

    Each relation owns ``K`` learned memory rows. After a linear projection and
    tangent-norm clipping they are lifted onto the Poincaré ball as points
    ``Y_i``. The query is projected and lifted the same way to a state ``R``,
    which is then updated ``max(1, iters)`` times:

    * weights ``p = softmax(theta * -cosh(d(R, Y_i)))``, with theta clamped to [0.1, 5];
    * ``a = -sum_i p_i * sinh(d_i) * log_R(Y_i) / d_i`` in the tangent space at R;
    * ``R <- exp0(eta * (lambda_R / 2) * (-a))``, i.e. ``-a`` scaled by the
      R->0 parallel-transport factor, then mapped out from the origin;
    * early stop once the mean distance moved falls below ``eps``.

    The final state is returned in ``T_0`` via ``log0``.

    NOTE(review): ``dropout`` builds ``self.drop``, but ``forward`` never applies it.
    """

    def __init__(
        self, dim, n_rel, K=16, c=1.0, theta=2.0,
        iters=1, eta=1.0, eps=1e-5, dropout=0.0,
        clip_tan=3.0    # max tangent norm before exp0 (None or <= 0 disables)
    ):
        super().__init__()
        self.dim = dim
        self.K = K
        self.ball = PoincareBall(c)
        self.ln = nn.LayerNorm(dim)

        # Query projection, plus K memory rows per relation (relation r uses rows r*K .. r*K+K-1).
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.mem = nn.Embedding(n_rel * K, dim)
        self.mem_proj = nn.Linear(dim, dim, bias=False)

        # NOTE(review): constructed but not used in forward().
        self.drop = nn.Dropout(dropout)

        self.theta = float(theta)   # softmax inverse temperature (clamped to [0.1, 5] in forward)
        self.iters = int(iters)     # update steps; at least one is always run
        self.eta = float(eta)       # step scale applied before exp0
        self.eps = float(eps)       # early-stop threshold on mean distance moved

        # Tangent-space norm cap applied before each exp0 lift.
        self.clip_tan = clip_tan
        self.tan_eps = 1e-5

        nn.init.normal_(self.mem.weight, std=0.02)

    def _tan_clip(self, v: torch.Tensor) -> torch.Tensor:
        """Scale rows of *v* down so their L2 norm is at most about ``clip_tan``."""
        if self.clip_tan is None or self.clip_tan <= 0:
            return v
        norm = v.norm(dim=-1, keepdim=True)
        scale = self.clip_tan / (norm + self.tan_eps)
        scale = torch.clamp(scale, max=1.0)  # never scale up, only down
        return v * scale

    def forward(self, q, r):                     # q: [B,d] (e_h + e_r), r: [B]
        """Return the retrieved state in the tangent space at the origin, shape ``[B, d]``."""
        B, d = q.shape

        # Initial state: projected query lifted onto the ball.
        q0 = self.q_proj(self.ln(q))
        q0 = self._tan_clip(q0)
        R  = self.ball.exp0(q0)                  # [B,d] on ball

        M = self.mem.weight.view(-1, self.K, d)[r]     # [B,K,d] tangent params
        M_t = self.mem_proj(M)
        M_t = self._tan_clip(M_t)
        Y = self.ball.exp0(M_t)                        # [B,K,d] on ball

        theta = max(0.1, min(5.0, float(self.theta)))

        # Iterate (one step is usually enough).
        for _ in range(max(1, self.iters)):
            # ---------- similarity & softmax weights ----------
            R_rep = R.unsqueeze(1).expand(-1, self.K, -1)          # [B,K,d]
            d_geod = self.ball.dist(R_rep, Y).squeeze(-1)          # [B,K]
            val = -torch.cosh(d_geod)                              # [B,K]

            z = theta * val
            z = z - z.max(dim=1, keepdim=True).values  # max shift for stability (softmax is shift-invariant)
            p = torch.softmax(z, dim=1)                            # [B,K]

            # ---------- intrinsic gradient of concave part ----------
            # log_R(Y)/d
            u = self.ball.log_p(R.unsqueeze(1), Y)                 # [B,K,d] in T_R
            u_dir = u / d_geod.clamp_min(1e-9).unsqueeze(-1)       # [B,K,d]
            # a = - Σ p_i * sinh(d_i) * u_dir_i
            a = - (p.unsqueeze(-1) * torch.sinh(d_geod).unsqueeze(-1) * u_dir).sum(dim=1)  # [B,d]

            # ---------- CCCP step with intrinsic regularizer ----------
            # PT_{R→0}( -a ) = (λ_R / 2) * ( -a )
            lam = self.ball.lambda_x(R)                            # [B,1]
            v0 = (lam / 2.0) * (-a)                                # [B,d] in T_0

            # The new state is mapped out from the origin; it depends on the
            # previous R only through `a` and `lam`.
            R_next = self.ball.exp0(self.eta * v0)                 # [B,d] on ball

            # ---------- stopping rule ----------
            # Comparing a tensor in `if` forces a device->host sync on CUDA.
            if self.ball.dist(R, R_next).mean() < self.eps:
                R = R_next
                break
            R = R_next


        return self.ball.log0(R)                                   # [B,d]

# =========================
# KGE Model
# =========================
class KGEModel(nn.Module):
    """Knowledge-graph embedding model: entity/relation tables plus a pluggable decoder.

    A query for ``(h, r)`` is ``E[h] + R[r]``. The decoder chosen by
    ``decoder`` refines it, and the result is compared with candidate tails:

    * ``"MHN_Euc"``: negative squared Euclidean distance.
    * ``"HypNN"``: negative squared Poincaré distance after ``exp0``.
    * ``"HypAttn"`` / ``"HAMN"``: as ``"HypNN"``, but tangent vectors are first
      norm-clipped to ``clip_tan_ent``.

    ``n_heads`` is accepted for interface compatibility and is not used.
    """

    def __init__(self, n_ent, n_rel_eff, dim=200, decoder="HypAttn",
                 K=16, n_heads=4, c=1.0, tau=5.0, dropout=0.1):
        super().__init__()
        self.decoder_name = decoder
        self.dim = dim
        self.c = c
        self.ball = PoincareBall(c)

        # Build both tables first, then re-initialise them (order fixes RNG use).
        self.E = nn.Embedding(n_ent, dim)
        self.R = nn.Embedding(n_rel_eff, dim)
        for table in (self.E, self.R):
            nn.init.normal_(table.weight, std=0.01)

        # Tangent-norm cap used by the HypAttn / HAMN scoring path.
        self.clip_tan_ent = 2.0
        self.tan_eps_ent = 1e-5

        if decoder == "MHN_Euc":
            dec = EuclidMemoryAttention(dim, n_rel_eff, K=K, dropout=dropout)
        elif decoder == "HypAttn":
            dec = HyperbolicAttentionLayer(
                dim, n_rel_eff, K=K, c=c, tau=tau, dropout=dropout,
                use_d2=True, attn_norm="zscore_max", beta_init=0.5,
                mem_init_std=0.01, clip_tan=1.0,
            )
        elif decoder == "HypNN":
            dec = HypNNBlock(dim, hidden=max(256, dim), c=c, dropout=dropout)
        elif decoder == "HAMN":
            dec = HAMN_Algo1(
                dim, n_rel_eff, K=K, c=c, theta=tau, iters=1, eta=1.0,
                dropout=dropout, clip_tan=1.0,
            )
        else:
            raise ValueError(decoder)
        self.dec = dec

    def _tan_clip_ent(self, v: torch.Tensor) -> torch.Tensor:
        """Shrink rows of *v* so their norm is at most about ``clip_tan_ent``."""
        cap = self.clip_tan_ent
        if cap is None or cap <= 0:
            return v
        factor = (cap / (v.norm(dim=-1, keepdim=True) + self.tan_eps_ent)).clamp(max=1.0)
        return v * factor

    def _to_ball(self, v):
        """Lift tangent vectors onto the ball (clipped first unless the decoder is HypNN)."""
        if self.decoder_name == "HypNN":
            return self.ball.exp0(v)
        return self.ball.exp0(self._tan_clip_ent(v))

    def query_vector(self, h, r):
        return self.E(h) + self.R(r)

    def score(self, h, r, t):
        """Score aligned ``(h, r, t)`` ID tensors; higher is better, shape ``[B]``."""
        z = self.dec(self.query_vector(h, r), r)
        tail = self.E(t)
        if self.decoder_name == "MHN_Euc":
            return -((z - tail) ** 2).sum(-1)
        gap = self.ball.dist(self._to_ball(z), self._to_ball(tail)).squeeze(-1)
        return -(gap ** 2)

    def score_all_tails(self, h, r, chunk=2048):
        """Score every entity as a tail for each ``(h, r)``; returns ``[B, n_ent]``.

        Entities are processed ``chunk`` at a time to bound peak memory.
        """
        z = self.dec(self.query_vector(h, r), r)
        table = self.E.weight
        n_total = table.size(0)
        pieces = []

        if self.decoder_name == "MHN_Euc":
            for start in range(0, n_total, chunk):
                cand = table[start:start + chunk]
                pieces.append(-(((z.unsqueeze(1) - cand.unsqueeze(0)) ** 2).sum(-1)))
            return torch.cat(pieces, dim=1)

        z_b = self._to_ball(z)
        for start in range(0, n_total, chunk):
            cand_b = self._to_ball(table[start:start + chunk])
            gap = self.ball.dist(
                z_b.unsqueeze(1).expand(-1, cand_b.size(0), -1),
                cand_b.unsqueeze(0),
            ).squeeze(-1)
            pieces.append(-(gap ** 2))
        return torch.cat(pieces, dim=1)

# =========================
# Data I/O (robust dicts)
# =========================
def load_triples(path):
    """Read ``head<TAB>relation<TAB>tail`` lines from *path*.

    Blank lines, such as a trailing empty line at EOF, are skipped.

    Returns:
        list of ``(h, r, t)`` string tuples, in file order.

    Raises:
        ValueError: if a non-blank line does not split into exactly three
            tab-separated fields.
    """
    tris = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            h, r, t = s.split('\t')
            tris.append((h, r, t))
    return tris

def _try_int(x):
    try: return int(x)
    except Exception: return None

def load_dict(path):
    """Parse a name -> id mapping file (``entities.dict`` / ``relations.dict``).

    Accepted whitespace-separated line layouts (blank lines are skipped):

    * ``<id> <name>``: the layout this script writes. It is also chosen when
      BOTH fields are numeric, as with WN18RR synset names like ``00260881``.
    * ``<name> <id>``: used when only the second field is numeric.
    * ``<name>``: a lone token gets the next free id, ``len(d)``.
    * ``<a> <b>`` with neither numeric: each unseen token gets a fresh id.

    Returns:
        dict mapping name (str) -> id (int).
    """
    d = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if len(parts) < 2:
                tok = parts[0]
                if tok not in d: d[tok] = len(d)
                continue
            a, b = parts[0], parts[1]
            ia, ib = _try_int(a), _try_int(b)
            if ia is not None:
                # Id-first. This also wins when both fields are numeric; the old
                # name-first guess mapped "0 00260881" to {"0": 260881}.
                d[b] = ia
            elif ib is not None:
                d[a] = ib
            else:
                if a not in d: d[a] = len(d)
                if b not in d: d[b] = len(d)
    return d

def build_or_load_mappings(root):
    """Return ``(entity2id, relation2id)`` for the dataset directory *root*.

    If both ``entities.dict`` and ``relations.dict`` exist, they are parsed with
    ``load_dict``. If either is missing, or parsing raises, both mappings are
    rebuilt from the sorted entity/relation names found in
    train/valid/test.txt. The rebuilt mappings are written back as
    ``<id><TAB><name>`` lines.
    """
    et_path = os.path.join(root, "entities.dict")
    rl_path = os.path.join(root, "relations.dict")

    def rebuild_from_triples():
        print("[Info] rebuilding dicts from triples ...")
        triples = [
            tri
            for fn in ("train.txt", "valid.txt", "test.txt")
            for tri in load_triples(os.path.join(root, fn))
        ]
        ent_names = sorted({h for h, _, _ in triples} | {t for _, _, t in triples})
        rel_names = sorted({r for _, r, _ in triples})
        ent_ids = {name: idx for idx, name in enumerate(ent_names)}
        rel_ids = {name: idx for idx, name in enumerate(rel_names)}
        for out_path, mapping in ((et_path, ent_ids), (rl_path, rel_ids)):
            with open(out_path, "w", encoding="utf-8") as f:
                f.writelines(f"{idx}\t{name}\n" for name, idx in mapping.items())
        return ent_ids, rel_ids

    if not (os.path.isfile(et_path) and os.path.isfile(rl_path)):
        return rebuild_from_triples()
    try:
        return load_dict(et_path), load_dict(rl_path)
    except Exception as e:
        print("[warn] failed to parse existing dicts, will rebuild:", e)
        return rebuild_from_triples()

def map_triples(tris, e2id, r2id):
    """Convert string ``(h, r, t)`` triples to an ``int64`` array of shape ``[N, 3]``.

    Raises:
        KeyError: if a head/tail is missing from *e2id* or a relation from
            *r2id*. The original KeyError is chained as ``__cause__``.
    """
    try:
        rows = [(e2id[h], r2id[r], e2id[t]) for (h, r, t) in tris]
    except KeyError as ke:
        raise KeyError(
            f"KeyError: {str(ke)}. *The contents of the dict do not match the expected triplets. You can delete the dict and let the script rebuild the data accordingly."
        ) from ke
    # reshape keeps the [N, 3] layout even for empty input (np.array([]) is 1-D).
    return np.array(rows, dtype=np.int64).reshape(-1, 3)

def make_reciprocal(tr, n_rel):
    """Return the inverse triple ``(t, r + n_rel, h)`` for every ``(h, r, t)`` row of *tr*.

    The input array is left untouched.
    """
    inverse = tr[:, [2, 1, 0]]  # fancy indexing already yields a fresh copy
    inverse[:, 1] = inverse[:, 1] + n_rel
    return inverse

# =========================
# Datasets & Filters
# =========================
class KGTrainDataset(Dataset):
    """Training triples, each paired with ``neg_k`` uniformly random tail IDs.

    Negatives are drawn from ``[0, n_ent)`` with NumPy's global RNG and are not
    filtered against true tails.
    """

    def __init__(self, triples, n_ent, neg_k=50):
        self.tr = triples.astype(np.int64)
        self.n_ent = n_ent
        self.neg_k = neg_k

    def __len__(self):
        return self.tr.shape[0]

    def __getitem__(self, i):
        head, rel, tail = self.tr[i]
        negatives = np.random.randint(0, self.n_ent, size=(self.neg_k,), dtype=np.int64)
        return head, rel, tail, negatives

def build_filters(tr_all):
    """Map each ``(head, relation)`` pair to the set of every tail it appears with.

    Returns a ``defaultdict(set)`` keyed by Python ints, so unseen pairs read
    as an empty set.
    """
    known = defaultdict(set)
    for head, rel, tail in tr_all:
        known[int(head), int(rel)].add(int(tail))
    return known

# =========================
# Evaluation (tail-only + AMP, with score_chunk)
# =========================
@torch.no_grad()
def evaluate_tail_only(model, triples, hr2t, batch=256, use_amp=True, score_chunk=2048):
    """Filtered tail-prediction metrics (MRR, Hits@1/3/10) for *triples*.

    For each ``(h, r, t)`` the model scores every entity as a tail. All other
    known tails of ``(h, r)`` in *hr2t* are masked to ``-1e9`` (filtered
    setting), and the rank of ``t`` is ``1 + #(scores strictly greater)``.
    Ties are therefore resolved optimistically. Under fp16 autocast ties become
    more likely, which can inflate the numbers slightly.

    Args:
        model: provides ``eval()`` and ``score_all_tails(h, r, chunk=...)``,
            which returns ``[B, n_ent]`` scores.
        triples: integer array ``[N, 3]`` of (h, r, t) IDs.
        hr2t: mapping ``(h, r) -> iterable of tail IDs``. Missing keys mean
            "no other known tails" and are not inserted.
        batch: queries per forward pass.
        use_amp: score under CUDA fp16 autocast when available.
        score_chunk: candidate entities scored per chunk.

    Returns:
        dict of floats ``"MRR"``, ``"H@1"``, ``"H@3"``, ``"H@10"`` (NaN for empty input).
    """
    model.eval()
    ranks = []
    with autocast_ctx(enabled=use_amp, dtype=torch.float16):
        for i in tqdm(range(0, len(triples), batch), desc="Eval(T)", leave=False):
            chunk = triples[i:i+batch]                       # np.ndarray [B,3]
            h_np, r_np, t_np = chunk[:, 0], chunk[:, 1], chunk[:, 2]
            h = torch.from_numpy(h_np.copy()).to(device=device, dtype=torch.long, non_blocking=True)
            r = torch.from_numpy(r_np.copy()).to(device=device, dtype=torch.long, non_blocking=True)
            t = torch.from_numpy(t_np.copy()).to(device=device, dtype=torch.long, non_blocking=True)

            scores = model.score_all_tails(h, r, chunk=score_chunk)  # [B,N]

            # Filtered setting: collect the (row, entity) pairs to mask on the
            # host from the numpy copy. This avoids a GPU->CPU sync per element.
            # A single indexed assignment then applies the mask.
            rows, cols = [], []
            for j, (hj, rj, tj) in enumerate(zip(h_np.tolist(), r_np.tolist(), t_np.tolist())):
                for tt in hr2t.get((hj, rj), ()):
                    if tt != tj:
                        rows.append(j)
                        cols.append(tt)
            if rows:
                scores[torch.tensor(rows, device=scores.device),
                       torch.tensor(cols, device=scores.device)] = -1e9

            gold = scores[torch.arange(scores.size(0), device=scores.device), t]
            rank = (scores > gold.unsqueeze(1)).sum(dim=1) + 1
            ranks.extend(rank.cpu().tolist())
    ranks = np.asarray(ranks, dtype=np.int64)
    if ranks.size == 0:
        nan = float("nan")
        return {"MRR": nan, "H@1": nan, "H@3": nan, "H@10": nan}
    return {
        "MRR": float(np.mean(1.0 / ranks)),
        "H@1": float(np.mean(ranks <= 1)),
        "H@3": float(np.mean(ranks <= 3)),
        "H@10": float(np.mean(ranks <= 10)),
    }

def avg_metrics(m1, m2):
    """Element-wise mean of two metric dicts; keys and their order come from *m1*."""
    return {name: 0.5 * (value + m2[name]) for name, value in m1.items()}

# =========================
# Train
# =========================
def train_one_epoch(model, loader, opt, neg_k=50, max_grad=5.0, use_amp=True):
    """Run one epoch of 1-vs-K negative-sampling training and return the mean batch loss.

    Each batch ``(h, r, t, neg)`` has ``neg`` of shape ``[B, K]``. The positive
    and the K corrupted-tail scores are treated as logits for a
    binary-cross-entropy loss (target 1 for the positive, 0 for negatives).

    Args:
        neg_k: unused; K is read from ``neg.size(1)``.
        max_grad: maximum global L2 norm of the *unscaled* gradients.
        use_amp: CUDA fp16 autocast with a GradScaler (no-op without CUDA).
    """
    model.train()
    run = 0.0
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp and torch.cuda.is_available())
    for (h, r, t, neg) in tqdm(loader, desc="Train", leave=False):
        h = h.to(device, non_blocking=True)
        r = r.to(device, non_blocking=True)
        t = t.to(device, non_blocking=True)
        neg = neg.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with autocast_ctx(enabled=use_amp, dtype=torch.float16):
            pos = model.score(h, r, t)
            B, K = neg.size()
            H = h.view(-1, 1).expand(-1, K).reshape(-1)
            R = r.view(-1, 1).expand(-1, K).reshape(-1)
            Tn = neg.view(-1)
            neg_s = model.score(H, R, Tn).view(B, K)
            logits = torch.cat([pos.unsqueeze(1), neg_s], dim=1)
            target = torch.zeros_like(logits); target[:, 0] = 1.0
            loss = F.binary_cross_entropy_with_logits(logits, target)
        scaler.scale(loss).backward()
        # Gradients currently carry the loss scale (e.g. 65536). Unscale them
        # first so max_grad applies to the true gradient norm; clipping the
        # scaled grads would shrink updates by orders of magnitude. This is a
        # no-op when the scaler is disabled.
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad)
        scaler.step(opt)
        scaler.update()
        run += loss.item()
    return run / max(1, len(loader))

# =========================
# Main
# =========================
def _map_splits(ds_dir, e2id, r2id, splits):
    """Map every string split with ONE shared vocabulary.

    The given ``e2id``/``r2id`` are tried first. If ANY split contains an
    unknown name, the dicts are rebuilt from the union of all splits (sorted
    names -> consecutive IDs) and rewritten to ``ds_dir``. Every split is then
    remapped, so train/valid/test can never end up with different IDs.

    Returns:
        (list of int64 ``[N, 3]`` arrays in the order of *splits*, e2id, r2id)
    """
    try:
        return [map_triples(s, e2id, r2id) for s in splits], e2id, r2id
    except KeyError:
        print("[warn] Detected mismatch between *.dict and triples. Rebuilding dicts...")

    all_tris = [tri for split in splits for tri in split]
    ents = sorted({h for h, _, _ in all_tris} | {t for _, _, t in all_tris})
    rels = sorted({r for _, r, _ in all_tris})
    e2id = {e: i for i, e in enumerate(ents)}
    r2id = {r: i for i, r in enumerate(rels)}
    with open(os.path.join(ds_dir, "entities.dict"), "w", encoding="utf-8") as f:
        for e, i in e2id.items(): f.write(f"{i}\t{e}\n")
    with open(os.path.join(ds_dir, "relations.dict"), "w", encoding="utf-8") as f:
        for r, i in r2id.items(): f.write(f"{i}\t{r}\n")
    return [map_triples(s, e2id, r2id) for s in splits], e2id, r2id


def _evaluate_split(model, fwd, rec, hr2t, eval_batch, eval_amp, score_chunk):
    """Tail metrics on *fwd*, averaged with head metrics (tails of reciprocal *rec*) if *rec* is given."""
    tail = evaluate_tail_only(model, fwd, hr2t, batch=eval_batch, use_amp=eval_amp, score_chunk=score_chunk)
    if rec is None:
        return tail
    head = evaluate_tail_only(model, rec, hr2t, batch=eval_batch, use_amp=eval_amp, score_chunk=score_chunk)
    return avg_metrics(tail, head)


def _plot_curves(curves, out_dir):
    """Save one PNG per tracked metric, with each run's best epoch marked.

    Best-effort: any plotting failure is reported and skipped so it cannot
    abort the run after training has finished.
    """
    try:
        ts = time.strftime("%Y%m%d_%H%M%S")
        metrics = [
            ("loss",       "Training Loss", "loss"),
            ("valid_MRR",  "Valid MRR",     "MRR"),
            ("H@1",        "Valid Hits@1",  "H@1"),
            ("H@3",        "Valid Hits@3",  "H@3"),
            ("H@10",       "Valid Hits@10", "H@10"),
        ]
        for key, title, ylabel in metrics:
            fig, ax = plt.subplots(figsize=(8, 4))
            try:
                any_plotted = False
                for name, hist in curves.items():
                    if key not in hist or len(hist[key]) == 0:
                        continue
                    y = np.asarray(hist[key], dtype=float)
                    x = np.arange(1, len(y) + 1)
                    line, = ax.plot(x, y, label=name)
                    color = line.get_color()
                    # Loss is minimised; every other metric is maximised.
                    best_idx = int(np.argmin(y)) if key.lower().startswith("loss") else int(np.argmax(y))
                    ax.axhline(y[best_idx], linestyle="--", color=color, alpha=0.4)
                    ax.scatter(x[best_idx], y[best_idx], color=color, s=25, zorder=3)
                    any_plotted = True

                if not any_plotted:
                    continue

                ax.set_title(title)
                ax.set_ylabel(ylabel)
                ax.set_xlabel("epoch")
                ax.grid(True, linestyle="--", alpha=0.3)
                ax.legend()
                fig.tight_layout()

                fp = os.path.join(out_dir, f"curves_{key}_{ts}.png")
                fig.savefig(fp, dpi=200)
                print(f"Saved {key} curves to:", fp)
            finally:
                plt.close(fig)
    except Exception as e:
        print("Plot skipped:", e)


def _append_summary(out_dir, all_summaries):
    """Append one timestamped record to ``out_dir/summary.json`` (written atomically)."""
    summary_path = os.path.join(out_dir, "summary.json")
    record = {"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "summaries": all_summaries}
    old = []
    if os.path.exists(summary_path):
        try:
            with open(summary_path, "r", encoding="utf-8") as f:
                old = json.load(f)
        except json.JSONDecodeError:
            old = []
    # Older files may hold a single dict (or another JSON value): keep it as the first entry.
    if not isinstance(old, list):
        old = [old]
    old.append(record)
    tmp = summary_path + ".tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(old, f, indent=2, ensure_ascii=False)
    os.replace(tmp, summary_path)
    print("Appended summary ->", summary_path)


def main(
    data_root="data",
    out_dir="runs_wn18rr",
    baselines=("MHN_Euc","HypAttn","HypNN","HAMN"),
    dim=200, c=1.0, tau=5.0, K=16, heads=4,
    batch_size=1024, neg_k=50, epochs=100, lr=2e-3, wd=1e-4, seed=42,
    reciprocal=True, eval_batch=256, eval_amp=True, score_chunk=2048
):
    """Train, validate and test each decoder in *baselines* on WN18RR.

    For every baseline:
    * train for *epochs* epochs;
    * keep the checkpoint with the best validation MRR (averaged over tail and
      head prediction when *reciprocal*);
    * report filtered test metrics from that checkpoint.

    Curves are saved as PNGs and a summary record is appended to
    ``out_dir/summary.json``. Decoders use ``2 * K`` memory slots per
    relation, recorded as ``K_eff`` in the summary.
    """
    set_seed(seed); ensure_dir(out_dir)

    ds_dir = find_wn18rr_dir(data_root)
    req = [os.path.join(ds_dir, x) for x in ["train.txt","valid.txt","test.txt"]]
    for p in req:
        if not os.path.isfile(p):
            raise FileNotFoundError(
                f"Missing file: {p}\n"
                " WN18RR in data_root/WN18RR/ or data_root/wn18rr/ （train.txt/valid.txt/test.txt）。"
            )

    e2id, r2id = build_or_load_mappings(ds_dir)

    train_tr_s = load_triples(os.path.join(ds_dir,"train.txt"))
    valid_tr_s = load_triples(os.path.join(ds_dir,"valid.txt"))
    test_tr_s  = load_triples(os.path.join(ds_dir,"test.txt"))

    # All splits share one vocabulary, even if the dicts must be rebuilt.
    (train_tr, valid_tr, test_tr), e2id, r2id = _map_splits(
        ds_dir, e2id, r2id, (train_tr_s, valid_tr_s, test_tr_s)
    )

    n_ent = int(max(train_tr[:,[0,2]].max(), valid_tr[:,[0,2]].max(), test_tr[:,[0,2]].max()) + 1)
    n_rel = int(max(train_tr[:,1].max(), valid_tr[:,1].max(), test_tr[:,1].max()) + 1)
    print(f"[Data] n_ent={n_ent}  n_rel={n_rel}  train={len(train_tr)}  valid={len(valid_tr)}  test={len(test_tr)}")

    # reciprocal: add (t, r + n_rel, h) so head prediction becomes tail prediction.
    if reciprocal:
        train_rec = make_reciprocal(train_tr, n_rel)
        valid_rec = make_reciprocal(valid_tr, n_rel)
        test_rec  = make_reciprocal(test_tr,  n_rel)
        n_rel_eff = n_rel * 2
        hr2t = build_filters(np.concatenate([train_tr, valid_tr, test_tr,
                                             train_rec, valid_rec, test_rec], axis=0))
        train_all = np.concatenate([train_tr, train_rec], axis=0)
    else:
        n_rel_eff = n_rel
        hr2t = build_filters(np.concatenate([train_tr, valid_tr, test_tr], axis=0))
        train_all = train_tr
        valid_rec = test_rec = None

    tr_ds = KGTrainDataset(train_all, n_ent=n_ent, neg_k=neg_k)
    nw = 0 if is_windows() else 2
    tr_ld = DataLoader(tr_ds, batch_size=batch_size, shuffle=True, num_workers=nw,
                       drop_last=False, pin_memory=torch.cuda.is_available())

    k_eff = K * 2  # memory slots per relation actually given to the decoders
    curves = {}; all_summaries = {}

    for base in baselines:
        print(f"\n========== Baseline: {base} ==========")
        lr_cur = lr * 0.5 if base == "HAMN" else lr
        model = KGEModel(n_ent, n_rel_eff, dim=dim, decoder=base, K=k_eff, n_heads=heads, c=c, tau=tau, dropout=0.1).to(device)
        opt = torch.optim.AdamW(model.parameters(), lr=lr_cur, weight_decay=wd, betas=(0.9,0.999), eps=1e-8)

        best_valid = -1.0; best_ckpt = os.path.join(out_dir, f"{base}_best.pth")
        saved_this_run = False
        hist = {"loss": [], "valid_MRR": [], "H@1": [], "H@3": [], "H@10": []}

        for ep in range(1, epochs+1):
            tr_loss = train_one_epoch(model, tr_ld, opt, neg_k=neg_k, use_amp=True)
            valid_metrics = _evaluate_split(model, valid_tr, valid_rec, hr2t, eval_batch, eval_amp, score_chunk)

            score = valid_metrics["MRR"]
            print(f"[{base}] Ep{ep:03d}  loss={tr_loss:.4f}  "
                  f"valid MRR={score:.4f}  H@1={valid_metrics['H@1']:.4f}  "
                  f"H@3={valid_metrics['H@3']:.4f}  H@10={valid_metrics['H@10']:.4f}")
            hist["loss"].append(float(tr_loss))
            hist["valid_MRR"].append(float(score))
            hist["H@1"].append(float(valid_metrics["H@1"]))
            hist["H@3"].append(float(valid_metrics["H@3"]))
            hist["H@10"].append(float(valid_metrics["H@10"]))

            if score > best_valid:
                best_valid = score; torch.save(model.state_dict(), best_ckpt)
                saved_this_run = True

        if not saved_this_run:
            # No epoch beat the initial -1.0 (epochs=0 or NaN MRR). Save the
            # current weights so we never load a missing checkpoint or a stale
            # one left by a previous run in the same out_dir.
            print(f"[warn] {base}: no validation improvement recorded; testing final weights.")
            torch.save(model.state_dict(), best_ckpt)

        model.load_state_dict(torch.load(best_ckpt, map_location=device))
        test_metrics = _evaluate_split(model, test_tr, test_rec, hr2t, eval_batch, eval_amp, score_chunk)

        print(f"[{base}] TEST  MRR={test_metrics['MRR']:.4f}  "
              f"H@1={test_metrics['H@1']:.4f}  H@3={test_metrics['H@3']:.4f}  H@10={test_metrics['H@10']:.4f}")

        curves[base] = hist
        all_summaries[base] = {
            "valid_best_MRR": float(best_valid),
            "test_MRR": float(test_metrics["MRR"]),
            "test_H@1": float(test_metrics["H@1"]),
            "test_H@3": float(test_metrics["H@3"]),
            "test_H@10": float(test_metrics["H@10"]),
            "n_ent": n_ent, "n_rel_eff": n_rel_eff, "dim": dim, "K": K, "K_eff": k_eff,
            "c": c, "reciprocal": int(reciprocal)
        }

    _plot_curves(curves, out_dir)
    _append_summary(out_dir, all_summaries)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # (flag, type, default), registered in this order.
    _cli_options = (
        ("--data_root", str, "data"),
        ("--out_dir", str, "runs_wn18rr"),
        ("--baselines", str, "HAMN"),
        ("--dim", int, 200),
        ("--c", float, 0.8),
        ("--tau", float, 1.0),
        ("--K", int, 16),
        ("--heads", int, 4),
        ("--batch_size", int, 1024),
        ("--neg_k", int, 50),
        ("--epochs", int, 30),
        ("--lr", float, 2e-3),
        ("--wd", float, 1e-4),
        ("--seed", int, 42),
        ("--reciprocal", int, 1),
        ("--eval_batch", int, 256),
        ("--eval_amp", int, 1),
        ("--score_chunk", int, 2048),
    )
    for _flag, _kind, _default in _cli_options:
        parser.add_argument(_flag, type=_kind, default=_default)
    args = parser.parse_args()

    # Comma-separated baseline names; blanks are dropped.
    bases = tuple(name.strip() for name in args.baselines.split(",") if name.strip())
    main(
        data_root=args.data_root,
        out_dir=args.out_dir,
        baselines=bases,
        dim=args.dim, c=args.c, tau=args.tau, K=args.K, heads=args.heads,
        batch_size=args.batch_size, neg_k=args.neg_k, epochs=args.epochs,
        lr=args.lr, wd=args.wd, seed=args.seed,
        reciprocal=bool(args.reciprocal),
        eval_batch=args.eval_batch, eval_amp=bool(args.eval_amp), score_chunk=args.score_chunk,
    )
