# ssdm_local_ou_teacher.py
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import time


# -----------------------------
# Config
# -----------------------------
@dataclass
class CFG:
    """Hyper-parameters shared by the forward process, the sampler and training."""

    # Number of qubits; state vectors have dimension 2 ** n_qubits.
    n_qubits: int = 6
    # Final diffusion time. Training times are drawn from [delta, T] and the
    # sigma schedule is parameterised by t / T.
    T: float = 1.0
    # Number of reverse-time integration steps taken by sample().
    K: int = 500
    # Step size of the forward simulation inside train().
    # NOTE: this is a literal, not derived from T / K; keep it in sync by hand.
    dt: float = 1.0 / 500
    # Size of the last local forward step (psi_prev -> psi_t) and the step used
    # in beta2() for the teacher variance; also the smallest training time.
    # NOTE: likewise a literal, independent of T and K.
    delta: float = 1.0 / 500
    # Number of states per training batch.
    batch_size: int = 64
    # Learning rate passed to AdamW.
    lr: float = 2e-4
    # Number of optimisation steps performed by train().
    steps: int = 10000
    # Evaluated once at class-definition time: CUDA if available, else CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # OU/VP schedules
    # sigma_t() interpolates geometrically from sigma_min (t = 0) to sigma_max (t = T).
    sigma_min: float = 0.05
    sigma_max: float = 1.0
    # Constant value returned by lambda_t() (drift coefficient).
    lambda0: float = 0.2  # mean reversion strength

    # toy data
    # Scale of the complex Gaussian perturbation added to |0...0> before
    # normalisation in sample_cluster_states().
    eps_cluster: float = 0.06

# Module-level configuration instance read by the schedules, sampler and trainer.
cfg = CFG()

# -----------------------------
# Complex helpers
# -----------------------------
def cinner(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Batched Hermitian inner product <a|b>.

    The first argument is conjugated and the product is summed over the last
    axis, so (B, d) complex inputs give a (B,) complex result.
    """
    a_conj = torch.conj(a)
    return torch.sum(a_conj * b, dim=-1)

def normalize(psi: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Rescale each vector along the last axis to (approximately) unit norm.

    ``eps`` is added to the squared norm before the square root so that an
    all-zero vector is returned unchanged instead of producing NaNs.
    """
    sq_norm = torch.sum(psi.conj() * psi, dim=-1, keepdim=True).real
    return psi / torch.sqrt(sq_norm + eps)

def tangent_project(psi: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Remove from ``v`` its component along ``psi``: v - psi <psi|v>.

    For unit-norm ``psi`` the result is orthogonal to ``psi``.
    """
    along_psi = cinner(psi, v)[..., None]  # (B, 1) coefficient <psi|v>
    return v - psi * along_psi

def random_pure_states(batch: int, d: int, device: str) -> torch.Tensor:
    """Draw ``batch`` normalised complex Gaussian vectors of dimension ``d``.

    Real and imaginary parts are sampled independently (real part first) and
    the resulting vector is normalised.
    """
    shape = (batch, d)
    real_part = torch.randn(shape, device=device)
    imag_part = torch.randn(shape, device=device)
    return normalize(torch.complex(real_part, imag_part))

# -----------------------------
# Schedules
# -----------------------------
def sigma_t(t: torch.Tensor) -> torch.Tensor:
    """Noise scale at time ``t``: geometric interpolation between the two
    configured endpoints, sigma_min at t = 0 and sigma_max at t = cfg.T."""
    lo, hi = cfg.sigma_min, cfg.sigma_max
    progress = t / cfg.T
    return lo * (hi / lo) ** progress

def lambda_t(t: torch.Tensor) -> torch.Tensor:
    """Drift coefficient at time ``t``; constant ``cfg.lambda0`` with the
    shape, dtype and device of ``t``."""
    return t.new_full(t.shape, cfg.lambda0)

def beta2(t: torch.Tensor, dt_small: float) -> torch.Tensor:
    """
    Variance of one short VP-style step of length ``dt_small`` starting at ``t``.

    Uses the one-point approximation  ∫ sigma^2 ds ≈ sigma(t)^2 * dt_small,
    which stands in for beta(t, dt)^2 in the analytic OU/VP teacher.
    """
    return torch.square(sigma_t(t)) * dt_small

# -----------------------------
# Toy data: cluster around |0...0>
# -----------------------------
def sample_cluster_states(batch: int, n_qubits: int, eps: float, device: str) -> torch.Tensor:
    """Sample normalised states concentrated around the basis state |0...0>.

    Each sample is e_0 plus complex Gaussian noise of scale ``eps`` (real part
    drawn first, then imaginary part), followed by normalisation.
    """
    dim = 2 ** n_qubits
    reference = torch.zeros(batch, dim, device=device, dtype=torch.complex64)
    reference[:, 0] = 1.0 + 0j
    pert_re = torch.randn(batch, dim, device=device) * eps
    pert_im = torch.randn(batch, dim, device=device) * eps
    return normalize(reference + torch.complex(pert_re, pert_im))

# -----------------------------
# Forward diffusion step (OU-like in tangent + noise)
# -----------------------------
def forward_step(psi: torch.Tensor, t: torch.Tensor, dt: float, anchor: torch.Tensor) -> torch.Tensor:
    """
    One Euler-Maruyama step of the forward process, followed by a retraction.

    The update adds
      - a deterministic term  -lambda(t) * P_psi(anchor - psi) * dt, and
      - a stochastic term      sigma(t) * sqrt(dt) * P_psi(xi),
    where P_psi is the tangent projection at ``psi`` and ``xi`` is complex
    Gaussian noise, then renormalises the result.

    psi, anchor: (B, d) complex; t: (B,) times; dt: positive step size.

    NOTE(review): with lambda > 0 the deterministic term points away from
    ``anchor``; sample() uses the same expression, so the two are consistent,
    but confirm that this sign is the intended forward drift.
    """
    n_batch, dim = psi.shape
    rate = lambda_t(t).unsqueeze(-1)    # (B, 1)
    scale = sigma_t(t).unsqueeze(-1)    # (B, 1)

    pull = tangent_project(psi, anchor - psi)

    # Real part is drawn before the imaginary part.
    xi = torch.complex(
        torch.randn(n_batch, dim, device=psi.device),
        torch.randn(n_batch, dim, device=psi.device),
    )
    xi = tangent_project(psi, xi)

    drift_term = (-rate * pull) * dt
    noise_term = (scale * math.sqrt(dt)) * xi
    return normalize(psi + drift_term + noise_term)

# -----------------------------
# Local OU teacher (FS normal-coordinate approximation, MVP)
# -----------------------------
def log_map_approx(phi: torch.Tensor, psi: torch.Tensor) -> torch.Tensor:
    """
    First-order stand-in for the Fubini-Study log map at ``phi``:

        log_phi(psi) ≈ P_phi(psi - phi)

    i.e. the chord from ``phi`` to ``psi`` projected onto the tangent space at
    ``phi``. Intended for small local steps only.
    """
    chord = psi - phi
    return tangent_project(phi, chord)

def transport_phi_to_psi_mvp(phi: torch.Tensor, psi: torch.Tensor, v_phi: torch.Tensor) -> torch.Tensor:
    """
    Cheap substitute for parallel transport from T_phi to T_psi.

    The vector is simply re-projected onto the tangent space at ``psi``;
    ``phi`` is accepted for interface symmetry but is not used.
    """
    reprojected = tangent_project(psi, v_phi)
    return reprojected

def teacher_score_local_ou(
    psi_t: torch.Tensor, psi_prev: torch.Tensor, t: torch.Tensor, dt_small: float, eps: float = 1e-6
) -> torch.Tensor:
    """
    Analytic teacher score from a local Euclidean OU/VP approximation.

    Pipeline:
      1) displacement  z = P_{psi_prev}(psi_t - psi_prev)   (approximate log map)
      2) Gaussian score in T_{psi_prev}:  -z / beta^2, with beta^2 = beta2(t, dt_small)
         floored at ``eps``
      3) re-projection of that vector onto T_{psi_t}

    Returns a (B, d) complex tensor tangent at ``psi_t``.
    """
    displacement = log_map_approx(psi_prev, psi_t)              # lives in T_{psi_prev}
    variance = beta2(t, dt_small).clamp_min(eps)                # (B,)
    score_at_prev = -displacement / variance.unsqueeze(-1)
    return transport_phi_to_psi_mvp(psi_prev, psi_t, score_at_prev)

# -----------------------------
# Time embedding
# -----------------------------
class TimeEmbedding(nn.Module):
    """Sinusoidal time features followed by a small MLP.

    A batch of scalar times (B,) is mapped to (B, dim): ``dim // 2`` sine and
    ``dim // 2`` cosine features at log-spaced frequencies between 1 and 1000
    cycles per unit time, then Linear -> SiLU -> Linear.

    When ``dim`` is odd the sinusoidal features only fill ``dim - 1`` columns;
    the last column is zero-padded so the MLP input width is always ``dim``
    (previously an odd ``dim`` failed with a shape error in the first Linear).
    """

    def __init__(self, dim: int = 128):
        super().__init__()
        self.dim = dim
        self.lin = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        # Frequencies are rebuilt on t's device so the module works wherever
        # the input lives.
        freqs = torch.exp(torch.linspace(math.log(1.0), math.log(1000.0), half, device=t.device))
        x = t.unsqueeze(-1) * freqs.unsqueeze(0) * 2 * math.pi
        emb = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)

        missing = self.dim - emb.shape[-1]
        if missing > 0:
            # Odd dim: append zero column(s) on the right of the feature axis.
            emb = F.pad(emb, (0, missing))
        return self.lin(emb)

# -----------------------------
# Score Network (classical)
# -----------------------------
class ScoreNet(nn.Module):
    """MLP score model on complex state vectors.

    The input is [Re(psi), Im(psi), time embedding]; the output is read as the
    real and imaginary halves of a complex vector, which is then projected
    onto the tangent space at ``psi``.
    """

    def __init__(self, d: int, tdim: int = 128, width: int = 512, depth: int = 5):
        super().__init__()
        self.d = d
        self.tenc = TimeEmbedding(tdim)

        # At least one hidden layer is always built; each hidden layer is
        # Linear + SiLU, followed by a final linear read-out of size 2 * d.
        widths = [2 * d + tdim] + [width] * max(depth, 1)
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.extend((nn.Linear(fan_in, fan_out), nn.SiLU()))
        blocks.append(nn.Linear(width, 2 * d))
        self.net = nn.Sequential(*blocks)

    def forward(self, psi: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        features = torch.cat([psi.real, psi.imag, self.tenc(t)], dim=-1)
        raw = self.net(features)
        re_part = raw[:, : self.d]
        im_part = raw[:, self.d :]
        return tangent_project(psi, torch.complex(re_part, im_part))

# -----------------------------
# MMD with the linear fidelity (overlap) kernel k(x, y) = |<x|y>|^2
# -----------------------------
@torch.no_grad()
def _pairwise_fidelity_kernel_linear(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    inner = X @ Y.conj().T
    return inner.abs().pow(2).real
@torch.no_grad()
def mmd2_overlap(
    X: torch.Tensor,
    Y: torch.Tensor,
    gamma: float = 5.0,
    unbiased: bool = True,
    eps: float = 1e-12,
) -> float:
    """
    Squared MMD between two sets of pure states under the linear fidelity
    kernel k(x, y) = |<x|y>|^2.

    X: (n, d) complex, Y: (m, d) complex. Rows are re-normalised first (norms
    floored at ``eps``). With ``unbiased=True`` the diagonal of the
    within-set kernel matrices is excluded whenever the set has more than one
    element. The result is clipped at zero and returned as a Python float.

    ``gamma`` is accepted for interface compatibility but is not used by the
    current (linear) kernel.
    """
    assert X.dtype.is_complex and Y.dtype.is_complex

    def _unit_rows(states: torch.Tensor) -> torch.Tensor:
        return states / (states.abs().pow(2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps))

    def _within_set_mean(gram: torch.Tensor) -> torch.Tensor:
        count = gram.shape[0]
        if unbiased and count > 1:
            # drop the self-overlap terms on the diagonal to reduce bias
            return (gram.sum() - gram.diag().sum()) / (count * (count - 1))
        return gram.mean()

    X = _unit_rows(X)
    Y = _unit_rows(Y)

    term_xx = _within_set_mean(_pairwise_fidelity_kernel_linear(X, X))
    term_yy = _within_set_mean(_pairwise_fidelity_kernel_linear(Y, Y))
    term_xy = _pairwise_fidelity_kernel_linear(X, Y).mean()

    value = (term_xx + term_yy - 2.0 * term_xy).item()
    return float(max(value, 0.0))



# -----------------------------
# Observable mismatch Δ_obs
# We use {Z_i} and {Z_i Z_{i+1}} as a simple, hardware-friendly observable set.
# -----------------------------
_OBS_CACHE = {}

def _z_sign_vector(n_qubits: int, qubit: int, device: str):
    """
    Return a length-d vector with entries (+1/-1) corresponding to Z on 'qubit'.
    Bit convention: basis index |b_{n-1}...b_0>, where b_0 is LSB.
    """
    key = ("Z", n_qubits, qubit, device)
    if key in _OBS_CACHE:
        return _OBS_CACHE[key]
    d = 2 ** n_qubits
    idx = torch.arange(d, device=device)
    bits = (idx >> qubit) & 1
    sign = (1.0 - 2.0 * bits.float())  # 0->+1, 1->-1
    _OBS_CACHE[key] = sign
    return sign

def _zz_sign_vector(n_qubits: int, q1: int, q2: int, device: str):
    """Diagonal of Z_{q1} Z_{q2}: the element-wise product of the two
    single-qubit sign vectors, memoised in ``_OBS_CACHE``."""
    cache_key = ("ZZ", n_qubits, q1, q2, device)
    try:
        return _OBS_CACHE[cache_key]
    except KeyError:
        pass

    first = _z_sign_vector(n_qubits, q1, device)
    second = _z_sign_vector(n_qubits, q2, device)
    product = first * second
    _OBS_CACHE[cache_key] = product
    return product

@torch.no_grad()
def observable_expectations_Z_ZZ(psi: torch.Tensor, n_qubits: int) -> torch.Tensor:
    """
    Expectation values of the diagonal observables {Z_i} and {Z_i Z_{i+1}}.

    psi: (B, d) complex with d = 2 ** n_qubits.
    Returns a (B, J) real tensor with J = n_qubits + (n_qubits - 1); the first
    n_qubits columns are <Z_i>, the remaining ones <Z_i Z_{i+1}>.

    Raises:
        ValueError: if the last dimension of ``psi`` is not 2 ** n_qubits.
    """
    dim = psi.shape[-1]
    if dim != 2 ** n_qubits:
        raise ValueError(
            f"psi has dimension {dim}, expected 2**{n_qubits} = {2 ** n_qubits}"
        )
    device = psi.device

    # Per-basis-state weights |psi_k|^2; every observable here is diagonal, so
    # each expectation is a signed sum of these weights.
    prob = (psi.conj() * psi).real  # (B, d)

    signs = [_z_sign_vector(n_qubits, i, device) for i in range(n_qubits)]
    signs += [_zz_sign_vector(n_qubits, i, i + 1, device) for i in range(n_qubits - 1)]

    feats = [(prob * sign.unsqueeze(0)).sum(dim=-1) for sign in signs]  # each (B,)
    return torch.stack(feats, dim=-1)  # (B, J)

@torch.no_grad()
def delta_obs(X: torch.Tensor, Y: torch.Tensor, n_qubits: int) -> float:
    """
    Observable mismatch between two ensembles of states:

        Δ_obs = mean_j | E_X[<O_j>] - E_Y[<O_j>] |

    where O_j runs over the Z / ZZ observables and E_X, E_Y are batch means.
    """
    mean_x = observable_expectations_Z_ZZ(X, n_qubits).mean(dim=0)  # (J,)
    mean_y = observable_expectations_Z_ZZ(Y, n_qubits).mean(dim=0)  # (J,)
    gap = torch.abs(mean_x - mean_y)
    return gap.mean().item()




# -----------------------------
# Entanglement stats & Wasserstein-1 (1D)
# We use von Neumann entropy S(ρ_A) for A = first n//2 qubits by default.
# -----------------------------
@torch.no_grad()
def entanglement_entropy_vn(psi: torch.Tensor, n_qubits: int, nA: int = None, eps: float = 1e-12) -> torch.Tensor:
    """
    Von Neumann entropy (natural log) of subsystem A for each pure state.

    psi: (B, 2^n) complex. The amplitude vector is reshaped to (B, dA, dB), so
    A corresponds to the leading (most significant) ``nA`` bits of the basis
    index; ``nA`` defaults to ``n_qubits // 2``.
    Returns a (B,) real tensor.
    """
    n_left = n_qubits // 2 if nA is None else nA
    dim_a = 2 ** n_left
    dim_b = 2 ** (n_qubits - n_left)

    amps = psi.view(psi.shape[0], dim_a, dim_b)  # (B, dA, dB)

    # Reduced state rho_A = Tr_B |psi><psi|, i.e. rho_{ik} = sum_j psi_{ij} psi*_{kj}.
    rho = torch.einsum("bij,bkj->bik", amps, amps.conj())  # (B, dA, dA)
    # Symmetrise to remove round-off asymmetry before the Hermitian solver.
    rho = 0.5 * (rho + rho.transpose(-1, -2).conj())

    # Spectrum, floored at eps so the logarithm below stays finite.
    spectrum = torch.linalg.eigvalsh(rho).real.clamp_min(eps)  # (B, dA)
    # Renormalise to sum to one (numerical safety).
    spectrum = spectrum / spectrum.sum(dim=-1, keepdim=True).clamp_min(eps)

    return -torch.sum(spectrum * spectrum.log(), dim=-1)

@torch.no_grad()
def wasserstein1_1d(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Wasserstein-1 distance between two 1D empirical distributions.

    Equal sample sizes use the closed form
        W1 = mean_i |sort(a)_i - sort(b)_i|.
    Unequal sample sizes use the exact CDF form
        W1 = ∫ |F_a(x) - F_b(x)| dx,
    evaluated on the merged sample points. (Previously the longer sample was
    truncated by position before sorting, which discards data arbitrarily and
    gives a wrong distance.)

    Inputs are flattened. Returns 0.0 if either input is empty.
    """
    a = a.flatten()
    b = b.flatten()
    n_a, n_b = a.numel(), b.numel()
    if n_a == 0 or n_b == 0:
        return 0.0

    a_sorted = torch.sort(a)[0]
    b_sorted = torch.sort(b)[0]
    if n_a == n_b:
        return (a_sorted - b_sorted).abs().mean().item()

    # searchsorted needs both operands in one dtype.
    common = torch.promote_types(a_sorted.dtype, b_sorted.dtype)
    a_sorted = a_sorted.to(common)
    b_sorted = b_sorted.to(common)

    # Both empirical CDFs are piecewise constant between consecutive merged
    # sample points, so the integral is a finite sum over those intervals.
    grid = torch.sort(torch.cat([a_sorted, b_sorted]))[0]
    widths = grid[1:] - grid[:-1]
    left_edges = grid[:-1].contiguous()
    cdf_a = torch.searchsorted(a_sorted, left_edges, right=True).to(widths.dtype) / n_a
    cdf_b = torch.searchsorted(b_sorted, left_edges, right=True).to(widths.dtype) / n_b
    return ((cdf_a - cdf_b).abs() * widths).sum().item()

@torch.no_grad()
def ent_w1(X: torch.Tensor, Y: torch.Tensor, n_qubits: int, nA: int = None) -> float:
    """W1 distance between the entanglement-entropy distributions of two
    ensembles of pure states (subsystem size ``nA``, see
    ``entanglement_entropy_vn``)."""
    entropies_x = entanglement_entropy_vn(X, n_qubits, nA=nA)
    entropies_y = entanglement_entropy_vn(Y, n_qubits, nA=nA)
    return wasserstein1_1d(entropies_x, entropies_y)

    

# -----------------------------
# Reverse-time sampling
# -----------------------------
@torch.no_grad()
def sample(model: nn.Module, batch: int = 256) -> torch.Tensor:
    """
    Draw ``batch`` states by integrating the reverse-time SDE from t = T to t = 0.

    Starts from random pure states and takes ``cfg.K`` Euler-Maruyama steps
        psi <- psi + (b - sigma^2 * s_theta) * dt + sigma * sqrt(|dt|) * noise
    with negative ``dt``, renormalising after each step. ``b`` is the same
    drift expression used by ``forward_step``.

    The score model is queried at t = T, T - T/K, ..., T/K. (The previous
    version walked the grid in the wrong direction: it started at t = 0 and
    finished near t = T while still using a negative dt.)

    Returns a (batch, 2 ** cfg.n_qubits) complex tensor.
    """
    device = cfg.device
    n = cfg.n_qubits
    d = 2 ** n

    # prior p_T
    psi = random_pure_states(batch, d, device)

    anchor = torch.zeros(batch, d, device=device, dtype=torch.complex64)
    anchor[:, 0] = 1.0 + 0j

    # Ascending grid 0 = t_0 < ... < t_K = T, traversed from k = K down to 1.
    times = torch.linspace(0.0, cfg.T, cfg.K + 1, device=device)
    for k in range(cfg.K, 0, -1):
        t = times[k].expand(batch)              # current time, starts at T
        dt = times[k - 1] - times[k]            # negative: stepping back in time
        dt_abs = (-dt).clamp_min(1e-12)

        sig = sigma_t(t)
        lam = lambda_t(t)

        drift_raw = anchor - psi
        b = -lam.unsqueeze(-1) * tangent_project(psi, drift_raw)

        s = model(psi, t)

        re = torch.randn(batch, d, device=device)
        im = torch.randn(batch, d, device=device)
        noise = torch.complex(re, im)
        noise = tangent_project(psi, noise)

        psi = psi + (b - (sig * sig).unsqueeze(-1) * s) * dt + sig.unsqueeze(-1) * torch.sqrt(dt_abs) * noise
        psi = normalize(psi)

    return psi

# -----------------------------
# Training loop
# -----------------------------
@torch.no_grad()
def _simulate_forward_to(psi0: torch.Tensor, t_stop: torch.Tensor, anchor: torch.Tensor) -> torch.Tensor:
    """Run forward_step from time 0 until each sample reaches its own stop time.

    Stop times are snapped down to the ``cfg.dt`` grid. All samples advance in
    lock-step; a sample that has taken its required number of steps is frozen
    by masking.
    """
    steps_needed = torch.floor(t_stop / cfg.dt).long().clamp_min(0)  # (B,)
    psi = psi0
    for i in range(int(steps_needed.max().item())):
        cur_t = torch.full_like(t_stop, i * cfg.dt)
        stepped = forward_step(psi, cur_t, cfg.dt, anchor)
        still_running = (steps_needed > i).unsqueeze(-1)  # (B, 1)
        psi = torch.where(still_running, stepped, psi)
    return psi


def train():
    """
    Train a ScoreNet against the local OU teacher and return the model.

    Each step:
      1. sample clean states and per-sample times t ~ U[delta, T];
      2. simulate the forward process so every sample reaches its own
         time t - delta, then take one more step of size delta to get psi_t;
      3. regress the network score at (psi_t, t) onto the analytic teacher
         with weight beta2(t, delta).

    Every 200 steps a batch is generated with sample() and compared with held
    out data (MMD, fidelity with |0...0>, Δ_obs, entanglement W1).

    Note: the previous version simulated the whole batch for
    int(mean(t) / dt) steps, so the states shown to the network did not match
    their time label. Simulating per sample costs up to about (T - delta) / dt
    forward steps per training step.
    """
    torch.manual_seed(0)
    device = cfg.device
    n = cfg.n_qubits
    d = 2 ** n

    model = ScoreNet(d=d).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    anchor = torch.zeros(cfg.batch_size, d, device=device, dtype=torch.complex64)
    anchor[:, 0] = 1.0 + 0j

    # Fixed pool of data states used only for evaluation.
    data_pool = sample_cluster_states(4096, n, cfg.eps_cluster, device)

    def sample_from_pool(k: int) -> torch.Tensor:
        idx = torch.randint(0, data_pool.shape[0], (k,), device=device)
        return data_pool[idx]

    for step in range(1, cfg.steps + 1):
        psi0 = sample_cluster_states(cfg.batch_size, n, cfg.eps_cluster, device)

        t = cfg.delta + (cfg.T - cfg.delta) * torch.rand(cfg.batch_size, device=device)

        # Forward-simulate every sample to its own time t - delta, then take
        # the final local step of size delta whose analytic score is the teacher.
        psi_prev = _simulate_forward_to(psi0, t - cfg.delta, anchor)
        psi_t = forward_step(psi_prev, t - cfg.delta, cfg.delta, anchor)

        # ---- analytic local OU teacher
        with torch.no_grad():
            s_teach = teacher_score_local_ou(psi_t, psi_prev, t, cfg.delta)

        s_pred = model(psi_t, t)

        # VP-consistent weighting: lambda = beta^2
        w = beta2(t, cfg.delta).detach().clamp_min(1e-6)
        diff = s_pred - s_teach
        loss = (w.unsqueeze(-1) * (diff.conj() * diff).real).sum(dim=-1).mean()

        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if step % 200 == 0:
            X0 = sample_from_pool(256)
            Xgen = sample(model, batch=256)
            Xhaar = random_pure_states(256, d, device)

            mmd_gen = mmd2_overlap(X0, Xgen, gamma=5.0, unbiased=True)
            mmd_haar = mmd2_overlap(X0, Xhaar, gamma=5.0, unbiased=True)

            # data-vs-data MMD, averaged over a few resamples, as a noise floor
            mmd_dd_list = []
            for _ in range(5):
                Xa = sample_from_pool(256)
                Xb = sample_from_pool(256)
                mmd_dd_list.append(mmd2_overlap(Xa, Xb, gamma=5.0, unbiased=True))
            mmd_dd = sum(mmd_dd_list) / len(mmd_dd_list)

            # Mean fidelity with |0...0>: the overlap with e_0 is component 0.
            fid = Xgen[:, 0].abs().pow(2).mean().item()

            # ---- Δ_obs and Ent. W1
            dob = delta_obs(X0, Xgen, n_qubits=n)
            ew1 = ent_w1(X0, Xgen, n_qubits=n, nA=n // 2)

            # optional: data-data "floor" (helps interpret)
            Xa = sample_from_pool(256)
            Xb = sample_from_pool(256)
            dob_dd = delta_obs(Xa, Xb, n_qubits=n)
            ew1_dd = ent_w1(Xa, Xb, n_qubits=n, nA=n // 2)

            print(
                f"step {step:5d} | loss {loss.item():.4e} | F0 {fid:.4f} | "
                f"MMD(data,gen) {mmd_gen:.4e} | MMD(data,haar) {mmd_haar:.4e} | "
                f"MMD(data,data) {mmd_dd:.4e} | "
                f"Δobs(data,gen) {dob:.4e} | Δobs(data,data) {dob_dd:.4e} | "
                f"EntW1(data,gen) {ew1:.4e} | EntW1(data,data) {ew1_dd:.4e}"
            )

    return model

if __name__ == "__main__":
    # Train, draw one final batch of samples, and report the wall-clock
    # duration of the whole run in seconds.
    started_at = time.time()
    model = train()
    _ = sample(model, batch=512)
    finished_at = time.time()
    print(finished_at - started_at)
