import numpy as np
import random
import torch
from torchdiffeq import odeint
import os
import json
from scipy.spatial import Delaunay


def stationary_from_K(K: np.ndarray):
    """Return the stationary distribution of the Markov kernel ``K``.

    The stationary law is taken as the eigenvector of ``K.T`` (left eigenvector
    of ``K``) whose eigenvalue is closest to 1.

    Args:
        K: square transition matrix.

    Returns:
        1-D array of non-negative entries summing to 1.

    Raises:
        ValueError: if the selected eigenvector carries no positive mass.
    """
    w, v = np.linalg.eig(K.T)
    i = np.argmin(np.abs(w - 1.0))
    pi = np.real(v[:, i])
    # Eigenvectors are only defined up to sign: if the solver returns the
    # negative representative, the clamp below would zero every entry and the
    # normalisation would divide by zero. Orient it first.
    if pi.sum() < 0:
        pi = -pi
    pi = np.maximum(pi, 0)
    total = pi.sum()
    if not total > 0:
        raise ValueError("Could not extract a stationary distribution from K")
    return pi / total



def log_mean(a, b, eps=1e-16):
    """Elementwise logarithmic mean Λ(a, b) = (a - b) / (log a - log b).

    Both arguments are floored at ``eps`` first. Wherever |log a - log b| does
    not exceed 1e-12 the floored ``a`` is returned instead, which is the limit
    of Λ as b → a.
    """
    lo_a = torch.clamp(a, min=eps)
    lo_b = torch.clamp(b, min=eps)
    log_gap = torch.log(lo_a) - torch.log(lo_b)
    quotient = (lo_a - lo_b) / log_gap
    return torch.where(torch.abs(log_gap) > 1e-12, quotient, lo_a)

def grad_V(V):
    """Return the matrix of pairwise differences of the potential ``V``.

    Args:
        V: 1-D potential given as a NumPy array, a torch tensor, or any
            array-like accepted by ``torch.as_tensor``.

    Returns:
        float32 tensor of shape [n, n] whose entry [i, j] equals V[i] - V[j].
    """
    # torch.as_tensor handles arrays and tensors uniformly, so no
    # exception-driven dispatch (previously a bare ``except``) is required.
    V = torch.as_tensor(V).float()
    return V.unsqueeze(1) - V.unsqueeze(0)  # [n,n] with [i,j] = V[i]-V[j]



def fp_rhs_logmean(K, p, V, beta, eps=1e-16):
    """
    Graph Fokker–Planck right-hand side with log-mean mobility.

    With G = V + beta * log p (p floored at eps), the value at node x is
    sum_y K[x,y] * Logmean(p[x], p[y]) * (G[x] - G[y]).
    """
    p_safe = torch.clamp(p, min=eps)

    # Node-level quantity G, shape [n]
    node_G = beta * torch.log(p_safe) + V

    # Pairwise quantities, shape [n, n]: entry [x, y] compares node x with node y
    gap = node_G[:, None] - node_G[None, :]
    mobility = log_mean(p_safe[:, None], p_safe[None, :], eps=eps)

    # Edge fluxes, then accumulate over the second index
    return (K * mobility * gap).sum(dim=1)


def heat_equation_rhs(K, p, V, beta, eps=1e-16):
    """Return K @ g - g for g = beta * p + V, with p floored at eps.

    K is expected to be an [n,n] row-stochastic matrix and p, V length-n
    vectors; the result has shape [n].
    """
    p_floored = torch.clamp(p, min=eps)
    field = beta * p_floored + V
    averaged = K @ field
    return averaged - field

def heat_equation_logmean_rhs(K, p, V, beta, eps=1e-16):
    """Heat-equation term on beta * p minus a log-mean weighted potential term.

    Returns K @ g - g - sum_y K[x,y] * grad_V(V)[x,y] * Logmean(p[x], p[y])
    with g = beta * p, p floored at eps and K cast to float32.
    K is expected to be [n,n] row-stochastic; the result has shape [n].
    """
    kernel = K.float()
    p_floored = torch.clamp(p, min=eps)
    diffusive = beta * p_floored
    mobility = log_mean(p_floored.unsqueeze(1), p_floored.unsqueeze(0), eps=eps)
    drift = (kernel * grad_V(V) * mobility).sum(dim=1)
    return kernel @ diffusive - diffusive - drift


def fp_rhs_compact(K, p, V, beta, eps=1e-16):
    """Return K @ g - g for g = beta * log(p) + V, with p floored at eps.

    K is expected to be an [n,n] row-stochastic matrix and p, V length-n
    vectors; the result has shape [n].
    """
    log_p = torch.log(torch.clamp(p, min=eps))
    field = beta * log_p + V
    return torch.matmul(K, field) - field


class FPCompact(torch.nn.Module):
    """ODE right-hand side module with signature ``forward(t, p)``.

    Holds the kernel K, the potential V and the coefficient beta, and evaluates
    ``heat_equation_logmean_rhs`` on the current state.
    """

    def __init__(self, K, V, beta):
        super().__init__()
        # K and V are stored as plain attributes (they are not registered as buffers).
        self.K, self.V = K, V
        self.beta = float(beta)

    def forward(self, t, p):
        # The time argument t is not used by the right-hand side.
        # Other right-hand sides defined in this module that take the same
        # arguments: fp_rhs_compact, fp_rhs_logmean, heat_equation_rhs.
        return heat_equation_logmean_rhs(self.K, p, self.V, self.beta)

def build_K_pi_from_delaunay(
    n: int,
    points: np.ndarray = None,
    jitter: float = 0.05,
    weight_sigma: float = None,
    seed: int = 0,
):
    """
    Build a random-walk kernel on the Delaunay triangulation of 2D points.

    Every Delaunay edge (i, j) gets the symmetric weight
    exp(-d(i,j)^2 / (2 sigma^2)); K is the row-normalised weight matrix and
    pi the normalised row sums.

    Args:
      n: number of nodes
      points: optional (n,2) array of coordinates; when None, a regular grid on
        [0,1]^2 is perturbed with Gaussian noise and clipped back to [0,1]
      jitter: standard deviation of that noise (only used when points is None)
      weight_sigma: sigma of the edge weight; defaults to the mean edge length
      seed: seed of the NumPy generator used for the jitter
    Returns:
      K: numpy.ndarray of shape (n,n), rows sum to 1
      pi: numpy.ndarray of shape (n,), row strengths normalised to sum to 1
    Raises:
      ValueError: if the triangulation yields no edge
    """
    rng = np.random.default_rng(seed)

    # --- node coordinates -------------------------------------------------
    if points is not None:
        assert points.shape == (n, 2)
        pts = points
    else:
        n_cols = int(np.round(np.sqrt(n)))
        n_rows = int(np.ceil(n / n_cols))
        grid_x, grid_y = np.meshgrid(
            np.linspace(0, 1, n_cols), np.linspace(0, 1, n_rows), indexing='xy'
        )
        pts = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)[:n]
        pts = np.clip(pts + rng.normal(scale=jitter, size=pts.shape), 0.0, 1.0)

    # --- undirected edge list from the triangles --------------------------
    edge_set = set()
    for a, b, c in Delaunay(pts).simplices:
        for u, v in ((a, b), (b, c), (c, a)):
            edge_set.add(tuple(sorted((u, v))))
    edge_list = list(edge_set)

    # --- edge lengths and kernel width -------------------------------------
    lengths = [np.linalg.norm(pts[u] - pts[v]) for u, v in edge_list]
    if not lengths:
        raise ValueError("No edges in Delaunay — check n or points")
    if weight_sigma is None:
        weight_sigma = float(np.mean(lengths))
    two_sigma_sq = 2 * (weight_sigma ** 2)

    # --- symmetric Gaussian weights ----------------------------------------
    W = np.zeros((n, n), dtype=float)
    for (u, v), length in zip(edge_list, lengths):
        W[u, v] = W[v, u] = np.exp(-(length ** 2) / two_sigma_sq)

    # --- nodes without any edge are tied to their nearest neighbour --------
    degree = W.sum(axis=1)
    lonely = np.flatnonzero(degree == 0)
    if lonely.size:
        for u in lonely:
            gap = np.linalg.norm(pts - pts[u], axis=1)
            gap[u] = np.inf  # exclude the node itself
            v = int(np.argmin(gap))
            W[u, v] = W[v, u] = np.exp(-(gap[v] ** 2) / two_sigma_sq)
        degree = W.sum(axis=1)

    # --- row-normalise ------------------------------------------------------
    return W / degree[:, None], degree / np.sum(degree)

def build_K_pi_from_delaunay_torch(
    n: int,
    points: np.ndarray = None,
    jitter: float = 0.05,
    weight_sigma: float = None,
    seed: int = 0,
    dtype=torch.float64,
    device=torch.device("cpu"),
):
    """
    Torch counterpart of ``build_K_pi_from_delaunay``.

    The graph is built by the NumPy implementation (one code path instead of
    two copies that can drift apart) and the result is converted to tensors.

    Args:
      n: number of nodes
      points: optional (n,2) NumPy array of coordinates; when None, a jittered grid is used
      jitter: scale of the Gaussian noise added to the grid (if points is None)
      weight_sigma: standard deviation of the distance-based weight; defaults to the mean edge length
      seed: RNG seed
      dtype: dtype of the returned tensors
      device: device of the returned tensors

    Returns:
      K_torch : (n,n) torch tensor, row-stochastic kernel
      pi_torch: (n,)   torch tensor, normalised row strengths
    """
    if points is not None:
        # The NumPy builder checks the shape; the array type is checked here
        # because this variant only accepts NumPy input.
        assert isinstance(points, np.ndarray)

    K_np, pi_np = build_K_pi_from_delaunay(
        n,
        points=points,
        jitter=jitter,
        weight_sigma=weight_sigma,
        seed=seed,
    )

    K_torch = torch.from_numpy(K_np).to(dtype=dtype, device=device)
    pi_torch = torch.from_numpy(pi_np).to(dtype=dtype, device=device)

    return K_torch, pi_torch

def generate_K(n, mode='complete', seed=42, **kwargs):
    """
    Build a transition matrix and its stationary distribution for a named graph type.

    This is a thin wrapper around ``graph_generators.generate_graph``, always
    requesting float64 tensors on the CPU and discarding the third returned item.

    Modes listed for the generator:
    - complete: fully connected
    - delaunay: Delaunay triangulation
    - erdos-renyi: random graph with edge probability p
    - d-regular: regular graph with degree d
    - watts-strogatz: small-world
    - sbm: stochastic block model
    - euclid-mst: Euclidean minimum spanning tree
    - k-partite: k-partite graph
    - grid: 2D grid
    - torus: 2D torus (periodic grid)
    - hypercube: hypercube graph
    - apollonian: Apollonian network

    Args:
        n: number of nodes
        mode: graph type
        seed: random seed
        **kwargs: forwarded unchanged to the generator

    Returns:
        (K, pi) as returned by the generator.
    """
    from .graph_generators import generate_graph

    kernel, stationary, _info = generate_graph(
        mode, n, seed=seed, device='cpu', dtype=torch.float64, **kwargs
    )
    return kernel, stationary

def initialize_p0(n, pi, mode = 'random', seed=42, maxp0 = 100.0):
    """Draw an initial density p0 normalised so that dot(p0, pi) == 1.

    Only mode 'random' is implemented: n values are drawn uniformly from
    [0, maxp0) with NumPy's global random state.

    NOTE: ``seed`` is accepted but not used here; reproducibility depends on
    the caller seeding ``np.random``.
    """
    if mode != 'random':
        raise NotImplementedError

    raw = np.random.uniform(0, maxp0, n)
    return raw / np.dot(raw, pi)

def initialize_V(n, mode = 'random', vrange = 150.0, seed=42, ):
    """Create a potential vector of length n.

    'zeros' returns an all-zero array; 'random' draws uniformly from
    [-vrange, vrange) with NumPy's global random state. Any other mode raises
    NotImplementedError.

    NOTE: ``seed`` is accepted but not used here; reproducibility depends on
    the caller seeding ``np.random``.
    """
    if mode == 'zeros':
        return np.zeros(n)
    if mode == 'random':
        return np.random.uniform(-vrange, vrange, n)
    raise NotImplementedError

def log_mean_torch(a, b, eps=1e-16):
    """Elementwise logarithmic mean θ(a, b) = (a - b) / (log a - log b).

    Inputs are floored at ``eps``; where |log a - log b| is not above 1e-12 the
    floored ``a`` is returned, i.e. θ(a, a) = a.
    """
    a_safe = a.clamp(min=eps)
    b_safe = b.clamp(min=eps)
    log_gap = a_safe.log() - b_safe.log()
    quotient = (a_safe - b_safe) / log_gap
    return torch.where(torch.abs(log_gap) > 1e-12, quotient, a_safe)

def build_undirected_edges_torch(K, pi, tol=0.0):
    """
    List the undirected edges of the graph underlying K.

    A pair {i, j} is an edge when K[i,j] > tol or K[j,i] > tol.

    Inputs:
      K:  (n,n) tensor (documented as row-stochastic or rates, reversible wrt pi)
      pi: (n,) tensor
    Returns:
      I, J: int64 index tensors of the edges, oriented so that i<j, in row-major order
      omega: float64 edge weights ω_ij = 0.5*(π_i K_ij + π_j K_ji)
    """
    K64, pi64 = K.double(), pi.double()

    linked = (K64 > tol) | (K64.T > tol)
    # Strict upper triangle keeps exactly one orientation (i<j) of every pair.
    node_ids = torch.arange(K64.shape[0], device=K64.device)
    upper = node_ids.unsqueeze(1) < node_ids.unsqueeze(0)
    src, dst = torch.where(linked & upper)

    forward = pi64[src] * K64[src, dst]
    backward = pi64[dst] * K64[dst, src]
    return src.long(), dst.long(), 0.5 * (forward + backward)

def tiny_psi_diff_solver_torch(K, pi, rho, drho, pin=0, eps=1e-16):
    """
    PyTorch solver for the local Maas problem (logarithmic mean):
      (B D B^T) ψ = -δμ,  with  D_e = ω_ij * θ(ρ_i,ρ_j),  δμ = π ⊙ δρ
    Inputs:
      K   : (n,n) tensor, reversible wrt pi
      pi  : (n,) tensor (stationary law)
      rho : (n,) tensor (π-density, ⟨rho,pi⟩=1)
      drho: (n,) tensor (desired time-derivative, must satisfy ⟨drho,pi⟩=0)
      pin : node to pin ψ=0 (gauge fix)
      eps : floor applied to rho inside the logarithmic mean
    Returns:
      psi            : (n,) tensor (ψ_pin=0)
      psi_diff_edges : (m,) tensor for oriented edges (i<j): ψ_j - ψ_i
      psi_diff_full  : (n,n) tensor, entry [x,y] = ψ(y)-ψ(x)
      metric_norm_sq : scalar 0.5 * ψ^T b with b = -δμ
      edges          : dict with I,J,omega,theta,D_diag
    Raises:
      ValueError: if |sum(π ⊙ δρ)| exceeds 1e-4
    All computations are carried out in float64.
    """
    device = K.device
    K = K.double(); pi = pi.double(); rho = rho.double(); drho = drho.double()

    # Undirected edges (i<j) and symmetric edge weights ω
    I, J, omega = build_undirected_edges_torch(K, pi)
    m = I.numel()

    # D_e = ω_ij * θ(ρ_i, ρ_j)  (Maas weights with log-mean)
    theta = log_mean_torch(rho[I], rho[J], eps=eps)
    D_diag = omega * theta  # (m,)
    edges = dict(I=I, J=J, omega=omega, theta=theta, D_diag=D_diag)

    # RHS b = -δμ = -(π ⊙ δρ); must sum to 0 for consistency
    b = -(pi * drho)
    if b.abs().sum() < 1e-10:
        # Zero right-hand side: the pinned solution is ψ = 0.
        psi = torch.zeros_like(pi)
        psi_diff_edges = torch.zeros_like(D_diag)
        psi_diff_full = psi[None, :] - psi[:, None]
        metric_norm_sq = torch.tensor(0.0, dtype=torch.double, device=device)
        return psi, psi_diff_edges, psi_diff_full, metric_norm_sq, edges

    if abs(b.sum().item()) > 1e-4:
        raise ValueError("Mass conservation violated: sum(π ⊙ δρ) must be 0.")

    # Incidence matrix B, built densely: column e has +1 at I[e] and -1 at J[e]
    n = K.shape[0]
    B = torch.zeros((n, m), dtype=torch.double, device=device)
    cols = torch.arange(m, device=device)
    B[I, cols] =  1.0
    B[J, cols] = -1.0

    # Maas Laplacian L = B D B^T. D is diagonal over edges, so it scales the
    # COLUMNS of B: (n,m) * (m,) broadcasts over the edge axis. (Scaling B.T
    # with a (1,m) row would only broadcast when m == n, and then along the
    # node axis, which is not B D B^T.)
    L = (B * D_diag) @ B.T  # (n,n)

    # Pin one node to remove the constant nullspace
    keep = torch.tensor(
        [node for node in range(n) if node != pin], dtype=torch.long, device=device
    )
    L_red = L.index_select(0, keep).index_select(1, keep)
    b_red = b.index_select(0, keep)

    # Solve for ψ on the reduced system; ψ[pin] stays 0
    psi = torch.zeros(n, dtype=torch.double, device=device)
    psi[keep] = torch.linalg.solve(L_red, b_red)

    # Edge potential differences (oriented i<j) and full pairwise matrix
    psi_diff_edges = psi[J] - psi[I]                      # (m,)
    psi_diff_full = psi.unsqueeze(0) - psi.unsqueeze(1)   # (n,n): ψ(y)-ψ(x)

    # Instantaneous squared norm (action): 0.5 * ψ^T b
    metric_norm_sq = 0.5 * (psi @ b)

    return psi, psi_diff_edges, psi_diff_full, metric_norm_sq, edges

def cholesky_solver_torch(K, pi, rho, drho, pin=0, eps=1e-16, precond_eps=1e-12):
    """
    Wrapper for the cholesky solver for the discrete graph geodesic.
    It builds the matrices M, A and c and then calls a standard
    Cholesky solver for the associated quadratic optimization problem.

    With θ the logarithmic mean of rho and Z = K ⊙ θ:
      - M = diag(rowsum(W)) - W + precond_eps * I,  W = diag(pi) Z
      - A stacks Z - diag(rowsum(Z)) with one extra row selecting node ``pin``
      - c stacks drho with a trailing 0
    so the extra equality row enforces the gauge choice psi[pin] = 0.

    Args:
      K: (n,n) kernel tensor
      pi: (n,) tensor
      rho: (n,) density tensor
      drho: (n,) right-hand side of the first n constraints
      pin: index of the node whose potential is constrained to 0
      eps: floor used inside the logarithmic mean
      precond_eps: ridge added to the diagonal of M (the ridge of the Schur
        complement inside ``solve_eq_cqp_cholesky`` keeps its own default)

    Returns:
      psi, lam, info as returned by ``solve_eq_cqp_cholesky``.
    """

    n = pi.shape[0]
    theta = log_mean(rho.unsqueeze(1), rho.unsqueeze(0), eps=eps)

    Z = K * theta

    W = Z * pi.reshape(n, 1)
    # Helper tensors follow the dtype/device of the data so that the ridge is
    # not rounded through float32 and non-CPU inputs do not hit a device mismatch.
    ridge = precond_eps * torch.eye(n, dtype=W.dtype, device=W.device)
    D = torch.diag(W.sum(dim=1)) + ridge  # small ridge for numerical stability
    M = D - W  # Laplacian matrix (plus ridge)

    Abase = Z - torch.diag(Z.sum(dim=1))

    # Gauge row: selects psi[pin]
    e = torch.zeros(n, dtype=Z.dtype, device=Z.device)
    e[pin] = 1.0

    A = torch.vstack([Abase, e])
    c = torch.hstack([drho, torch.zeros(1, dtype=drho.dtype, device=drho.device)])

    psi, lam, info = solve_eq_cqp_cholesky(M, A, c, check_pd=True)

    return psi, lam, info


def solve_eq_cqp_cholesky(M, A, c, *, check_pd=True, precond_eps=1e-12):
    """
    Solve:   minimize   x^T M x    subject to   A x = c
    via Schur complement with Cholesky factorizations, in float64.

    Inputs
    ------
    M : (n,n) tensor, assumed symmetric positive definite (SPD)
    A : (m,n) tensor
    c : (m,)   tensor
    check_pd : accepted for API compatibility; it currently has no effect,
               because the Cholesky factorization of M is always attempted and
               raises on failure
    precond_eps : ridge added to the diagonal of the Schur complement before
                  it is factorized

    Returns
    -------
    x_star : (n,) tensor   optimal primal variable
    lam    : (m,) tensor   Lagrange multipliers (for Ax=c)
    info   : dict          diagnostics: "primal_res_norm",
                           "stationarity_res_norm", "schur_cond_hint"

    Note: because of the ridge on the Schur complement, the constraint
    residual reported in ``info`` is small but not exactly zero.
    """

    # ---- 0) Work in float64 for accuracy
    dtype = torch.float64
    M = M.to(dtype)
    A = A.to(dtype)
    c = c.to(dtype)

    # Symmetrize to damp tiny asymmetries from fp roundoff
    M = 0.5 * (M + M.T)

    m = A.shape[0]

    # ---- 1) Cholesky factorization of M (M = L_M L_M^T), raises if not SPD
    L_M = torch.linalg.cholesky(M)  # lower-triangular

    # ---- 2) Y = M^{-1} A^T via triangular solves (no explicit inverse)
    Y = torch.cholesky_solve(A.T, L_M, upper=False)  # shape (n, m)

    # ---- 3) Schur complement S = A M^{-1} A^T (m x m), symmetrized + ridge.
    # The identity is created on S's device so non-CPU inputs work.
    S = A @ Y
    S = 0.5 * (S + S.T) + precond_eps * torch.eye(m, dtype=dtype, device=S.device)

    # ---- 4) Solve S λ = -2 c via Cholesky
    # (objective is x^T M x, so stationarity reads 2 M x + A^T λ = 0)
    rhs = -2.0 * c
    L_S = torch.linalg.cholesky(S)
    lam = torch.cholesky_solve(rhs.unsqueeze(-1), L_S, upper=False).squeeze(-1)  # (m,)

    # ---- 5) Back-substitute: x* = -1/2 * M^{-1} A^T λ = -1/2 * Y λ
    x_star = -0.5 * (Y @ lam)  # (n,)

    # ---- 6) Diagnostics: primal feasibility & stationarity residuals
    r_primal = A @ x_star - c                    # should be ~0
    r_stat = 2.0 * (M @ x_star) + A.T @ lam      # should be ~0
    info = {
        "primal_res_norm": torch.linalg.norm(r_primal).item(),
        "stationarity_res_norm": torch.linalg.norm(r_stat).item(),
        "schur_cond_hint": torch.linalg.cond(S).item() if m > 0 else 0.0
    }

    return x_star, lam, info


def approx_ctmc_step_probs(K, pi, V, beta, rho, delta_t: float, eps=1e-8):
    """
    One-step transition matrix from a first-order expansion P ≈ I + delta_t * Q.

    The generator Q is assembled here from the inputs: with
    Phi = V + beta * log(rho) (rho floored at eps), the off-diagonal rate is
      Q[x,y] = pi[x] * K[x,y] * θ(rho[x], rho[y]) * max(Phi[x] - Phi[y], 0) / (rho[x] * pi[x])
    and the diagonal is set so that every row of Q sums to zero. Negative
    entries of I + delta_t * Q are clipped to 0 and rows are renormalised
    (with eps added to the row sum).

    delta_t: time step of the linearisation
    Returns the (n,n) matrix P.
    """
    n = K.shape[0]
    pi_col = pi.reshape(n, 1)

    mobility = log_mean_torch(rho.unsqueeze(1), rho.unsqueeze(0), eps=eps)
    weighted = K * mobility * pi_col
    mass_col = (rho * pi).reshape(n, 1)

    Phi = V + beta * torch.log(torch.clamp(rho, min=eps))
    # Only the positive part of Phi[x] - Phi[y] contributes to the rate x -> y
    drop = torch.clamp(Phi.unsqueeze(1) - Phi.unsqueeze(0), min=0.0)

    rates = weighted * drop / mass_col
    Q = rates - torch.diag(rates.sum(dim=1))

    step = torch.eye(Q.size(0), dtype=Q.dtype, device=Q.device) + delta_t * Q
    step = torch.clamp(step, min=0.0)
    return step / (step.sum(dim=1, keepdim=True) + eps)


def project_pi_simplex(p, pi, eps_floor=1e-5):
    """Floor p at eps_floor, cast it to pi's dtype and rescale so that dot(p, pi) ≈ 1."""
    floored = p.clamp(min=eps_floor).to(pi.dtype)
    mass = torch.dot(floored, pi)
    return floored / (mass + 1e-16)


def generate_synthetic_data(config):
    """
    Generate synthetic trajectory data from configuration and write it to disk.

    For each of the splits 'train' and 'val' the ODE defined by ``FPCompact``
    is integrated with explicit Euler on ``torch.linspace(t0, t1, num_timesteps)``.
    For every consecutive pair of states the potential ψ is obtained from
    ``cholesky_solver_torch`` and its pairwise differences are stored.

    Args:
        config: dict with keys:
            - n_nodes: number of nodes
            - graph_type: graph type passed to ``generate_K``
            - seed: random seed (default 42); also forwarded to the graph generator
            - beta: entropy coefficient
            - V_range: range for potential V
            - V_mode: 'random' or 'zeros' (default 'random')
            - p0_mode: 'random' (default)
            - max_p0: max value for p0 (default 100.0)
            - t0, t1: time range
            - num_timesteps: number of time steps (at least 2)
            - num_samples: number of samples per timestep
            - sample_from_gt: whether to sample from ground truth (default True)
            - same_p0_eval: if True (default) 'val' reuses the 'train' p0
            - output_dir: output directory
            - save_metadata: whether to save metadata (default True)
          The keys ``jitter`` and ``weight_sigma`` are not read by this function.

    Raises:
        ValueError: if num_timesteps < 2.

    Files written under output_dir: graph_data.pkl, gt_data.pkl,
    metadata.json (optional) and, per split, samples_tm.pt (optional),
    v_mat_seq.pt, rho_gt_seq.pt.
    """
    # Extract config
    n = config['n_nodes']
    graph_type = config['graph_type']
    seed = config.get('seed', 42)
    beta = config['beta']
    V_range = config['V_range']
    V_mode = config.get('V_mode', 'random')
    p0_mode = config.get('p0_mode', 'random')
    max_p0 = config.get('max_p0', 100.0)
    t0 = config['t0']
    t1 = config['t1']
    T = config['num_timesteps']
    N_samples = config['num_samples']
    sample_from_gt = config.get('sample_from_gt', True)
    output_path = config['output_dir']
    save_metadata = config.get('save_metadata', True)
    same_p0_eval = config.get('same_p0_eval', True)

    # At least one time interval is needed to form a finite difference.
    if T < 2:
        raise ValueError("num_timesteps must be at least 2")

    # Set seeds
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # Generate graph (the configured seed is forwarded so the graph is tied to it)
    print(f"Generating {graph_type} graph with {n} nodes...")
    K, pi = generate_K(n, mode=graph_type, seed=seed)

    # Update n to actual graph size (some graph types like torus adjust n to fit constraints)
    n = K.shape[0]
    print(f"Graph generated: n={n}, beta={beta}")

    # Generate potential V
    V = initialize_V(n, mode=V_mode, vrange=V_range, seed=seed)

    # Save graph structure
    out_root = os.path.expanduser(output_path)
    os.makedirs(out_root, exist_ok=True)
    import pickle
    import networkx as nx

    # Create a simple networkx graph for visualization
    G = nx.Graph()
    G.add_nodes_from(range(n))
    # Add edges where K[i,j] > 1e-6 (upper triangle only)
    for i in range(n):
        for j in range(i + 1, n):
            if K[i, j] > 1e-6:
                G.add_edge(i, j, weight=float(K[i, j]))

    # Node positions for visualization: a spring layout for every graph type
    pos = nx.spring_layout(G, seed=seed)

    # Save graph data
    graph_data_path = os.path.join(out_root, 'graph_data.pkl')
    with open(graph_data_path, 'wb') as f:
        pickle.dump((K, pi, pos, K, G), f)  # (K, pi, pos, W, G)
    print(f"Saved graph data to {graph_data_path}")

    # Quantities shared by both splits
    solver = FPCompact(K, V, beta)
    t = torch.linspace(t0, t1, T)
    # linspace places T points on [t0, t1], i.e. T-1 intervals between
    # consecutive states; the finite difference must use that spacing.
    deltat = (t1 - t0) / (T - 1)
    pi64 = pi.to(torch.float64)

    # Generate trajectories for train and val
    for mode in ['train', 'val']:

        if mode == 'train':
            p0 = initialize_p0(n, pi.numpy(), mode=p0_mode, seed=seed, maxp0=max_p0)
            p0 = torch.from_numpy(p0).float()
        elif mode == 'val' and not same_p0_eval:
            p0 = initialize_p0(n, pi.numpy(), mode=p0_mode, seed=seed+1, maxp0=max_p0)
            p0 = torch.from_numpy(p0).float()
        # otherwise 'val' reuses the p0 drawn for 'train'

        print(f"\nGenerating {mode} trajectories...")
        mode_out_path = os.path.join(output_path, mode)
        mode_dir = os.path.expanduser(mode_out_path)
        os.makedirs(mode_dir, exist_ok=True)

        traj = odeint(solver, p0, t, method='euler')

        print(f"Trajectory shape: {traj.shape}")

        samples = []
        vectors = []
        rho_full = []

        for tau in range(T-1):

            rho = project_pi_simplex(traj[tau].double(), pi64, 1e-7)
            rho_plus = project_pi_simplex(traj[tau+1].double(), pi64, 1e-7)

            # Sign convention: the solver's constraint matrix has rows
            # sum_j K_ij θ_ij (ψ_j - ψ_i), which is matched with the NEGATIVE
            # finite-difference time derivative.
            drho = (rho - rho_plus)/deltat

            psi, lam, info = cholesky_solver_torch(K, pi, rho, drho, pin=0, eps=1e-6, precond_eps=1e-3)
            psi_diff_full = psi.unsqueeze(0) - psi.unsqueeze(1)

            if sample_from_gt:
                samples_t = torch.multinomial(pi * rho, N_samples, replacement=True)
                samples.append(samples_t)

            vectors.append(psi_diff_full)
            rho_full.append(rho)

        # Save data
        samples_pt = torch.stack(samples, dim=0) if samples else None
        vectors_pt = torch.stack(vectors, dim=0)
        rho_pt = torch.stack(rho_full, dim=0)

        if samples_pt is not None:
            torch.save(samples_pt, os.path.join(mode_dir, 'samples_tm.pt'))
        torch.save(vectors_pt, os.path.join(mode_dir, 'v_mat_seq.pt'))
        torch.save(rho_pt, os.path.join(mode_dir, 'rho_gt_seq.pt'))

        print(f"Saved {mode} data to {mode_out_path}")

    # Save ground truth data
    # NOTE: 'trajectories' holds the trajectory of the last split processed ('val').
    gt_data = {
        'V': V,
        'beta': beta,
        'trajectories': traj
    }
    gt_data_path = os.path.join(out_root, 'gt_data.pkl')
    with open(gt_data_path, 'wb') as f:
        pickle.dump(gt_data, f)

    # Save metadata
    if save_metadata:
        metadata = {
            "n": n,
            "graph_type": graph_type,
            "beta": beta,
            "V_range": V_range,
            "t0": t0,
            "t1": t1,
            "num_timesteps": T,
            "num_samples": N_samples,
            "seed": seed,
            "pi": pi.tolist()
        }
        metadata_path = os.path.join(out_root, 'metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"Saved metadata to {metadata_path}")

    print(f"\n✅ Synthetic data generation complete!")
    print(f"Output directory: {output_path}")


if __name__ == "__main__":
    # Command-line entry point: read a YAML configuration and run the generator.
    import argparse
    import yaml

    cli = argparse.ArgumentParser(description='Generate synthetic gradient flow data')
    cli.add_argument(
        '--config', '-f',
        type=str,
        required=True,
        help='Path to configuration YAML file',
    )
    config_path = cli.parse_args().config

    # Parse the configuration file, then hand the resulting dict to the generator
    with open(config_path, 'r') as fh:
        settings = yaml.safe_load(fh)

    generate_synthetic_data(settings)


