import argparse
import os
import jax
import jax.numpy as jnp
import optax
import wandb
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm
from flax.training import train_state
from jax import jacrev
from datetime import datetime
import hashlib
import sympy as sp
import pickle
from archs import ModifiedMlp
import torch
torch.manual_seed(0) # seed torch's RNG; torch is used only for the DataLoader (shuffling)
# Debugging switches: uncomment to make JAX raise on the first NaN / inf produced.
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_debug_infs", True)
# Enable float64 support globally (the SPD linear algebra below casts to float64)
# and request full-precision matmuls. Set at startup, before arrays are created.
jax.config.update('jax_enable_x64', True)
jax.config.update("jax_default_matmul_precision", "highest")
# --- Dataset definitions ---
class CSVDataset(Dataset):
    """Training set of (SPD matrix, condition vector) pairs read from one CSV file.

    Column layout used by the slicing below: columns 1..13 are the 13-entry
    condition ``Y`` and columns 14 onward are the row-major flattened
    ``spd_size x spd_size`` matrix. Column 0 is not read (presumably an index
    column; verify against the CSV).

    Args:
        dataset_path: path of the CSV file.
        spd_size: matrix dimension ``m``; each sample is reshaped to (m, m).
        to_jnp: if True, the matrix columns are converted to one float32 jnp
            array up front, so ``__getitem__`` is a plain index.
    """

    _COND_COLS = slice(1, 14)
    _SPD_FIRST_COL = 14

    def __init__(self, dataset_path, spd_size, to_jnp=False):
        self.path = dataset_path
        self.data_frame = pd.read_csv(dataset_path)
        self.m = spd_size
        self.to_jnp = to_jnp
        if to_jnp:
            spd_block = self.data_frame.iloc[:, self._SPD_FIRST_COL:].values
            self.data = jnp.array(spd_block, dtype=jnp.float32)

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        flat = (
            self.data[idx]
            if self.to_jnp
            else self.data_frame.iloc[idx, self._SPD_FIRST_COL:].values
        )
        matrix = flat.reshape(self.m, self.m)
        cond = self.data_frame.iloc[idx, self._COND_COLS].values
        return jnp.array(matrix, dtype=jnp.float32), jnp.array(cond, dtype=jnp.float32)
    

class CSVTestDataset(Dataset):
    """Validation set built from two row-aligned CSV files.

    Only the first 1100 rows of each file are used. Every column of the SPD
    file is read as the row-major flattened ``spd_size x spd_size`` matrix.
    Columns 1..13 of the condition file are the condition vector.

    Args:
        spd_path: CSV of flattened matrices.
        y_path: CSV of condition vectors (column 0 is skipped).
        spd_size: matrix dimension ``m``.
        to_jnp: if True, both tables are converted to float32 jnp arrays up front.
    """

    _MAX_ROWS = 1100

    def __init__(self, spd_path, y_path, spd_size, to_jnp=False):
        self.spd_path = spd_path
        self.y_path = y_path
        self.m = spd_size
        self.to_jnp = to_jnp

        spd_frame = pd.read_csv(spd_path)
        cond_frame = pd.read_csv(y_path)
        spd_frame = spd_frame.iloc[:self._MAX_ROWS]
        cond_frame = cond_frame.iloc[:self._MAX_ROWS, 1:14]
        if to_jnp:
            self.spd_data = jnp.array(spd_frame.values, dtype=jnp.float32)
            self.y_data = jnp.array(cond_frame.values, dtype=jnp.float32)
        else:
            self.spd_data = spd_frame
            self.y_data = cond_frame

    def __len__(self):
        return len(self.spd_data)

    def __getitem__(self, idx):
        if self.to_jnp:
            flat, cond = self.spd_data[idx], self.y_data[idx]
        else:
            flat = self.spd_data.iloc[idx].values
            cond = self.y_data.iloc[idx].values
        return (
            jnp.array(flat.reshape(self.m, self.m), dtype=jnp.float32),
            jnp.array(cond, dtype=jnp.float32),
        )


def jnp_collate(batch):
    """Collate a list of (matrix, condition) pairs into two stacked jnp arrays.

    Used as the DataLoader ``collate_fn``, so batches arrive as jax arrays
    instead of torch tensors.
    """
    columns = list(zip(*batch))
    matrices, conditions = columns
    return jnp.stack(matrices), jnp.stack(conditions)

# --- Training State Wrapper ---
class TrainState(train_state.TrainState):
    """flax TrainState (params, optimizer state, step) extended with a PRNG key.

    train_step splits ``rng`` on every step and stores the new key back via
    ``state.replace(rng=...)``.
    """
    # PRNG key carried from one training step to the next.
    rng: jax.Array

# --- Core Logic ---
# Command-line interface. Values are read once into the module-level `args`.
parser = argparse.ArgumentParser(description="SPD DDPM Training")

# Hyperparams
parser.add_argument('--exp_name', type=str, default='test')

parser.add_argument("--dataset_path", type=str, default='SPD-DDPM/data/condition/train_data.csv', help="Path to CSV dataset")
parser.add_argument("--spd_size", type=int, default=10, help="Dimension n of SPD matrices")
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--lr_schedule", type=str, default='cosine')
# None -> derived later from a hash of the timestamped experiment name.
parser.add_argument("--seed", type=int, default=None)

# Pretrained-PINN settings; these must describe the checkpoint at --pinn_path.
parser.add_argument('--pinn_t_0', type=float, default=0.01)
parser.add_argument('--pinn_t_boundary', type=float, default=0.1)
parser.add_argument('--pinn_t_bundle', type=int, default=10)
parser.add_argument('--pinn_r_boundary', type=float, default=1.)
parser.add_argument('--pinn_r_max', type=float, default=10.)
parser.add_argument('--pinn_t_max_per_phase', type=float, nargs='+', default=[0., 1.])

parser.add_argument('--pinn_n_layer', type=int, default=4)
parser.add_argument('--pinn_n_hidden', type=int, default=256)
parser.add_argument('--pinn_n_symm_features', type=int, default=256)
parser.add_argument('--pinn_n_fourier_features', type=int, default=0)
parser.add_argument('--pinn_sigma_fourier_features', type=float, nargs='+', default=[1.0])

parser.add_argument('--pinn_sigma_min_symm_features', type=float, nargs='+', default=[1.0])
parser.add_argument('--pinn_sigma_max_symm_features', type=float, nargs='+', default=[10.0])
parser.add_argument('--pinn_rescale_factor_symm_features', type=float, default=0.)
parser.add_argument('--global_add_logG', default = False, action='store_true')
parser.add_argument('--reparam_type', default= 'none', type=str, choices=['weight_fact', 'none'])
parser.add_argument('--pinn_boundary_func', type=str, default='log_G', choices=['log_G', 'S0'])
parser.add_argument('--pinn_path', type=str, default='experiment-spd-pinn/exp62-1_20251222_190053/model_final.pkl')
# Model Params
parser.add_argument('--model_type', type=str, default='spd_mlp')
parser.add_argument("--hidden_dim", type=int, default=256)
parser.add_argument("--num_layers", type=int, default=4)
parser.add_argument("--n_fourier", type=int, default=32)
parser.add_argument("--fourier_scale", type=float, default=10.0)
parser.add_argument("--fourier_log_scale", default = False, action='store_true')
parser.add_argument("--log_t", default = False, action='store_true')
parser.add_argument("--spd_normalize", default = False, action='store_true')
parser.add_argument("--t_max_train", type=float, default=1.0)
parser.add_argument("--t_min_train", type=float, default=1e-3)
# NOTE(review): the sampling-time bounds default to the training bounds but
# one_sample currently uses t_max_train / t_min_train directly.
parser.add_argument("--t_max_sample", type=float, default=None)
parser.add_argument("--t_min_sample", type=float, default=None)
parser.add_argument("--loss_weight", type=str, default='uniform', choices=['uniform', 'sigma', 'inv_sigma'])
parser.add_argument("--gradient_clip", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--loss_type", type=str, default='at_x0', choices=['at_x0', 'at_xt', 'at_I'])

parser.add_argument("--clip_min_eigval", type=float, default=1e-3, help="Minimum eigenvalue for SPD clipping")
parser.add_argument("--clip_max_norm", type=float, default=10.0, help="Maximum log-eigenvalue norm for SPD clipping")
parser.add_argument("--bfunc_only", default = False, action='store_true')
parser.add_argument("--pulled_score", default = False, action='store_true')
# Validated in sample_t: 'log_uniform' or 'uniform'.
parser.add_argument("--sample_t_method", type=str, default='log_uniform')

# System

# NOTE(review): save_interval is parsed but unused; checkpoints are written
# every sample_interval epochs.
parser.add_argument("--save_interval", type=int, default=10)
parser.add_argument("--sample_interval", type=int, default=25)


args = parser.parse_args()
# Make the run name unique. Without --seed, derive a deterministic seed from
# the timestamped name; the modulus keeps it a positive 31-bit integer.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
args.exp_name = f"{args.exp_name}_{timestamp}"
if args.seed is None:
    args.seed = int(hashlib.sha256(args.exp_name.encode('utf-8')).hexdigest(), 16) % (2**31 - 1)
n = args.spd_size

# --- Architecture of the pretrained PINN (must match the --pinn_path checkpoint) ---
n_layer = args.pinn_n_layer
n_hidden = args.pinn_n_hidden
n_symm_features = args.pinn_n_symm_features
sigma_min_symm_features = args.pinn_sigma_min_symm_features
sigma_max_symm_features = args.pinn_sigma_max_symm_features
rescale_factor_symm_features = args.pinn_rescale_factor_symm_features
# global_multiplier = 4 * n ** 2
global_multiplier = 1.
# Honour the --reparam_type flag (default 'none') instead of hard-coding 'none'.
# Otherwise a PINN trained with weight-factorised layers could never be rebuilt.
reparam_type = args.reparam_type

n_fourier_features = args.pinn_n_fourier_features
sigma_fourier_features = args.pinn_sigma_fourier_features

t_bundle = args.pinn_t_bundle
t_0 = args.pinn_t_0
# Time threshold between the analytic boundary function and the PINN (see logp_at_I).
t_boundary = args.pinn_t_boundary
t_max = args.pinn_t_max_per_phase[-1]
r_boundary = args.pinn_r_boundary
r_max = args.pinn_r_max
# Sampling-time range defaults to the training-time range.
if args.t_max_sample is None:
    args.t_max_sample = args.t_max_train
if args.t_min_sample is None:
    args.t_min_sample = args.t_min_train

pinn_path = args.pinn_path

project_name = "diffusion-spd-pinn"
exp_dir = os.path.join(f"experiment-{project_name}", args.exp_name)
os.makedirs(exp_dir, exist_ok=True)


def log_G_func(t, r):
    """Log of the isotropic Gaussian heat kernel at time ``t``.

    Computes ``-(n(n+1)/4) * log(4*pi*t) - |r|^2 / (4t)``, which is
    ``log[(4*pi*t)^(-d/2) * exp(-|r|^2 / (4t))]`` with ``d = n(n+1)/2``.
    The squared norm is taken over the last axis of ``r``, keeping that axis
    as size 1.
    """
    coef = -n*(n+1) / 4
    sq_radius = jnp.sum(r**2, axis=-1, keepdims=True)
    return coef * jnp.log(4 * jnp.pi * t) - sq_radius / (4 * t)

def log_G_func_cat(tr):
    """``log_G_func`` on a packed input whose last axis is ``[t, r_1, ..., r_k]``."""
    return log_G_func(tr[..., :1], tr[..., 1:])


# Build a JAX-callable approximation of log(sinh(x) / x) = log(sinh x) - log x.
# The exact expression is singular-looking at x = 0, so it is replaced by its
# Taylor series about 0 with the O(x**10) remainder dropped. This is a
# polynomial, so accuracy presumably degrades for large |x|; confirm the range
# of arguments log_D_func feeds it.
x = sp.symbols('x')
log_sinh_by_x_func = sp.log(sp.sinh(x)) - sp.log(x)
log_sinh_by_x_func = log_sinh_by_x_func.series(x, 0, 10).removeO().evalf()
log_sinh_by_x_func = sp.lambdify(x, log_sinh_by_x_func, 'jax')

def log_D_func(t, r):
    """Return ``sum_{i<j} log(sinh(d_ij) / d_ij)`` with ``d_ij = (r_i - r_j) / 2``.

    ``r`` is a single vector (treated as a batch of one) or a 2-D batch of
    vectors. Each pairwise term uses ``log_sinh_by_x_func``, a truncated
    Taylor polynomial. The per-vector results are reshaped to ``t.shape``
    (a 0-d ``t`` gives a scalar).
    """
    batched = r if r.ndim != 1 else r[None, :]
    _, dim = batched.shape

    half_gaps = 0.5 * (batched[:, :, None] - batched[:, None, :])  # (b, dim, dim)
    strictly_upper = jnp.triu(jnp.ones((dim, dim), dtype=batched.dtype), k=1)[None]
    per_vector = jnp.sum(strictly_upper * log_sinh_by_x_func(half_gaps), axis=(1, 2))  # (b,)

    if t.ndim == 0:
        return per_vector.reshape(())
    return jnp.reshape(per_vector, t.shape)


def f_S0_func(t, r):
    """Boundary log-density ``log_G_func(t, r) - 0.5 * log_D_func(t, r)``."""
    return log_G_func(t, r) - 0.5 * log_D_func(t, r)

# Variant without the log_D correction: f_S0_func = log_G_func

# Rebuild the pretrained PINN architecture. The hyper-parameters come from the
# --pinn_* flags and must match the checkpoint loaded from pinn_path below.
MODEL = ModifiedMlp
# MODEL = MlpBlock
if reparam_type == 'weight_fact':
    reparam = {
        'type': 'weight_fact',
        'mean': 1.,
        'stddev': 0.1,
    }
else:
    reparam = None
# Settings for the model's symmetric-feature embedding (semantics defined in archs.ModifiedMlp).
symmetric_emb = {
    # 'embed_scale': sigma_symm_features,
    'embed_scale_min': sigma_min_symm_features,
    'embed_scale_max': sigma_max_symm_features,
    'embed_dim': n_symm_features,
    'rescale_factor': rescale_factor_symm_features
}
# Optional Fourier embedding; disabled when --pinn_n_fourier_features is 0.
fourier_emb = {
    'embed_scale': sigma_fourier_features,
    'embed_dim': n_fourier_features
} if n_fourier_features > 0 else None

# The PINN input is the concatenation [t, r] (see f_0). no_embedding_for_first_dim
# presumably keeps t out of the embeddings, and num_layers=n_layer - 1 presumably
# counts hidden blocks only — confirm both in archs.ModifiedMlp.
pinn = MODEL(
    num_layers=n_layer - 1,
    hidden_dim=n_hidden,
    out_dim=1,
    symmetric_emb=symmetric_emb,
    fourier_emb=fourier_emb,
    reparam=reparam,
    global_multiplier=global_multiplier,
    global_add_func=log_G_func_cat if args.global_add_logG else None,
    no_embedding_for_first_dim=True,
)

# Load the pretrained PINN variables (used by f_0 via pinn.apply).
# NOTE(review): pickle.load can execute arbitrary code from the file; only point
# --pinn_path at trusted checkpoints.
with open(pinn_path, 'rb') as f:
    pinn_variables = pickle.load(f)


def _select_boundary(kind):
    """Return the analytic boundary function named by ``--pinn_boundary_func``."""
    if kind == 'log_G':
        def boundary(t, r):
            return log_G_func(t, r)
    elif kind == 'S0':
        def boundary(t, r):
            return f_S0_func(t, r)
    else:
        raise ValueError(f"Unknown boundary_func: {kind}")
    return boundary


boundary = _select_boundary(args.pinn_boundary_func)

def b_0(t, r):
    """Analytic boundary log-density, returned as a ``(value, value)`` pair like ``f_0``."""
    value = boundary(t, r)
    return (value, value)

def f_0(t, r):
    """PINN log-density at ``[t, r]`` (concatenated on the last axis), returned as ``(value, value)``."""
    packed = jnp.concatenate((t, r), axis=-1)
    value = pinn.apply(pinn_variables, packed)
    return (value, value)

# Orthonormal basis of symmetric n x n matrices under <A, B> = tr(A B):
# the n diagonal units E_ii, then (E_ij + E_ji) / sqrt(2) for i < j (row-major).
_unit_vectors = jnp.eye(n)
basis_diag = [jnp.outer(_unit_vectors[i], _unit_vectors[i]) for i in range(n)]
basis_off_diag = [
    jnp.sqrt(0.5) * (
        jnp.outer(_unit_vectors[i], _unit_vectors[j])
        + jnp.outer(_unit_vectors[j], _unit_vectors[i])
    )
    for i in range(n)
    for j in range(i + 1, n)
]

basis = jnp.stack(basis_diag + basis_off_diag, axis=0)  # (n_basis, n, n)
n_basis = basis.shape[0]



def expm_sym(A):
    """Matrix exponential of symmetric ``A`` (batched over leading axes) via ``eigh``."""
    eigvals, eigvecs = jnp.linalg.eigh(A)
    scaled = eigvecs * jnp.exp(eigvals)[..., None, :]  # column k times exp(lambda_k)
    return scaled @ jnp.swapaxes(eigvecs, -1, -2)

def logm_sym(A):
    """Matrix logarithm of SPD ``A`` (batched over leading axes) via ``eigh``.

    Non-positive eigenvalues give NaN/-inf entries (``jnp.log``).
    """
    eigvals, eigvecs = jnp.linalg.eigh(A)
    scaled = eigvecs * jnp.log(eigvals)[..., None, :]
    return scaled @ jnp.swapaxes(eigvecs, -1, -2)

def symmetrize(A):
    """Symmetric part ``(A + A^T) / 2`` over the last two axes."""
    transposed = jnp.swapaxes(A, -1, -2)
    return (A + transposed) / 2

def clip_eigvals(A, min_eigval=1e-6):
    """Raise every eigenvalue of symmetric ``A`` to at least ``min_eigval``.

    Works on batches (leading axes). Uses ``jnp.maximum`` for the lower clip;
    this is equivalent to ``jnp.clip(w, a_min=...)`` but avoids the deprecated
    ``a_min`` keyword.

    Args:
        A: symmetric matrix (or batch of matrices).
        min_eigval: eigenvalue floor.

    Returns:
        ``Q diag(max(w, min_eigval)) Q^T`` where ``A = Q diag(w) Q^T``.
    """
    w, Q = jnp.linalg.eigh(A)                 # (..., n), (..., n, n)
    w_clipped = jnp.maximum(w, min_eigval)    # (..., n)
    return (Q * w_clipped[..., None, :]) @ jnp.swapaxes(Q, -1, -2)

def clip_eigvals_norm(A, max_norm = 10):
    """Shrink SPD ``A`` so that its log-eigenvalue vector has norm at most ``max_norm``.

    Equivalently, bound the affine-invariant distance of ``A`` from the
    identity by radially scaling ``log(A)``. Eigenvalues must be positive;
    callers run ``clip_eigvals`` first.
    """
    eigvals, eigvecs = jnp.linalg.eigh(A)
    log_spectrum = jnp.log(eigvals)
    radius = jnp.linalg.norm(log_spectrum, axis=-1, keepdims=True)
    shrink = jnp.minimum(1.0, max_norm / radius)
    new_eigvals = jnp.exp(log_spectrum * shrink)
    return (eigvecs * new_eigvals[..., None, :]) @ jnp.swapaxes(eigvecs, -1, -2)

def t_to_sigma(t):
    """Noise scale for diffusion time ``t``: ``sigma = sqrt(2 t)``."""
    return pow(2 * t, 0.5)

def sigma_to_t(sigma):
    """Inverse of ``t_to_sigma``: ``t = sigma^2 / 2``."""
    return sigma * sigma * 0.5

def sample_gaussian(key, sigma = None, t = None):
    """Draw an SPD matrix ``expm(S)``, where ``S`` is isotropic Gaussian in the symmetric basis.

    Exactly one of ``sigma`` (coefficient std) or ``t`` (diffusion time, mapped
    through ``t_to_sigma``) should be given. Raises ValueError if both are.
    """
    tangent = sample_gaussian_tangent_I(key, sigma=sigma, t=t)
    return expm_sym(tangent)

def sample_gaussian_tangent_I(key, sigma=None, t=None):
    """Draw a symmetric matrix ``sum_b c_b * basis[b]`` with ``c_b ~ N(0, sigma^2)``.

    This is the tangent-space (pre-``expm``) noise at the identity. Exactly one
    of ``sigma`` or ``t`` should be given (``t`` is converted with ``t_to_sigma``);
    passing both raises ValueError.
    """
    if (sigma is not None) and (t is not None):
        raise ValueError("Only one of sigma or t should be provided.")
    scale = sigma if t is None else t_to_sigma(t)

    coeffs = jax.random.normal(key, (n_basis,)) * scale
    return jnp.einsum('b,bij->ij', coeffs, basis)

def distance_to_origin(X):
    """Affine-invariant distance from SPD ``X`` to the identity: ``||log eig(X)||_2``."""
    log_spectrum = jnp.log(jnp.linalg.eigh(X)[0])
    return jnp.sqrt((log_spectrum ** 2).sum(axis=-1))

def map_to_I(P, X, L = None):
    """Congruence-transport ``P`` from ``X`` to the identity: ``L^{-1} P L^{-T}``.

    Args:
        P: matrix (or batch of matrices) to transport.
        X: SPD base point with ``X = L L^T``. Only used to compute ``L`` when
            ``L`` is not given.
        L: optional lower Cholesky factor of ``X``.

    Returns:
        ``L^{-1} P L^{-T}``, computed with two triangular solves (no explicit
        inverse). Uses ``swapaxes`` rather than ``transpose(-1, -2)`` so that
        batched ``(..., n, n)`` inputs also work; 2-D inputs behave as before.
    """
    if L is None:
        L = jnp.linalg.cholesky(X)
    Y = jax.scipy.linalg.solve_triangular(L, P, lower=True)                 # L^{-1} P
    Z = jax.scipy.linalg.solve_triangular(L, jnp.swapaxes(Y, -1, -2), lower=True)  # L^{-1} P^T L^{-T}
    return jnp.swapaxes(Z, -1, -2)

def map_from_I(Q, X, L = None):
    """Inverse of ``map_to_I``: transport ``Q`` from the identity to ``X`` as ``L Q L^T``.

    Args:
        Q: matrix (or batch of matrices) at the identity.
        X: SPD target point with ``X = L L^T``. Only used when ``L`` is not given.
        L: optional lower Cholesky factor of ``X``.

    Returns:
        ``L Q L^T``. Uses ``swapaxes`` so batched ``(..., n, n)`` inputs also work.
    """
    if L is None:
        L = jnp.linalg.cholesky(X)
    return L @ Q @ jnp.swapaxes(L, -1, -2)

def inner_at_X(X, U, V, L=None, normalized=False):
    """Affine-invariant inner product ``tr(L^{-1} U L^{-T} L^{-1} V L^{-T})``.

    With ``L = chol(X)`` this equals ``tr(X^{-1} U X^{-1} V)``. When ``L`` is
    supplied, ``X`` is not used and the metric is the one at ``L L^T``.
    ``normalized=True`` divides by ``(n(n+1)/2)^2``.
    """
    factor = jnp.linalg.cholesky(X) if L is None else L
    U_white = map_to_I(U, X, L=factor)
    V_white = map_to_I(V, X, L=factor)
    value = jnp.trace(U_white @ V_white)
    if not normalized:
        return value
    return value / ((n* (n+1) / 2) ** 2)

def norm_at_X(X, U, L=None, normalized=False):
    """Affine-invariant norm ``sqrt(tr((L^{-1} U L^{-T})^2))`` of tangent vector ``U``.

    When ``L`` is supplied, ``X`` is not used and the metric is the one at
    ``L L^T``. ``normalized=True`` divides by ``n(n+1)/2``.
    """
    factor = jnp.linalg.cholesky(X) if L is None else L
    U_white = map_to_I(U, X, L=factor)
    value = jnp.sqrt(jnp.trace(U_white @ U_white))
    if not normalized:
        return value
    return value / (n* (n+1) / 2)

def exp_spd(U, X, L = None):
    """Riemannian exponential map at ``X``: ``L expm(L^{-1} U L^{-T}) L^T`` with ``X = L L^T``.

    ``U`` is a symmetric tangent vector at ``X``; ``L`` may be passed to reuse a
    precomputed Cholesky factor.
    """
    factor = jnp.linalg.cholesky(X) if L is None else L
    return map_from_I(expm_sym(map_to_I(U, X, factor)), X, factor)

def log_spd(Y, X, L = None):
    """Riemannian logarithm map at ``X``: ``L logm(L^{-1} Y L^{-T}) L^T`` with ``X = L L^T``.

    ``Y`` is an SPD point; the result is a tangent vector at ``X``. ``L`` may be
    passed to reuse a precomputed Cholesky factor.
    """
    factor = jnp.linalg.cholesky(X) if L is None else L
    return map_from_I(logm_sym(map_to_I(Y, X, factor)), X, factor)

# logp_at_I(t, y): log-density at time t of SPD matrix y, evaluated through the
# log-eigenvalues of y. --bfunc_only selects the analytic-boundary-only variant.
if args.bfunc_only:
    def logp_at_I(t, y):
        # Analytic boundary function (b_0) for every t.
        r = jnp.log(jnp.linalg.eigvalsh(y))  # log-eigenvalues of y
        t = jnp.asarray(t)
        t_batch = t[None]  # scalar t -> shape (1,)
        logp_b, _ = b_0(t_batch, r)

        return jnp.squeeze(logp_b)
else:
    def logp_at_I(t, y):
        # PINN (f_0) for t > t_boundary, analytic boundary function (b_0) otherwise.
        r = jnp.log(jnp.linalg.eigvalsh(y))  # log-eigenvalues of y
        t = jnp.asarray(t)
        t_batch = t[None]  # scalar t -> shape (1,); f_0 concatenates it with r
        # Both branches are always evaluated; jnp.where only selects the value.
        # NOTE(review): if the unselected branch is NaN/inf, gradients through
        # jnp.where can still become NaN — worth checking at small t.
        logp_f, _ = f_0(t_batch, r)
        logp_b, _ = b_0(t_batch, r)

        logp = jnp.where(t > t_boundary, logp_f, logp_b)
        return jnp.squeeze(logp)


def logp_at_X(t, x, y):
    """Log-density of ``y`` for a diffusion started at ``x``.

    Transports ``y`` to the identity with ``map_to_I(y, x)``, clips its spectrum
    (eigenvalue floor, then log-norm cap), and evaluates ``logp_at_I``.
    """
    whitened = map_to_I(y, x)
    stabilized = clip_eigvals_norm(
        clip_eigvals(whitened, min_eigval=args.clip_min_eigval),
        max_norm=args.clip_max_norm,
    )
    return logp_at_I(t, stabilized)

def logp_at_X_0(t, x, y):
    """``logp_at_X`` returned as ``(value, value)`` for ``jacrev(..., has_aux=True)``."""
    value = logp_at_X(t, x, y)
    return (value, value)

def logp_at_X_by_Y(t, x, y):
    """Euclidean gradient of ``logp_at_X`` with respect to ``y``.

    Returns ``(grad_y, (grad_y, logp))``, where ``logp`` is the log-density value.
    """
    grad_wrt_y = jacrev(logp_at_X_0, argnums=2, has_aux=True)
    grad_y, logp = grad_wrt_y(t, x, y)
    return grad_y, (grad_y, logp)

from functools import partial

@partial(jax.jit, static_argnames=("n_steps",))
def sample_forward(key, t, n_steps=100, prop_t=None):
    """Random-walk Metropolis chain on SPD(n) targeting ``logp_at_I(t, .)``.

    The chain starts at ``sample_gaussian(.., t=t)``. Each step proposes
    ``x' = L eps L^T`` (``x = L L^T``, ``eps = sample_gaussian(.., t=prop_t)``)
    and accepts with probability ``min(1, exp(logp(x') - logp(x)))``. The
    proposal is treated as symmetric (no Hastings correction).

    Args:
        key: PRNG key.
        t: diffusion time (scalar).
        n_steps: number of Metropolis steps (static under jit).
        prop_t: proposal time; defaults to ``min(0.02 * t, 0.003)``.

    Returns:
        samples: (n_steps + 1, n, n) chain states, initial state at index 0.
        accepts: (n_steps + 1,) accept flags; index 0 is always True.
        logps: (n_steps + 1,) log-density of each state.
        accept_rate: fraction of the n_steps proposals that were accepted.
    """
    if prop_t is None:
        # prop_t = 0.5 * t ** 0.5
        prop_t = jnp.minimum(0.02 * t, 0.003)
    # init
    t = jnp.asarray(t)
    # Dedicated key for the initial state. Drawing x0 from `key` and then
    # splitting that same `key` in the chain would reuse its random bits.
    key, k_init = jax.random.split(key)
    x0 = sample_gaussian(k_init, t=t)
    logp0 = logp_at_I(t, x0)

    def one_step(carry, _):
        key, x, logp, acc = carry
        key, k_prop, k_u = jax.random.split(key, 3)

        eps = sample_gaussian(k_prop, t=prop_t)
        L = jnp.linalg.cholesky(x)
        x_new = map_from_I(eps, x, L=L)

        logp_new = logp_at_I(t, x_new)

        log_alpha = logp_new - logp  # symmetric RW
        u = jax.random.uniform(k_u)
        accept = jnp.log(u) < jnp.minimum(0.0, log_alpha)

        x_next = jnp.where(accept, x_new, x)
        logp_next = jnp.where(accept, logp_new, logp)
        acc_next = acc + accept.astype(jnp.int32)

        carry_next = (key, x_next, logp_next, acc_next)
        out = (x_next, accept, logp_next)
        return carry_next, out

    carry0 = (key, x0, logp0, jnp.array(0, dtype=jnp.int32))
    carryf, outs = jax.lax.scan(one_step, carry0, xs=None, length=n_steps)

    xs, accepts, logps = outs  # each has length n_steps

    # Prepend the initial state so index 0 is the starting point.
    samples = jnp.concatenate([x0[None, ...], xs], axis=0)                # (n_steps+1, n, n)
    accepts = jnp.concatenate([jnp.array([True]), accepts], axis=0)       # (n_steps+1,)
    logps = jnp.concatenate([logp0[None], logps], axis=0)                # (n_steps+1,)

    accept_rate = accepts[1:].sum() / jnp.array(n_steps, jnp.float32)
    return samples, accepts, logps, accept_rate


@partial(jax.jit, static_argnames=("prior_steps",))
def sample_prior_mcmc(key, t_max, prior_steps=100):
    """Run ``sample_forward`` at ``t_max`` and return ``(last_state, accept_rate)``."""
    samples, _accepts, _logps, accept_rate = sample_forward(key, t=t_max, n_steps=prior_steps)
    return samples[-1], accept_rate

def fro(a, b):
    """Frobenius norm of ``a - b`` over the last two axes (leading axes are batch)."""
    difference = a - b
    return jnp.linalg.norm(difference, ord="fro", axis=(-2, -1))

def tensor_power(A, r, eps=0.0):
    """Matrix power ``A^r`` of symmetric ``A`` via ``eigh``.

    When ``eps > 0`` the eigenvalues are floored at ``eps`` first, which keeps
    negative powers finite. ``eps`` must be a concrete Python number (it is
    used in a Python ``if``).
    """
    evals, evecs = jnp.linalg.eigh(A)
    floored = jnp.maximum(evals, eps) if eps > 0.0 else evals
    powered = floored ** r
    return (evecs * powered[..., None, :]) @ jnp.swapaxes(evecs, -1, -2)

def spd_dis(A, B, eps=0.0):
    """Squared affine-invariant distance ``sum_i log(lambda_i)^2`` between SPD ``A`` and ``B``.

    ``lambda_i`` are the eigenvalues of ``A^{-1/2} B A^{-1/2}``. Note that this
    returns the *squared* distance. ``eps > 0`` floors eigenvalues (both in
    ``A^{-1/2}`` and in the final spectrum) for numerical safety.
    """
    inv_sqrt_A = tensor_power(A, -0.5, eps=eps)
    spectrum = jnp.linalg.eigh(inv_sqrt_A @ B @ inv_sqrt_A)[0]
    if eps > 0.0:
        spectrum = jnp.maximum(spectrum, eps)
    return jnp.sum(jnp.log(spectrum) ** 2, axis=-1)

def _stats(name, x):
    x0 = jnp.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    jax.debug.print(
        "{name}: finite={finite}  max|x|={mx}  fro={fro}",
        name=name,
        finite=jnp.all(jnp.isfinite(x)),
        mx=jnp.max(jnp.abs(x0)),
        fro=jnp.linalg.norm(x0),
    )

def project_spd(A, *, min_eigval=1e-6, max_log_norm=None, eps=1e-12, max_fro=1e7):
    """Project a (nearly) symmetric matrix onto SPD with a bounded spectrum.

    Steps:
      1. Symmetrize and cast to float64.
      2. Rescale so the Frobenius norm is at most ``max_fro``.
      3. Floor the eigenvalues at ``min_eigval``.
      4. Optionally shrink the log-eigenvalue vector to norm ``max_log_norm``.
      5. Rebuild as a Gram product ``(Q sqrt(w))(Q sqrt(w))^T``, which is
         symmetric PSD by construction.

    The unused eigen-gap diagnostic (a sort plus a min over neighbouring gaps)
    has been removed. Beyond wasting work, it raised for 1x1 inputs, because
    the min was taken over an empty array.

    Args:
        A: matrix or batch of matrices.
        min_eigval: eigenvalue floor.
        max_log_norm: optional cap on ``||log eig||``.
        eps: guards the norm divisions.
        max_fro: cap on the Frobenius norm applied before ``eigh``.

    Returns:
        The projected matrix in ``A``'s original dtype.
    """
    orig_dtype = A.dtype
    A = symmetrize(A).astype(jnp.float64)  # float64 for a stable eigh

    # Bound the Frobenius norm so eigh never sees huge entries.
    fro_norm = jnp.linalg.norm(A, axis=(-1, -2), keepdims=True)
    A = A * jnp.minimum(1.0, max_fro / (fro_norm + eps))

    w, Q = jnp.linalg.eigh(A)
    w = jnp.maximum(w, min_eigval)

    if max_log_norm is not None:
        w_log = jnp.log(w)
        nrm = jnp.linalg.norm(w_log, axis=-1, keepdims=True)
        w = jnp.exp(w_log * jnp.minimum(1.0, max_log_norm / (nrm + eps)))

    # Gram-form reconstruction; the final symmetrize is a near no-op safety step.
    L = Q * jnp.sqrt(w)[..., None, :]
    A = symmetrize(L @ jnp.swapaxes(L, -1, -2))
    return A.astype(orig_dtype)

# Start the W&B run; every parsed CLI argument is recorded as run config.
wandb.init(
    project=project_name,
    config=vars(args),
    name=args.exp_name
)
# NOTE(review): the key is re-created from args.seed again before model init,
# so this assignment appears to be redundant.
key = jax.random.PRNGKey(args.seed)


# 2. Data loading. Validation data comes from separate fixed CSV files; the
#    training CSV is not randomly split.
print(f"Loading data from {args.dataset_path}...")
train_dataset = CSVDataset(args.dataset_path, args.spd_size)
val_dataset = CSVTestDataset(
    spd_path = 'SPD-DDPM/data/condition/data_true.csv',
    y_path = 'SPD-DDPM/data/condition/test_y.csv',
    spd_size = n
)
from mlp_spd import SPDToSymmetricMLP

train_size = len(train_dataset)
val_size = len(val_dataset)

# jnp_collate stacks samples into jnp arrays. drop_last=True keeps every batch
# at batch_size, presumably to avoid re-tracing the jitted steps on a ragged last batch.
train_loader = DataLoader(
    train_dataset, 
    batch_size=args.batch_size, 
    shuffle=True, 
    collate_fn=jnp_collate,
    drop_last=True
)

# NOTE(review): drop_last also discards the validation remainder (e.g. 76 of the
# 1100 rows with the default batch_size=128, if the file has at least 1100 rows).
val_loader = DataLoader(
    val_dataset, 
    batch_size=args.batch_size, 
    shuffle=False, 
    collate_fn=jnp_collate,
    drop_last=True
)

print(f"Data split: {train_size} training, {val_size} validation.")

# 3. Model & Init
key = jax.random.PRNGKey(args.seed)
key, init_key = jax.random.split(key)  # `key` later seeds TrainState.rng

# Re-instantiate model based on args
if args.model_type == 'spd_mlp':
    denoiser = SPDToSymmetricMLP(
        mat_dim=args.spd_size,
        num_layers=args.num_layers,
        hidden_dim=args.hidden_dim,
        activation='swish',
        n_fourier=args.n_fourier,
        fourier_scale=args.fourier_scale,
        fourier_log_scale=args.fourier_log_scale,
        log_t=args.log_t,
        name='SPDDenoiser'
    )
elif args.model_type == 'spd_net':
    from spd_net import SPDNet
    denoiser = SPDNet(
        spd_size=args.spd_size,
        y_dim=13,
        n_fourier=args.n_fourier,
        fourier_scale=args.fourier_scale,
        spd_normalize=args.spd_normalize
    )
elif args.model_type == 'spd_net2':
    from spd_net2 import SPDNet
    denoiser = SPDNet(
        spd_size=args.spd_size,
        y_dim=13,
        n_fourier=args.n_fourier,
        fourier_scale=args.fourier_scale,
    )
else:
    raise ValueError(f"Unknown model_type: {args.model_type}")

# Dummy input for initialization: a random SPD matrix (G G^T + 1e-3 I),
# t of shape (1,) and a 13-entry condition, matching the training-time call.
# NOTE(review): init_key is used both to draw x_init and for denoiser.init.
x_init = jax.random.normal(init_key, (args.spd_size, args.spd_size))
x_init = x_init @ x_init.T + jnp.eye(args.spd_size) * 1e-3
t_init = jnp.array([0.1])
y_init = jnp.zeros(13)

params = denoiser.init(init_key, x_init, t_init, condition=y_init)

steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * args.epochs

print(f"Total training steps: {total_steps}")
warmup_steps = int(0.05 * total_steps) # 5% warmup
    
# In optax's warmup_cosine_decay_schedule, decay_steps is the total schedule
# length (warmup included).
if args.lr_schedule == 'cosine':
    scheduler = optax.warmup_cosine_decay_schedule(
        init_value=0.0,           # Start at 0
        peak_value=args.lr,       # Warm up to this
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
        end_value=0.0             # Decay down to 0
    )
elif args.lr_schedule == 'constant':
    scheduler = optax.constant_schedule(args.lr)
else:
    raise ValueError(f"Unknown lr_schedule: {args.lr_schedule}")

# Global-norm clipping runs inside the optimizer, after train_step has
# measured the raw gradient norm.
optimizer = optax.chain(
    optax.clip_by_global_norm(args.gradient_clip),  # Clip gradients
    optax.adam(learning_rate=scheduler)             # Apply updates
)
state = TrainState.create(
    apply_fn=denoiser.apply,
    params=params,
    tx=optimizer,
    rng=key
)

# 5. Loss functions (traced inside the jitted train_step / eval_step)

# Per-sample loss; vmapped over the batch below.
def loss_func(params, x_0, t, y, key):
    """Per-sample score-matching loss for one clean SPD matrix ``x_0``.

    1. Take the last state ``eps`` of the ``sample_forward`` MCMC chain at
       time ``t[0]`` and move it to ``x_0``: ``x_t = L eps L^T`` with
       ``x_0 = L L^T``. Then clip ``x_t``'s spectrum.
    2. The target is the Euclidean gradient of ``logp_at_X(t, x_0, .)`` at
       ``x_t`` (via ``jacrev``), converted to ``score_riem = x_t @ grad @ x_t``.
    3. Compare the denoiser output with ``score_riem`` using ``norm_at_X``.
       The metric is chosen by ``--pulled_score`` / ``--loss_type``, and the
       result is reweighted by ``--loss_weight``.

    Args:
        params: denoiser parameters.
        x_0: clean SPD matrix (one element of the vmapped batch).
        t: diffusion time; indexed as ``t[0]``, so at least 1-D (shape (1,) from sample_t).
        y: condition vector passed to the denoiser.
        key: PRNG key for this sample.

    Returns:
        Scalar loss.
    """
    x_0, t, y = x_0.astype(jnp.float64), t.astype(jnp.float64), y.astype(jnp.float64)
    key_sample, key_fn = jax.random.split(key, 2)  # key_fn is currently unused
    lx_0 = jnp.linalg.cholesky(x_0)

    # sample_forward returns (samples, accepts, logps, accept_rate); take the chain's last state.
    eps = sample_forward(key_sample, t=t[0])[0][-1]
    x_t = map_from_I(eps, x_0, L=lx_0)

    x_t = clip_eigvals(x_t, min_eigval=args.clip_min_eigval)
    x_t = clip_eigvals_norm(x_t, max_norm=args.clip_max_norm)

    # score_y: Euclidean gradient w.r.t. x_t; `score` is the log-density value (unused below).
    _, (score_y, score) = logp_at_X_by_Y(t[0], x_0, x_t)
    # score_riem = x_t @ score_y @ x_t
    score_y = symmetrize(score_y)
    score_riem = symmetrize(x_t @ score_y @ x_t)
    # if args.pulled_score:
    #     score_pulled =  exp_spd(score_riem * 2*t, x_t)
    # Apply model
    denoised = denoiser.apply(params, x_t, t, condition=y)
    denoised = denoised.astype(jnp.float64)
    # import pdb; pdb.set_trace()

    # loss = norm_at_X(x_t, denoised - score_y, L=lx_0, normalized=True) ** 2
    if args.pulled_score:
        # The network predicts an SPD point; project it to SPD, then convert it
        # to a tangent vector at x_t with log_spd(.) / (2 t).
        # denoised = clip_eigvals(denoised, min_eigval=args.clip_min_eigval)
        # denoised = clip_eigvals_norm(denoised, max_norm=args.clip_max_norm)
        # denoised = symmetrize(denoised)
        denoised = project_spd(denoised, min_eigval=args.clip_min_eigval, max_log_norm=args.clip_max_norm)
        # denoised_mapped = map_to_I(denoised, score_pulled)
        # loss = distance_to_origin(denoised_mapped) ** 2
        normalized = True if args.loss_weight == 'uniform' else False
        denoised = log_spd(denoised, x_t) / (2.0 * t)
        loss = norm_at_X(x_t, denoised - score_riem,normalized=normalized)**2

    else:
        if args.loss_type == 'at_x0':
            # L=lx_0: the metric is taken at x_0 (X=x_t is ignored when L is given).
            loss = norm_at_X(x_t, denoised - score_riem, L = lx_0, normalized=True) ** 2
        elif args.loss_type == 'at_xt':
            loss = norm_at_X(x_t, denoised - score_riem, L = jnp.linalg.cholesky(x_t), normalized=False) ** 2
            # loss = norm_at_X(x_t, denoised - score_riem, L = jnp.linalg.cholesky(x_t), normalized=True) ** 2
        elif args.loss_type == 'at_I':
            # Plain Frobenius norm (identity metric), normalized.
            loss = norm_at_X(jnp.eye(n), denoised - score_riem, L = jnp.eye(n), normalized=True) ** 2
        else:
            raise ValueError(f"Unknown loss_type: {args.loss_type}")
    # Time weighting: sigma^2 = 2t (see t_to_sigma).
    if args.loss_weight == 'sigma':
        loss = loss * (t_to_sigma(t[0]) ** 2)
    elif args.loss_weight == 'uniform':
        pass
    elif args.loss_weight == 'inv_sigma':
        loss = loss / (t_to_sigma(t[0]) ** 2)
    else:
        raise ValueError(f"Unknown loss_weight: {args.loss_weight}")
    return loss

# Vectorized loss over the batch: params are shared, all other arguments are mapped along axis 0.
vmap_loss_func = jax.vmap(loss_func, in_axes=(None, 0, 0, 0, 0))

def compute_loss(params, x, t, y, keys):
    """Batch mean of ``loss_func`` (one PRNG key per sample)."""
    per_sample = vmap_loss_func(params, x, t, y, keys)
    return per_sample.mean()

# Training step
# (sample_t below has no @jax.jit of its own; it is traced inside train_step / eval_step.)

def sample_t(key, batch_size, t_min, t_max):
    """Draw diffusion times of shape ``(batch_size, 1)`` in ``[t_min, t_max]``.

    ``--sample_t_method`` selects between uniform in log-time
    (``'log_uniform'``) and uniform in time (``'uniform'``). Any other value
    raises ValueError.
    """
    shape = (batch_size, 1)
    log_t_samplers = {
        'log_uniform': lambda: jax.random.uniform(key, shape, minval=jnp.log(t_min), maxval=jnp.log(t_max)),
        'uniform': lambda: jnp.log(jax.random.uniform(key, shape, minval=t_min, maxval=t_max)),
    }
    if args.sample_t_method not in log_t_samplers:
        raise ValueError(f"Unknown sample_t_method: {args.sample_t_method}")
    return jnp.exp(log_t_samplers[args.sample_t_method]())

@jax.jit
def train_step(state, x, y):
    """Run one optimisation step on a batch.

    Args:
        state: TrainState carrying params, optimizer state and a PRNG key.
        x: batch of SPD matrices stacked along axis 0.
        y: batch of condition vectors stacked along axis 0.

    Returns:
        ``(new_state, loss, grad_norm)``. ``loss`` is the batch-mean loss and
        ``grad_norm`` is the global norm of the raw gradients, computed before
        ``state.tx`` is applied.
    """
    # Independent streams for the carried key, the diffusion times and the
    # per-sample keys. Feeding one key to both jax.random.uniform and
    # jax.random.split would reuse its random bits.
    rng, t_rng, sample_rng = jax.random.split(state.rng, 3)

    t = sample_t(t_rng, x.shape[0], args.t_min_train, args.t_max_train)
    batch_keys = jax.random.split(sample_rng, x.shape[0])

    def loss_fn(p):
        return compute_loss(p, x, t, y, batch_keys)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    grad_norm = optax.global_norm(grads)

    new_state = state.apply_gradients(grads=grads).replace(rng=rng)
    return new_state, loss, grad_norm

# Validation Step
@jax.jit
def eval_step(state, x, y, rng):
    """Batch-mean loss at random diffusion times, with no parameter update.

    Times come from the training range, so the value is comparable with the
    training loss.
    """
    # Separate keys for t and for the per-sample forward chains (no key reuse).
    t_rng, sample_rng = jax.random.split(rng)
    t = sample_t(t_rng, x.shape[0], args.t_min_train, args.t_max_train)
    batch_keys = jax.random.split(sample_rng, x.shape[0])

    return compute_loss(state.params, x, t, y, batch_keys)


def accept_by_condition_number(x, max_condition_number=1e3):
    """Return whether the spectral condition number of SPD ``x`` is below a threshold.

    ``jnp.linalg.eigvalsh`` returns eigenvalues in ascending order, so the
    condition number is ``eigvals[..., -1] / eigvals[..., 0]`` (largest over
    smallest). Dividing by ``eigvals[-2]`` would only compare the two largest
    eigenvalues. Works on batches (leading axes).

    The former ``print`` is dropped: inside the jitted sampler it only printed
    an abstract tracer once, at trace time.
    """
    eigvals = jnp.linalg.eigvalsh(x)
    condition_number = jnp.abs(eigvals[..., -1] / eigvals[..., 0])
    return condition_number < max_condition_number

def accept_by_distance_to_origin(x, max_distance = 10.0):
    """Return whether the affine-invariant distance from ``x`` to I is below ``max_distance``.

    The former ``print`` is dropped: inside the jitted sampler it only printed
    an abstract tracer once, at trace time.
    """
    return distance_to_origin(x) < max_distance

@partial(jax.jit, static_argnames=("num_steps", "prior_steps", "use_log_grid"))
def sample_reverse_em_spd(
    key,
    params,
    condition_y,
    *,
    t_max,
    t_min,
    num_steps=100,
    prior_steps=100,
    use_log_grid=True,
):
    """
    Reverse-time sampling on SPD(n) with the affine-invariant metric,
    using a geodesic Euler–Maruyama / GRW step:

        W = dt * b_rev(t, X) + sqrt(dt) * sigma * Z   (in TxM)
        X_next = exp_X[W]

    Here forward diffusion is assumed: dX = sqrt(2) dW  (so sigma = sqrt(2))
    => reverse drift uses g^2 = 2: b_rev = 2 * score_theta(t, X).

    Time grid: num_steps + 1 points from t_max down to t_min (log-spaced when
    use_log_grid), followed by a final point t = 0. That gives num_steps + 1
    integration steps in total.

    A proposal that fails accept_by_condition_number or
    accept_by_distance_to_origin is replaced by x * sigma(t_next) / sigma(t).
    On the final step (t_next = 0) that replacement is the zero matrix, which
    the eigenvalue clipping that follows turns into a multiple of the identity.

    condition_y: condition vector (13 entries for this dataset).
    Returns: (x_final, full_path, prior_accept_rate). full_path has shape
    (num_steps + 2, n, n), with the MCMC prior sample at index 0.
    """
    # --- 1) sample prior at t_max via MCMC ---
    key, k_prior = jax.random.split(key)
    x0, prior_acc_rate = sample_prior_mcmc(k_prior, t_max=t_max, prior_steps=prior_steps)
    x0 = x0.astype(jnp.float64)

    # --- 2) build decreasing time grid ---
    if use_log_grid:
        ts = jnp.exp(jnp.linspace(jnp.log(t_max), jnp.log(t_min), num_steps + 1))
    else:
        ts = jnp.linspace(t_max, t_min, num_steps + 1)
    ts = jnp.concatenate([ts, jnp.array([0.0])], axis=0)  # append final t = 0

    t_pairs = jnp.stack([ts[:-1], ts[1:]], axis=1)  # (num_steps + 1, 2): consecutive (t, t_next)
    t_pairs = t_pairs.astype(jnp.float64)

    def one_step(carry, t_pair):
        key, x = carry
        t, t_next = t_pair
        dt = t - t_next  # positive

        # stabilize before Cholesky
        x = clip_eigvals(x, min_eigval=args.clip_min_eigval)
        x = clip_eigvals_norm(x, max_norm=args.clip_max_norm)
        L = jnp.linalg.cholesky(x)

        # score network
        # denoiser expects t array-like; keep consistent with training
        t_in = jnp.array([t], dtype=x.dtype)
        score = denoiser.apply(params, x, t_in, condition=condition_y)
        score = score.astype(jnp.float64)
        if args.pulled_score:
            # Same conversion as in loss_func: SPD prediction -> tangent vector at x.
            score = project_spd(score, min_eigval=args.clip_min_eigval, max_log_norm=args.clip_max_norm)
            score = log_spd(score, x, L=L) / (2 * t)
        score = 0.5 * (score + score.T)  # safety symmetrization

        # sample *tangent* Gaussian Z ~ N(0, I) at I, then transport to TxM at x
        key, k_noise = jax.random.split(key)
        z_I = sample_gaussian_tangent_I(k_noise, sigma=1.0)  # pre-expm noise
        z_x = map_from_I(z_I, x, L=L)  # now in Tx SPD at x

        # reverse SDE coefficients for forward dX = sqrt(2) dW:
        # drift = 2 * score, noise scale = sqrt(2)
        drift = 2.0 * score
        w = drift * dt + jnp.sqrt(2.0 * dt) * z_x  # tangent increment in TxM
        w, x = w.astype(jnp.float64), x.astype(jnp.float64)

        x_next = exp_spd(w, x, L=L)
        accept = accept_by_condition_number(x_next, max_condition_number=1e3) & accept_by_distance_to_origin(x_next, max_distance=args.clip_max_norm)
        # Fallback on rejection: shrink the current state by sigma(t_next) / sigma(t).
        x_when_rejected = x * t_to_sigma(t_next) / t_to_sigma(t)
        x_next = jnp.where(accept, x_next, x_when_rejected)
        # optional clipping after update (helps numerical stability)
        x_next = clip_eigvals(x_next, min_eigval=args.clip_min_eigval)
        x_next = clip_eigvals_norm(x_next, max_norm=args.clip_max_norm)
        x_next = symmetrize(x_next)

        return (key, x_next), x_next

    (keyf, xf), xs = jax.lax.scan(one_step, (key, x0), t_pairs)

    # include initial state in the returned path
    path = jnp.concatenate([x0[None, ...], xs], axis=0)  # (num_steps + 2, n, n)
    return xf, path, prior_acc_rate

def one_sample(params, k, y):
    """Generate one SPD sample conditioned on ``y`` and return only the final matrix.

    Uses the training time range (t_max_train down to t_min_train), 1000
    reverse steps on a log-spaced grid, and 100 MCMC steps for the prior.
    NOTE(review): --t_max_sample / --t_min_sample are parsed but not used here.
    """
    settings = dict(
        t_max=args.t_max_train,
        t_min=args.t_min_train,
        num_steps=1000,
        prior_steps=100,
        use_log_grid=True,
    )
    x_final, _path, _prior_acc = sample_reverse_em_spd(k, params, y, **settings)
    return x_final


# Vectorized over (key, condition) pairs, with params shared.
batched_sampler = jax.jit(jax.vmap(one_sample, in_axes=(None, 0, 0)))

# 6. Training Loop
# Each epoch: train, validate, then every sample_interval epochs write a
# checkpoint and evaluate generated samples against the validation matrices.
print("Starting training...")
for epoch in range(args.epochs):
    # --- Training ---
    train_losses = []
    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs} [Train]") as pbar:
        for x, y in pbar:
            # x and y are already stacked jnp arrays (see jnp_collate).
            state, loss, grad_norm = train_step(state, x, y)
            loss_value = loss.item()
            train_losses.append(loss_value)
            pbar.set_postfix(loss=loss_value)

            # float() turns the schedule output (jax scalar or Python float)
            # into a plain number for wandb.
            current_lr = float(scheduler(state.step))
            wandb.log({
                "train_loss": loss_value,
                "learning_rate": current_lr,
                "grad_norm": grad_norm.item()
            })

    avg_train_loss = np.mean(train_losses)

    # --- Validation ---
    val_losses = []
    # Validation key seeded by the epoch index, so validation noise is reproducible.
    val_rng = jax.random.PRNGKey(epoch)

    with tqdm(val_loader, desc=f"Epoch {epoch+1}/{args.epochs} [Val]") as pbar:
        for x, y in pbar:
            val_rng, batch_rng = jax.random.split(val_rng)
            loss = eval_step(state, x, y, batch_rng)
            val_losses.append(loss.item())

    avg_val_loss = np.mean(val_losses)

    # Log Epoch Stats
    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")
    wandb.log({
        "epoch": epoch + 1,
        "avg_train_loss": avg_train_loss,
        "avg_val_loss": avg_val_loss
    })

    # Checkpointing and sampling share sample_interval; --save_interval is unused.
    # if (epoch + 1) % args.save_interval == 0:

    if (epoch + 1) % args.sample_interval == 0:
        print(f"Saving checkpoint at epoch {epoch+1}...")
        checkpoint_path = os.path.join(exp_dir, f"checkpoint_epoch_{epoch+1}.pkl")
        with open(checkpoint_path, 'wb') as f:
            pickle.dump({
                'params': state.params,
                'optimizer_state': state.opt_state,
                'epoch': epoch + 1,
                'rng': state.rng,
                'hparams': vars(args),
            }, f)
        print(f"Generating samples at epoch {epoch+1}...")
        preds = []
        trues = []

        for spds, y in val_loader:
            # Fresh keys per batch. Reusing val_rng unchanged would give every
            # batch the same per-sample noise.
            val_rng, sample_rng = jax.random.split(val_rng)
            result = batched_sampler(state.params, jax.random.split(sample_rng, y.shape[0]), y)
            preds.append(result)
            trues.append(spds)
        preds = jnp.concatenate(preds, axis=0)  # (n_total, n, n)
        trues = jnp.concatenate(trues, axis=0)  # (n_total, n, n)
        f_dis = fro(trues, preds)               # fro reduces only the last two axes
        f_dis_mean = float(np.mean(np.asarray(f_dis)))
        a_dis = spd_dis(trues, preds)           # squared affine-invariant distance
        a_dis_mean = float(jnp.mean(a_dis))

        f_dis_q0, f_dis_q1, f_dis_q2, f_dis_q3, f_dis_q4 = np.percentile(f_dis, [0, 25, 50, 75, 100])
        a_dis_q0, a_dis_q1, a_dis_q2, a_dis_q3, a_dis_q4 = np.percentile(a_dis, [0, 25, 50, 75, 100])
        wandb.log({
            "epoch": epoch + 1,
            "f_dis_mean": f_dis_mean,
            "a_dis_mean": a_dis_mean,
            "f_dis_q0": f_dis_q0,
            "f_dis_q1": f_dis_q1,
            "f_dis_q2": f_dis_q2,
            "f_dis_q3": f_dis_q3,
            "f_dis_q4": f_dis_q4,
            "a_dis_q0": a_dis_q0,
            "a_dis_q1": a_dis_q1,
            "a_dis_q2": a_dis_q2,
            "a_dis_q3": a_dis_q3,
            "a_dis_q4": a_dis_q4,
        })
wandb.finish()

