# Rdkit import should be first, do not move it
try:
    from rdkit import Chem
except ModuleNotFoundError:
    pass

import os
from types import SimpleNamespace
from datetime import datetime
import argparse
import json

import torch
import wandb

import jax
import jax.numpy as jnp
import optax
from flax import serialization

import sys
sys.path.append('e3_diffusion_for_molecules')

from configs.datasets_config import get_dataset_info
from qm9 import dataset as qm9_dataset
from qm9.models import get_model
from qm9.utils import prepare_context, compute_mean_mad

from equivariant_diffusion.utils import (
    remove_mean_with_mask,
    assert_mean_zero_with_mask,
    assert_correctly_masked,
    sample_center_gravity_zero_gaussian_with_mask,
)

from complete_graph_gnn_jax import MoleculePINN


# -------------------------
# Minimal EDM/QM9 config (replaces the big first argparse)
# -------------------------
def make_edm_cfg() -> SimpleNamespace:
    """
    Build the fixed settings namespace used on the EDM/QM9 side.

    The values are only needed for:
      - creating the QM9 dataloaders
      - constructing the EDM model (used for normalize + noise sampling)
      - preprocessing one very large batch into z_0

    A fresh namespace (with fresh list objects) is returned on every call.
    """
    # dataset / loader
    loader_opts = dict(
        dataset="qm9",
        datadir="qm9/temp",
        num_workers=0,
        remove_h=False,
        filter_n_atoms=None,
        dequantization="argmax_variational",
        include_charges=True,
        # Very large on purpose: a single batch is meant to hold the data
        # that is later re-batched for PINN training.
        batch_size=100000,
        # Empty list means no property conditioning.
        conditioning=[],
    )

    # model selection + diffusion schedule
    diffusion_opts = dict(
        model="egnn_dynamics",
        probabilistic_model="score",
        diffusion_steps=1000,
        diffusion_noise_schedule="polynomial_2",
        diffusion_noise_precision=1e-5,
        diffusion_loss_type="l2",
        sigma_max=1.0,
    )

    # EGNN architecture
    egnn_opts = dict(
        n_layers=9,
        nf=256,
        tanh=True,
        attention=True,
        norm_constant=1.0,
        sin_embedding=False,
        sigma_weight=True,
        normalization_factor=1.0,
        aggregation_method="sum",
        normalize_factors=[1, 4, 10],
    )

    # preprocessing augmentation (present, but switched off by default)
    augmentation_opts = dict(
        augment_noise=0.0,
        data_augmentation=False,
    )

    # remaining flags that some EDM code paths look up
    misc_opts = dict(
        no_cuda=False,
        dp=True,
        actnorm=True,
        condition_time=True,
        clip_grad=True,
        hard_clip=0.0,
        trace="hutch",
        ode_regularization=1e-3,
    )

    return SimpleNamespace(
        **loader_opts,
        **diffusion_opts,
        **egnn_opts,
        **augmentation_opts,
        **misc_opts,
    )


def pad(x: torch.Tensor, dim: int, to_pad: int) -> torch.Tensor:
    """Zero-pad `x` at the end of axis `dim` so that the axis has length `to_pad`.

    Args:
        x: Tensor to pad.
        dim: Axis to pad. Negative values count from the last axis.
        to_pad: Target length of that axis.

    Returns:
        `x` itself (no copy) when the axis already has length `to_pad`;
        otherwise a new tensor with the same dtype/device as `x`, the original
        values at the start of the axis and zeros after them.

    Raises:
        IndexError: If `dim` is not a valid axis of `x`.
        ValueError: If the axis is already longer than `to_pad`.
    """
    ndim = x.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a tensor with {ndim} dimension(s)")
    # Normalise so that negative axes build the right output shape.
    axis = dim % ndim

    current = x.shape[axis]
    if current == to_pad:
        return x
    if current > to_pad:
        raise ValueError(
            f"cannot pad axis {axis} of length {current} down to {to_pad}"
        )

    out_shape = list(x.shape)
    out_shape[axis] = to_pad
    padded = x.new_zeros(out_shape)
    # Write the original values into the leading slice of the padded axis.
    padded.narrow(axis, 0, current).copy_(x)
    return padded


def random_rotation(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate every batch element by its own random proper rotation.

    An orthogonal matrix is taken from the QR factorisation of a Gaussian
    3x3 matrix; where its determinant is negative the first column is negated
    so the result has determinant +1.

    x: (B, N, 3)
    returns: tensor of the same shape holding the rotated coordinates
    """
    n_batch = x.size(0)
    gaussian = torch.randn(n_batch, 3, 3, device=x.device, dtype=x.dtype)
    rot, _upper = torch.linalg.qr(gaussian)

    # Turn reflections (det < 0) into rotations by flipping one column.
    is_reflection = torch.det(rot) < 0
    rot[is_reflection, :, 0] = -rot[is_reflection, :, 0]

    # out[b, n, i] = sum_j rot[b, i, j] * x[b, n, j]
    return torch.einsum("bij,bnj->bni", rot, x)


def build_model_and_dataloaders(edm_cfg: SimpleNamespace):
    """
    Create the QM9 dataloaders and the EDM model described by `edm_cfg`.

    Side effects on `edm_cfg`: sets `cuda` and `context_node_nf`.

    Returns:
      (model, dataloaders, dataset_info, device, dtype, property_norms)
      where `property_norms` is None when no conditioning is configured.
    """
    # CUDA is used only if it is not disabled and is actually available.
    edm_cfg.cuda = False if edm_cfg.no_cuda else torch.cuda.is_available()
    device = torch.device("cuda") if edm_cfg.cuda else torch.device("cpu")
    dtype = torch.float32

    dataset_info = get_dataset_info(edm_cfg.dataset, edm_cfg.remove_h)

    dataloaders, _charge_scale = qm9_dataset.retrieve_dataloaders(edm_cfg)
    train_loader = dataloaders["train"]
    first_batch = next(iter(train_loader))

    # Optional conditioning: the context width is read off one prepared batch.
    if not edm_cfg.conditioning:
        edm_cfg.context_node_nf = 0
        property_norms = None
    else:
        property_norms = compute_mean_mad(dataloaders, edm_cfg.conditioning, edm_cfg.dataset)
        sample_context = prepare_context(edm_cfg.conditioning, first_batch, property_norms)
        edm_cfg.context_node_nf = sample_context.size(2)

    # Histogram of molecule sizes, stored on dataset_info before get_model runs.
    num_atoms = train_loader.dataset.data["num_atoms"]
    largest = num_atoms.max()
    size_histogram = {}
    for size in range(1, largest + 1):
        size_histogram[size] = (num_atoms == size).sum().item()
    dataset_info["n_nodes"] = size_histogram
    dataset_info["max_n_nodes"] = largest.item()

    model, _nodes_dist, prop_dist = get_model(edm_cfg, device, dataset_info, train_loader)
    if property_norms is not None and prop_dist is not None:
        prop_dist.set_normalizer(property_norms)
    model = model.to(device)

    return model, dataloaders, dataset_info, device, dtype, property_norms


def preprocess_batch(
    data: dict,
    *,
    model,
    edm_cfg: SimpleNamespace,
    device: torch.device,
    dtype: torch.dtype,
    n_max: int,
    property_norms,
):
    """
    Convert one raw dataloader batch into padded, normalized latents.

    Steps: move tensors to `device`/`dtype`, pad the node axis to `n_max`,
    centre the positions (plus optional noise / rotation augmentation),
    sanity-check the masking, then normalize via `model.normalize` and
    concatenate positions, one-hot types and charges along the last axis.

    Returns:
      z_0: (B, n_max, D)
      node_mask: (B, n_max, 1)
      n_nodes: (B,) int32
    """

    def _fetch(field: str) -> torch.Tensor:
        return data[field].to(device, dtype)

    x = _fetch("positions")  # (B,N,3)
    node_mask = _fetch("atom_mask").unsqueeze(2)  # (B,N,1)
    one_hot = _fetch("one_hot")  # (B,N,C)

    # Without charges, use zeros shaped so the later concat still works.
    charges = (
        _fetch("charges")  # typically (B,N,1)
        if edm_cfg.include_charges
        else torch.zeros((x.size(0), x.size(1), 1), device=device, dtype=dtype)
    )

    # Pad the node axis of every tensor up to n_max.
    x, node_mask, one_hot, charges = (
        pad(tensor, 1, n_max) for tensor in (x, node_mask, one_hot, charges)
    )

    # Centre positions; optionally jitter them and re-centre.
    x = remove_mean_with_mask(x, node_mask)
    if edm_cfg.augment_noise > 0:
        jitter = sample_center_gravity_zero_gaussian_with_mask(x.size(), x.device, node_mask)
        x = x + jitter * edm_cfg.augment_noise
        x = remove_mean_with_mask(x, node_mask)

    if edm_cfg.data_augmentation:
        x = random_rotation(x).detach()

    for masked_tensor in (x, one_hot, charges):
        assert_correctly_masked(masked_tensor, node_mask)
    assert_mean_zero_with_mask(x, node_mask)

    # The context is built only to validate its masking; it is not returned.
    if edm_cfg.conditioning:
        context = prepare_context(edm_cfg.conditioning, data, property_norms).to(device, dtype)
        assert_correctly_masked(context, node_mask)

    features = {"categorical": one_hot, "integer": charges}
    x_norm, h_norm, _delta_log_px = model.normalize(x, features, node_mask)

    parts = [x_norm, h_norm["categorical"], h_norm["integer"]]
    z_0 = torch.cat(parts, dim=2)

    # Atom count per molecule = number of active mask entries.
    per_molecule = node_mask.squeeze(-1).sum(dim=1)
    n_nodes = per_molecule.to(torch.int32)  # (B,)

    return z_0, node_mask, n_nodes


# -------------------------
# PINN training config (keep argparse here)
# -------------------------
def parse_pinn_args() -> argparse.Namespace:
    """
    Parse the PINN training options from the command line.

    Options are declared in a table (flag, add_argument keyword arguments)
    and registered in order, so the `--help` listing follows the table.
    """
    option_table = [
        # optimisation
        ("--lr", dict(type=float, default=1e-4)),
        ("--weight_decay", dict(type=float, default=0.0)),
        ("--ema_beta", dict(type=float, default=0.99)),
        ("--ema_eps", dict(type=float, default=1e-8)),
        # batching
        ("--batch_size", dict(type=int, default=4)),
        ("--n_boundary", dict(type=int, default=1)),
        # loop control / logging
        ("--num_epochs", dict(type=int, default=100)),
        ("--log_every", dict(type=int, default=50)),
        ("--max_iters_per_epoch", dict(type=int, default=1000)),
        # experiment naming + time curriculum
        ("--exp_name", dict(type=str, default="tmp")),
        ("--t_max_schedule", dict(type=float, nargs="+", default=[0, 0.1, 0.3, 0.5])),
        ("--epoch_schedule", dict(type=int, nargs="+", default=[20, 50, 50, 50])),
        ("--t_min", dict(type=float, default=0.01)),
        # Fourier feature settings forwarded to the PINN
        ("--n_fourier", dict(type=int, default=32)),
        ("--r_fourier_min", dict(type=float, default=0.1)),
        ("--r_fourier_max", dict(type=float, default=1.0)),
        ("--t_fourier_min", dict(type=float, default=0.1)),
        ("--t_fourier_max", dict(type=float, default=1.0)),
        ("--apply_log", dict(action="store_true", default=False)),
        # noise scaling; sigma_scale_method is one of: constant | adaptive
        ("--sigma_scale_factor", dict(type=float, default=3.0)),
        ("--sigma_scale_method", dict(type=str, default="constant")),
        ("--sigma_scale_adaptive", dict(type=float, default=0.3)),
        ("--normalize_by_u_t", dict(action="store_true", default=False)),
    ]

    parser = argparse.ArgumentParser(description="MoleculePINN training")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


# -------------------------
# Torch -> JAX dataset wrappers
# -------------------------
def t_to_sigma(t: jax.Array) -> jax.Array:
    """Map a time value `t` to the noise scale sigma = sqrt(2 * t)."""
    doubled = 2.0 * t
    return jnp.sqrt(doubled)


def get_eps_from_edm(model, z_0_torch: torch.Tensor, node_mask_torch: torch.Tensor) -> torch.Tensor:
    """
    Draw noise for a batch of latents using the EDM model's own sampler.

    The sample count and node count are taken from the first two axes of
    `z_0_torch`; the mask is passed through unchanged.
    """
    batch_size, node_count = z_0_torch.size(0), z_0_torch.size(1)
    return model.sample_combined_position_feature_noise(
        n_samples=batch_size,
        n_nodes=node_count,
        node_mask=node_mask_torch,
    )


class PINNDataset(torch.utils.data.Dataset):
    """
    Torch dataset that pairs each latent with a noise sample and an atom count.

    Noise for the whole set is drawn once at construction and can be
    re-drawn in bulk with `recompute_eps`.
    """

    def __init__(self, model, z_0: torch.Tensor, node_mask: torch.Tensor, n_nodes: torch.Tensor):
        self.model, self.z_0 = model, z_0
        self.node_mask, self.n_nodes = node_mask, n_nodes
        # Sample count is the leading axis of z_0.
        self.length = z_0.size(0)
        # Draw the initial noise so indexing works right away.
        self.recompute_eps()

    def recompute_eps(self):
        """Replace `self.eps` with a fresh draw covering every stored latent."""
        fresh_noise = get_eps_from_edm(self.model, self.z_0, self.node_mask)
        self.eps = fresh_noise

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the `(z_0, eps, n_nodes)` entries stored at `idx`."""
        sample = (self.z_0[idx], self.eps[idx], self.n_nodes[idx])
        return sample


def jnp_collate(batch):
    """
    Collate `(z_0, eps, n_nodes)` torch samples into stacked JAX arrays.

    Returns `(z_0_batch, eps_batch, n_nodes_batch)`; the first two are stacked
    along a new leading axis, the last is an int32 vector.
    """

    def _host_array(tensor):
        # Detach and move to CPU before handing the data to JAX.
        return tensor.detach().cpu().numpy()

    latents, noises, counts = [], [], []
    for sample in batch:
        latents.append(_host_array(sample[0]))
        noises.append(_host_array(sample[1]))
        counts.append(int(sample[2].detach().cpu().item()))

    z_0_batch = jnp.stack(latents, axis=0)
    eps_batch = jnp.stack(noises, axis=0)
    n_nodes_batch = jnp.array(counts, dtype=jnp.int32)
    return z_0_batch, eps_batch, n_nodes_batch


# -------------------------
# PINN model + losses
# -------------------------
def build_pinn(pinn_args: argparse.Namespace, n_max: int):
    """
    Instantiate MoleculePINN and initialise its parameters.

    Returns `(pinn, params, key)` where `key` is the PRNG key (seed 0) that
    was used for initialisation.
    """
    key = jax.random.PRNGKey(0)

    # Fourier-feature options are forwarded from the CLI namespace unchanged.
    forwarded = (
        "n_fourier",
        "r_fourier_min",
        "r_fourier_max",
        "t_fourier_min",
        "t_fourier_max",
        "apply_log",
    )
    pinn = MoleculePINN(
        n_max=n_max,
        **{name: getattr(pinn_args, name) for name in forwarded},
    )

    # Dummy inputs that only fix shapes/dtypes for `init`.
    coord_shape = (n_max, 3)
    x0 = jnp.zeros(coord_shape, dtype=jnp.float32)
    y0 = jnp.zeros(coord_shape, dtype=jnp.float32)
    t0 = jnp.array(0.0, dtype=jnp.float32)
    n0 = jnp.array(10, dtype=jnp.int32)

    variables = pinn.init(key, x0, y0, t0, n0)
    params = variables["params"]
    return pinn, params, key


def boundary_func(x, y, t, n):
    """
    Closed-form boundary target: -(masked squared distance) / (4 t).

    Only the first `n` rows of `x` and `y` contribute to the squared
    distance; the remaining rows are multiplied by zero.
    """
    row_ids = jnp.arange(x.shape[0])
    row_weight = (row_ids < n).astype(x.dtype)[:, None]  # (Nmax,1)
    squared_diff = (x - y) ** 2
    masked_sq_dist = jnp.sum(squared_diff * row_weight)
    return -masked_sq_dist / (4.0 * t)


def make_losses(pinn: "MoleculePINN", pinn_args: argparse.Namespace):
    """
    Build the batched PDE-residual and boundary losses for `pinn`.

    The per-sample PDE residual is |u_t - |grad_y u|^2 - trace(hess_y u)|,
    where u = pinn(x, y, t, n) and derivatives are taken with respect to the
    third positional input `y` and the time `t`. When
    `pinn_args.normalize_by_u_t` is set, the residual is divided by
    max(|u_t|, 1) with the denominator excluded from differentiation.

    The per-sample boundary loss is the squared difference between u and
    `boundary_func(x, y, t, n)`.

    Returns:
      (pde_loss_batch, boundary_loss_batch)
        pde_loss_batch(params, X, Y, T, N) -> mean residual over the batch
        boundary_loss_batch(params, X, Y, T, N, n_boundary) -> mean boundary
          loss over the first `n_boundary` samples of the batch
    """

    def f_0(params, x, y, t, n):
        # Return the output twice: with has_aux=True the first copy is
        # differentiated and the second is handed back unchanged.
        u_0 = pinn.apply({"params": params}, x, y, t, n)
        return u_0, u_0

    def pde_loss(params, x, y, t, n):
        # du/dt. The aux value is the single network output, so it is
        # discarded here rather than destructured.
        u_t, _ = jax.jacrev(f_0, argnums=3, has_aux=True)(params, x, y, t, n)

        # First derivative w.r.t. the y argument, returned as aux as well so
        # the outer jacrev yields both the Hessian and the gradient.
        def f_x(params_, x_, y_, t_, n_):
            u_x, _ = jax.jacrev(f_0, argnums=2, has_aux=True)(params_, x_, y_, t_, n_)
            return u_x, u_x

        u_xx, u_x = jax.jacrev(f_x, argnums=2, has_aux=True)(params, x, y, t, n)

        # Replace NaN/inf derivative entries by finite values before reducing.
        u_t = jnp.nan_to_num(u_t)
        u_x = jnp.nan_to_num(u_x)
        u_xx = jnp.nan_to_num(u_xx)

        # |grad u|^2
        term1 = jnp.sum(jnp.reshape(u_x, (-1,)) ** 2)
        # Laplacian = trace of the Hessian flattened to (n_k, n_k).
        n_k = u_x.size
        term2 = jnp.trace(jnp.reshape(u_xx, (n_k, n_k)))

        value = jnp.abs(u_t - term1 - term2)
        if pinn_args.normalize_by_u_t:
            denom = jnp.maximum(jax.lax.stop_gradient(jnp.abs(u_t)), 1.0)
            value = value / denom
        return value

    def boundary_loss(params, x, y, t, n):
        u_0, _ = f_0(params, x, y, t, n)
        u_b = boundary_func(x, y, t, n)
        u_0 = jnp.nan_to_num(u_0)
        u_b = jnp.nan_to_num(u_b)
        return jnp.mean((u_0 - u_b) ** 2)

    # Parameters are shared across the batch; every other input is per-sample.
    batched_pde_loss = jax.vmap(pde_loss, in_axes=(None, 0, 0, 0, 0))
    batched_boundary_loss = jax.vmap(boundary_loss, in_axes=(None, 0, 0, 0, 0))

    def pde_loss_batch(params, X, Y, T, N):
        return jnp.mean(batched_pde_loss(params, X, Y, T, N))

    def boundary_loss_batch(params, X, Y, T, N, n_boundary: int):
        # Only the leading n_boundary samples take part in the boundary loss.
        Xb, Yb, Tb, Nb = X[:n_boundary], Y[:n_boundary], T[:n_boundary], N[:n_boundary]
        return jnp.mean(batched_boundary_loss(params, Xb, Yb, Tb, Nb))

    return pde_loss_batch, boundary_loss_batch


def pytree_l2_norm(tree) -> jax.Array:
    """Return the L2 norm taken jointly over every leaf of `tree`."""
    squared_total = 0
    for leaf in jax.tree_util.tree_leaves(tree):
        squared_total = squared_total + jnp.sum(jnp.square(leaf))
    return jnp.sqrt(squared_total)


def pytree_add_scaled(g1, g2, a: float, b: float):
    """Return the pytree `a * g1 + b * g2`, combined leaf by leaf."""

    def _combine(left, right):
        return a * left + b * right

    return jax.tree_util.tree_map(_combine, g1, g2)


def weights_from_ema(ema_pde, ema_b, eps=1e-8):
    """
    Turn two running magnitudes into loss weights that sum to one.

    Each weight is proportional to 1 / (magnitude + eps), so the term with
    the larger magnitude receives the smaller weight.
    """
    inverse_pde, inverse_b = (1.0 / (magnitude + eps) for magnitude in (ema_pde, ema_b))
    normaliser = inverse_pde + inverse_b
    return inverse_pde / normaliser, inverse_b / normaliser


def sample_t(key, batch_size: int, t_min: float, t_max: float):
    """
    Draw `batch_size` times uniformly from [t_min, t_max).

    If the interval is empty (t_max <= t_min) every entry is `t_min` and the
    key is not consumed.
    """
    interval_is_open = float(t_max) > float(t_min)
    if interval_is_open:
        return jax.random.uniform(key, shape=(batch_size,), minval=t_min, maxval=t_max)
    return jnp.full((batch_size,), t_min, dtype=jnp.float32)


def prepare_noisy(key, z_0, eps, t, pinn_args: argparse.Namespace):
    """
    Return `z_0 + sigma_scaled * eps`, with one noise scale per batch entry.

    `sigma_scaled` is derived from `t_to_sigma(t)` according to
    `pinn_args.sigma_scale_method`:
      - "constant": sigma * sigma_scale_factor, capped at 1.0
      - "adaptive": sigma times a uniform random factor drawn with `key` from
        [1 - sigma_scale_adaptive, 1 + sigma_scale_adaptive)

    Raises ValueError for any other method name.
    """
    sigma = t_to_sigma(t)
    method = pinn_args.sigma_scale_method

    if method == "constant":
        sigma_scaled = jnp.minimum(sigma * pinn_args.sigma_scale_factor, 1.0)
    elif method == "adaptive":
        spread = pinn_args.sigma_scale_adaptive
        factor = jax.random.uniform(
            key,
            shape=sigma.shape,
            minval=1.0 - spread,
            maxval=1.0 + spread,
        )
        sigma_scaled = sigma * factor
    else:
        raise ValueError(f"Unknown sigma_scale_method: {pinn_args.sigma_scale_method}")

    # Broadcast the per-sample scale over the two trailing axes of eps.
    scale = sigma_scaled[:, None, None]
    return z_0 + scale * eps


# -------------------------
# Checkpoint helpers
# -------------------------
def save_checkpoint(path, params, opt_state, ema_pde_gn, ema_b_gn, global_step, key):
    """
    Serialize the training state to `path`.

    The bytes are first written to `path + ".tmp"` and then moved over
    `path` with `os.replace`, so `path` itself is only ever swapped in after
    the write has finished.
    """
    state = dict(
        params=params,
        opt_state=opt_state,
        ema_pde_gn=ema_pde_gn,
        ema_b_gn=ema_b_gn,
        global_step=jnp.array(global_step, dtype=jnp.int32),
        key=key,
    )
    payload = serialization.to_bytes(state)

    staging_path = path + ".tmp"
    with open(staging_path, "wb") as handle:
        handle.write(payload)
    os.replace(staging_path, path)


# -------------------------
# Main
# -------------------------
def main():
    """
    Train MoleculePINN on latents derived from one large QM9 batch.

    Pipeline:
      1. parse the PINN command-line options
      2. build the EDM model and turn one train / one valid batch into z_0
      3. build the PINN, its losses and the jitted gradient functions
      4. wrap the latents in torch DataLoaders that collate to JAX arrays
      5. run the curriculum given by `--t_max_schedule` / `--epoch_schedule`:
         stage 0 trains on the boundary loss only (all samples at t_min),
         later stages combine PDE and boundary losses with weights derived
         from running averages of their gradient norms
      6. validate after every epoch, keep `best.ckpt`, write `final.ckpt`

    Epoch counts come from `--epoch_schedule`; `--num_epochs` is not read here.
    """
    # NOTE(review): presumably the largest atom count in QM9 — confirm.
    n_max = 29

    # 1) PINN args. Parsed before any heavy work so that `--help` and invalid
    #    flags are reported immediately rather than after the dataset load.
    pinn_args = parse_pinn_args()
    if len(pinn_args.t_max_schedule) != len(pinn_args.epoch_schedule):
        # zip() below stops at the shorter list; make that visible.
        print(
            f"[warn] t_max_schedule has {len(pinn_args.t_max_schedule)} entries but "
            f"epoch_schedule has {len(pinn_args.epoch_schedule)}; extra entries are ignored"
        )

    # 2) Build EDM model + load a huge train/val batch -> z_0
    edm_cfg = make_edm_cfg()
    model, dataloaders, _dataset_info, device, dtype, property_norms = build_model_and_dataloaders(edm_cfg)

    train_data = next(iter(dataloaders["train"]))
    val_data = next(iter(dataloaders["valid"]))

    z_0_train, node_mask_train, n_nodes_train = preprocess_batch(
        train_data, model=model, edm_cfg=edm_cfg, device=device, dtype=dtype, n_max=n_max, property_norms=property_norms
    )
    z_0_val, node_mask_val, n_nodes_val = preprocess_batch(
        val_data, model=model, edm_cfg=edm_cfg, device=device, dtype=dtype, n_max=n_max, property_norms=property_norms
    )

    # Experiment directory; the parsed args are stored next to the checkpoints.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    exp_name = f"{pinn_args.exp_name}_{timestamp}"

    project_name = "e3_diffusion_pinn"
    ckpt_dir = os.path.join(project_name, exp_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, "pinn_args.json"), "w") as f:
        json.dump(vars(pinn_args), f, indent=2, sort_keys=True)

    # 3) Build PINN + losses
    pinn, params, key = build_pinn(pinn_args, n_max=n_max)
    pde_loss_batch, boundary_loss_batch = make_losses(pinn, pinn_args)

    pde_loss_batch_jit = jax.jit(pde_loss_batch)
    boundary_loss_batch_jit = jax.jit(boundary_loss_batch, static_argnames=("n_boundary",))

    def compute_grads(params, X, Y, T, N, n_boundary: int, w_pde: float, w_b: float):
        # Separate gradients per loss so their norms can be tracked individually.
        pde_val, pde_grad = jax.value_and_grad(pde_loss_batch)(params, X, Y, T, N)
        b_val, b_grad = jax.value_and_grad(boundary_loss_batch)(params, X, Y, T, N, n_boundary)

        total_val = w_pde * pde_val + w_b * b_val
        total_grad = pytree_add_scaled(pde_grad, b_grad, w_pde, w_b)

        return (
            total_val,
            total_grad,
            pytree_l2_norm(total_grad),
            pde_val,
            pytree_l2_norm(pde_grad),
            b_val,
            pytree_l2_norm(b_grad),
        )

    compute_grads_jit = jax.jit(compute_grads, static_argnames=("n_boundary",))

    def compute_grads_bonly(params, X, Y, T, N, n_boundary: int):
        # Boundary-only variant with the same return layout; PDE slots are zero.
        b_val, b_grad = jax.value_and_grad(boundary_loss_batch)(params, X, Y, T, N, n_boundary)
        total_val = b_val
        total_grad = b_grad
        return (
            total_val,
            total_grad,
            pytree_l2_norm(total_grad),
            jnp.array(0.0),
            jnp.array(0.0),
            b_val,
            pytree_l2_norm(b_grad),
        )

    compute_grads_bonly_jit = jax.jit(compute_grads_bonly, static_argnames=("n_boundary",))

    # 4) Torch->JAX loaders
    train_ds = PINNDataset(model, z_0_train, node_mask_train, n_nodes_train)
    val_ds = PINNDataset(model, z_0_val, node_mask_val, n_nodes_val)

    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=pinn_args.batch_size, shuffle=True, drop_last=True, collate_fn=jnp_collate
    )
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=pinn_args.batch_size, shuffle=False, drop_last=True, collate_fn=jnp_collate
    )

    # 5) wandb + optimizer
    wandb.init(project=project_name, name=exp_name, config=vars(pinn_args))
    optimizer = optax.adamw(learning_rate=pinn_args.lr, weight_decay=pinn_args.weight_decay)
    opt_state = optimizer.init(params)

    # Running averages of the per-loss gradient norms (used for loss weights).
    ema_pde_gn = jnp.array(1.0, dtype=jnp.float32)
    ema_b_gn = jnp.array(1.0, dtype=jnp.float32)

    global_step = 0
    best_val = float("inf")
    n_boundary_cfg = int(pinn_args.n_boundary)

    # 6) training
    epoch = 0
    for i_curriculum, (t_max, n_epochs) in enumerate(zip(pinn_args.t_max_schedule, pinn_args.epoch_schedule)):
        for _ in range(n_epochs):
            # Fresh noise for every epoch.
            train_ds.recompute_eps()
            val_ds.recompute_eps()

            # ---- TRAIN ----
            train_sum = {k: 0.0 for k in [
                "loss_total", "loss_pde", "loss_boundary",
                "gn_total", "gn_pde", "gn_boundary",
                "ema_gn_pde", "ema_gn_boundary",
                "w_pde", "w_boundary",
            ]}
            train_count = 0

            for i_iter, (z_0, eps, n_nodes) in enumerate(train_loader):
                key, subkey = jax.random.split(key)
                B = z_0.shape[0]

                if i_curriculum == 0:
                    # boundary-only
                    t = jnp.full((B,), pinn_args.t_min, dtype=jnp.float32)
                    z_noisy = prepare_noisy(subkey, z_0, eps, t, pinn_args)
                    n_boundary_curr = B

                    w_pde, w_b = 1.0, 1.0
                    total_val, total_grad, total_gn, pde_val, pde_gn, b_val, b_gn = compute_grads_bonly_jit(
                        params, z_noisy, z_0, t, n_nodes, n_boundary_curr
                    )
                    pde_val = 0.0
                    pde_gn = 0.0
                else:
                    # PDE + boundary
                    # Independent keys for the time draw and the noise-scale
                    # draw; reusing one key would correlate the two samples.
                    t_key, noise_key = jax.random.split(subkey)
                    t = sample_t(t_key, B, pinn_args.t_min, t_max)
                    # The leading samples are pinned to t_min for the boundary loss.
                    t = t.at[:n_boundary_cfg].set(pinn_args.t_min)

                    z_noisy = prepare_noisy(noise_key, z_0, eps, t, pinn_args)

                    w_pde, w_b = weights_from_ema(ema_pde_gn, ema_b_gn, eps=pinn_args.ema_eps)
                    w_pde = jax.lax.stop_gradient(w_pde)
                    w_b = jax.lax.stop_gradient(w_b)

                    total_val, total_grad, total_gn, pde_val, pde_gn, b_val, b_gn = compute_grads_jit(
                        params, z_noisy, z_0, t, n_nodes, n_boundary_cfg, w_pde, w_b
                    )

                    ema_pde_gn = pinn_args.ema_beta * ema_pde_gn + (1.0 - pinn_args.ema_beta) * pde_gn
                    ema_b_gn = pinn_args.ema_beta * ema_b_gn + (1.0 - pinn_args.ema_beta) * b_gn

                updates, opt_state = optimizer.update(total_grad, opt_state, params=params)
                params = optax.apply_updates(params, updates)
                global_step += 1

                # accumulate
                train_sum["loss_total"] += float(total_val)
                train_sum["loss_pde"] += float(pde_val)
                train_sum["loss_boundary"] += float(b_val)
                train_sum["gn_total"] += float(total_gn)
                train_sum["gn_pde"] += float(pde_gn)
                train_sum["gn_boundary"] += float(b_gn)
                train_sum["ema_gn_pde"] += float(ema_pde_gn)
                train_sum["ema_gn_boundary"] += float(ema_b_gn)
                train_sum["w_pde"] += float(w_pde)
                train_sum["w_boundary"] += float(w_b)
                train_count += 1

                if global_step % pinn_args.log_every == 0 and train_count > 0:
                    # Log the mean since the last log line, then reset.
                    denom = float(train_count)
                    metrics = {f"train/{k}": train_sum[k] / denom for k in train_sum}
                    metrics["epoch"] = epoch
                    wandb.log(metrics, step=global_step)
                    print(
                        f"[epoch {epoch:03d} step {global_step:06d}] "
                        f"loss={metrics['train/loss_total']:.6g} "
                        f"pde={metrics['train/loss_pde']:.6g} "
                        f"b={metrics['train/loss_boundary']:.6g} "
                        f"w_pde={metrics['train/w_pde']:.3f} w_b={metrics['train/w_boundary']:.3f}"
                    )
                    for k in train_sum:
                        train_sum[k] = 0.0
                    train_count = 0

                if (i_iter + 1) == pinn_args.max_iters_per_epoch:
                    # Epoch is capped; checkpoint and move on to validation.
                    save_checkpoint(os.path.join(ckpt_dir, f"latest_step{global_step:08d}.ckpt"),
                                   params, opt_state, ema_pde_gn, ema_b_gn, global_step, key)
                    break

            # ---- VAL ----
            val_total = 0.0
            val_pde = 0.0
            val_b = 0.0
            val_count = 0

            for (z_0, eps, n_nodes) in val_loader:
                key, subkey = jax.random.split(key)
                B = z_0.shape[0]

                if i_curriculum == 0:
                    t = jnp.full((B,), pinn_args.t_min, dtype=jnp.float32)
                    z_noisy = prepare_noisy(subkey, z_0, eps, t, pinn_args)
                    b = float(boundary_loss_batch_jit(params, z_noisy, z_0, t, n_nodes, B))
                    p = 0.0
                    tot = b
                else:
                    # Same key discipline as in training.
                    t_key, noise_key = jax.random.split(subkey)
                    t = sample_t(t_key, B, pinn_args.t_min, t_max)
                    t = t.at[:n_boundary_cfg].set(pinn_args.t_min)
                    z_noisy = prepare_noisy(noise_key, z_0, eps, t, pinn_args)

                    p = float(pde_loss_batch_jit(params, z_noisy, z_0, t, n_nodes))
                    b = float(boundary_loss_batch_jit(params, z_noisy, z_0, t, n_nodes, n_boundary_cfg))
                    # Validation total is the unweighted sum of both losses.
                    tot = p + b

                val_total += tot
                val_pde += p
                val_b += b
                val_count += 1

            # max(..., 1) guards against an empty validation loader.
            val_total /= max(val_count, 1)
            val_pde /= max(val_count, 1)
            val_b /= max(val_count, 1)

            print(f"[epoch {epoch:03d} VAL] loss={val_total:.6g} pde={val_pde:.6g} b={val_b:.6g}")
            wandb.log(
                {"val/loss_total": val_total, "val/loss_pde": val_pde, "val/loss_boundary": val_b, "epoch": epoch},
                step=global_step,
            )

            if val_total < best_val:
                best_val = val_total
                save_checkpoint(os.path.join(ckpt_dir, "best.ckpt"),
                               params, opt_state, ema_pde_gn, ema_b_gn, global_step, key)

            epoch += 1

    save_checkpoint(os.path.join(ckpt_dir, "final.ckpt"),
                   params, opt_state, ema_pde_gn, ema_b_gn, global_step, key)
    wandb.finish()


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
