# =========================
# Standard library
# =========================
import argparse
import os
import socket
import time
from timeit import default_timer as timer
from typing import Tuple

os.environ["GEOMSTATS_BACKEND"] = "jax"

# =========================
# Third-party
# =========================
import haiku as hk
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
import sympy as sp
import wandb
import yaml
from hydra.utils import get_class, instantiate
from jax import grad, jacrev, jit, vmap
from omegaconf import OmegaConf
from tqdm import tqdm
# Alias used to annotate PRNG-key arguments further down in this file.
try:
    KeyArray = jax.random.KeyArray
except AttributeError:
    # This JAX installation has no `jax.random.KeyArray`: fall back to
    # `jax.Array` when it exists, otherwise to `jnp.ndarray`.
    KeyArray = getattr(jax, "Array", jnp.ndarray)
# =========================
# Local project
# =========================
from archs import ModifiedMlp

# =========================
# score_sde
# =========================
import sys
sys.path.append('riemannian-score-sde')
from score_sde.datasets import DataLoader, TensorDataset, random_split
from score_sde.losses import get_ema_loss_step_fn
from score_sde.models.flow import SDEPushForward
from score_sde.utils import ParametrisedScoreFunction, TrainState, batch_mul, restore, save

# =========================
# riemannian_score_sde
# =========================
from riemannian_score_sde.losses import get_dsm_loss_fn
from riemannian_score_sde.utils.normalization import compute_normalization
from riemannian_score_sde.utils.vis import plot, plot_ref


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Train pinn')
parser.add_argument('--exp_name', type=str, default='tmp',
                    help='experiment name; also used as the wandb run name '
                         'and the results sub-directory')
# BooleanOptionalAction provides --use_pinn / --no-use_pinn on its own.
# It must not be combined with `type=`: that parameter is deprecated for
# this action since Python 3.12 and rejected from Python 3.14 on.
parser.add_argument('--use_pinn', default=True,
                    action=argparse.BooleanOptionalAction,
                    help='train with the PINN-based loss '
                         '(--no-use_pinn selects get_dsm_loss_fn instead)')
parser.add_argument('--pinn_name', type=str, default='exp022-0',
                    help='name of the saved PINN run whose args.yaml and '
                         'last_epoch.npy are loaded')
parser.add_argument('--n_max', type=int, default=-1,
                    help='n_max forwarded to get_dsm_loss_fn '
                         '(only used with --no-use_pinn)')
parser.add_argument('--beta_f', type=float, default=5,
                    help='value written to cfg.flow.beta_schedule.beta_f')
parser.add_argument('--data', type=str, default='earthquake',
                    help='dataset: earthquake, volcano, fire or flood')
parser.add_argument('--seed', type=int, default=None,
                    help='random seed; derived from the clock and exp_name '
                         'when omitted')
parser.add_argument('--mcmc_t_thresh', type=float, default=0.,
                    help='t_thresh forwarded to the MCMC sampler')
args = parser.parse_args()

if args.seed is None:
    # No seed given: derive a 32-bit value from the current time and the
    # experiment name. The chosen value is stored back on `args`.
    args.seed = hash(str(time.time())+args.exp_name) % (2**32)

# Identify the saved PINN run to load: files live under
# <eqn_name>/data-pinn/<exp_name>/.
exp_name = args.pinn_name
eqn_name = 'FP_S2_jax'

# Hyper-parameters of the saved PINN, exposed as attributes.
# NOTE(review): FullLoader can construct more than plain YAML types; only
# load files produced by trusted runs.
with open(f'{eqn_name}/data-pinn/{exp_name}/args.yaml', 'r') as f:
    pinn_args = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader))

# Saved network parameters: a pickled object wrapped in an array, unwrapped
# with .item().
variables = jnp.load(
    f'{eqn_name}/data-pinn/{exp_name}/last_epoch.npy',
    allow_pickle=True,
).item()

# Rebuild the network with the architecture settings read from args.yaml.
pinn = ModifiedMlp(
    num_layers=pinn_args.n_layers - 1,
    hidden_dim=pinn_args.n_hidden,
    out_dim=1,
    fourier_emb=dict(
        embed_scale=pinn_args.sigma_fourier,
        embed_dim=pinn_args.n_fourier,
    ),
    reparam=dict(type='weight_fact', mean=1., stddev=0.1),
)
def log_print(msg: str):
    """Write ``msg`` followed by a newline to stdout and flush right away."""
    print(msg, end="\n", flush=True)

def f_0(variables,x,t):
    """Evaluate the PINN on ``x`` and ``t`` concatenated along the last axis.

    The network output is returned twice so that this function can be passed
    to ``jacrev(..., has_aux=True)``: the first copy is differentiated and the
    second is carried through unchanged as auxiliary data.
    """
    network_input = jnp.concatenate((x, t), axis=-1)
    value = pinn.apply(variables, network_input)
    return value, value

def f_t(variables,x,t):
    """Jacobian of the PINN output with respect to ``t``.

    Returns ``(u_t, (u_t, u_0))`` so the result can itself be differentiated
    with ``has_aux=True`` while keeping the lower-order values.
    """
    d_dt = jacrev(f_0, argnums=2, has_aux=True)
    u_t, u_0 = d_dt(variables, x, t)
    return u_t, (u_t, u_0)

def f_x(variables,x,t):
    """Jacobian of the PINN output with respect to ``x``.

    Returns ``(u_x, (u_x, u_0))`` so the result can itself be differentiated
    with ``has_aux=True`` while keeping the lower-order values.
    """
    d_dx = jacrev(f_0, argnums=1, has_aux=True)
    u_x, u_0 = d_dx(variables, x, t)
    return u_x, (u_x, u_0)

def f_xx(variables,x,t):
    """Second Jacobian of the PINN output with respect to ``x``.

    Differentiates ``f_x`` once more and returns
    ``(u_xx, (u_xx, u_x, u_0))``.
    """
    d2_dx2 = jacrev(f_x, argnums=1, has_aux=True)
    u_xx, (u_x, u_0) = d2_dx2(variables, x, t)
    return u_xx, (u_xx, u_x, u_0)


import sympy as sp

# Symbolic variable of the closed-form series built below; at evaluation time
# it receives the clipped arccos of a dot product (see compute_logp_sp).
r = sp.symbols('r')
# D(r) = sin(r) / r.
D = sp.sin(r) / r
# Logarithmic derivative D'(r) / D(r), simplified once up front.
DD = (sp.diff(D, r) / D).cancel().simplify()

def lap(u):
    """Return ``-u'' - (1/r + DD) * u'`` with derivatives taken in ``r``."""
    first = sp.diff(u, r)
    second = sp.diff(u, r, 2)
    return -second - (1/r + DD) * first

# Integration variable for the series recursion.
s = sp.symbols('s')
def iterate(u,i):
    """Compute the next series coefficient from ``u`` symbolically.

    Integrates ``sqrt(D) * lap(u) * s**(i-1)`` in ``s`` (using the heurisch
    integrator), substitutes ``s -> r`` and multiplies by
    ``-r**(-i) / sqrt(D)``.
    """
    weighted = (sp.sqrt(D) * lap(u)).cancel().subs(r, s) * s ** (i - 1)
    antiderivative = sp.integrate(weighted.cancel(), s, heurisch=True).cancel().simplify()
    # The antiderivative is used as returned by SymPy. The variant that
    # subtracts its limit at s = 0 was disabled in the original:
    #   integral = integral.subs(s,r) - integral.limit(s,0)
    antiderivative = antiderivative.subs(s, r)
    prefactor = -(r ** (-i)) / sp.sqrt(D)
    return (prefactor * antiderivative).cancel().simplify()


# Integration variable for the series recursion.
s = sp.symbols('s')
def iterate_approx(u,i):
    """Series-truncated counterpart of ``iterate``.

    The integrand is replaced by its expansion around ``s = 0`` (order term
    removed, coefficients evaluated numerically) before integrating, and the
    final expression is expanded again around ``r = 0``.
    """
    weighted = (sp.sqrt(D) * lap(u)).cancel().subs(r, s) * s ** (i - 1)
    truncated = weighted.series(s, 0, 10).removeO().evalf()
    antiderivative = sp.integrate(truncated.cancel(), s).cancel().simplify()
    # As in `iterate`, the limit at s = 0 is not subtracted:
    #   integral = integral.subs(s,r) - integral.limit(s,0)
    antiderivative = antiderivative.subs(s, r)
    closed_form = (-(r ** (-i)) / sp.sqrt(D) * antiderivative).cancel().simplify()
    return closed_form.series(r, 0, 10).removeO()


# Coefficients u_0..u_3 of a power series in t (see S0..S3 below).
# NOTE(review): `r != 0` and `r == 0` are Python comparisons on a Symbol;
# SymPy appears to evaluate them structurally to plain True/False, so this
# Piecewise presumably collapses to its first branch 1/sqrt(D). A genuine
# piecewise definition would need sp.Ne / sp.Eq -- confirm before changing,
# since the integrations below run on whatever this expression becomes.
u_0 = sp.Piecewise((1/sp.sqrt(D), r != 0), (1, r == 0))
# u_1 is integrated symbolically; u_2 and u_3 use the truncated variant.
u_1 = iterate(u_0,1)
u_2 = iterate_approx(u_1,2)
u_3 = iterate_approx(u_2,3)
# D = sp.Piecewise((sp.sin(r) / r, r != 0), (1, r == 0))

# Symbolic time variable of the series.
t = sp.symbols('t')
# G = (2*sp.pi*t)**(-1) * sp.exp(-r**2/(2*t))
# Gaussian prefactor (4*pi*t)^-1 * exp(-r^2 / (4t)) and its logarithm.
G = (4*sp.pi*t)**(-1) * sp.exp(-r**2/(4*t))
logG = -sp.log(4 * sp.pi * t) - r**2/(4*t)

# Truncations of G * sum_k u_k t^k at increasing order.
S0 = G * u_0
S1 = G * (u_0 + u_1*t)
S2 = G * (u_0 + u_1*t + u_2*t**2)
S3 = G * (u_0 + u_1*t + u_2*t**2 + u_3*t**3)

# Logarithms of the same truncations, written as logG + log(sum).
logS0 = logG + sp.log(u_0)
logS1 = logG + sp.log(u_0 + u_1*t)
logS2 = logG + sp.log(u_0 + u_1*t + u_2*t**2)
logS3 = logG + sp.log(u_0 + u_1*t + u_2*t**2 + u_3*t**3)

# x = sp.symbols('x:3')
# acos = sp.Max(sp.Min(sp.acos(x[2]), sp.pi - 1e-4), 1e-4)
# f_logS3 = sp.lambdify((x, t),logS3.subs(r, acos).simplify(),modules = "jax")

# JAX-callable version of the third-order log series, taking (r, t).
f_logS3 = sp.lambdify((r,t),logS3.simplify().evalf(),modules = "jax")


def project(x, v):
    """Remove from ``v`` its component along ``x``, taken over the last axis."""
    along_x = jnp.sum(x * v, axis=-1, keepdims=True)
    return v - along_x * x

def exp(x, v):
    """Return ``cos(|v|) * x + sin(|v|) * v / |v|``, norm over the last axis.

    Args:
        x: base point(s).
        v: vector(s) with the same trailing dimension as ``x``.

    Returns:
        An array shaped like the broadcast of ``x`` and ``v``. A zero vector
        ``v`` returns ``x`` itself instead of NaN.
    """
    theta = jnp.linalg.norm(v, ord=2, axis=-1, keepdims=True)
    # Guard the division: where theta == 0 the numerator sin(theta) * v is
    # already zero, so dividing by 1 yields the correct limit. Rows with
    # theta > 0 are computed exactly as before.
    safe_theta = jnp.where(theta > 0, theta, 1.0)
    return jnp.cos(theta) * x + jnp.sin(theta) * v / safe_theta

def orthonormal(v):
    """
    Build a 3x3 matrix that sends the normalised 3-vector ``v`` to [0, 0, 1].

    The general case is a Rodrigues rotation about ``v x e_z``. When that
    cross product is (numerically) zero, ``v`` is parallel to the z axis and
    the result is the identity (``v[2] > 0``) or ``diag(-1, 1, -1)``.
    """
    north = jnp.array([0.0, 0.0, 1.0])

    # Normalise the input so the rotation angle below is well defined.
    v = v / jnp.linalg.norm(v)

    # Unnormalised rotation axis and its length.
    axis = jnp.cross(v, north)
    axis_norm = jnp.linalg.norm(axis)

    def _rodrigues():
        k = axis / axis_norm
        angle = jnp.arccos(jnp.dot(v, north))
        # Cross-product matrix of the unit rotation axis.
        K = jnp.array([[0, -k[2], k[1]],
                       [k[2], 0, -k[0]],
                       [-k[1], k[0], 0]])
        return jnp.eye(3) + jnp.sin(angle) * K + (1 - jnp.cos(angle)) * (K @ K)

    def _parallel_to_z():
        # Pointing up: nothing to do. Pointing down: half turn about y.
        return jax.lax.cond(
            v[2] > 0,
            lambda: jnp.eye(3),
            lambda: jnp.diag(jnp.array([-1., 1., -1.])),
        )

    return jax.lax.cond(axis_norm < 1e-6, _parallel_to_z, _rodrigues)

# Switch-over time between the two log-density approximations: entries with
# t > t_threshold use the PINN, all others use the closed-form series
# (see compute_logp and compute_logp_grad).
t_threshold = 0.3

def compute_logp_pinn(x_0,x,t):
    """PINN value for a single ``(x_0, x, t)`` triple.

    ``x`` is first rotated by the matrix that sends ``x_0`` to [0, 0, 1];
    the network is then evaluated on the rotated point and ``t``, and the
    first output channel is returned.
    """
    rotation = orthonormal(x_0)
    x_rotated = jnp.dot(rotation, x)
    _, out = f_0(variables, x_rotated, t)
    return out[..., 0]
    
def compute_logp_sp(x_0,x,t):
    """Closed-form series value for a single ``(x_0, x, t)`` triple.

    The angle ``arccos(<x_0, x>)`` is clipped to ``[1e-7, pi - 0.001]``
    before being passed to ``f_logS3`` together with ``t[..., 0]``.
    """
    angle = jnp.arccos(jnp.dot(x_0, x))
    dist = jnp.clip(angle, 1e-7, jnp.pi - 0.001)
    return f_logS3(dist, t[..., 0])

def compute_logp_grad_pinn(x_0,x,t):
    """Gradient of ``compute_logp_pinn`` in ``x``, passed through ``project`` at ``x``."""
    ambient_grad = grad(compute_logp_pinn, argnums=1)(x_0, x, t)
    return project(x, ambient_grad)

def compute_logp_grad_sp(x_0,x,t):
    """Gradient of ``compute_logp_sp`` in ``x``, passed through ``project`` at ``x``."""
    ambient_grad = grad(compute_logp_sp, argnums=1)(x_0, x, t)
    return project(x, ambient_grad)

@jit
def compute_logp(x_0,x,t):
    """Batched log-density estimate.

    Both approximations are evaluated for every row; rows whose ``t[:, 0]``
    exceeds ``t_threshold`` take the PINN value, the others the series value.
    """
    use_pinn = t[:, 0] > t_threshold
    logp_pinn = vmap(compute_logp_pinn)(x_0, x, t)
    logp_series = vmap(compute_logp_sp)(x_0, x, t)
    return jnp.where(use_pinn, logp_pinn, logp_series)

@jit
def compute_logp_grad(x_0,x,t):
    """Batched gradient of the log-density estimate with respect to ``x``.

    Rows are selected between the PINN and series gradients by comparing
    ``t`` with ``t_threshold``. Rows where ``|<x, x_0>|`` exceeds
    ``1 - 1e-7`` are replaced by zeros.
    """
    grad_pinn = vmap(compute_logp_grad_pinn)(x_0, x, t)
    grad_series = vmap(compute_logp_grad_sp)(x_0, x, t)
    blended = jnp.where(t > t_threshold, grad_pinn, grad_series)

    # Zero out rows where x and x_0 are (anti-)parallel to within 1e-7.
    cos_angle = jnp.sum(x * x_0, axis=1, keepdims=True)
    degenerate = jnp.abs(cos_angle) > 1 - 1e-7
    return jnp.where(degenerate, jnp.zeros_like(blended), blended)



def compute_mcmc(key, x_0, t, n_iter=25, t_thresh = 0.):
    """Draw samples around ``x_0`` with a random-walk accept/reject chain.

    Each proposal perturbs the current point with Gaussian noise scaled by
    ``sqrt(2 t)``, passed through ``project`` and ``exp``. A proposal is
    accepted when a uniform draw falls below
    ``exp(logp(proposal) - logp(current))``, with ``logp`` given by
    ``compute_logp``.

    Args:
        key: JAX PRNG key.
        x_0: starting points, one row per sample (3 columns).
        t: per-row times, indexed as ``t[:, 0]`` and broadcast against rows.
        n_iter: number of accept/reject iterations.
        t_thresh: rows with ``t < t_thresh`` return ``x_0`` and a log-density
            placeholder of 1.0 instead of the chain result.

    Returns:
        ``(x, logp_x)``: the final points and their log-density values.
    """
    n_data = t.shape[0]
    step_scale = (t * 2) ** 0.5

    # Every random draw gets its own key. A JAX key must not be consumed and
    # then split again, nor shared between two draws, otherwise the draws are
    # correlated.
    key, init_key = jr.split(key)
    x = exp(x_0, project(x_0, jr.normal(init_key, (n_data, 3)) * step_scale))
    logp_x = compute_logp(x_0, x, t)

    def body_fn(i, state):
        key, x, logp_x = state
        # Separate keys for the proposal noise and the acceptance test.
        key, proposal_key, accept_key = jr.split(key, 3)
        x_new = exp(x, project(x, jr.normal(proposal_key, (n_data, 3)) * step_scale))
        logp_x_new = compute_logp(x_0, x_new, t)
        accept_ratio = jnp.exp(logp_x_new - logp_x)
        accept = jr.uniform(accept_key, (n_data,)) < accept_ratio
        x = jnp.where(accept[:, None], x_new, x)
        logp_x = jnp.where(accept, logp_x_new, logp_x)
        return key, x, logp_x

    key, x, logp_x = jax.lax.fori_loop(0, n_iter, body_fn, (key, x, logp_x))

    # Below the threshold the chain output is discarded in favour of x_0.
    x = jnp.where(t< t_thresh, x_0, x)
    logp_x = jnp.where(t[:,0] < t_thresh, 1.0, logp_x)

    return x, logp_x
# Compiled entry point; n_iter and t_thresh are treated as static arguments.
mcmc_sample = jit(compute_mcmc, static_argnames=["n_iter", "t_thresh"])


def sample_sphere(key, n_sample):
    """Return ``n_sample`` unit 3-vectors as an ``(n_sample, 3)`` array.

    The height ``z`` is drawn uniformly from [-1, 1] and the angle ``phi``
    uniformly from [0, 2*pi); the point is
    ``(sin(phi) * rho, cos(phi) * rho, z)`` with ``rho = sqrt(1 - z**2)``.
    """
    key_z, key_phi = jr.split(key)
    z = jr.uniform(key_z, (n_sample,), minval=-1.0, maxval=1.0)
    phi = jr.uniform(key_phi, (n_sample,)) * 2 * jnp.pi
    rho = jnp.sqrt(1 - z**2)
    first = jnp.sin(phi) * rho
    second = jnp.cos(phi) * rho
    return jnp.stack([first, second, z], axis=1)

import jax.numpy as jnp

def spherical_grid(grid_size):
    """Return a ``(grid_size**2, 3)`` array of points on the unit sphere.

    The grid is the product of ``grid_size`` heights in [-1, 1] and
    ``grid_size`` angles in [0, 2*pi], both endpoints included, flattened
    with the height index varying slowest.
    """
    heights = jnp.linspace(-1, 1, grid_size)
    azimuths = jnp.linspace(0, 2 * jnp.pi, grid_size)

    Z, PHI = jnp.meshgrid(heights, azimuths, indexing='ij')

    # Radius of the horizontal circle at each height.
    ring_radius = jnp.sqrt(1 - Z**2)
    columns = (ring_radius * jnp.cos(PHI), ring_radius * jnp.sin(PHI), Z)
    return jnp.stack([c.flatten() for c in columns], axis=-1)


def get_dsm_loss_fn_pinn(
    pushforward: SDEPushForward,
    model: ParametrisedScoreFunction,
    train: bool = True,
    like_w: bool = True,
    eps: float = 1e-3,
    s_zero=True,
    **kwargs
):
    """Build a score-matching loss whose target comes from ``compute_logp_grad``.

    Args:
        pushforward: provides ``sde`` and the data ``transform``.
        model: score network handed to ``sde.reparametrise_score_fn``.
        train: forwarded to ``sde.reparametrise_score_fn``.
        like_w: if true, weight the squared error by the squared second
            output of ``sde.coefficients``; otherwise scale both score and
            target by ``std`` first.
        eps: offset added to ``sde.t0`` for the lower bound of sampled times.
        s_zero: if true, sample ``y_t`` with ``mcmc_sample`` and use
            ``compute_logp_grad`` as the target; otherwise use the SDE's
            ``marginal_sample`` / ``varhadan_exp``.
        **kwargs: accepted for signature compatibility; not used here.

    Returns:
        ``loss_fn(rng, params, states, batch) -> (loss, new_model_state)``.
    """
    sde = pushforward.sde

    def loss_fn(
        rng: KeyArray, params: dict, states: dict, batch: dict
    ) -> Tuple[float, dict]:
        score_fn = sde.reparametrise_score_fn(model, params, states, train, True)
        y_0 = pushforward.transform.inv(batch["data"])
        context = batch["context"]

        # One time per batch row, uniform over the SDE time frame.
        rng, time_rng = jr.split(rng)
        t = jr.uniform(time_rng, (y_0.shape[0],), minval=sde.t0 + eps, maxval=sde.tf)
        rng, step_rng = jr.split(rng)

        # Sample y_t given y_0 and build the regression target.
        if s_zero:  # l_{t|0}
            # Time handed to the sampler and the target: half the rescaled
            # time, as a column.
            half_rescaled_t = sde.beta_schedule.rescale_t(t[:, None]) / 2
            y_t, _ = mcmc_sample(
                step_rng, y_0, half_rescaled_t, t_thresh=args.mcmc_t_thresh
            )
            logp_grad = compute_logp_grad(y_0, y_t, half_rescaled_t)
        else:  # l_{t|s}
            y_t, y_hist, timesteps = sde.marginal_sample(
                step_rng, y_0, t, return_hist=True
            )
            y_s = y_hist[-2]
            # The time difference returned here is discarded: the original
            # overrides it with t ("NOTE: works better?").
            _, logp_grad = sde.varhadan_exp(y_s, y_t, timesteps[-2], timesteps[-1])
        std = jnp.expand_dims(sde.marginal_prob(jnp.zeros_like(y_t), t)[1], -1)

        # Model score at y_t.
        score, new_model_state = score_fn(y_t, t, context, rng=step_rng)
        score = score.reshape(y_t.shape)

        if like_w:
            # Squared error weighted by g(t)^2.
            g2 = sde.coefficients(jnp.zeros_like(y_0), t)[1] ** 2
            losses = sde.manifold.metric.squared_norm(score - logp_grad, y_t) * g2
        else:
            score = batch_mul(std, score)
            logp_grad = batch_mul(std, logp_grad)
            losses = sde.manifold.metric.squared_norm(score - logp_grad, y_t)

        return jnp.mean(losses), new_model_state

    return loss_fn


# log = logging.getLogger(__name__)

# Base configuration, then overrides coming from the command line.
cfg = OmegaConf.load("riemannian-score-sde/configs/config_s2.yaml")
cfg.dataset.data_dir = 'data'
cfg.flow.beta_schedule.beta_f = args.beta_f

# Output directory of this experiment.
path = f'experiment_results/{args.exp_name}'
os.makedirs(path, exist_ok=True)

# Track the run in wandb with the CLI arguments as its config.
wandb.init(project='s2-diffusion', config=vars(args), name=args.exp_name)

# Keep a copy of the arguments next to the results.
with open(os.path.join(path, 'args.yaml'), 'w') as f:
    yaml.dump(vars(args), f)


# Map the --data choice to the dataset class instantiated later on.
_DATASET_TARGETS = {
    'earthquake': 'riemannian_score_sde.datasets.earth.Earthquake',
    'volcano': 'riemannian_score_sde.datasets.earth.VolcanicErruption',
    'fire': 'riemannian_score_sde.datasets.earth.Fire',
    'flood': 'riemannian_score_sde.datasets.earth.Flood',
}
if args.data not in _DATASET_TARGETS:
    raise ValueError('data not found')
cfg.dataset._target_ = _DATASET_TARGETS[args.data]
cfg.dataset.name = args.data

cfg.seed = args.seed

# Checkpoints are written below the experiment directory.
cfg.ckpt_dir = f"{path}/ckpt"
def train(train_state):
    """Run the training loop from ``train_state.step`` up to ``cfg.steps``.

    The loss is the PINN-based one when ``args.use_pinn`` is set and
    ``get_dsm_loss_fn`` otherwise. Every 50 steps the loss is logged; every
    ``cfg.val_freq`` steps (except step 0) a checkpoint is saved and, when
    configured, validation and plotting run.

    Returns:
        ``(train_state, success)``; ``success`` is False when training
        stopped because the loss became NaN.
    """
    shared_loss_kwargs = dict(
        like_w=False,
        eps=1e-3,
        pushforward=pushforward,
        model=model,
        train=True,
        thresh=0.,
    )
    if args.use_pinn:
        loss = get_dsm_loss_fn_pinn(n_max=-1, **shared_loss_kwargs)
    else:
        loss = get_dsm_loss_fn(n_max=args.n_max, **shared_loss_kwargs)

    train_step_fn = jax.jit(get_ema_loss_step_fn(loss, optimizer=optimiser, train=True))

    rng = train_state.rng
    first_step = train_state.step
    tbar = tqdm(
        range(first_step, cfg.steps),
        total=cfg.steps - first_step,
        bar_format="{desc}{bar}{r_bar}",
        mininterval=1,
    )

    # Wall-clock bookkeeping: `train_time` marks the start of the current
    # interval between two validation points.
    train_time = timer()
    total_train_time = 0.0

    for step in tbar:
        data, context = next(train_ds)
        batch = {"data": data, "context": context}

        rng, next_rng = jax.random.split(rng)
        (rng, train_state), loss_val = train_step_fn((next_rng, train_state), batch)

        if jnp.isnan(loss_val).any():
            log_print("[WARN] Loss is NaN — stopping training.")
            return train_state, False

        if step % 50 == 0:
            loss_float = float(loss_val)
            wandb.log({"train/loss": loss_float}, step=step)
            tbar.set_description(f"Loss: {loss_float:.3f}")

        # Everything below only runs at validation points.
        if step <= 0 or step % cfg.val_freq != 0:
            continue

        time_per_it = (timer() - train_time) / cfg.val_freq
        wandb.log({"train/time_per_it": float(time_per_it)}, step=step)

        total_train_time += timer() - train_time
        save(ckpt_path, train_state)

        eval_time = timer()
        if cfg.train_val:
            evaluate(train_state, "val", step)
            wandb.log({"val/time_per_it": float(timer() - eval_time)}, step=step)

        if cfg.train_plot:
            # Plotting is best effort: a failure is reported, not raised.
            try:
                generate_plots(train_state, "val", step=step)
            except Exception as e:
                log_print(f"[WARN] Failed to generate plots: {e}")

        # Restart the interval clock so evaluation time is not counted.
        train_time = timer()

    wandb.log({"train/total_time": float(total_train_time)}, step=cfg.steps - 1)
    save(ckpt_path, train_state)
    return train_state, True

def evaluate(train_state, stage, step=None):
    """Log the mean log-likelihood and NFE of the EMA model for one stage.

    ``stage == "val"`` uses ``eval_ds``; any other value uses ``test_ds``.
    A dataset with ``__len__`` is iterated once. Otherwise about 20,000
    samples are drawn in batches of ``cfg.eval_batch_size``. For the
    ``"test"`` stage the normalisation constant is computed and logged too.
    """
    log_print(f"Running evaluation: {stage}")
    dataset = eval_ds if stage == "val" else test_ds

    model_w_dicts = (model, train_state.params_ema, train_state.model_state)
    likelihood_fn = jax.jit(pushforward.get_log_prob(model_w_dicts, train=False))

    def _tally(batch, logp, nfe, N):
        # Fold one batch into the running sums.
        logp_step, nfe_step = likelihood_fn(*batch)
        return logp + logp_step.sum(), nfe + nfe_step, N + logp_step.shape[0]

    logp, nfe, N = 0.0, 0.0, 0

    if hasattr(dataset, "__len__"):
        for batch in dataset:
            logp, nfe, N = _tally(batch, logp, nfe, N)
        denom = len(dataset)
    else:
        # Temporarily switch to the evaluation batch size, then restore the
        # training batch size.
        dataset.batch_dims = [cfg.eval_batch_size]
        samples = round(20_000 / cfg.eval_batch_size)
        for _ in range(samples):
            logp, nfe, N = _tally(next(dataset), logp, nfe, N)
        dataset.batch_dims = [cfg.batch_size]
        denom = samples

    logp = logp / N
    nfe = nfe / denom

    log_print(f"{stage}/logp = {float(logp):.3f}")
    log_print(f"{stage}/nfe  = {float(nfe):.1f}")
    wandb.log({f"{stage}/logp": float(logp), f"{stage}/nfe": float(nfe)}, step=step)

    if stage != "test":
        return

    default_context = None
    Z = compute_normalization(likelihood_fn, data_manifold, context=default_context)
    log_print(f"{stage}/Z = {float(Z):.2f}")
    wandb.log({f"{stage}/Z": float(Z)}, step=step)

def generate_plots(train_state, stage, step=None):
    """Sample from the EMA model and log diagnostic figures to wandb.

    Args:
        train_state: state whose ``params_ema`` / ``model_state`` are used.
        stage: ``"val"`` (``"eval"`` is accepted as an alias) selects
            ``eval_ds``; anything else selects ``test_ds``. It also prefixes
            the logged wandb keys.
        step: wandb step. The data plot and the forward-sampling plot are
            only produced when ``step`` is not None and ``step <= 0``.
    """
    log_print(f"Generating plots: {stage}")
    rng = jax.random.PRNGKey(cfg.seed)
    # Callers pass "val", the name `evaluate` also uses. The previous test
    # against "eval" alone never matched, so validation plots were silently
    # drawn from the test set.
    dataset = eval_ds if stage in ("val", "eval") else test_ds

    M = 32 if isinstance(pushforward, SDEPushForward) else 8
    model_w_dicts = (model, train_state.params_ema, train_state.model_state)
    sampler_kwargs = dict(N=100, eps=cfg.eps, predictor="GRW")
    sampler = pushforward.get_sampler(model_w_dicts, train=False, **sampler_kwargs)

    # Draw cfg.batch_size * M model samples, reusing the context of one batch.
    x0, context = next(dataset)
    shape = (int(cfg.batch_size * M),)
    rng, next_rng = jax.random.split(rng)
    x = sampler(next_rng, shape, context)
    prop_in_M = float(data_manifold.belongs(x, atol=1e-4).mean())
    log_print(f"Prop samples in M = {100 * prop_in_M:.1f}%")
    wandb.log({f"{stage}/prop_in_M": prop_in_M}, step=step)

    likelihood_fn = pushforward.get_log_prob(model_w_dicts, train=False)
    log_prob = jax.jit(lambda x: likelihood_fn(x)[0])

    fig = plot(data_manifold, None, x, log_prob=log_prob)
    wandb.log({f"{stage}/x0_bwd": wandb.Image(fig)}, step=step)

    if step is not None and step <= 0:
        # Plot a batch of data of the same size as the model samples, then
        # restore the dataset's batch size.
        dataset.batch_dims = shape[0]
        x0 = next(dataset)[0]
        log_prob_data = dataset.log_prob if hasattr(dataset, "log_prob") else None
        fig = plot(data_manifold, None, x0, log_prob=log_prob_data)
        wandb.log({f"{stage}/x0_data": wandb.Image(fig)}, step=step)
        dataset.batch_dims = cfg.batch_size

    if (step is not None and step <= 0) and isinstance(pushforward, SDEPushForward):
        # Push the data through the forward sampler and plot the result
        # against the base distribution's log-density.
        sampler_fwd = pushforward.get_sampler(
            model_w_dicts, train=False, reverse=False, **sampler_kwargs
        )
        zT = sampler_fwd(rng, None, context, z=transform.inv(x0))
        fig = plot_ref(model_manifold, transform.inv(zT), log_prob=base.log_prob)
        wandb.log({f"{stage}/xT_fwd": wandb.Image(fig)}, step=step)

### Main
# --- Startup: report the environment and prepare the checkpoint directory.
print("Stage : Startup")
print(f"Jax devices: {jax.devices()}")
run_path = os.getcwd()
print(f"run_path: {run_path}")
print(f"hostname: {socket.gethostname()}")
ckpt_path = os.path.join(run_path, cfg.ckpt_dir)
os.makedirs(ckpt_path, exist_ok=True)

# --- Build manifold, transform, flow, base distribution and push-forward
# from the configuration.
print("Stage : Instantiate model")
rng = jax.random.PRNGKey(cfg.seed)
data_manifold = instantiate(cfg.manifold)
transform = instantiate(cfg.transform, data_manifold)
model_manifold = transform.domain
beta_schedule = instantiate(cfg.beta_schedule)
flow = instantiate(cfg.flow, manifold=model_manifold, beta_schedule=beta_schedule)
base = instantiate(cfg.base, model_manifold, flow)
pushforward = instantiate(cfg.pushf, flow, base, transform=transform)

# --- Dataset. The same `next_rng` is handed to the dataset, the split and
# all three loaders.
print("Stage : Instantiate dataset")
rng, next_rng = jax.random.split(rng)
dataset = instantiate(cfg.dataset, rng=next_rng)

if not isinstance(dataset, TensorDataset):
    # Not a TensorDataset: the same object serves as train/val/test source.
    train_ds, eval_ds, test_ds = dataset, dataset, dataset
else:
    # Finite dataset: split it, then wrap each part in a loader.
    train_split, eval_split, test_split = random_split(
        dataset, lengths=cfg.splits, rng=next_rng
    )
    train_ds = DataLoader(
        train_split, batch_dims=cfg.batch_size, rng=next_rng, shuffle=True
    )
    eval_ds = DataLoader(eval_split, batch_dims=cfg.eval_batch_size, rng=next_rng)
    test_ds = DataLoader(test_split, batch_dims=cfg.eval_batch_size, rng=next_rng)
    print(
        f"Train size: {len(train_ds.dataset)}. Val size: {len(eval_ds.dataset)}. Test size: {len(test_ds.dataset)}"
    )

print("Stage : Instantiate vector field model")

def model(y, t, context=None):
    r"""Vector field s_\theta: (y, t, context) -> T_y M.

    Instantiates the configured generator and evaluates it at ``y``. Without
    a context the network is conditioned on ``t`` alone; otherwise ``t`` is
    prepended to the context, which is first repeated along a new leading
    axis when its first dimension differs from ``y``'s.
    """
    generator_cls = get_class(cfg.generator._target_)
    output_shape = generator_cls.output_shape(model_manifold)
    score = instantiate(
        cfg.generator,
        cfg.architecture,
        cfg.embedding,
        output_shape,
        manifold=model_manifold,
    )
    # TODO: parse context into embedding map
    if context is None:
        context = t
    else:
        t_expanded = jnp.expand_dims(t.reshape(-1), -1)
        if context.shape[0] != y.shape[0]:
            context = jnp.repeat(jnp.expand_dims(context, 0), y.shape[0], 0)
        context = jnp.concatenate([t_expanded, context], axis=-1)
    return score(y, context)

# Wrap the function with Haiku so it exposes init/apply with state.
model = hk.transform_with_state(model)

# Initialise the network parameters on one training batch and zero times.
rng, next_rng = jax.random.split(rng)
t = jnp.zeros((cfg.batch_size, 1))
data, context = next(train_ds)
params, state = model.init(rng=next_rng, y=transform.inv(data), t=t, context=context)

print("Stage : Instantiate optimiser")
schedule_fn = instantiate(cfg.scheduler)
# Configured optimiser followed by the learning-rate schedule.
optimiser = optax.chain(
    instantiate(cfg.optim),
    optax.scale_by_schedule(schedule_fn),
)
opt_state = optimiser.init(params)

if not (cfg.resume or cfg.mode == "test"):
    # Fresh run: build the initial state and checkpoint it right away.
    rng, next_rng = jax.random.split(rng)
    train_state = TrainState(
        opt_state=opt_state,
        model_state=state,
        step=0,
        params=params,
        ema_rate=cfg.ema_rate,
        params_ema=params,
        rng=next_rng,  # TODO: we should actually use this for reproducibility
    )
    save(ckpt_path, train_state)
else:
    # Resuming or test-only: load the stored state instead.
    train_state = restore(ckpt_path)

# --- Training phase (modes "train" and "all").
if cfg.mode in ("train", "all"):
    # Disabled in the original:
    # if train_state.step == 0 and cfg.test_test:
    #     evaluate(train_state, "test", step=cfg.steps)
    if train_state.step == 0 and cfg.test_plot:
        # Initial plots are best effort: report failures and carry on.
        try:
            generate_plots(train_state, "test", step=-1)
        except Exception as e:
            log_print(f"[WARN] Failed to generate plots: {e}")
    print("Stage : Training")
    train_state, success = train(train_state)

# --- Test phase: always in mode "test", and in mode "all" only when training
# finished successfully.
if cfg.mode == "test" or (cfg.mode == "all" and success):
    print("Stage : Test")
    if cfg.test_val:
        evaluate(train_state, "val", step=cfg.steps)
    if cfg.test_test:
        evaluate(train_state, "test", step=cfg.steps)
    if cfg.test_plot:
        try:
            generate_plots(train_state, "test", step=cfg.steps)
        except Exception as e:
            log_print(f"[WARN] Failed to generate plots: {e}")
    success = True






