import numpy as np
import sympy as sp
import os
os.environ["LD_LIBRARY_PATH"] = ""

import jax
import jax.numpy as jnp
from jax import grad, jacrev, vmap, jit, value_and_grad
from jax.tree_util import tree_map
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.core.frozen_dict import freeze
import optax
from archs import ModifiedMlp, MlpBlock
import argparse

from tqdm import tqdm
import wandb
import pickle
from datetime import datetime
import hashlib

# I = np.eye(n)

# ---------------------------------------------------------------------------
# Command-line configuration. Every option has a default, so the script can
# be launched without arguments.
# ---------------------------------------------------------------------------
argparser = argparse.ArgumentParser()
argparser.add_argument('--exp_name', type=str, default='test')
# Length of the r vector that is sampled and fed to the network.
argparser.add_argument('--n', type=int, default=20)

# When left as None, a seed is derived from the (timestamped) experiment name.
argparser.add_argument('--seed', type=int, default=None)

# --- sampling domain -------------------------------------------------------
# t_0 is the lower bound for every sampled t and the first time-bucket edge.
argparser.add_argument('--t_0', type=float, default=0.01)
# argparser.add_argument('--t_max', type=float, default=1)
# Upper bound of t for the boundary samples.
argparser.add_argument('--t_boundary', type=float, default=0.1)
# Number of log-spaced time-bucket edges used by the PDE loss.
argparser.add_argument('--t_bundle', type=int, default=10)
# NOTE(review): r_boundary is stored below but not used by any sampler in the
# visible code — confirm whether it is still needed.
argparser.add_argument('--r_boundary', type=float, default=1.)
# Largest radius; the radial sampling bands are log-spaced between 1 and r_max.
argparser.add_argument('--r_max', type=float, default=10.)
# One upper bound of t per training phase (used for the interior samples).
argparser.add_argument('--t_max_per_phase', type=float, nargs='+', default=[0., 1.])

# --- network architecture --------------------------------------------------
argparser.add_argument('--n_layer', type=int, default=4)
argparser.add_argument('--n_hidden', type=int, default=256)
argparser.add_argument('--n_symm_features', type=int, default=256)
# 0 disables the Fourier embedding entirely.
argparser.add_argument('--n_fourier_features', type=int, default=0)
argparser.add_argument('--sigma_fourier_features', type=float, nargs='+', default=[1.0])
# argparser.add_argument('--sigma_fourier', type=float, default=1.0)
# argparser.add_argument('--sigma_symm_features', type=float, nargs='+', default=[1.0])
argparser.add_argument('--sigma_min_symm_features', type=float, nargs='+', default=[1.0])
argparser.add_argument('--sigma_max_symm_features', type=float, nargs='+', default=[10.0])
argparser.add_argument('--rescale_factor_symm_features', type=float, default=0.)
# When set, log_G_func_cat is handed to the model as its `global_add_func`.
argparser.add_argument('--global_add_logG', default = False, action='store_true')
argparser.add_argument('--reparam_type', default= 'none', type=str, choices=['weight_fact', 'none'])
# Selects the target used by the boundary loss: log_G_func or f_S0_func.
argparser.add_argument('--boundary_func', type=str, default='log_G', choices=['log_G', 'S0'])

# --- batch composition -----------------------------------------------------
# Total points per step in the PDE phases; the first n_boundary of them are
# boundary samples, the rest interior samples.
argparser.add_argument('--n_data', type=int, default=1024)

argparser.add_argument('--n_boundary', type=int, default=256)
# NOTE(review): --iterations is not read anywhere in the visible code; the
# per-phase list --iterations_per_phase is what drives the loop.
argparser.add_argument('--iterations', type=int, default=100000)
# Checkpoint period (in steps); also used to derive the logged "epoch".
argparser.add_argument('--log_freq', type=int, default=1000)
argparser.add_argument('--loss_weight_pde', type=float, default=1.0)
argparser.add_argument('--loss_weight_b', type=float, default=1.0)
# --- training schedule: the *_per_phase lists are indexed by phase number ---
argparser.add_argument('--phases', type=int, default=2)
argparser.add_argument('--iterations_per_phase', type=int, nargs='+', default=[10000, 100000])
argparser.add_argument('--warmup_per_phase', type=int, nargs='+', default=[0, 10000])
argparser.add_argument('--lr_per_phase', type=float, nargs='+', default=[1e-3, 1e-4])
# Global-norm threshold passed to optax.clip_by_global_norm.
argparser.add_argument('--gradient_clip', type=float, default=1.0)
argparser.add_argument('--normalize_boundary_loss', default = False, action='store_true')


# NOTE(review): --lr is parsed but unused in the visible code (the assignment
# `lr = args.lr` is commented out); --lr_per_phase supplies the learning rate.
argparser.add_argument('--lr', type=float, default=1e-3)

# Decay factor of the running average of per-loss gradient norms.
argparser.add_argument('--ema_beta', type=float, default=0.99)

args = argparser.parse_args()


# Suffix the experiment name with a timestamp so every run gets its own
# name (used for the wandb run and for the checkpoint directory).
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
args.exp_name = f"{args.exp_name}_{timestamp}"
if args.seed is None:
    # Derive a seed from the SHA-256 digest of the timestamped name,
    # reduced modulo 2**31 - 1.
    args.seed = int(hashlib.sha256(args.exp_name.encode('utf-8')).hexdigest(), 16) % (2**31 - 1)
n = args.n

# --- architecture settings -------------------------------------------------
n_layer = args.n_layer
n_hidden = args.n_hidden
n_symm_features = args.n_symm_features
sigma_min_symm_features = args.sigma_min_symm_features
sigma_max_symm_features = args.sigma_max_symm_features
rescale_factor_symm_features = args.rescale_factor_symm_features
# global_multiplier = 4 * n ** 2
global_multiplier = 1.
reparam_type = args.reparam_type 

n_fourier_features = args.n_fourier_features
sigma_fourier_features = args.sigma_fourier_features

# --- sampling domain -------------------------------------------------------
t_bundle = args.t_bundle
t_0 = args.t_0
# t_max = args.t_max
t_boundary = args.t_boundary
# NOTE(review): r_boundary is not referenced again in the visible code.
r_boundary = args.r_boundary
r_max = args.r_max
# Three log-spaced radii from 1 to r_max; passed to the samplers as the outer
# edges of the radial bands.
r_step = np.exp(np.linspace(np.log(1.0), np.log(r_max), 3))
t_max_per_phase = args.t_max_per_phase

# --- batch composition -----------------------------------------------------
n_data = args.n_data
n_boundary = args.n_boundary
# NOTE(review): not validated — this is negative when n_boundary > n_data.
n_interior = n_data - n_boundary

# --- training schedule -----------------------------------------------------
phases = args.phases
iterations_per_phase = args.iterations_per_phase
warmup_per_phase = args.warmup_per_phase
lr_per_phase = args.lr_per_phase
log_freq = args.log_freq
gradient_clip = args.gradient_clip
normalize_boundary_loss = args.normalize_boundary_loss

ema_beta = args.ema_beta

# Normalise the two loss weights so that they sum to one.
# NOTE(review): divides by zero when both weights are 0.
weight_pde = args.loss_weight_pde
weight_b = args.loss_weight_b
weight_pde, weight_b = weight_pde / (weight_pde + weight_b), weight_b / (weight_pde + weight_b)

# Checkpoints are written under experiment-<project>/<timestamped exp name>/.
project_name = "spd-pinn"
exp_dir = os.path.join(f"experiment-{project_name}", args.exp_name)
os.makedirs(exp_dir, exist_ok=True)

# lr = args.lr


def log_G_func(t, r):
    """Return ``-(n(n+1)/4) * log(4*pi*t) - |r|^2 / (4*t)``.

    ``n`` is the module-level dimension. The squared norm of ``r`` is taken
    over its last axis, which is kept with size 1 so the result broadcasts
    against ``t``.
    """
    squared_radius = jnp.sum(r**2, axis=-1, keepdims=True)
    log_prefactor = -n*(n+1) / 4 * jnp.log(4 * jnp.pi * t)
    # const = jnp.log(4 * jnp.pi * t)
    return log_prefactor - squared_radius / (4 * t)

def log_G_func_cat(tr):
    """Evaluate ``log_G_func`` on a packed input.

    The last axis of ``tr`` is split into its first entry (t) and the
    remaining entries (r).
    """
    t, r = jnp.split(tr, [1], axis=-1)
    return log_G_func(t, r)


# x = sp.symbols('x')
# log_sinh_by_x_func = sp.log(sp.sinh(x)) - sp.log(x)
# log_sinh_by_x_func = log_sinh_by_x_func.series(x, 0, 10).removeO().evalf()
# log_sinh_by_x_func = sp.lambdify(x, log_sinh_by_x_func, 'jax')

def log_sinh_by_x_func(x):
    """Return ``log(sinh(x) / x)`` element-wise, safely for all finite ``x``.

    Three regimes are handled:

    * ``|x| <= 1e-5`` (and NaN inputs): the result is set to 0, the limit of
      the function at 0.
    * ``1e-5 < |x| <= 20``: the formula is evaluated directly.
    * ``|x| > 20``: the equivalent form
      ``|x| + log1p(-exp(-2|x|)) - log(2|x|)`` is used, because ``sinh``
      overflows to ``inf`` in float32 for large arguments.

    Each branch is evaluated on a sanitised argument (the "double where"
    pattern), so neither the value nor the gradient of an unselected branch
    can produce NaN/inf. In particular the gradient at ``x == 0`` is 0 rather
    than NaN.
    """
    x = jnp.asarray(x)
    abs_x = jnp.abs(x)
    # Written as a negation so that NaN inputs land in the "small" bucket.
    is_small = ~(abs_x > 1e-5)
    is_large = abs_x > 20.0

    # Direct formula; out-of-regime entries are replaced by a harmless 1.0.
    mid_x = jnp.where(is_small | is_large, 1.0, x)
    mid_val = jnp.log(jnp.sinh(mid_x) / mid_x)

    # Overflow-free form for large |x| (the function is even in x).
    big_x = jnp.where(is_large, abs_x, 1.0)
    big_val = big_x + jnp.log1p(-jnp.exp(-2.0 * big_x)) - jnp.log(2.0 * big_x)

    y = jnp.where(is_large, big_val, mid_val)
    return jnp.where(is_small, 0.0, y)


# def log_D_func(t, r):
#     r_i = r[:, :, None]
#     r_j = r[:, None, :]
#     r_ij = 0.5 * (r_i - r_j)
#     mask = jnp.triu(jnp.ones((n, n)), k=1)[None, :, :]
#     log_D = jnp.sum(mask * log_sinh_by_x_func(r_ij), axis=(1, 2))
#     return log_D[:,None]

def log_D_func(t, r):
    """Sum ``log_sinh_by_x_func((r_i - r_j) / 2)`` over all pairs ``i < j``.

    ``r`` may be a single vector or a 2-D batch of vectors. ``t`` is used only
    for its shape: the per-sample sums are reshaped to ``t.shape`` (or to a
    scalar when ``t`` is 0-dimensional).
    """
    # Promote a lone vector to a batch of one.
    batched_r = r if r.ndim != 1 else jnp.expand_dims(r, 0)
    batch, dim = batched_r.shape

    half_gaps = 0.5 * (batched_r[:, :, None] - batched_r[:, None, :])  # (batch, dim, dim)
    # Strictly upper-triangular selector: keeps each unordered pair once.
    upper = jnp.triu(jnp.ones((dim, dim), dtype=batched_r.dtype), k=1)[None]
    per_sample = jnp.sum(upper * log_sinh_by_x_func(half_gaps), axis=(1, 2))  # (batch,)
    per_sample = jnp.broadcast_to(per_sample, (batch,))

    if t.ndim == 0:
        return per_sample.reshape(())
    return jnp.reshape(per_sample, t.shape)


def f_S0_func(t, r):
    """Return ``log_G_func(t, r)`` minus half of ``log_D_func(t, r)``."""
    return log_G_func(t, r) - 0.5 * log_D_func(t, r)

# f_S0_func = lambda t, r: log_G_func(t, r) 

# Network class to instantiate.
MODEL = ModifiedMlp
# MODEL = MlpBlock

# Weight-factorisation settings, or None when reparameterisation is disabled.
reparam = (
    {
        'type': 'weight_fact',
        'mean': 1.,
        'stddev': 0.1,
    }
    if reparam_type == 'weight_fact'
    else None
)

# The Fourier embedding is only configured when a positive size is requested.
fourier_emb = None
if n_fourier_features > 0:
    fourier_emb = {
        'embed_scale': sigma_fourier_features,
        'embed_dim': n_fourier_features
    }

symmetric_emb = {
    # 'embed_scale': sigma_symm_features,
    'embed_scale_min': sigma_min_symm_features,
    'embed_scale_max': sigma_max_symm_features,
    'embed_dim': n_symm_features,
    'rescale_factor': rescale_factor_symm_features
}

# Collect the constructor arguments first, then build the model in one call.
model_kwargs = {
    'num_layers': n_layer - 1,
    'hidden_dim': n_hidden,
    'out_dim': 1,
    'symmetric_emb': symmetric_emb,
    'fourier_emb': fourier_emb,
    'reparam': reparam,
    'global_multiplier': global_multiplier,
    'global_add_func': log_G_func_cat if args.global_add_logG else None,
    'no_embedding_for_first_dim': True,
}
model = MODEL(**model_kwargs)

# Dummy batch of one sample, used to initialise the model parameters.
init_key = jax.random.PRNGKey(0)

r_init = jax.random.normal(init_key, (1, n))
t_init = jnp.array([[1.]])

# The packed input must be ordered [t, r]: f_0 concatenates [t, r] and
# log_G_func_cat reads column 0 as t. Packing it as [r, t] would place a
# random normal value (possibly negative) in the t column, so the init-time
# forward pass could evaluate log(4*pi*t) on a negative number.
variables = model.init(
    jax.random.PRNGKey(42),
    jnp.concatenate([t_init, r_init], axis=-1)
)
variables = freeze(variables)

# Module-name prefixes that build_trainable_mask treats as frozen: a leaf
# whose owning module (the key right after 'params') starts with one of these
# is marked non-trainable.
NO_GRAD_PREFIXES = ("AntiSymmetricEmbs", "SymmetricEmbs", "FourierEmbs")

def _path_item_to_key(p):
    """Turn one pytree path entry into a plain key.

    The first attribute found among ``key``, ``name`` and ``idx`` (checked in
    that order) is returned; entries with none of them fall back to
    ``str(p)``.
    """
    # DictKey / GetAttrKey / SequenceKey etc.
    for attr in ("key", "name", "idx"):
        if hasattr(p, attr):
            return getattr(p, attr)
    return str(p)

def build_trainable_mask(variables):
    """Build a boolean pytree, shaped like ``variables``, marking trainable leaves.

    A leaf is marked ``False`` when its path contains ``"params"`` and the key
    immediately after it starts with one of ``NO_GRAD_PREFIXES``; every other
    leaf is ``True``.
    """
    frozen_prefixes = tuple(NO_GRAD_PREFIXES)

    def is_trainable(path, leaf):
        names = list(map(_path_item_to_key, path))
        # expect something like: 'params' -> 'Dense_0' / 'SymmetricEmbs_0' -> ...
        try:
            owner_pos = names.index("params") + 1
        except ValueError:
            return True  # no 'params' collection on this path
        if owner_pos >= len(names):
            return True  # 'params' is the last path entry
        return not str(names[owner_pos]).startswith(frozen_prefixes)

    return jax.tree_util.tree_map_with_path(is_trainable, variables)

# Boolean pytree with the same structure as `variables`; True marks leaves
# that receive gradient updates. Read by stop_grad_frozen,
# zero_out_frozen_grads and the masked optimizer.
trainable_mask = build_trainable_mask(variables)

def stop_grad_frozen(tree):
    """Return ``tree`` with ``stop_gradient`` applied to every frozen leaf.

    Leaves whose entry in the module-level ``trainable_mask`` is truthy are
    passed through untouched.
    """
    def _detach_if_frozen(leaf, is_trainable):
        # frozen leaves become constants for autodiff
        if is_trainable:
            return leaf
        return jax.lax.stop_gradient(leaf)

    return jax.tree_util.tree_map(_detach_if_frozen, tree, trainable_mask)

def zero_out_frozen_grads(grads):
    """Return ``grads`` with every frozen leaf replaced by zeros.

    This keeps gradient norms and optimizer updates free of frozen leaves even
    if a non-zero value leaks through. Trainable leaves (per the module-level
    ``trainable_mask``) are returned unchanged.
    """
    def _keep_or_zero(leaf_grad, is_trainable):
        if is_trainable:
            return leaf_grad
        return jnp.zeros_like(leaf_grad)

    return jax.tree_util.tree_map(_keep_or_zero, grads, trainable_mask)

def f_0(variables, t, r):
    """Evaluate the model on the packed input ``[t, r]``.

    The output is returned twice, as ``(value, aux)``, so the function can be
    differentiated with ``has_aux=True`` while still exposing the value.
    """
    packed = jnp.concatenate((t, r), axis=-1)
    value = model.apply(variables, packed)
    return value, value


def f_t(variables, t, r):
    """Jacobian of the model output with respect to ``t``.

    Returns ``(u_t, (u_t, u_0))`` so a further ``has_aux=True`` transform can
    carry both the derivative and the value along.
    """
    # t is positional argument 1 of f_0
    d_dt = jacrev(f_0, argnums=1, has_aux=True)
    u_t, u_0 = d_dt(variables, t, r)
    return u_t, (u_t, u_0)


def f_r(variables, t, r):
    """Jacobian of the model output with respect to ``r``.

    Returns ``(u_r, (u_r, u_0))`` so a further ``has_aux=True`` transform can
    carry both the derivative and the value along.
    """
    # r is positional argument 2 of f_0
    d_dr = jacrev(f_0, argnums=2, has_aux=True)
    u_r, u_0 = d_dr(variables, t, r)
    return u_r, (u_r, u_0)


def f_rr(variables, t, r):
    """Second derivative of the model output with respect to ``r``.

    Differentiates ``f_r`` once more in ``r`` and returns
    ``(u_rr, (u_rr, u_r, u_0))``.
    """
    d2_dr2 = jacrev(f_r, argnums=2, has_aux=True)
    u_rr, lower_order = d2_dr2(variables, t, r)
    u_r, u_0 = lower_order
    return u_rr, (u_rr, u_r, u_0)


# Function to compute the PDE equation value
def equation(t, r, u_0, u_t, u_r, u_rr):
    """Per-sample PDE residual.

    Computes ``u_t - (trace(u_rr) + |u_r|^2 + 0.5 * sum_{i<j} coth((r_i - r_j)/2) * (u_r_i - u_r_j))``.

    ``r`` must be 2-D ``(batch, dim)``; ``u_r`` is reshaped to match and
    ``u_rr`` is traced over its last two axes. ``t`` and ``u_0`` are accepted
    but not used. Infinite coth values (coincident coordinates, including the
    diagonal) are replaced by 0.
    """
    batch, dim = r.shape

    def pairwise_diff(vec):
        # (batch, dim) -> (batch, dim, dim) holding vec_i - vec_j
        return jnp.reshape(vec, (batch, dim, 1)) - jnp.reshape(vec, (batch, 1, dim))

    laplacian = jnp.trace(u_rr, axis1=-1, axis2=-2)  # Equivalent to diagonal().sum()
    grad_sq_norm = jnp.sum(u_r ** 2, axis=-1)

    coth_half_gap = jnp.nan_to_num(
        1. / jnp.tanh(pairwise_diff(r) / 2), posinf=0, neginf=0
    )
    # Strictly upper-triangular selector: counts each pair i < j once.
    upper = jnp.triu(jnp.ones((dim, dim)), k=1)[None, :, :]
    drift = 0.5 * jnp.sum(upper * coth_half_gap * pairwise_diff(u_r), axis=(1, 2))

    return u_t - (laplacian + grad_sq_norm + drift)


# Reject unsupported choices up front, then bind `boundary` to the selected
# target function.
if args.boundary_func not in ('log_G', 'S0'):
    raise ValueError(f"Unknown boundary_func: {args.boundary_func}")

if args.boundary_func == 'S0':
    def boundary(t, r):
        """Boundary target: f_S0_func evaluated at (t, r)."""
        return f_S0_func(t, r)
else:
    def boundary(t, r):
        """Boundary target: log_G_func evaluated at (t, r)."""
        return log_G_func(t, r)


def boundary_data_gen(B, key, r_step=(1.0, 3.0, 10.0), tmin=0.001, tmax=10.0, eps=1e-12):
    """
    Draw ``B`` random samples ``(t, r)``.

    r is built from:
      - a random direction: a Gaussian vector of length ``n`` (module-level)
        divided by its norm plus ``eps``
      - a radius: one band is picked uniformly among
          [0, r_step[0]], [r_step[0], r_step[1]], ...
        and the radius is then drawn uniformly inside that band

    t is drawn uniformly between ``tmin`` and ``tmax``.

    Returns ``(t, r)`` with shapes ``(B, 1)`` and ``(B, n)``.
    """
    dir_key, band_key, frac_key, time_key = jax.random.split(key, 4)

    # Unit direction vectors
    gauss = jax.random.normal(dir_key, (B, n))
    unit = gauss / (jnp.linalg.norm(gauss, axis=-1, keepdims=True) + eps)

    # Band edges: 0 followed by the entries of r_step
    outer = jnp.asarray(r_step, dtype=unit.dtype)
    edges = jnp.concatenate([jnp.array([0.0], dtype=unit.dtype), outer])
    inner_edges, outer_edges = edges[:-1], edges[1:]

    # Pick a band per sample, then a position inside it
    which = jax.random.randint(band_key, (B,), 0, inner_edges.shape[0])
    frac = jax.random.uniform(frac_key, (B,), minval=0.0, maxval=1.0)
    band_lo = inner_edges[which]
    band_hi = outer_edges[which]
    radius = band_lo + frac * (band_hi - band_lo)

    t = jax.random.uniform(time_key, (B, 1), minval=tmin, maxval=tmax)
    r = unit * radius[:, None]
    return t, r


# def interior_data_gen(B, key, tmin=0.001, tmax=10, rmax=10):
#     key1, key2 = jax.random.split(key)
#     r = jax.random.normal(key1, (B, n)) * rmax
#     t = jax.random.uniform(key2, (B, 1), minval=tmin, maxval=tmax)
#     return t, r

def interior_data_gen(B, key, r_step=(1.0, 3.0, 10.0), tmin=0.001, tmax=10.0, eps=1e-12):
    """
    Draw ``B`` random interior samples ``(t, r)``.

    The sampling scheme is the same as ``boundary_data_gen`` (random
    direction, radius uniform within a uniformly chosen band delimited by
    ``r_step``, and ``t`` uniform between ``tmin`` and ``tmax``). This function
    therefore delegates to it instead of keeping a second copy of the code;
    for a given ``key`` the samples are the same as before.

    Returns ``(t, r)`` with shapes ``(B, 1)`` and ``(B, n)``.
    """
    return boundary_data_gen(B, key, r_step=r_step, tmin=tmin, tmax=tmax, eps=eps)

def bucketize(t, boundaries):
    """Map each value in ``t`` to a bucket index.

    The index is the position of the first entry of ``boundaries`` that is
    greater than or equal to the value (``len(boundaries)`` if there is none).
    ``boundaries`` is expected to be sorted in ascending order.
    """
    bucket_ids = jnp.searchsorted(boundaries, t, side='left')
    return bucket_ids


def segment_mean(data, segment_ids, num_segments):
    """Average ``data`` within each segment.

    Returns one entry per segment. A small constant (1e-10) is added to the
    count, so an empty segment yields 0 instead of dividing by zero.
    """
    def _per_segment_sum(values):
        return jax.ops.segment_sum(values, segment_ids, num_segments=num_segments)

    totals = _per_segment_sum(data)
    sizes = _per_segment_sum(jnp.ones_like(data))
    return totals / (sizes + 1.0e-10)


def grad_norm(grad):
    """Return the L2 norm of all leaves of the pytree ``grad``, taken together.

    The pytree is flattened into a single 1-D vector before the norm is taken.
    """
    # `jax.flatten_util` is a submodule; `import jax` alone is not guaranteed
    # to load it, so import the helper explicitly rather than relying on
    # attribute access through the top-level package.
    from jax.flatten_util import ravel_pytree

    flat, _ = ravel_pytree(grad)
    return jnp.linalg.norm(flat)


@jit
def compute_pde_loss(variables, t, r, t_bundlestep, t_weight):
    """PDE-residual loss, averaged per time bucket.

    Samples are grouped by ``bucketize(t, t_bundlestep)``; the mean absolute
    residual of each bucket is multiplied by ``t_weight`` and the weighted
    bucket means are averaged.

    Returns ``(bucketed_loss, plain_loss)``, where ``plain_loss`` is the mean
    absolute residual over all samples without bucketing. Frozen leaves of
    ``variables`` are detached from autodiff first.
    """
    params = stop_grad_frozen(variables)

    # Per-sample derivatives: vmap over the leading axis of t and r.
    time_derivs = vmap(f_t, in_axes=(None, 0, 0))
    space_derivs = vmap(f_rr, in_axes=(None, 0, 0))
    _, (du_dt, _) = time_derivs(params, t, r)
    _, (hessian, grad_r, value) = space_derivs(params, t, r)

    # Drop the singleton axes (presumably from the size-1 model output and
    # the size-1 t input — TODO confirm against the model definition).
    t_flat = t[:, 0]
    residual = equation(
        t_flat, r, value[:, 0], du_dt[:, 0, 0], grad_r[:, 0], hessian[:, 0]
    )
    abs_residual = jnp.abs(residual)

    # Bucketized loss
    bucket_ids = bucketize(t_flat, t_bundlestep)
    per_bucket = segment_mean(abs_residual, bucket_ids, t_bundlestep.shape[0])

    return jnp.mean(t_weight * per_bucket), jnp.mean(abs_residual)


@jit
def compute_boundary_loss(variables, t, r):
    """Mean absolute error between the model and the ``boundary`` target.

    When the module-level ``normalize_boundary_loss`` flag is set, each error
    is divided by ``|target| + 1`` before averaging. Frozen leaves of
    ``variables`` are detached from autodiff first.
    """
    params = stop_grad_frozen(variables)

    # Per-sample forward pass
    _, prediction = vmap(f_0, in_axes=(None, 0, 0))(params, t, r)

    target = boundary(t, r)
    error = prediction - target
    if normalize_boundary_loss:
        error = error / (jnp.abs(target) + 1)
    return jnp.mean(jnp.abs(error))


def loss_pde_only(variables, t, r, t_bundlestep=None, t_weight=None):
    """Return only the bucketed PDE loss from ``compute_pde_loss``.

    ``compute_pde_loss`` requires the time-bucket edges and per-bucket
    weights, so they are accepted here and forwarded. When omitted:

    * ``t_bundlestep`` defaults to a single edge at +inf, which puts every
      sample in one bucket, so the loss is the plain mean absolute residual;
    * ``t_weight`` defaults to a weight of 1 for each bucket.
    """
    if t_bundlestep is None:
        t_bundlestep = jnp.array([jnp.inf])
    if t_weight is None:
        t_weight = jnp.ones(t_bundlestep.shape[0])
    loss_pde, _ = compute_pde_loss(variables, t, r, t_bundlestep, t_weight)
    return loss_pde


def loss_b_only(variables, t, r):
    """Return the boundary loss for the given samples (thin wrapper)."""
    return compute_boundary_loss(variables, t, r)


def loss_func(variables, t, r, norm_pde, norm_b, weight_pde, weight_b,
              t_bundlestep=None, t_weight=None):
    """
    Combined PDE + boundary loss with per-loss normalisation.

    norm_pde, norm_b: scalar normalizers (typically EMA of grad norms)
    weight_pde, weight_b: relative weights of the two normalised losses
    t_bundlestep, t_weight: time-bucket edges and per-bucket weights forwarded
        to ``compute_pde_loss``. When omitted, a single bucket (one edge at
        +inf) with weight 1 is used, so the PDE loss is the plain mean
        absolute residual.

    The boundary loss is evaluated on the first ``n_boundary`` samples only.

    Returns ``(total_loss, aux)`` with
    ``aux = (loss_pde, loss_b, loss_unweighted, loss_pde_scaled, loss_b_scaled)``.
    """
    if t_bundlestep is None:
        t_bundlestep = jnp.array([jnp.inf])
    if t_weight is None:
        t_weight = jnp.ones(t_bundlestep.shape[0])

    loss_pde, loss_unweighted = compute_pde_loss(variables, t, r, t_bundlestep, t_weight)
    loss_b = compute_boundary_loss(variables, t[:n_boundary], r[:n_boundary])

    # normalize each loss by its (EMA) grad-norm
    loss_pde_scaled = loss_pde / norm_pde * weight_pde
    loss_b_scaled = loss_b / norm_b * weight_b

    total_loss = loss_pde_scaled + loss_b_scaled
    # aux keeps raw losses + unweighted pde loss
    return total_loss, (loss_pde, loss_b, loss_unweighted, loss_pde_scaled, loss_b_scaled)


@jit
def loss_and_grad(variables, t, r, t_bundlestep, t_weight, norm_pde, norm_b, weight_pde, weight_b):
    """Evaluate both losses, their gradients, and the combined gradient.

    The PDE loss uses every sample; the boundary loss uses the first
    ``n_boundary`` samples. Each loss is scaled by ``weight / norm`` and the
    combined gradient is the same weighted sum of the per-loss gradients.
    Frozen leaves are zeroed in every gradient before norms are taken or
    gradients are returned.

    Returns the tuple
    ``(total_loss, loss_pde, loss_b, loss_unweighted, loss_pde_scaled,
    loss_b_scaled, gnorm_pde, gnorm_b, grads_total)``, where the two norms are
    those of the *unscaled* per-loss gradients.
    """
    # PDE loss and gradient w.r.t. variables
    pde_value_and_grad = value_and_grad(compute_pde_loss, has_aux=True)
    (loss_pde, loss_unweighted), raw_grads_pde = pde_value_and_grad(
        variables, t, r, t_bundlestep, t_weight
    )

    # Boundary loss and gradient, on the leading boundary samples only
    boundary_value_and_grad = value_and_grad(compute_boundary_loss)
    loss_b, raw_grads_b = boundary_value_and_grad(
        variables, t[:n_boundary], r[:n_boundary]
    )

    # Per-loss scale factors (the norms are EMA grad norms from earlier steps)
    scale_pde = weight_pde / norm_pde
    scale_b = weight_b / norm_b
    loss_pde_scaled = scale_pde * loss_pde
    loss_b_scaled = scale_b * loss_b

    def _weighted_sum(g_pde, g_b):
        return scale_pde * g_pde + scale_b * g_b

    grads_total = zero_out_frozen_grads(
        tree_map(_weighted_sum, raw_grads_pde, raw_grads_b)
    )

    # Norms of the unscaled gradients, consumed by the EMA update
    gnorm_pde = grad_norm(zero_out_frozen_grads(raw_grads_pde))
    gnorm_b = grad_norm(zero_out_frozen_grads(raw_grads_b))

    return (
        loss_pde_scaled + loss_b_scaled,
        loss_pde,
        loss_b,
        loss_unweighted,
        loss_pde_scaled,
        loss_b_scaled,
        gnorm_pde,
        gnorm_b,
        grads_total,
    )

@jit
def loss_and_grad_b_only(variables, t, r):
    """Return the boundary loss and its gradient, with frozen leaves zeroed."""
    loss_b, raw_grad = value_and_grad(compute_boundary_loss)(variables, t, r)
    return loss_b, zero_out_frozen_grads(raw_grad)


def warmup_cosine_schedule(base_lr: float, warmup_steps: int, total_steps: int):
    """Build a learning-rate schedule ``step -> lr``.

    The rate rises linearly from 0 to ``base_lr`` over ``warmup_steps`` and
    then follows a cosine decay from ``base_lr`` to 0 over the remaining
    steps. ``total_steps`` is clamped to at least 1 and ``warmup_steps`` to the
    range ``[0, total_steps]``. With no warmup, the schedule is the cosine
    decay alone. Steps are counted from 0.
    """
    peak = float(base_lr)
    n_total = max(int(total_steps), 1)
    n_warm = min(max(int(warmup_steps), 0), n_total)  # clamp for safety
    decay_span = float(max(n_total - n_warm, 1))

    def _cosine(step):
        # equals `peak` at step == n_warm and 0 once the decay span is used up
        frac = jnp.clip((step - n_warm) / decay_span, 0.0, 1.0)
        return 0.5 * peak * (1.0 + jnp.cos(jnp.pi * frac))

    if n_warm == 0:
        def schedule(step):
            return _cosine(jnp.asarray(step, dtype=jnp.float32))
        return schedule

    ramp_span = float(n_warm)

    def schedule(step):
        step = jnp.asarray(step, dtype=jnp.float32)
        ramp = peak * jnp.clip(step / ramp_span, 0.0, 1.0)
        return jnp.where(step < n_warm, ramp, _cosine(step))

    return schedule
# ---------------------------------------------------------------------------
# Training driver: one wandb run, `phases` consecutive training phases, and a
# final checkpoint.
# ---------------------------------------------------------------------------
wandb.init(
    project=project_name,
    config=vars(args),
    name=args.exp_name
)
key = jax.random.PRNGKey(args.seed)
# Global step offset, so step numbers keep increasing across phases.
start_iteration = 0
for phase in range(phases):
    print(f"Phase {phase+1}/{phases}")
    # Per-phase hyper-parameters.
    # NOTE(review): the *_per_phase lists are indexed directly; a list shorter
    # than `phases` raises IndexError here.
    lr = lr_per_phase[phase]
    iterations = iterations_per_phase[phase]
    warmup_steps = warmup_per_phase[phase]
    t_max = t_max_per_phase[phase]
    print(f"  t_max = {t_max}, r_max = {r_max}")
    # Log-spaced time-bucket edges from t_0 to t_max, all weighted equally.
    # NOTE(review): with the default t_max of 0.0 for phase 0, log(0) is -inf
    # and the edges are not finite; phase 0 below never uses them.
    t_bundlestep = jnp.exp(jnp.linspace(jnp.log(t_0), jnp.log(t_max), t_bundle))
    t_weight = jnp.ones(t_bundle)


    # The schedule, optimizer and optimizer state are rebuilt for every phase,
    # so each phase starts from a fresh optimizer state.
    schedule = warmup_cosine_schedule(
        base_lr=lr,
        warmup_steps=warmup_steps,
        total_steps=iterations,
    )
    base_optimizer = optax.chain(
        optax.clip_by_global_norm(gradient_clip),  # <-- gradient clipping
        optax.adam(schedule),                      # <-- then Adam w/ LR schedule
    )
    # Restrict the optimizer to the leaves marked True in trainable_mask.
    optimizer = optax.masked(base_optimizer, trainable_mask)
    opt_state = optimizer.init(variables)


    # Running averages of the per-loss gradient norms; reset at each phase.
    ema_pde_gnorm, ema_b_gnorm = None, None
    for i in tqdm(range(start_iteration, start_iteration + iterations)):
        step = i + 1                        # 1-based global step
        local_step = i - start_iteration    # 0-based step within this phase
        epoch = (step - 1) // log_freq
        if phase == 0:
            # Phase 0 fits the boundary target only, using n_data boundary
            # samples per step. key_i is split off but not used here.
            key, key_b, key_i = jax.random.split(key, 3)
            t_b, r_b = boundary_data_gen(
                n_data, key_b, r_step=r_step,tmin=t_0, tmax=t_boundary
            )
            (loss_b, grads) = loss_and_grad_b_only(
                variables, t_b, r_b
            )
            loss = loss_b
            # No PDE term in this phase: log zeros for the PDE quantities.
            loss_pde = jnp.array(0.0)
            loss_unweighted = jnp.array(0.0)
            loss_pde_scaled = jnp.array(0.0)
            loss_b_scaled = loss_b
            gnorm = grad_norm(grads)
            gnorm_pde = jnp.array(0.0)
            gnorm_b = gnorm
            # print(tree_map(lambda x:jnp.linalg.norm(x),grads))
            # if i==10:
            #     break
        else:

            # sample boundary & interior, then concatenate (boundary first, then interior)
            key, key_b, key_i = jax.random.split(key, 3)
            t_b, r_b = boundary_data_gen(
                n_boundary, key_b, r_step=r_step,tmin=t_0, tmax=t_boundary
                # n_boundary, key_b, tmin=t_0, tmax=t_boundary, rmax=r_boundary
            )
            t_i, r_i = interior_data_gen(
                # n_interior, key_i, tmin=t_0, tmax=t_max, rmax=r_max
                n_interior, key_i, r_step=r_step, tmin=t_0, tmax=t_max
            )

            # loss_and_grad relies on this ordering: it treats the first
            # n_boundary rows as the boundary samples.
            t_batch = jnp.concatenate([t_b, t_i], axis=0)
            r_batch = jnp.concatenate([r_b, r_i], axis=0)

            # current normalizers (avoid division by zero; use 1.0 until EMA exists)
            norm_pde_host = ema_pde_gnorm if ema_pde_gnorm is not None else 1.0
            norm_b_host = ema_b_gnorm if ema_b_gnorm is not None else 1.0

            # ensure they are JAX scalars for jit
            norm_pde = jnp.array(norm_pde_host + 1e-8, dtype=jnp.float32)
            norm_b = jnp.array(norm_b_host + 1e-8, dtype=jnp.float32)

            (loss,
            loss_pde,
            loss_b,
            loss_unweighted,
            loss_pde_scaled,
            loss_b_scaled,
            gnorm_pde,
            gnorm_b,
            grads) = loss_and_grad(
                variables, t_batch, r_batch, t_bundlestep, t_weight, norm_pde, norm_b, weight_pde, weight_b
            )
            # grad norm for total (scaled) loss
            gnorm = grad_norm(grads)

        # optimizer step
        # `grads` already has its frozen leaves zeroed by the loss_and_grad*
        # helpers.
        updates, opt_state = optimizer.update(grads, opt_state, params=variables)
        variables = optax.apply_updates(variables, updates)

        # move to host floats
        loss_f = float(loss)
        loss_pde_f = float(loss_pde)
        loss_b_f = float(loss_b)
        loss_unweighted_f = float(loss_unweighted)
        loss_pde_scaled_f = float(loss_pde_scaled)
        loss_b_scaled_f = float(loss_b_scaled)

        gnorm_f = float(gnorm)
        gnorm_pde_f = float(gnorm_pde)
        gnorm_b_f = float(gnorm_b)

        # update EMA of per-loss grad norms (host-side)
        # The first value seen in a phase seeds the average.
        if ema_pde_gnorm is None:
            ema_pde_gnorm = gnorm_pde_f
        else:
            ema_pde_gnorm = ema_beta * ema_pde_gnorm + (1.0 - ema_beta) * gnorm_pde_f

        if ema_b_gnorm is None:
            ema_b_gnorm = gnorm_b_f
        else:
            ema_b_gnorm = ema_beta * ema_b_gnorm + (1.0 - ema_beta) * gnorm_b_f

        # Learning rate for logging, evaluated at the phase-local step.
        current_lr = float(schedule(local_step))

        wandb.log(
            {
                "loss": loss_f,
                "loss_pde": loss_pde_f,
                "loss_boundary": loss_b_f,
                "loss_unweighted": loss_unweighted_f,
                "loss_pde_scaled": loss_pde_scaled_f,
                "loss_boundary_scaled": loss_b_scaled_f,
                "grad_norm": gnorm_f,
                "grad_norm_pde": gnorm_pde_f,
                "grad_norm_boundary": gnorm_b_f,
                "ema_grad_norm_pde": ema_pde_gnorm,
                "ema_grad_norm_boundary": ema_b_gnorm,
                "lr": current_lr,
                "step": step,
                "epoch": epoch,
            },
            step=step,
        )

        # Periodic checkpoint: pickle the host copy of the variables.
        if step % log_freq == 0:
            params_host = jax.device_get(variables)
            ckpt_path = os.path.join(exp_dir, f"model_step_{step}.pkl")
            with open(ckpt_path, "wb") as f:
                pickle.dump(params_host, f)
    start_iteration += iterations

# Final checkpoint after the last phase.
final_ckpt_path = os.path.join(exp_dir, f"model_final.pkl")
params_host = jax.device_get(variables)
with open(final_ckpt_path, "wb") as f:
    pickle.dump(params_host, f)
