import glob
import os
import json
import numpy as np
import torch
from torchdiffeq import odeint
from scipy.spatial import Delaunay

# ---------- Device & dtype ----------
# Module-wide defaults: several helpers below create or cast tensors with
# device=DEVICE (CUDA when available, otherwise CPU) and dtype=DTYPE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE  = torch.float64  # double precision is used throughout this module

# ---------- Utilities ----------
def stationary_from_K(K: np.ndarray) -> np.ndarray:
    """Return the stationary distribution ``pi`` of a transition matrix ``K``.

    ``pi`` is taken as the left eigenvector of ``K`` (an eigenvector of
    ``K.T``) whose eigenvalue is closest to 1, i.e. ``pi @ K == pi``.
    ``K`` is expected to be row-stochastic. CPU/NumPy helper.

    Args:
        K: (n, n) transition matrix (anything ``np.asarray`` accepts).

    Returns:
        (n,) non-negative array summing to 1.
    """
    w, v = np.linalg.eig(np.asarray(K).T)
    idx = int(np.argmin(np.abs(w - 1.0)))
    pi = np.real(v[:, idx])
    # Eigenvectors are only defined up to a scalar factor: eig may hand back
    # the Perron vector with all-negative entries. Orient it before clipping,
    # otherwise the clip zeroes every entry and the division yields NaN.
    if pi.sum() < 0:
        pi = -pi
    pi = np.maximum(pi, 0)
    return pi / pi.sum()

def log_mean(a, b, eps=1e-16):
    """Elementwise logarithmic mean Λ(a, b) = (a - b) / (log a - log b).

    Inputs are clamped below at ``eps``. Where ``|log a - log b| <= 1e-12``
    the limit value ``a`` is returned (Λ(a, a) = a).

    Args:
        a, b: broadcast-compatible tensors.
        eps: lower clamp applied to both inputs before taking logs.

    Returns:
        Tensor of the broadcast shape.
    """
    a = torch.clamp(a, min=eps)
    b = torch.clamp(b, min=eps)
    denom = torch.log(a) - torch.log(b)
    use_ratio = torch.abs(denom) > 1e-12
    # torch.where evaluates BOTH branches. With a raw denominator the
    # discarded branch is 0/0 = NaN, which leaks NaN into gradients even
    # though the forward value is masked. Substitute 1 where unused.
    safe_denom = torch.where(use_ratio, denom, torch.ones_like(denom))
    return torch.where(use_ratio, (a - b) / safe_denom, a)

def log_mean_torch(a, b, eps=1e-16):
    """Elementwise logarithmic mean θ(a, b) = (a - b) / (log a - log b).

    Both inputs are clamped below at ``eps``. When the log-difference is
    at most 1e-12 in magnitude the limit θ(a, a) = a is returned.

    Args:
        a, b: broadcast-compatible tensors (e.g. ``p[:, None]``, ``p[None, :]``).
        eps: lower clamp applied before taking logs.

    Returns:
        Tensor of the broadcast shape.
    """
    a = torch.clamp(a, min=eps)
    b = torch.clamp(b, min=eps)
    den = torch.log(a) - torch.log(b)
    use_ratio = den.abs() > 1e-12
    # Keep the unselected torch.where branch finite: a raw 0/0 there gives
    # NaN gradients for a and b even though its value is discarded.
    safe_den = torch.where(use_ratio, den, torch.ones_like(den))
    return torch.where(use_ratio, (a - b) / safe_den, a)  # θ(a,a)=a

def grad_V(V):
    """Return the pairwise potential-difference matrix of ``V``.

    Entry ``[i, j]`` equals ``V[i] - V[j]``. This is the same orientation as
    the ``dV`` / ``DeltaG`` matrices built in the RHS helpers of this module.

    Args:
        V: (n,) potential; torch tensor, NumPy array, or any array-like.

    Returns:
        (n, n) tensor on ``DEVICE`` with dtype ``DTYPE``.
    """
    # torch.as_tensor handles tensors (equivalent to .to(device, dtype)),
    # NumPy arrays and plain sequences alike.
    V = torch.as_tensor(V, device=DEVICE, dtype=DTYPE)
    return V.unsqueeze(1) - V.unsqueeze(0)  # [n,n] with [i,j] = V[i]-V[j]

# ---------- FP RHS variants (device-aware) ----------
def fp_rhs_logmean(K, p, V, beta, eps=1e-16):
    """Log-mean Fokker-Planck right-hand side dp/dt.

    With ``G = beta * log p + V`` and θ the logarithmic mean, this returns::

        out_i = sum_j K_ij * θ(p_i, p_j) * (G_j - G_i)

    Mass flows from high to low ``G``. For row-stochastic ``K`` with
    ``beta = 1`` and ``V = 0`` this reduces to the heat equation ``K p - p``.
    The orientation agrees with ``heat_equation_logmean_rhs`` and
    ``fp_rhs_compact`` (``K @ g - g``).

    Args:
        K: (n, n) transition weights.
        p: (n,) density; clamped below at ``eps``.
        V: (n,) potential tensor.
        beta: entropy coefficient.
        eps: floor applied to ``p`` and inside the log mean.

    Returns:
        (n,) tensor of time derivatives.
    """
    p = torch.clamp(p.to(dtype=DTYPE), min=eps)
    G = beta * torch.log(p) + V.to(dtype=DTYPE)              # [n]
    # [i, j] = G_j - G_i. The previous G_i - G_j orientation made the flow
    # anti-diffusive (backward heat equation), i.e. numerically unstable.
    DeltaG = G.unsqueeze(0) - G.unsqueeze(1)                 # [n,n]
    theta = log_mean_torch(p.unsqueeze(1), p.unsqueeze(0), eps=eps)  # [n,n]
    flux = K * theta * DeltaG                                # [n,n]
    return flux.sum(dim=1)

def heat_equation_rhs(K, p, V, beta, eps=1e-16):
    """Linear heat-equation right-hand side ``K @ g - g`` with ``g = beta * p + V``.

    Args:
        K: (n, n) transition matrix; cast to ``DTYPE`` so the matmul never
            hits a float32/float64 mismatch.
        p: (n,) density; clamped below at ``eps``.
        V: (n,) potential tensor.
        beta: scale on ``p``.
        eps: floor applied to ``p``.

    Returns:
        (n,) tensor of time derivatives.
    """
    K = K.to(dtype=DTYPE)
    p = torch.clamp(p.to(dtype=DTYPE), min=eps)
    g = beta * p + V.to(dtype=DTYPE)
    return K @ g - g

def heat_equation_logmean_rhs(K, pi, p, V, beta, eps=1e-3):
    """Log-mean heat/drift right-hand side (the one ``FPCompact`` evaluates).

    The state ``p`` is floored at ``eps`` and rescaled so that
    ``<p, pi> = 1``. With ``rho`` the rescaled state and ``g = beta * rho``::

        out_i = (K g)_i - g_i - sum_j K_ij (V_i - V_j) θ(rho_i, rho_j)

    where θ is the logarithmic mean from ``log_mean_torch``. All inputs are
    cast to ``DTYPE``; they are expected to already share one device.
    """
    K, pi, V = (x.to(dtype=DTYPE) for x in (K, pi, V))
    rho = p.to(dtype=DTYPE).clamp(min=eps)
    # Normalize so the pi-weighted mass <rho, pi> equals 1.
    rho = rho / (torch.dot(rho, pi) + 1e-16)

    g = beta * rho
    potential_gap = V[:, None] - V[None, :]  # [i, j] = V_i - V_j
    theta = log_mean_torch(rho[:, None], rho[None, :], eps=eps)
    drift = (K * potential_gap * theta).sum(dim=1)
    diffusion = K @ g - g
    return diffusion - drift

def fp_rhs_compact(K, p, V, beta, eps=1e-16):
    """Compact Fokker-Planck right-hand side ``K @ g - g``, ``g = beta * log p + V``.

    Args:
        K: (n, n) transition matrix; cast to ``DTYPE`` so the matmul does not
            fail on a float32 ``K``.
        p: (n,) density; clamped below at ``eps`` before the log.
        V: (n,) potential tensor.
        beta: entropy coefficient.
        eps: floor applied to ``p``.

    Returns:
        (n,) tensor of time derivatives.
    """
    K = K.to(dtype=DTYPE)
    p = torch.clamp(p.to(dtype=DTYPE), min=eps)
    g = beta * torch.log(p) + V.to(dtype=DTYPE)
    return K @ g - g

class FPCompact(torch.nn.Module):
    """ODE vector field ``f(t, p)`` suitable for ``torchdiffeq.odeint``.

    Despite the name, ``forward`` evaluates ``heat_equation_logmean_rhs``
    (not ``fp_rhs_compact``). The time argument ``t`` is ignored.

    Args:
        K: (n, n) transition matrix (tensor, NumPy array or nested list).
        pi: (n,) weights used to normalize the state (any array-like).
        V: (n,) potential (any array-like).
        beta: scalar coefficient, stored as a Python float.
    """

    def __init__(self, K, pi, V, beta):
        super().__init__()
        # Buffers follow .to()/.cuda() but are not trainable parameters.
        # torch.as_tensor accepts tensors as well as NumPy arrays / lists.
        self.register_buffer('K',  torch.as_tensor(K, device=DEVICE, dtype=DTYPE))
        self.register_buffer('pi', torch.as_tensor(pi, device=DEVICE, dtype=DTYPE))
        self.register_buffer('V',  torch.as_tensor(V, device=DEVICE, dtype=DTYPE))
        self.beta = float(beta)

    def forward(self, t, p):
        """Return dp/dt at state ``p`` (``t`` is unused)."""
        return heat_equation_logmean_rhs(self.K, self.pi, p, self.V, self.beta)

# ---------- Graph builders ----------
def build_K_pi_from_delaunay(
    n: int,
    points: np.ndarray = None,
    jitter: float = 0.05,
    weight_sigma: float = None,
    seed: int = 0,
):
    """Build a reversible random-walk kernel on a Gaussian-weighted Delaunay graph.

    Edge (i, j) of the Delaunay triangulation gets weight
    ``w = exp(-|x_i - x_j|^2 / (2 sigma^2))``. A point left without edges is
    linked to its nearest neighbour.

    Args:
        n: number of points / states.
        points: optional (n, 2) coordinates (any array-like). When omitted,
            a near-square grid on [0, 1]^2 is jittered with Gaussian noise of
            scale ``jitter`` and clipped back into the unit square.
        jitter: std of the grid jitter (only used when ``points`` is None).
        weight_sigma: Gaussian bandwidth; defaults to the mean edge length.
        seed: seed for the jitter RNG.

    Returns:
        K: (n, n) row-stochastic ndarray, ``K[i, j] = W[i, j] / sum_k W[i, k]``.
        pi: (n,) ndarray, ``pi[i] = sum_k W[i, k] / W.sum()``; since W is
            symmetric, ``pi[i] * K[i, j] == pi[j] * K[j, i]``.

    Raises:
        ValueError: if the triangulation yields no edges.
    """
    rng = np.random.default_rng(seed)
    if points is None:
        m = int(np.round(np.sqrt(n)))
        k = int(np.ceil(n / m))
        X, Y = np.meshgrid(np.linspace(0, 1, m), np.linspace(0, 1, k), indexing='xy')
        pts = np.stack([X.ravel(), Y.ravel()], axis=1)[:n]
        pts = pts + rng.normal(scale=jitter, size=pts.shape)
        pts = np.clip(pts, 0.0, 1.0)
    else:
        # Accept lists/tuples as well as ndarrays.
        pts = np.asarray(points, dtype=float)
        assert pts.shape == (n, 2)

    tri = Delaunay(pts)
    simplices = np.asarray(tri.simplices)
    # Every triangle (a, b, c) contributes edges (a,b), (b,c), (c,a); sort each
    # pair and deduplicate so each undirected edge appears exactly once.
    pairs = np.concatenate(
        [simplices[:, [0, 1]], simplices[:, [1, 2]], simplices[:, [2, 0]]]
    )
    edges = np.unique(np.sort(pairs, axis=1), axis=0)
    if edges.shape[0] == 0:
        raise ValueError("No edges in Delaunay — check n or points")

    dists = np.linalg.norm(pts[edges[:, 0]] - pts[edges[:, 1]], axis=1)
    if weight_sigma is None:
        weight_sigma = float(np.mean(dists))
    two_sigma_sq = 2 * (weight_sigma ** 2)

    W = np.zeros((n, n), dtype=float)
    w = np.exp(-(dists ** 2) / two_sigma_sq)
    # Edges are unique, so these fancy-index assignments never collide.
    W[edges[:, 0], edges[:, 1]] = w
    W[edges[:, 1], edges[:, 0]] = w

    strengths = W.sum(axis=1)
    isolated = np.flatnonzero(strengths == 0)
    if isolated.size > 0:
        # Points Qhull left out of the triangulation: connect to nearest point.
        for i in isolated:
            ds = np.linalg.norm(pts - pts[i], axis=1)
            ds[i] = np.inf
            j = int(np.argmin(ds))
            w_ij = np.exp(-(ds[j] ** 2) / two_sigma_sq)
            W[i, j] = w_ij
            W[j, i] = w_ij
        strengths = W.sum(axis=1)

    K = W / strengths[:, None]
    pi = strengths / np.sum(strengths)
    return K, pi

def build_K_pi_from_delaunay_torch(
    n: int,
    points: np.ndarray = None,
    jitter: float = 0.05,
    weight_sigma: float = None,
    seed: int = 0,
    dtype=DTYPE,
    device=DEVICE,
):
    """Torch version of ``build_K_pi_from_delaunay``.

    Builds exactly the same kernel ``K`` and distribution ``pi`` with the
    NumPy builder, then converts both to tensors with the requested
    ``dtype`` and ``device``. Delegating avoids maintaining two copies of
    the graph-construction code.

    Args:
        n, points, jitter, weight_sigma, seed: as for ``build_K_pi_from_delaunay``;
            here ``points`` must be an ndarray of shape (n, 2) when given.
        dtype: target tensor dtype.
        device: target tensor device.

    Returns:
        (K, pi) as torch tensors of shapes (n, n) and (n,).
    """
    if points is not None:
        assert isinstance(points, np.ndarray)
        assert points.shape == (n, 2)
    K_np, pi_np = build_K_pi_from_delaunay(
        n, points=points, jitter=jitter, weight_sigma=weight_sigma, seed=seed
    )
    K_torch = torch.from_numpy(K_np).to(dtype=dtype, device=device)
    pi_torch = torch.from_numpy(pi_np).to(dtype=dtype, device=device)
    return K_torch, pi_torch

def generate_K(n, mode='complete'):
    """Return a transition kernel ``K`` and distribution ``pi`` as tensors.

    Args:
        n: number of states.
        mode: ``'complete'`` gives ``K[i, j] = 1/n`` with uniform ``pi``.
            ``'delaunay'`` (the legacy spelling ``'delauney'`` is still
            accepted) gives the Gaussian-weighted Delaunay graph from
            ``build_K_pi_from_delaunay_torch``.

    Returns:
        (K, pi) on ``DEVICE`` with dtype ``DTYPE``.

    Raises:
        NotImplementedError: for any other mode.
    """
    if mode == 'complete':
        K = torch.ones((n, n), dtype=DTYPE, device=DEVICE) / n
        pi = torch.ones(n, dtype=DTYPE, device=DEVICE) / n
    elif mode in ('delaunay', 'delauney'):
        K, pi = build_K_pi_from_delaunay_torch(n, device=DEVICE, dtype=DTYPE)
    else:
        raise NotImplementedError(f"generate_K: unknown mode {mode!r}")
    return K, pi

def initialize_p0(n, pi_np, mode='random', seed=42, maxp0=100.0):
    """Draw a random initial density normalized so that ``<p0, pi> = 1``.

    Args:
        n: number of states.
        pi_np: (n,) weights; a NumPy array or a torch tensor (tensors are
            detached and moved to CPU first).
        mode: only ``'random'`` is supported: i.i.d. Uniform(0, maxp0) entries.
        seed: RNG seed.
        maxp0: upper bound of the uniform draw before normalization.

    Returns:
        (n,) float64 ndarray.

    Raises:
        NotImplementedError: for any mode other than ``'random'``.
    """
    if mode != 'random':
        raise NotImplementedError(f"initialize_p0: unsupported mode {mode!r}")
    if isinstance(pi_np, torch.Tensor):
        pi_np = pi_np.detach().cpu().numpy()
    rng = np.random.default_rng(seed)
    p0 = rng.uniform(0, maxp0, n)
    return p0 / np.dot(p0, pi_np)

def initialize_V(n, mode='random', vrange=150.0, seed=42):
    """Return an initial potential vector of length ``n``.

    ``'random'`` draws i.i.d. Uniform(-vrange, vrange) entries from a
    generator seeded with ``seed``; ``'zeros'`` returns all zeros. Any other
    mode raises ``NotImplementedError``.
    """
    generator = np.random.default_rng(seed)
    if mode == 'zeros':
        return np.zeros(n, dtype=float)
    if mode == 'random':
        return generator.uniform(-vrange, vrange, n).astype(float)
    raise NotImplementedError

def initialize_V_stable(n, mode='random', vrange=2.0, seed=42, tau=0.2, pi=None, K=None, smooth_steps=0):
    """Return a small, optionally centred and smoothed, initial potential.

    Steps:
      1. Draw ``V`` (Uniform(-vrange, vrange) for ``'random'``, zeros for ``'zeros'``).
      2. If ``pi`` is given, subtract the weighted mean ``<V, pi>``.
      3. If ``K`` is given and ``smooth_steps > 0``, apply ``V <- K @ V`` that many times.
      4. Scale by ``tau``.

    Args:
        n: number of states.
        mode: ``'random'`` or ``'zeros'``.
        vrange: half-width of the uniform draw.
        seed: RNG seed.
        tau: final scale factor.
        pi: optional (n,) weights; NumPy array, sequence, or torch tensor.
        K: optional (n, n) smoothing matrix; NumPy array or torch tensor.
        smooth_steps: number of ``K @ V`` applications.

    Returns:
        (n,) float ndarray.

    Raises:
        NotImplementedError: for an unknown mode.
    """
    rng = np.random.default_rng(seed)
    if mode == 'random':
        V = rng.uniform(-vrange, vrange, n).astype(float)
    elif mode == 'zeros':
        V = np.zeros(n, dtype=float)
    else:
        raise NotImplementedError(f"initialize_V_stable: unsupported mode {mode!r}")
    if pi is not None:
        # Convert tensors the same way K is converted below (handles CUDA).
        pi_np = pi.detach().cpu().numpy() if isinstance(pi, torch.Tensor) else np.asarray(pi)
        V = V - np.dot(V, pi_np)
    if K is not None and smooth_steps > 0:
        Knp = K.detach().cpu().numpy() if isinstance(K, torch.Tensor) else K
        for _ in range(smooth_steps):
            V = Knp @ V
    return tau * V

# ---------- Geodesic solvers ----------
def build_undirected_edges_torch(K, pi, tol=0.0):
    """List undirected edges of the graph defined by ``K`` with symmetric weights.

    A pair (i, j) is an edge when ``K[i, j] > tol`` or ``K[j, i] > tol``;
    each pair is reported once with ``i < j``, in row-major order.

    Returns:
        I, J: int64 index tensors of edge endpoints.
        omega: edge weights ``0.5 * (pi_i K_ij + pi_j K_ji)`` in ``DTYPE``.
    """
    K = K.to(dtype=DTYPE)
    pi = pi.to(dtype=DTYPE)
    linked = (K > tol) | (K.T > tol)
    rows, cols = torch.where(linked)
    upper = rows < cols
    rows = rows[upper]
    cols = cols[upper]
    omega = 0.5 * (pi[rows] * K[rows, cols] + pi[cols] * K[cols, rows])
    return rows.long(), cols.long(), omega

def tiny_psi_diff_solver_torch(K, pi, rho, drho, pin=0, eps=1e-16):
    """Solve the weighted graph Poisson problem ``L psi = -pi * drho``.

    ``L`` is the graph Laplacian over the undirected edges of ``K``, with
    edge weight ``D_e = omega_e * θ(rho_I, rho_J)`` (θ = logarithmic mean).
    The gauge is fixed by pinning ``psi[pin] = 0``.

    Args:
        K: (n, n) transition matrix.
        pi: (n,) weights.
        rho: (n,) density at which θ is evaluated.
        drho: (n,) tangent direction; ``sum(pi * drho)`` must be ~0.
        pin: index of the node fixed to 0 (negative indices count from the end).
        eps: floor used inside the log mean.

    Returns:
        psi: (n,) potential.
        psi_diff_edges: per-edge ``psi[J] - psi[I]``.
        psi_diff_full: (n, n) with ``[i, j] = psi[j] - psi[i]``.
        metric_norm_sq: ``0.5 * psi @ b`` with ``b = -pi * drho``.
        edges: dict with keys I, J, omega, theta, D_diag.

    Raises:
        ValueError: if ``|sum(pi * drho)| > 1e-4``.
        IndexError: if ``pin`` is outside ``[-n, n)``.
    """
    device = K.device
    K = K.to(dtype=DTYPE); pi = pi.to(dtype=DTYPE)
    rho = rho.to(dtype=DTYPE); drho = drho.to(dtype=DTYPE)
    I, J, omega = build_undirected_edges_torch(K, pi)
    theta = log_mean_torch(rho[I], rho[J], eps=eps)
    D_diag = omega * theta
    n = K.shape[0]
    edges = dict(I=I, J=J, omega=omega, theta=theta, D_diag=D_diag)

    b = -(pi * drho)
    if b.abs().sum() < 1e-10:
        # Zero tangent: the potential is identically zero.
        psi = torch.zeros_like(pi)
        psi_diff_edges = torch.zeros_like(D_diag)
        psi_diff_full = psi[None, :] - psi[:, None]
        metric_norm_sq = torch.tensor(0.0, dtype=DTYPE, device=device)
        return psi, psi_diff_edges, psi_diff_full, metric_norm_sq, edges
    if abs(b.sum().item()) > 1e-4:
        raise ValueError("Mass conservation violated: sum(π ⊙ δρ) must be 0.")
    if not -n <= pin < n:
        raise IndexError(f"pin={pin} out of range for n={n}")
    pin = pin % n  # a negative pin would otherwise drop no row and leave L singular

    # Assemble L = B diag(D) B^T directly instead of materializing the dense
    # (n, m) incidence matrix B, which needs O(n*m) memory (O(n^3) for a
    # complete graph). Edges satisfy I < J, so no off-diagonal entry repeats.
    L = torch.zeros((n, n), dtype=DTYPE, device=device)
    L[I, J] = -D_diag
    L[J, I] = -D_diag
    L = L - torch.diag(L.sum(dim=1))  # diagonal = weighted degree

    keep_idx = torch.cat([
        torch.arange(0, pin, device=device),
        torch.arange(pin + 1, n, device=device),
    ])
    L_red = L.index_select(0, keep_idx).index_select(1, keep_idx)
    b_red = b.index_select(0, keep_idx)
    psi = torch.zeros(n, dtype=DTYPE, device=device)
    psi[keep_idx] = torch.linalg.solve(L_red, b_red)

    psi_diff_edges = psi[J] - psi[I]
    psi_diff_full = psi.unsqueeze(0) - psi.unsqueeze(1)
    metric_norm_sq = 0.5 * (psi @ b)
    return psi, psi_diff_edges, psi_diff_full, metric_norm_sq, edges

def solve_eq_cqp_cholesky(M, A, c, *, check_pd=True, precond_eps=1e-3):
    """Solve ``min_x x^T M x`` subject to ``A x = c`` via a Schur complement.

    Both the symmetrized ``M`` and the Schur complement ``S = A M^{-1} A^T``
    are regularized by ``precond_eps * I`` before their Cholesky factorizations.
    The KKT conditions ``2 M x + A^T lam = 0`` and ``A x = c`` then give
    ``lam = S^{-1}(-2 c)`` and ``x = -0.5 M^{-1} A^T lam``.

    ``check_pd`` is accepted for API compatibility but is not used.

    Returns:
        x_star: (n,) minimizer.
        lam: (m,) Lagrange multipliers.
        info: dict with primal/stationarity residual norms (computed with the
            regularized ``M``) and the condition number of ``S``.
    """
    device = M.device
    M, A, c = M.to(DTYPE), A.to(DTYPE), c.to(DTYPE)
    n, m = M.shape[0], A.shape[0]

    M_reg = 0.5 * (M + M.T) + precond_eps * torch.eye(n, dtype=DTYPE, device=device)
    chol_M = torch.linalg.cholesky(M_reg)
    Y = torch.cholesky_solve(A.T, chol_M, upper=False)  # M^{-1} A^T

    S = A @ Y
    S = 0.5 * (S + S.T) + precond_eps * torch.eye(m, dtype=DTYPE, device=device)
    chol_S = torch.linalg.cholesky(S)
    lam = torch.cholesky_solve((-2.0 * c).unsqueeze(-1), chol_S, upper=False).squeeze(-1)
    x_star = -0.5 * (Y @ lam)

    primal_residual = A @ x_star - c
    stationarity_residual = 2.0 * (M_reg @ x_star) + A.T @ lam
    info = {
        "primal_res_norm": torch.linalg.norm(primal_residual).item(),
        "stationarity_res_norm": torch.linalg.norm(stationarity_residual).item(),
        "schur_cond_hint": torch.linalg.cond(S).item() if m > 0 else 0.0
    }
    return x_star, lam, info

def cholesky_solver_torch(K, pi, rho, drho, pin=0, eps=1e-16, precond_eps=1e-12):
    """Compute the potential ``psi`` for tangent ``drho`` as an equality-constrained QP.

    With ``Z = K * θ(rho_i, rho_j)`` and ``W = pi_i * Z``:
      * objective matrix ``M = diag(W.sum(1)) + precond_eps * I - W``;
      * constraints ``(Z - diag(Z.sum(1))) psi = drho`` and ``psi[pin] = 0``.

    ``precond_eps`` only regularizes ``M`` here. ``solve_eq_cqp_cholesky``
    is called with its own default ``precond_eps``.

    Returns:
        (psi, lam, info) exactly as returned by ``solve_eq_cqp_cholesky``.
    """
    device = K.device
    n = pi.shape[0]
    eye = torch.eye(n, dtype=DTYPE, device=device)

    Z = K * log_mean_torch(rho.unsqueeze(1), rho.unsqueeze(0), eps=eps)
    W = Z * pi.reshape(n, 1)
    M = torch.diag(W.sum(dim=1)) + precond_eps * eye - W

    pin_row = eye[pin]  # one-hot row selecting psi[pin]
    A = torch.vstack([Z - torch.diag(Z.sum(dim=1)), pin_row])
    c = torch.hstack([drho, torch.zeros(1, dtype=DTYPE, device=device)])
    return solve_eq_cqp_cholesky(M, A, c, check_pd=True)

def approx_ctmc_step_probs(K, pi, V, beta, rho, delta_t: float, eps=1e-8):
    """One-step transition matrix ``P ≈ I + delta_t * Q`` of an upwind CTMC.

    With ``Phi = V + beta * log(rho)`` the jump rate from i to j is::

        Q_ij = K_ij θ(rho_i, rho_j) pi_i * max(Phi_i - Phi_j, 0) / (rho_i pi_i)

    and the diagonal holds minus the row sums. Negative entries of
    ``I + delta_t * Q`` are clipped to 0 and rows are renormalized.

    Args:
        K: (n, n) transition matrix.
        pi: (n,) weights.
        V: (n,) potential.
        beta: entropy coefficient.
        rho: (n,) density; entries below ``eps`` are treated as ``eps``.
        delta_t: time step.
        eps: density floor and row-normalization guard.

    Returns:
        (n, n) row-stochastic tensor.
    """
    n = K.shape[0]
    # Clamp once and use the clamped density for both Phi and the mass mu.
    # An unclamped mu made rows with rho_i == 0 evaluate 0/0 = NaN.
    rho_safe = torch.clamp(rho, min=eps)
    theta = log_mean_torch(rho.unsqueeze(1), rho.unsqueeze(0), eps=eps)
    Z = K * theta * pi.reshape(n, 1)
    mu = rho_safe * pi
    Phi = V + beta * torch.log(rho_safe)
    posflux = torch.clamp(Phi.unsqueeze(1) - Phi.unsqueeze(0), min=0.0)  # upwind
    Q = Z * posflux / mu.reshape(n, 1)
    Q = Q - torch.diag(Q.sum(dim=1))
    P = torch.eye(n, dtype=Q.dtype, device=Q.device) + delta_t * Q
    P = torch.clamp(P, min=0.0)
    return P / (P.sum(dim=1, keepdim=True) + eps)

def project_pi_simplex(p, pi, eps_floor=1e-5):
    """Floor ``p`` at ``eps_floor`` and rescale it so that ``<p, pi> = 1``.

    The result is cast to ``pi.dtype``. Despite the name, this is a
    pi-weighted normalization, not a Euclidean simplex projection.
    """
    floored = p.clamp(min=eps_floor).to(pi.dtype)
    weighted_mass = torch.dot(floored, pi)
    return floored / (weighted_mass + 1e-16)


def compute_geodesic(model, K, p0, t_span, beta=None, device=None, pi=None, method='euler'):
    """
    Integrate the ``FPCompact`` dynamics from ``p0`` using a trained model's potential.

    Args:
        model: callable returning ``(V_all, beta_learned)`` (e.g. VAndBetaMLP).
            Required; ``None`` raises ``ValueError``. The model is evaluated
            in eval mode under ``torch.no_grad()``, and its previous training
            mode is restored afterwards.
        K: (n, n) transition kernel (tensor or array-like).
        p0: (n,) initial distribution (tensor or array-like).
        t_span: time points to evaluate at (tensor or array-like).
        beta: entropy coefficient; if None, use the model's learned beta.
        device: torch device; defaults to ``DEVICE``.
        pi: optional (n,) weights used for the ``<p, pi> = 1`` normalization;
            defaults to the uniform distribution (previous behaviour).
        method: ``torchdiffeq.odeint`` integration method (default ``'euler'``).

    Returns:
        trajectory: (T, n) tensor of states at the times in ``t_span``.

    Raises:
        ValueError: if ``model`` is None.
    """
    if model is None:
        raise ValueError("Must provide a model")
    if device is None:
        device = DEVICE

    # Query V and beta without leaking eval mode back to a caller that is
    # mid-training.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            V_all, beta_learned = model()
    finally:
        if was_training:
            model.train()
    V = V_all.to(device=device, dtype=DTYPE)
    if beta is None:
        beta = float(beta_learned.item())

    K = torch.as_tensor(K, device=device, dtype=DTYPE)
    n = K.shape[0]
    if pi is None:
        pi = torch.ones(n, device=device, dtype=DTYPE) / n
    else:
        pi = torch.as_tensor(pi, device=device, dtype=DTYPE)

    p0 = torch.as_tensor(p0, device=device, dtype=DTYPE)
    t_span = torch.as_tensor(t_span, device=device, dtype=DTYPE)
    solver = FPCompact(K, pi, V, beta).to(device)
    return odeint(solver, p0, t_span, method=method)


