# ssdm_local_ou_teacher.py
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import time


# -----------------------------
# Config
# -----------------------------
@dataclass
class CFG:
    """Hyper-parameters for training the local-OU-teacher score model and for sampling."""

    n_qubits: int = 6  # number of qubits; state dimension is 2 ** n_qubits
    T: float = 1.0  # terminal diffusion time
    K: int = 500  # number of reverse-time steps used by `sample`
    # NOTE(review): dt and delta are hard-coded to 1/500 rather than derived from
    # T / K, so changing K or T alone does not update them.
    dt: float = 1.0 / 500  # step of the forward simulation in `train`
    delta: float = 1.0 / 500  # teacher step size; also the lower bound of sampled t in `train`
    batch_size: int = 64
    lr: float = 2e-4  # AdamW learning rate
    steps: int = 2000  # number of optimizer steps
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # evaluated once, at class definition

    # OU/VP schedules
    sigma_min: float = 0.05  # sigma_t at t = 0
    sigma_max: float = 1.0  # sigma_t at t = T
    lambda0: float = 0.2  # mean reversion strength

    # toy data
    eps_cluster: float = 0.06  # noise std per real/imag component of the |0...0> cluster

# Module-wide configuration instance read by the schedule, sampling and training code.
cfg = CFG()

# -----------------------------
# Complex helpers
# -----------------------------
def cinner(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the batched inner product <a|b> (the first argument is conjugated).

    (B, d) complex x (B, d) complex -> (B,) complex; reduction is over the last axis.
    """
    elementwise = torch.conj(a) * b
    return torch.sum(elementwise, dim=-1)

def normalize(psi: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Rescale every vector (last axis) of `psi` to unit L2 norm.

    `eps` is added to the squared norm before the square root, so an all-zero
    vector maps to zero instead of NaN.
    """
    norm_sq = (psi.conj() * psi).sum(dim=-1).real
    denom = torch.sqrt(norm_sq + eps)[..., None]
    return psi / denom

def tangent_project(psi: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Remove from `v` its component along `psi`: returns v - psi <psi|v>.

    For unit-norm `psi` the result is orthogonal to `psi`, i.e. it lies in the
    tangent space at `psi`.
    """
    coeff = (psi.conj() * v).sum(dim=-1)  # <psi|v>, one complex number per row
    return v - coeff.unsqueeze(-1) * psi

def random_pure_states(batch: int, d: int, device: str) -> torch.Tensor:
    """Draw `batch` random unit vectors in C^d.

    Entries start as independent standard normals: all real parts are drawn
    first, then all imaginary parts. Each row is then normalized.
    """
    real_part = torch.randn(batch, d, device=device)
    imag_part = torch.randn(batch, d, device=device)
    return normalize(torch.complex(real_part, imag_part))

# -----------------------------

# =============================
# QML downstream utilities
# =============================

@torch.no_grad()
def overlap_kernel(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Fidelity Gram matrix K_ij = |<x_i|y_j>|^2.

    X: (N, d) complex, Y: (M, d) complex -> (N, M) real.

    NOTE(review): another `overlap_kernel` is defined later in this module and
    replaces this definition once the module has finished importing.
    """
    gram = X @ Y.conj().T
    return gram.abs().pow(2).real


@torch.no_grad()
def kernel_alignment(K: torch.Tensor, y: torch.Tensor, eps: float = 1e-12) -> float:
    """Centered alignment between kernel K and the label kernel y y^T.

    NOTE(review): another `kernel_alignment` is defined later in this module and
    replaces this definition once the module has finished importing.

    Args:
        K: (n, n) real Gram matrix.
        y: (n,) labels; cast to K's dtype.
        eps: added to the denominator so that a centered kernel or label matrix
            that is exactly zero (e.g. constant labels) yields 0.0 instead of NaN.

    Returns:
        <Kc, Yc>_F / (||Kc||_F ||Yc||_F) as a Python float.
    """
    n = K.shape[0]
    y = y.to(K.dtype)
    # H is built in K's dtype so float64 kernels do not hit a matmul dtype mismatch.
    H = (
        torch.eye(n, device=K.device, dtype=K.dtype)
        - torch.ones(n, n, device=K.device, dtype=K.dtype) / n
    )
    Kc = H @ K @ H
    yyc = H @ torch.outer(y, y) @ H
    num = (Kc * yyc).sum()
    den = torch.norm(Kc) * torch.norm(yyc)
    return (num / (den + eps)).item()


def kernel_ridge_train(K: torch.Tensor, y: torch.Tensor, lam: float = 1e-3):
    """Closed-form kernel ridge regression: alpha = (K + lam I)^{-1} y.

    Args:
        K: (n, n) real Gram matrix of the training points.
        y: (n,) targets; cast to K's dtype because torch.linalg.solve requires
            both operands to share a dtype (e.g. float64 K with float32 labels).
        lam: ridge regularization added to the diagonal.

    Returns:
        alpha: (n,) dual coefficients in K's dtype.
    """
    n = K.shape[0]
    eye = torch.eye(n, device=K.device, dtype=K.dtype)
    return torch.linalg.solve(K + lam * eye, y.to(K.dtype))


@torch.no_grad()
def kernel_ridge_predict(K_test: torch.Tensor, alpha: torch.Tensor):
    """Kernel ridge predictions f = K_test @ alpha.

    K_test: (M, N) kernel between test and train points; alpha: (N,) dual coefficients.
    """
    return torch.matmul(K_test, alpha)


# Schedules
# -----------------------------
def sigma_t(t: torch.Tensor) -> torch.Tensor:
    """Geometric noise schedule from cfg.sigma_min at t = 0 to cfg.sigma_max at t = cfg.T."""
    lo = cfg.sigma_min
    growth = cfg.sigma_max / lo
    return lo * growth ** (t / cfg.T)

def lambda_t(t: torch.Tensor) -> torch.Tensor:
    """Constant drift-strength schedule: cfg.lambda0 with the shape, dtype and device of `t`."""
    return torch.full_like(t, fill_value=cfg.lambda0)

def beta2(t: torch.Tensor, dt_small: float) -> torch.Tensor:
    """Variance of one local VP/OU step of length `dt_small` starting at time `t`.

    Approximates the integral of sigma(s)^2 over the step by sigma(t)^2 * dt_small.
    It plays the role of beta(t, dt)^2 in the analytic local OU/VP teacher.
    """
    sig = sigma_t(t)
    return sig * sig * dt_small


import torch

# -------------------------
# Kernel: overlap |<psi|phi>|^2
# -------------------------
def overlap_kernel(X: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
    """Fidelity kernel between two batches of pure states.

    X: (N, d) complex, Z: (M, d) complex -> (N, M) real with entries |<x_i|z_j>|^2.
    """
    # entry (i, j) = sum_k conj(x_ik) * z_jk = <x_i|z_j>
    amplitudes = torch.matmul(X.conj(), Z.T)
    return amplitudes.abs().pow(2).real

def center_gram(K: torch.Tensor) -> torch.Tensor:
    """Double-center a Gram matrix: returns H K H with H = I - 11^T / n.

    Computed as K - (column means) - (row means) + (grand mean). This is
    algebraically identical to H K H, costs O(n^2) instead of two O(n^3)
    products, and keeps K's dtype. A float32 H against a float64 K would
    raise a matmul dtype error.
    """
    col_mean = K.mean(dim=0, keepdim=True)  # (1, n) == 1^T K / n
    row_mean = K.mean(dim=1, keepdim=True)  # (n, 1) == K 1 / n
    return K - col_mean - row_mean + K.mean()

# -------------------------
# 1) Kernel Alignment (Centered Kernel Target Alignment)
# -------------------------
def kernel_alignment(K: torch.Tensor, y: torch.Tensor, centered: bool = True, eps: float = 1e-12) -> float:
    """Kernel-target alignment <Kc, Yc>_F / (||Kc||_F * ||Yc||_F), where Y = y y^T.

    K: (N, N) real Gram matrix; y: (N,) labels, expected in {-1, +1}.
    With `centered=True` both matrices are double-centered first (centered KTA).
    `eps` is added under the square root of the denominator.
    Returns a Python float, typically in [-1, 1].
    """
    labels = y.float().view(-1, 1)
    target = labels @ labels.T

    if not centered:
        Kc, Yc = K, target
    else:
        Kc, Yc = center_gram(K), center_gram(target)

    numerator = (Kc * Yc).sum()
    denominator = ((Kc * Kc).sum() * (Yc * Yc).sum() + eps).sqrt()
    return (numerator / denominator).item()

# -------------------------
# 2) Class-conditional kernel gap
#    gap = mean(K_ii_sameclass) - mean(K_ij_diffclass)
# -------------------------
def class_conditional_kernel_gap(K: torch.Tensor, y: torch.Tensor) -> dict:
    """Compare mean kernel values within a class against those across classes.

    K: (N, N) real Gram matrix; y: (N,) labels.
    Returns a dict with:
      within_mean  - mean K_ij over same-label pairs with i != j (NaN if none),
      between_mean - mean K_ij over different-label pairs (NaN if none),
      gap          - within_mean - between_mean.
    """
    labels = y.view(-1)
    n = labels.numel()
    same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
    off_diagonal = ~torch.eye(n, device=K.device, dtype=torch.bool)

    def _masked_mean(mask: torch.Tensor) -> float:
        return K[mask].mean().item() if mask.any() else float("nan")

    within = _masked_mean(same_label & off_diagonal)
    between = _masked_mean(~same_label)
    return {"within_mean": within, "between_mean": between, "gap": within - between}

# -------------------------
# 3) Margin (for Kernel Ridge Regression classifier)
#    Train: alpha = (K + lam I)^(-1) y
#    f(x_i)=sum_j K_ij alpha_j
#    ||f||_H^2 = alpha^T K alpha
#    normalized margin_i = y_i f(x_i) / ||f||_H
# -------------------------
def krr_train_alpha(Ktr: torch.Tensor, ytr: torch.Tensor, lam: float = 1e-3) -> torch.Tensor:
    """Solve (Ktr + lam I) alpha = ytr for the kernel-ridge dual coefficients.

    Args:
        Ktr: (N, N) real Gram matrix of the training points.
        ytr: (N,) labels in {-1, +1}; any numeric dtype. They are cast to Ktr's
            dtype because torch.linalg.solve needs matching dtypes, so a float64
            Gram matrix would otherwise fail against float32 labels.
        lam: ridge regularization added to the diagonal.

    Returns:
        alpha: (N,) in Ktr's dtype.
    """
    N = Ktr.shape[0]
    A = Ktr + lam * torch.eye(N, device=Ktr.device, dtype=Ktr.dtype)
    return torch.linalg.solve(A, ytr.to(Ktr.dtype))

def krr_predict(Kte: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Evaluate the kernel-ridge predictor on test points.

    Kte: (M, N) kernel between test and train points; alpha: (N,) dual coefficients.
    Returns f: (M,) with f_i = sum_j Kte_ij * alpha_j.
    """
    return torch.matmul(Kte, alpha)

def krr_margins(Ktr: torch.Tensor, ytr: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-12) -> dict:
    """Normalized training margins y_i f(x_i) / ||f||_H of a kernel-ridge predictor.

    f = Ktr @ alpha and ||f||_H^2 = alpha^T Ktr alpha (clamped below at `eps`).
    Returns the mean, minimum and 10% quantile of the margins, plus the RKHS norm.
    """
    K_alpha = Ktr @ alpha
    rkhs_norm = (alpha * K_alpha).sum().clamp_min(eps).sqrt()
    margins = ytr.float() * K_alpha / rkhs_norm

    summary = {}
    summary["margin_mean"] = margins.mean().item()
    summary["margin_min"] = margins.min().item()
    summary["margin_q10"] = margins.quantile(0.10).item()
    summary["rkhs_norm"] = rkhs_norm.item()
    return summary

# -------------------------
# 4) MMD^2 with overlap kernel
#    Biased estimator:
#    MMD^2 = mean(Kxx) + mean(Kyy) - 2 mean(Kxy)
# -------------------------
def mmd2_biased(X: torch.Tensor, Y: torch.Tensor) -> float:
    """Biased (V-statistic) MMD^2 under the fidelity kernel |<x|y>|^2.

    X: (N, d) complex, Y: (M, d) complex.
    Returns mean(Kxx) + mean(Kyy) - 2 * mean(Kxy), including diagonal terms.
    """
    mean_xx = overlap_kernel(X, X).mean()
    mean_yy = overlap_kernel(Y, Y).mean()
    mean_xy = overlap_kernel(X, Y).mean()
    return (mean_xx + mean_yy - 2 * mean_xy).item()

# -------------------------
# One-stop evaluation helper
# -------------------------
@torch.no_grad()
def eval_rep_metrics(Xtr: torch.Tensor, ytr: torch.Tensor, Xgen: torch.Tensor, lam: float = 1e-3) -> dict:
    """Kernel-level representation metrics for a labeled training set.

    Xtr: (N, d) training states; ytr: (N,) labels; Xgen: (M, d) generated states.
    `lam` is the kernel-ridge regularization used for the margin statistics.
    All kernel metrics use the fidelity kernel on Xtr. The MMD compares Xtr to Xgen.
    """
    K_train = overlap_kernel(Xtr, Xtr)
    gap_stats = class_conditional_kernel_gap(K_train, ytr)
    dual = krr_train_alpha(K_train, ytr, lam=lam)
    margin_stats = krr_margins(K_train, ytr, dual)

    return {
        "KernelAlign": kernel_alignment(K_train, ytr, centered=True),
        "KernelGap_within": gap_stats["within_mean"],
        "KernelGap_between": gap_stats["between_mean"],
        "KernelGap": gap_stats["gap"],
        "Margin_mean": margin_stats["margin_mean"],
        "Margin_min": margin_stats["margin_min"],
        "Margin_q10": margin_stats["margin_q10"],
        "MMD2(train,gen)": mmd2_biased(Xtr, Xgen),
    }

# -----------------------------
# Toy data: two-class states clustered around |0...0>



# =============================

# =============================
# Build dataset
# =============================
@torch.no_grad()
def make_two_class_data(
    n_per_class: int,
    n_qubits: int,
    eps: float,
    shift: float,
    device: str,
):
    """Build a balanced two-class dataset of pure states near |0...0>.

    Each state is |0...0> plus complex Gaussian noise (std `eps` on the real and
    imaginary part of every amplitude). Then `sign * shift` is added to the
    amplitude of basis index 1, and the state is normalized.
    Class -1 uses sign = -1 and class +1 uses sign = +1.

    Returns:
        X: (2 * n_per_class, 2 ** n_qubits) complex64 states, class -1 rows first.
        y: (2 * n_per_class,) float labels in {-1, +1}.
    """
    d = 2 ** n_qubits
    j = 1  # basis index |...0001> carries the class-dependent bias

    base = torch.zeros(n_per_class, d, device=device, dtype=torch.complex64)
    base[:, 0] = 1.0 + 0j

    def sample_class(sign: float):
        re = torch.randn(n_per_class, d, device=device) * eps
        im = torch.randn(n_per_class, d, device=device) * eps
        psi = base + torch.complex(re, im)  # new tensor; `base` is not modified
        psi[:, j] += (sign * shift) + 0j
        return normalize(psi)

    X0 = sample_class(-1.0)
    y0 = torch.full((n_per_class,), -1.0, device=device)

    X1 = sample_class(+1.0)
    y1 = torch.full((n_per_class,), +1.0, device=device)

    return torch.cat([X0, X1], dim=0), torch.cat([y0, y1], dim=0)

# =============================
# RQ4: representation-level augmentation experiment
# =============================
def _krr_fit_eval(Xfit: torch.Tensor, yfit: torch.Tensor, Xeval: torch.Tensor, yeval: torch.Tensor, lam: float):
    """Fit fidelity-kernel ridge on (Xfit, yfit); return (accuracy on Xeval, alignment on Xfit)."""
    K_fit = overlap_kernel(Xfit, Xfit)
    K_eval = overlap_kernel(Xeval, Xfit)
    alpha = kernel_ridge_train(K_fit, yfit, lam=lam)
    pred = kernel_ridge_predict(K_eval, alpha)
    acc = (pred.sign() == yeval).float().mean().item()
    return acc, kernel_alignment(K_fit, yfit)


def _pseudo_label_by_prototype(Xtr: torch.Tensor, ytr: torch.Tensor, Xaug: torch.Tensor):
    """Label each state in Xaug by the closer class prototype (normalized mean training state).

    Returns (labels in {-1, +1}, confidence = |s1 - s0|), where s_c is the
    fidelity between the state and the class-c prototype.
    """
    proto0 = normalize(Xtr[ytr < 0].mean(dim=0, keepdim=True))
    proto1 = normalize(Xtr[ytr > 0].mean(dim=0, keepdim=True))
    s0 = cinner(proto0, Xaug).abs() ** 2  # (n_aug,)
    s1 = cinner(proto1, Xaug).abs() ** 2
    labels = torch.where(s1 > s0, torch.ones_like(s1), -torch.ones_like(s1))
    return labels, (s1 - s0).abs()


@torch.no_grad()
def qml_augmentation_experiment(model, n_qubits: int):
    """RQ4: does adding SSDM-generated, pseudo-labeled states help a kernel classifier?

    Trains fidelity-kernel ridge regression on a two-class toy dataset, once on
    the original data and once augmented with model samples. The samples are
    labeled by prototype similarity, and only those above the median confidence
    are kept. Prints test accuracy, kernel alignment and representation metrics.

    Returns:
        (acc_orig, align_orig, acc_aug, align_aug)

    Raises:
        ValueError: if the generated states do not have dimension 2 ** n_qubits.
            `sample` always generates 2 ** cfg.n_qubits amplitudes.
    """
    device = cfg.device

    n_train = 100  # per class
    n_test = 200  # per class
    aug_ratio = 1.0  # generate as many states as there are training states
    lam = 1e-3

    Xtr, ytr = make_two_class_data(n_train, n_qubits, eps=0.10, shift=0.07, device=device)
    Xte, yte = make_two_class_data(n_test, n_qubits, eps=0.10, shift=0.07, device=device)

    # ---------- Original only ----------
    acc_orig, align_orig = _krr_fit_eval(Xtr, ytr, Xte, yte, lam)

    # ---------- + SSDM augmentation ----------
    n_aug = int(aug_ratio * Xtr.shape[0])
    Xaug = sample(model, batch=n_aug)
    if Xaug.shape[-1] != Xtr.shape[-1]:
        raise ValueError(
            f"generated states have dimension {Xaug.shape[-1]} but data has {Xtr.shape[-1]}; "
            f"n_qubits={n_qubits} must match cfg.n_qubits={cfg.n_qubits}"
        )

    yaug, conf = _pseudo_label_by_prototype(Xtr, ytr, Xaug)
    keep = conf > conf.quantile(0.5)  # keep the most confident half
    Xaug, yaug = Xaug[keep], yaug[keep]

    Xtr_aug = torch.cat([Xtr, Xaug], dim=0)
    ytr_aug = torch.cat([ytr, yaug], dim=0)
    acc_aug, align_aug = _krr_fit_eval(Xtr_aug, ytr_aug, Xte, yte, lam)

    print("RQ4 results")
    print(f"Original only     | Test Acc = {acc_orig:.4f} | Kernel Align = {align_orig:.4f}")
    print(f"+ SSDM augment    | Test Acc = {acc_aug:.4f} | Kernel Align = {align_aug:.4f}")

    # ---------- Representation-level metrics ----------
    metrics_orig = eval_rep_metrics(Xtr=Xtr, ytr=ytr, Xgen=Xtr, lam=lam)  # baseline: data vs data
    metrics_aug = eval_rep_metrics(Xtr=Xtr_aug, ytr=ytr_aug, Xgen=Xaug, lam=lam)  # data vs generated

    print("Representation metrics (train space)")
    for title, metrics in (("Original:", metrics_orig), ("With SSDM augmentation:", metrics_aug)):
        print(title)
        for k, v in metrics.items():
            print(f"  {k:22s}: {v:.4e}")

    return acc_orig, align_orig, acc_aug, align_aug

# -----------------------------
def sample_cluster_states(batch: int, n_qubits: int, eps: float, device: str) -> torch.Tensor:
    """Sample `batch` normalized states scattered around |0...0>.

    Each state is |0...0> plus i.i.d. complex Gaussian noise (std `eps` on the
    real and imaginary part of every amplitude), then renormalized.
    Returns a (batch, 2 ** n_qubits) complex tensor.
    """
    dim = 2 ** n_qubits
    noise_re = torch.randn(batch, dim, device=device) * eps
    noise_im = torch.randn(batch, dim, device=device) * eps

    center = torch.zeros(batch, dim, device=device, dtype=torch.complex64)
    center[:, 0] = 1.0 + 0j
    return normalize(center + torch.complex(noise_re, noise_im))

# -----------------------------
# Forward diffusion step (OU-like in tangent + noise)
# -----------------------------
def forward_step(psi: torch.Tensor, t: torch.Tensor, dt: float, anchor: torch.Tensor) -> torch.Tensor:
    """One Euler-Maruyama step of the forward OU-like diffusion on unit vectors.

        psi_next = normalize(psi - lambda(t) * P_psi(anchor - psi) * dt
                             + sigma(t) * sqrt(dt) * P_psi(xi))

    Here xi has independent N(0, 1) real and imaginary parts, and P_psi is the
    tangent projection at psi. Normalization acts as the retraction.

    NOTE(review): with the minus sign the drift points away from `anchor`, while
    cfg.lambda0 is commented as a "mean reversion strength". Confirm the sign is
    intended. `sample` uses the same sign, so forward and reverse agree.
    """
    batch, dim = psi.shape
    lam = lambda_t(t).unsqueeze(-1)  # (B, 1)
    sig = sigma_t(t).unsqueeze(-1)  # (B, 1)

    toward_anchor = tangent_project(psi, anchor - psi)

    xi = torch.complex(
        torch.randn(batch, dim, device=psi.device),  # real parts drawn first
        torch.randn(batch, dim, device=psi.device),
    )
    xi = tangent_project(psi, xi)

    stepped = psi + (-lam * toward_anchor) * dt + (sig * math.sqrt(dt)) * xi
    return normalize(stepped)

# -----------------------------
# Local OU teacher (FS normal-coordinate approximation, MVP)
# -----------------------------
def log_map_approx(phi: torch.Tensor, psi: torch.Tensor) -> torch.Tensor:
    """First-order stand-in for the Fubini-Study log map log_phi(psi).

    Projects the chord psi - phi onto the tangent space at phi. Only meaningful
    when psi is close to phi (small local steps).
    """
    chord = psi - phi
    return tangent_project(phi, chord)

def transport_phi_to_psi_mvp(phi: torch.Tensor, psi: torch.Tensor, v_phi: torch.Tensor) -> torch.Tensor:
    """Crude substitute for parallel transport of v_phi from T_phi to T_psi.

    Re-projects v_phi onto the tangent space at psi. `phi` is part of the
    signature for symmetry but is not used.
    """
    projected = tangent_project(psi, v_phi)
    return projected

def teacher_score_local_ou(
    psi_t: torch.Tensor, psi_prev: torch.Tensor, t: torch.Tensor, dt_small: float, eps: float = 1e-6
) -> torch.Tensor:
    """Analytic teacher score at psi_t from a local Gaussian (OU/VP) approximation.

    Treats psi_t as a one-step Gaussian perturbation of psi_prev in approximate
    Fubini-Study normal coordinates:
        z = P_{psi_prev}(psi_t - psi_prev)     (log_map_approx)
        s = -z / beta2(t, dt_small)            (variance clamped below at eps)
    It then moves s to the tangent space at psi_t by projection.
    """
    variance = beta2(t, dt_small).clamp_min(eps).unsqueeze(-1)  # (B, 1)
    displacement = log_map_approx(psi_prev, psi_t)  # tangent at psi_prev
    score_prev = -displacement / variance
    return transport_phi_to_psi_mvp(psi_prev, psi_t, score_prev)

# -----------------------------
# Time embedding
# -----------------------------
class TimeEmbedding(nn.Module):
    """Sinusoidal features of a scalar time, followed by a two-layer SiLU MLP.

    Uses dim // 2 frequencies log-spaced in [1, 1000]. The features are
    [sin(2*pi*f*t), cos(2*pi*f*t)], whose length equals `dim` when `dim` is even.
    """

    def __init__(self, dim: int = 128):
        super().__init__()
        self.dim = dim
        self.lin = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed times `t` (shape (...,)) into features of shape (..., dim)."""
        n_freq = self.dim // 2
        log_freqs = torch.linspace(math.log(1.0), math.log(1000.0), n_freq, device=t.device)
        freqs = torch.exp(log_freqs)
        phase = t[..., None] * freqs[None, :] * 2 * math.pi
        features = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return self.lin(features)

# -----------------------------
# Score Network (classical)
# -----------------------------
class ScoreNet(nn.Module):
    """MLP score model s_theta(psi, t) that returns a tangent vector at psi.

    Input features are [Re psi, Im psi, TimeEmbedding(t)]. They pass through
    `depth` Linear+SiLU layers of width `width`, then a final Linear to 2*d
    outputs. The outputs are read as real and imaginary parts and projected
    onto the tangent space at psi.
    """

    def __init__(self, d: int, tdim: int = 128, width: int = 512, depth: int = 5):
        super().__init__()
        self.d = d
        self.tenc = TimeEmbedding(tdim)

        hidden = [nn.Linear(2 * d + tdim, width), nn.SiLU()]
        for _ in range(depth - 1):
            hidden.extend([nn.Linear(width, width), nn.SiLU()])
        self.net = nn.Sequential(*hidden, nn.Linear(width, 2 * d))

    def forward(self, psi: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """psi: (B, d) complex, t: (B,) -> (B, d) complex, tangent at psi."""
        features = torch.cat([psi.real, psi.imag, self.tenc(t)], dim=-1)
        raw = self.net(features)
        re_part, im_part = raw[:, : self.d], raw[:, self.d :]
        return tangent_project(psi, torch.complex(re_part, im_part))

# -----------------------------
# MMD with the fidelity (overlap) kernel
# -----------------------------
@torch.no_grad()
def _pairwise_fidelity_kernel_linear(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    inner = X @ Y.conj().T
    return inner.abs().pow(2).real
@torch.no_grad()
def mmd2_overlap(
    X: torch.Tensor,
    Y: torch.Tensor,
    gamma: float = 5.0,
    unbiased: bool = True,
    eps: float = 1e-12,
    kernel: str = "linear",
) -> float:
    """MMD^2 between two sets of pure states under a fidelity-based kernel.

    Args:
        X, Y: (N, d) and (M, d) complex states; rows are renormalized (norm
            clamped below at `eps`).
        gamma: inverse temperature of the exponential kernel. It has an effect
            only when kernel="exp"; the default linear kernel ignores it.
        unbiased: drop the diagonal of Kxx / Kyy (U-statistic) when the set has
            more than one element; otherwise use plain means.
        eps: norm floor used during renormalization.
        kernel: "linear" -> k(x, y) = |<x|y>|^2 (default, previous behavior);
                "exp"    -> k(x, y) = exp(gamma * |<x|y>|^2).

    Returns:
        max(MMD^2, 0) as a Python float.

    Raises:
        ValueError: for an unknown `kernel`.
    """
    assert X.dtype.is_complex and Y.dtype.is_complex
    if kernel not in ("linear", "exp"):
        raise ValueError(f"unknown kernel {kernel!r}; expected 'linear' or 'exp'")

    X = X / (X.abs().pow(2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps))
    Y = Y / (Y.abs().pow(2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps))

    def _k(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        fid = _pairwise_fidelity_kernel_linear(A, B)
        return torch.exp(gamma * fid) if kernel == "exp" else fid

    def _within_mean(K: torch.Tensor, size: int) -> torch.Tensor:
        # Removing the diagonal reduces bias; impossible with a single sample.
        if unbiased and size > 1:
            return (K.sum() - K.diag().sum()) / (size * (size - 1))
        return K.mean()

    Kxx_sum = _within_mean(_k(X, X), X.shape[0])
    Kyy_sum = _within_mean(_k(Y, Y), Y.shape[0])
    Kxy_sum = _k(X, Y).mean()

    mmd2 = (Kxx_sum + Kyy_sum - 2.0 * Kxy_sum).item()
    return float(max(mmd2, 0.0))



# -----------------------------
# Observable mismatch Δ_obs
# We use {Z_i} and {Z_i Z_{i+1}} as a simple, hardware-friendly observable set.
# -----------------------------
# Module-level memo of +/-1 sign vectors for Z and ZZ observables, keyed by
# ("Z", n_qubits, qubit, device) or ("ZZ", n_qubits, q1, q2, device).
# Entries are never evicted.
_OBS_CACHE = {}

def _z_sign_vector(n_qubits: int, qubit: int, device: str):
    """Diagonal of Pauli Z on `qubit` as a length-2**n_qubits float vector of +/-1.

    Bit convention: basis index |b_{n-1}...b_0>, with b_0 the least significant
    bit. Entry k is +1 when bit `qubit` of k is 0 and -1 otherwise.
    Results are memoized in _OBS_CACHE.
    """
    key = ("Z", n_qubits, qubit, device)
    cached = _OBS_CACHE.get(key)
    if cached is None:
        bit = (torch.arange(2 ** n_qubits, device=device) >> qubit) & 1
        cached = 1.0 - 2.0 * bit.float()
        _OBS_CACHE[key] = cached
    return cached

def _zz_sign_vector(n_qubits: int, q1: int, q2: int, device: str):
    """Diagonal of Z_{q1} Z_{q2}: elementwise product of the single-qubit sign vectors (memoized)."""
    key = ("ZZ", n_qubits, q1, q2, device)
    if key not in _OBS_CACHE:
        _OBS_CACHE[key] = _z_sign_vector(n_qubits, q1, device) * _z_sign_vector(n_qubits, q2, device)
    return _OBS_CACHE[key]

@torch.no_grad()
def observable_expectations_Z_ZZ(psi: torch.Tensor, n_qubits: int) -> torch.Tensor:
    """Expectations <Z_i> for every qubit and <Z_i Z_{i+1}> for neighbouring pairs.

    Args:
        psi: (B, 2**n_qubits) complex state amplitudes.
        n_qubits: number of qubits.

    Returns:
        (B, 2*n_qubits - 1) real tensor. Columns are Z_0..Z_{n-1}, then
        Z_0Z_1..Z_{n-2}Z_{n-1}.

    Raises:
        ValueError: if psi's last dimension is not 2**n_qubits.
    """
    if psi.shape[-1] != 2 ** n_qubits:
        raise ValueError(f"psi has dimension {psi.shape[-1]}, expected 2**{n_qubits}")
    device = psi.device
    prob = (psi.conj() * psi).real  # |psi_k|^2 per basis state, (B, d)

    signs = [_z_sign_vector(n_qubits, i, device) for i in range(n_qubits)]
    signs += [_zz_sign_vector(n_qubits, i, i + 1, device) for i in range(n_qubits - 1)]
    feats = [(prob * s.unsqueeze(0)).sum(dim=-1) for s in signs]  # each (B,)
    return torch.stack(feats, dim=-1)  # (B, J)

@torch.no_grad()
def delta_obs(X: torch.Tensor, Y: torch.Tensor, n_qubits: int) -> float:
    """Mean absolute gap between ensemble-averaged Z / ZZ expectations of X and Y.

    Δ_obs = mean_j | E_X[<O_j>] - E_Y[<O_j>] |
    """
    mean_x, mean_y = (observable_expectations_Z_ZZ(S, n_qubits).mean(dim=0) for S in (X, Y))
    return torch.abs(mean_x - mean_y).mean().item()




# -----------------------------
# Entanglement stats & Wasserstein-1 (1D)
# We use von Neumann entropy S(ρ_A) for A = first n//2 qubits by default.
# -----------------------------
@torch.no_grad()
def entanglement_entropy_vn(psi: torch.Tensor, n_qubits: int, nA: int = None, eps: float = 1e-12) -> torch.Tensor:
    """Von Neumann entropy (natural log) of subsystem A for a batch of pure states.

    psi: (B, 2**n_qubits) complex. A consists of the nA most significant bits of
    the basis index (default nA = n_qubits // 2). Eigenvalues are clamped below
    at `eps` and renormalized before taking logs. Returns a (B,) real tensor.
    """
    n_a = n_qubits // 2 if nA is None else nA
    dim_a, dim_b = 2 ** n_a, 2 ** (n_qubits - n_a)

    amp = psi.view(psi.shape[0], dim_a, dim_b)  # amp[b, i, j]: i indexes A, j indexes B
    rho = torch.einsum("bij,bkj->bik", amp, amp.conj())  # rho_A = Tr_B |psi><psi|
    rho = (rho + rho.conj().transpose(-1, -2)) * 0.5  # symmetrize against round-off

    spectrum = torch.linalg.eigvalsh(rho).real.clamp_min(eps)
    spectrum = spectrum / spectrum.sum(dim=-1, keepdim=True).clamp_min(eps)
    return -(spectrum * spectrum.log()).sum(dim=-1)

@torch.no_grad()
def wasserstein1_1d(a: torch.Tensor, b: torch.Tensor) -> float:
    """Empirical 1-D Wasserstein-1 distance: mean_i |sort(a)_i - sort(b)_i|.

    Both inputs are flattened. If their sizes differ, only the first
    min(len(a), len(b)) entries of each are used. Returns 0.0 when either is empty.
    """
    flat_a, flat_b = a.flatten(), b.flatten()
    k = min(flat_a.numel(), flat_b.numel())
    if k == 0:
        return 0.0
    sorted_a = flat_a[:k].sort().values
    sorted_b = flat_b[:k].sort().values
    return (sorted_a - sorted_b).abs().mean().item()

@torch.no_grad()
def ent_w1(X: torch.Tensor, Y: torch.Tensor, n_qubits: int, nA: int = None) -> float:
    """W1 distance between the entanglement-entropy distributions of state sets X and Y."""
    entropies = [entanglement_entropy_vn(S, n_qubits, nA=nA) for S in (X, Y)]
    return wasserstein1_1d(*entropies)

    

# -----------------------------
# Reverse-time sampling
# -----------------------------
@torch.no_grad()
def sample(model: nn.Module, batch: int = 256) -> torch.Tensor:
    """Generate `batch` states by integrating the reverse-time SDE from t = T to t = 0.

    Starts from `random_pure_states` and applies cfg.K Euler-Maruyama steps of
        dpsi = [b(psi, t) - sigma(t)^2 * s_theta(psi, t)] dt + sigma(t) dW,   dt < 0.
    Here b = -lambda(t) * P_psi(anchor - psi) is the drift used by `forward_step`,
    anchor = |0...0>, and the noise is projected onto the tangent space at psi.
    Every step is followed by renormalization.

    The time grid is increasing (times[0] = 0, times[K] = T) and is walked from
    k = K down to 1, so the model is queried at t = T, ..., T/K. Iterating a
    decreasing grid from its end would query t = 0 first and run the noise
    schedule backwards.
    """
    device = cfg.device
    d = 2 ** cfg.n_qubits

    # prior p_T
    psi = random_pure_states(batch, d, device)

    anchor = torch.zeros(batch, d, device=device, dtype=torch.complex64)
    anchor[:, 0] = 1.0 + 0j

    times = torch.linspace(0.0, cfg.T, cfg.K + 1, device=device)
    for k in range(cfg.K, 0, -1):
        t = times[k].expand(batch)
        dt = times[k - 1] - times[k]  # negative: integrating backwards in time
        dt_abs = (-dt).clamp_min(1e-12)

        sig = sigma_t(t)
        lam = lambda_t(t)

        b = -lam.unsqueeze(-1) * tangent_project(psi, anchor - psi)
        s = model(psi, t)

        re = torch.randn(batch, d, device=device)
        im = torch.randn(batch, d, device=device)
        noise = tangent_project(psi, torch.complex(re, im))

        psi = psi + (b - (sig * sig).unsqueeze(-1) * s) * dt + sig.unsqueeze(-1) * torch.sqrt(dt_abs) * noise
        psi = normalize(psi)

    return psi

# -----------------------------
# Training loop
# -----------------------------
def _forward_to_times(psi0: torch.Tensor, t_target: torch.Tensor, anchor: torch.Tensor) -> torch.Tensor:
    """Advance row i of psi0 by floor(t_target[i] / cfg.dt) forward steps, starting at time 0.

    Every row is moved toward its own target time, from below. Rows that have
    used up their step count are frozen with torch.where. Noise is still drawn
    for the whole batch at each step.
    """
    n_steps = torch.floor(t_target / cfg.dt).long().clamp_min(0)  # (B,)
    total = int(n_steps.max().item()) if n_steps.numel() else 0
    psi = psi0
    for k in range(total):
        cur_t = torch.full_like(t_target, min(k * cfg.dt, cfg.T))
        stepped = forward_step(psi, cur_t, cfg.dt, anchor)
        still_moving = (n_steps > k).unsqueeze(-1)  # (B, 1), broadcast over amplitudes
        psi = torch.where(still_moving, stepped, psi)
    return psi


def _eval_snapshot(model: nn.Module, sample_from_pool, n: int, d: int, device: str) -> dict:
    """Compute the periodic diagnostics printed by `train`: fidelity, MMDs, Δobs and entanglement W1."""
    X0 = sample_from_pool(256)
    Xgen = sample(model, batch=256)
    Xhaar = random_pure_states(256, d, device)

    mmd_gen = mmd2_overlap(X0, Xgen, gamma=5.0, unbiased=True)
    mmd_haar = mmd2_overlap(X0, Xhaar, gamma=5.0, unbiased=True)
    # data-vs-data MMD averaged over 5 resamples: the noise floor for comparison
    mmd_dd = sum(
        mmd2_overlap(sample_from_pool(256), sample_from_pool(256), gamma=5.0, unbiased=True)
        for _ in range(5)
    ) / 5

    center = torch.zeros(256, d, device=device, dtype=torch.complex64)
    center[:, 0] = 1.0 + 0j
    fid = (cinner(center, Xgen).abs() ** 2).mean().item()

    dob = delta_obs(X0, Xgen, n_qubits=n)
    ew1 = ent_w1(X0, Xgen, n_qubits=n, nA=n // 2)

    # data-data "floor" for Δobs and entanglement W1 (helps interpretation)
    Xa = sample_from_pool(256)
    Xb = sample_from_pool(256)
    return {
        "fid": fid,
        "mmd_gen": mmd_gen,
        "mmd_haar": mmd_haar,
        "mmd_dd": mmd_dd,
        "dob": dob,
        "dob_dd": delta_obs(Xa, Xb, n_qubits=n),
        "ew1": ew1,
        "ew1_dd": ent_w1(Xa, Xb, n_qubits=n, nA=n // 2),
    }


def train():
    """Train a ScoreNet against the analytic local-OU teacher and return it.

    Each step does the following:
      1. Draw clean states near |0...0> and times t ~ U[cfg.delta, cfg.T], one per sample.
      2. Simulate the forward process of each sample to its own t - cfg.delta.
      3. Take one more forward step of size cfg.delta to obtain psi_t.
      4. Regress model(psi_t, t) onto the teacher score, with weight beta2(t, cfg.delta).
    Every 200 steps, fidelity / MMD / Δobs / entanglement-W1 diagnostics are printed.
    """
    torch.manual_seed(0)
    device = cfg.device
    n = cfg.n_qubits
    d = 2 ** n

    model = ScoreNet(d=d).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    anchor = torch.zeros(cfg.batch_size, d, device=device, dtype=torch.complex64)
    anchor[:, 0] = 1.0 + 0j

    # Fixed pool of data states, used only for evaluation.
    data_pool = sample_cluster_states(4096, n, cfg.eps_cluster, device)

    def sample_from_pool(k: int) -> torch.Tensor:
        idx = torch.randint(0, data_pool.shape[0], (k,), device=device)
        return data_pool[idx]

    for step in range(1, cfg.steps + 1):
        psi0 = sample_cluster_states(cfg.batch_size, n, cfg.eps_cluster, device)
        t = cfg.delta + (cfg.T - cfg.delta) * torch.rand(cfg.batch_size, device=device)

        with torch.no_grad():
            # Each sample is simulated to its own time t_i - delta, so psi_t
            # sits (approximately) at the t_i given to the model and the teacher.
            psi_prev = _forward_to_times(psi0, t - cfg.delta, anchor)
            psi_t = forward_step(psi_prev, t - cfg.delta, cfg.delta, anchor)
            # ---- analytic local OU teacher
            s_teach = teacher_score_local_ou(psi_t, psi_prev, t, cfg.delta)

        s_pred = model(psi_t, t)

        # VP-consistent weighting: lambda = beta^2
        w = beta2(t, cfg.delta).detach().clamp_min(1e-6)
        diff = s_pred - s_teach
        loss = (w.unsqueeze(-1) * (diff.conj() * diff).real).sum(dim=-1).mean()

        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if step % 200 == 0:
            m = _eval_snapshot(model, sample_from_pool, n, d, device)
            print(
                f"step {step:5d} | loss {loss.item():.4e} | F0 {m['fid']:.4f} | "
                f"MMD(data,gen) {m['mmd_gen']:.4e} | MMD(data,haar) {m['mmd_haar']:.4e} | "
                f"MMD(data,data) {m['mmd_dd']:.4e} | "
                f"Δobs(data,gen) {m['dob']:.4e} | Δobs(data,data) {m['dob_dd']:.4e} | "
                f"EntW1(data,gen) {m['ew1']:.4e} | EntW1(data,data) {m['ew1_dd']:.4e}"
            )

    return model

if __name__ == "__main__":
    # Train the score model, then run the RQ4 augmentation study with it.
    trained_model = train()
    qml_augmentation_experiment(trained_model, n_qubits=cfg.n_qubits)

    

