"""
Student‑network training script (DDP) — DKAN Version
----------------------------------------------------
This version uses DKAN layers instead of standard MLP layers for the student
network, trained on the methane dataset. Includes an epoch-based training
schedule with warmup, MLP, DKAN turn-on, Frobenius decay, and LR decay phases.
"""

from __future__ import annotations

import argparse
import os
import shutil
import time
from typing import Generator, Tuple, Dict, Union, Iterable

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import yaml
from itertools import permutations
from scipy.spatial.transform import Rotation  # random rotations
from torch.nn.parallel import DistributedDataParallel as DDP
from joblib import Parallel, delayed
from functools import lru_cache
from displacements_invariants import compute_invariants_displacements_wrapper
from ldKAN.dkan_2d import DKAN_2D_Layer  # DKAN Layer import

# ---------------------------------------------------------------------------
# Constant‑time permutation lookup (24 possible permutations of 4 elements) is
# provided by `_perm_table` in the mini‑batch section further below.
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class FeedForwardNet(nn.Module):
    def __init__(
        self,
        actual_input_dim: int,
        padded_input_dim: int,
        hidden_dim: int,
        n_inner_layers: int,
        padded_output_dim: int,
        n_chunks: int,
        block_size_forward: int,
        block_size_backward: int,
        tile_size_forward: int,
        tile_size_backward: int,
        init_scale: float,
    ):
        super().__init__()
        self.input_bn = nn.BatchNorm1d(actual_input_dim, affine=False)
        self.padded_input_dim = padded_input_dim
        self.fcs, self.bns = nn.ModuleList(), nn.ModuleList()

        # First DKAN layer
        self.fcs.append(
            DKAN_2D_Layer(
                n_chunks,
                padded_input_dim,
                hidden_dim,
                block_size_forward,
                block_size_backward,
                tile_size_forward,
                tile_size_backward,
                False, False, True, False, init_scale, True
            )
        )
        self.bns.append(nn.BatchNorm1d(hidden_dim, affine=False))

        # Inner DKAN layers
        for _ in range(n_inner_layers):
            self.fcs.append(
                DKAN_2D_Layer(
                    n_chunks,
                    hidden_dim,
                    hidden_dim,
                    block_size_forward,
                    block_size_backward,
                    tile_size_forward,
                    tile_size_backward,
                    False, False, True, False, init_scale, True
                )
            )
            self.bns.append(nn.BatchNorm1d(hidden_dim, affine=False))

        # Final DKAN layer
        self.fcs.append(
            DKAN_2D_Layer(
                n_chunks,
                hidden_dim,
                padded_output_dim,
                block_size_forward,
                block_size_backward,
                tile_size_forward,
                tile_size_backward,
                False, False, True, False, init_scale, True
            )
        )

    def forward(self, x: torch.Tensor, weight_dkan: float) -> torch.Tensor:
        # Input x shape: (B, actual_input_dim)
        x = self.input_bn(x)  # (B, actual_input_dim)

        # Pad input to be divisible by tile_size_forward
        pad_size = self.padded_input_dim - x.shape[1]
        if pad_size > 0:
            x = F.pad(x, (0, pad_size))  # (B, padded_input_dim)

        # Transpose for DKAN (expects batch-last)
        x = x.transpose(0, 1).contiguous()  # (padded_input_dim, B)

        '''# First DKAN layer
        x = self.fcs[0](x, weight_dkan, True)  # (hidden_dim, B)

        # Inner layers
        for i in range(len(self.fcs) - 2):  # Loop up to second-to-last DKAN
            # Transpose for BatchNorm (expects batch-first)
            x = x.transpose(0, 1)  # (B, hidden_dim)
            x = self.bns[i](x)    # Apply ith BN layer
            # Transpose back for DKAN
            x = x.transpose(0, 1)  # (hidden_dim, B)
            # Apply (i+1)th DKAN layer
            x = self.fcs[i+1](x, weight_dkan, True)  # (hidden_dim, B)

        # Final DKAN layer (no subsequent BN or activation needed by DKAN)
        x = self.fcs[-1](x, weight_dkan, False)  # (padded_output_dim, B)
        '''

        for i in range(len(self.fcs) - 1):
            x = self.fcs[i](x, weight_dkan, True)
            x = x.transpose(0, 1)
            x = self.bns[i](x)
            x = x.transpose(0, 1)

        x = self.fcs[-1](x, weight_dkan, False)

        # Transpose back to batch-first
        x = x.transpose(0, 1).contiguous()  # (B, padded_output_dim)

        # Slice to get the single output value
        x = x[:, 0]  # (B,)

        return x

    def get_frobenius_regularization(self) -> torch.Tensor:
        reg = torch.tensor(0.0, device=next(self.parameters()).device)
        for fc in self.fcs:
            reg += fc.get_frobenius_regularization()
        return reg


# ---------------------------------------------------------------------------
# Mini‑batch generator / augmentation
# ---------------------------------------------------------------------------

def _make_rotations(num: int) -> np.ndarray:
    """Generate *num* random 3×3 rotation matrices (half improper), float32."""
    R = Rotation.random(num).as_matrix().astype(np.float32)  # (num, 3, 3)
    R[np.random.rand(num) >= 0.5] *= -1  # flip half to improper rotations
    return R

@lru_cache(maxsize=None)
def _perm_table(device: torch.device) -> torch.Tensor:
    return torch.tensor(list(permutations(range(4))),
                        dtype=torch.long, device=device)  # (24,4)

def iterate_minibatches(
    X: np.ndarray,
    y: np.ndarray,
    batch_size: int,
    shuffle: bool,
    task: str,
    mode: str,
    device: Union[str, torch.device] = "cpu",
) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
    """Yield ``(features, targets)`` float32 mini-batches on *device*.

    Samples are visited in index order, or in an order shuffled with
    ``np.random`` when *shuffle* is true; the last batch may be smaller than
    *batch_size*. Augmentation is applied only when ``mode == "train"``; any
    other *mode* leaves the data unaugmented.

    Per task:
      * ``INVARIANTS_DISTANCES``: features are passed through unchanged.
      * ``DISTANCES``: X is expected to be ``(N, 24, F)``. Training picks one of
        the 24 rows per sample at random; otherwise row 0 is used.
      * ``HYDROGEN_DISPLACEMENTS``: X is expected to be ``(N, 4, 3)``. Training
        applies a random permutation of the 4 rows and a random
        proper/improper rotation. The batch is then flattened to ``(B, -1)``.
      * ``DISPLACEMENTS_INVARIANTS``: training applies a random rotation. The
        batch is then mapped through ``compute_invariants_displacements_wrapper``.

    The rotations come from a pool with one matrix per sample. The pool is
    generated once per call in parallel joblib workers and indexed by sample
    index.

    Raises:
        ValueError: *task* is not one of the four supported names.
    """
    supported_tasks = {
        "INVARIANTS_DISTANCES",
        "DISTANCES",
        "HYDROGEN_DISPLACEMENTS",
        "DISPLACEMENTS_INVARIANTS",
    }
    # Validate before any expensive setup (rotation-pool generation).
    if task not in supported_tasks:
        raise ValueError(f"Unsupported TASK '{task}'")

    device = torch.device(device)
    train = mode == "train"

    n_samples = X.shape[0]
    order = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(order)

    # Constant (24, 4) permutation table, cached per device; only the
    # HYDROGEN_DISPLACEMENTS training path uses it.
    perm_tbl = (
        _perm_table(device)
        if (train and task == "HYDROGEN_DISPLACEMENTS")
        else None
    )

    # -------------------------------------------------------------------------
    # pre-generate one random proper/improper rotation per sample, then move
    # the pool to the device. An empty dataset is skipped: there is nothing to
    # rotate, and joblib rejects n_jobs=0.
    # -------------------------------------------------------------------------
    rot_pool = None
    if (
        train
        and task in {"HYDROGEN_DISPLACEMENTS", "DISPLACEMENTS_INVARIANTS"}
        and n_samples > 0
    ):
        N_ROT_WORKERS = min(20, os.cpu_count() or 1)
        CHUNK_SIZE = 250_000
        chunk_sizes = [
            min(CHUNK_SIZE, n_samples - chunk_start)
            for chunk_start in range(0, n_samples, CHUNK_SIZE)
        ]

        rot_pool_np = np.concatenate(
            Parallel(n_jobs=min(N_ROT_WORKERS, len(chunk_sizes)), backend="loky")(
                delayed(_make_rotations)(m) for m in chunk_sizes
            ),
            axis=0,
        )  # (n_samples, 3, 3)

        rot_pool = torch.tensor(rot_pool_np, device=device)  # (n_samples, 3, 3)
        del rot_pool_np  # free host RAM

    # -------------------------------------------------------------------------
    # iterate over batches
    # -------------------------------------------------------------------------
    for start in range(0, n_samples, batch_size):
        idx = order[start : start + batch_size]

        # NumPy -> Torch, host -> device
        xb = torch.as_tensor(X[idx], dtype=torch.float32, device=device)  # (B, ...)
        yb = torch.as_tensor(y[idx], dtype=torch.float32, device=device)

        if task == "INVARIANTS_DISTANCES":
            # nothing to do: tensors already on device
            pass

        elif task == "DISTANCES":
            if train:
                sel = torch.randint(0, 24, (xb.shape[0],), device=device)
                xb = xb[torch.arange(xb.shape[0], device=device), sel, :]  # (B, F)
            else:
                xb = xb[:, 0, :]  # (B, F)

        elif task == "HYDROGEN_DISPLACEMENTS":
            if train:
                B = xb.shape[0]

                # --- random row permutation (one of 24 per sample) ------------
                perm_idx = torch.randint(0, 24, (B,), device=device)  # (B,)
                xb = xb[
                    torch.arange(B, device=device).unsqueeze(1),  # (B, 1)
                    perm_tbl[perm_idx],  # (B, 4)
                    :,
                ]

                # --- pre-generated rotation, looked up by sample index ---------
                idx_t = torch.as_tensor(idx, device=device)
                Rb = rot_pool[idx_t]  # (B, 3, 3)
                xb = torch.matmul(Rb, xb.transpose(1, 2)).transpose(1, 2)

            # validation/test: no permutation, no rotation
            xb = xb.reshape(xb.shape[0], -1)

        else:  # "DISPLACEMENTS_INVARIANTS"
            if train:
                # random rotation only (no permutation)
                idx_t = torch.as_tensor(idx, device=device)
                Rb = rot_pool[idx_t]  # (B, 3, 3)
                xb = torch.matmul(Rb, xb.transpose(1, 2)).transpose(1, 2)

            xb = compute_invariants_displacements_wrapper(xb)

        yield xb, yb

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def cpu_clone_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Snapshot *model*'s ``state_dict()`` as detached CPU tensors.

    Every entry is cloned. Later in-place updates to the model, such as
    further optimizer steps, therefore do not alter the snapshot.
    """
    snapshot: Dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        snapshot[name] = tensor.detach().cpu().clone()
    return snapshot


# --------------------------------------------------------------------------- #
# Training (per rank)
# --------------------------------------------------------------------------- #
def ddp_train(args: argparse.Namespace) -> None:
    """Train the DKAN student network as one rank of a DDP job.

    Reads ``RANK`` and ``WORLD_SIZE`` from the environment, plus ``LOCAL_RANK``
    when present. It then joins an NCCL process group, pins this process to
    its node-local CUDA device and runs :func:`_train_rank`. The process group
    is destroyed in a ``finally`` clause, so it is released even when training
    raises.

    Args:
        args: Parsed CLI namespace providing ``calc_specification_path`` (YAML
            config file) and ``calc_folder`` (output directory written by rank 0).
    """
    rank       = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # CUDA device indices are node-local. RANK is global and exceeds the
    # per-node GPU count on multi-node jobs, so prefer LOCAL_RANK. Fall back
    # to RANK for single-node launches that do not export LOCAL_RANK.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))

    dist.init_process_group("nccl", init_method="env://",
                            world_size=world_size, rank=rank)
    try:
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        _train_rank(args, rank, world_size, local_rank, device)
    finally:
        dist.destroy_process_group()


def _train_rank(
    args: argparse.Namespace,
    rank: int,
    world_size: int,
    local_rank: int,
    device: torch.device,
) -> None:
    """Load config and data, train for the scheduled epochs, and save artefacts on rank 0.

    Expects the process group to be initialised and *device* to be the
    current CUDA device.

    Raises:
        KeyError: a required configuration key is missing.
        ValueError: the config is not a mapping, TASK is unsupported, the
            schedule spans zero epochs, or feature/target counts differ.
        FileNotFoundError: a dataset file is missing.
    """
    # ---------------- config ----------------
    with open(args.calc_specification_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if cfg is None:  # empty YAML file
        cfg = {}
    if not isinstance(cfg, dict):
        raise ValueError("Configuration file must contain a YAML mapping at the top level.")

    required_keys = [
        "TASK", "BATCH_SIZE", "HIDDEN_DIM", "N_INNER_LAYERS", "RECORD_INTERVAL",
        "N_CHUNKS", "BLOCK_SIZE_FORWARD", "BLOCK_SIZE_BACKWARD",
        "TILE_SIZE_FORWARD", "TILE_SIZE_BACKWARD", "INIT_SCALE",
        "WARMUP_EPOCHS", "PURE_MLP_EPOCHS", "PURE_MLP_LR",
        "DKAN_TURN_ON_EPOCHS", "DKAN_TURN_ON_SCALE", "DKAN_TURN_ON_CAP",
        "DKAN_FROBENIUS_DECAY_EPOCHS", "DKAN_FROBENIUS_DECAY_SCALE",
        "FROBENIUS_WEIGHT_CAP", "DKAN_LEARNING_RATE_DECAY_EPOCHS",
        "DKAN_LEARNING_RATE_DECAY_SCALE", "INITIAL_FROBENIUS_WEIGHT",
        "DKAN_BASE_LR"
    ]
    for key in required_keys:
        if key not in cfg:
            raise KeyError(f"Missing required configuration parameter: {key}")

    TASK       = cfg["TASK"]
    SUPPORTED  = {"INVARIANTS_DISTANCES", "DISTANCES", "HYDROGEN_DISPLACEMENTS", "DISPLACEMENTS_INVARIANTS"}
    if TASK not in SUPPORTED:
        raise ValueError(f"Unsupported TASK '{TASK}'. Choose from {SUPPORTED}")

    BATCH_SIZE       = int(cfg["BATCH_SIZE"])
    HIDDEN_DIM       = int(cfg["HIDDEN_DIM"])
    N_INNER_LAYERS   = int(cfg["N_INNER_LAYERS"])
    RECORD_INT       = int(cfg["RECORD_INTERVAL"])

    # DKAN Hyperparameters
    N_CHUNKS            = int(cfg["N_CHUNKS"])
    BLOCK_SIZE_FORWARD  = int(cfg["BLOCK_SIZE_FORWARD"])
    BLOCK_SIZE_BACKWARD = int(cfg["BLOCK_SIZE_BACKWARD"])
    TILE_SIZE_FORWARD   = int(cfg["TILE_SIZE_FORWARD"])
    TILE_SIZE_BACKWARD  = int(cfg["TILE_SIZE_BACKWARD"])
    INIT_SCALE          = float(cfg["INIT_SCALE"])

    # Schedule Parameters (Epoch-based)
    WARMUP_EPOCHS                   = int(cfg["WARMUP_EPOCHS"])
    PURE_MLP_EPOCHS                 = int(cfg["PURE_MLP_EPOCHS"])
    PURE_MLP_LR                     = float(cfg["PURE_MLP_LR"])
    DKAN_TURN_ON_EPOCHS             = int(cfg["DKAN_TURN_ON_EPOCHS"])
    DKAN_TURN_ON_SCALE              = int(cfg["DKAN_TURN_ON_SCALE"])  # epochs over which dkan_weight ramps up
    DKAN_TURN_ON_CAP                = float(cfg["DKAN_TURN_ON_CAP"])
    DKAN_FROBENIUS_DECAY_EPOCHS     = int(cfg["DKAN_FROBENIUS_DECAY_EPOCHS"])
    DKAN_FROBENIUS_DECAY_SCALE      = int(cfg["DKAN_FROBENIUS_DECAY_SCALE"])  # epochs per 10x frobenius_weight decay
    FROBENIUS_WEIGHT_CAP            = float(cfg["FROBENIUS_WEIGHT_CAP"])
    DKAN_LEARNING_RATE_DECAY_EPOCHS = int(cfg["DKAN_LEARNING_RATE_DECAY_EPOCHS"])
    DKAN_LEARNING_RATE_DECAY_SCALE  = int(cfg["DKAN_LEARNING_RATE_DECAY_SCALE"])  # epochs per LR halving
    INITIAL_FROBENIUS_WEIGHT        = float(cfg["INITIAL_FROBENIUS_WEIGHT"])
    DKAN_BASE_LR                    = float(cfg["DKAN_BASE_LR"])

    # Cumulative phase boundaries (last epoch of each phase, 1-based).
    END_WARMUP   = WARMUP_EPOCHS
    END_PURE_MLP = END_WARMUP + PURE_MLP_EPOCHS
    END_TURN_ON  = END_PURE_MLP + DKAN_TURN_ON_EPOCHS
    END_FRO      = END_TURN_ON + DKAN_FROBENIUS_DECAY_EPOCHS
    TOTAL_EPOCHS = END_FRO + DKAN_LEARNING_RATE_DECAY_EPOCHS
    # The summary below reports the last epoch's metrics, so at least one
    # epoch must run.
    if TOTAL_EPOCHS < 1:
        raise ValueError(f"The training schedule must span at least one epoch (got {TOTAL_EPOCHS}).")

    # ---------------- dataset ---------------
    def _load(path: str) -> np.ndarray:
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        return np.load(path)

    if TASK == "INVARIANTS_DISTANCES":
        X_train = _load("./dataset_preparation/methane_invariants_distances_train.npy")
        X_val   = _load("./dataset_preparation/methane_invariants_distances_val.npy")
        ACTUAL_INPUT_DIM = 31
    elif TASK == "DISTANCES":
        X_train = _load("./dataset_preparation/methane_distances_train.npy")
        X_val   = _load("./dataset_preparation/methane_distances_val.npy")
        ACTUAL_INPUT_DIM = 10
    elif TASK == "HYDROGEN_DISPLACEMENTS":
        X_train = _load("./dataset_preparation/methane_hydrogen_displacements_train.npy")
        X_val   = _load("./dataset_preparation/methane_hydrogen_displacements_val.npy")
        ACTUAL_INPUT_DIM = 12
    elif TASK == "DISPLACEMENTS_INVARIANTS":
        X_train = _load("./dataset_preparation/methane_hydrogen_displacements_train.npy")
        X_val   = _load("./dataset_preparation/methane_hydrogen_displacements_val.npy")
        ACTUAL_INPUT_DIM = 34
    else:
        raise ValueError(f"Unsupported TASK '{TASK}'.")

    y_train = _load("./dataset_preparation/methane_energies_train_normalized.npy")
    y_val   = _load("./dataset_preparation/methane_energies_val_normalized.npy")
    if X_train.shape[0] != y_train.shape[0] or X_val.shape[0] != y_val.shape[0]:
        raise ValueError("Feature/target sample mismatch.")

    X_train, X_val = X_train.astype(np.float32), X_val.astype(np.float32)
    y_train, y_val = y_train.astype(np.float32), y_val.astype(np.float32)

    # ---------------- Calculate padded dimensions ---------
    # Round the input width up to a multiple of TILE_SIZE_FORWARD.
    PADDED_INPUT_DIM = ((ACTUAL_INPUT_DIM - 1) // TILE_SIZE_FORWARD + 1) * TILE_SIZE_FORWARD
    PADDED_OUTPUT_DIM = TILE_SIZE_FORWARD  # Since target dim is 1

    # ---------------- model & optim ---------
    student = FeedForwardNet(
        actual_input_dim=ACTUAL_INPUT_DIM,
        padded_input_dim=PADDED_INPUT_DIM,
        hidden_dim=HIDDEN_DIM,
        n_inner_layers=N_INNER_LAYERS,
        padded_output_dim=PADDED_OUTPUT_DIM,
        n_chunks=N_CHUNKS,
        block_size_forward=BLOCK_SIZE_FORWARD,
        block_size_backward=BLOCK_SIZE_BACKWARD,
        tile_size_forward=TILE_SIZE_FORWARD,
        tile_size_backward=TILE_SIZE_BACKWARD,
        init_scale=INIT_SCALE
    ).to(device)
    ddp_model = DDP(student, device_ids=[local_rank], output_device=local_rank)

    criterion  = nn.MSELoss()
    # Initial LR is irrelevant: it is overwritten every epoch from get_params.
    optimizer  = optim.Adam(ddp_model.parameters(), lr=0)

    # ---------------- training schedule --------
    def _frobenius_weight(offset: float) -> float:
        """Return INITIAL_FROBENIUS_WEIGHT decayed 10x per DECAY_SCALE epochs, floored at the cap."""
        if DKAN_FROBENIUS_DECAY_SCALE > 0:
            # 10 ** (-x) underflows gracefully to 0.0, whereas dividing by
            # 10 ** x raises OverflowError once x exceeds about 308.
            decay = 10.0 ** (-offset / DKAN_FROBENIUS_DECAY_SCALE)
        else:
            decay = 0.0  # non-positive scale: decay immediately to the cap
        return max(INITIAL_FROBENIUS_WEIGHT * decay, FROBENIUS_WEIGHT_CAP)

    def get_params(epoch: int) -> Tuple[float, float, float]:
        """Return ``(lr, dkan_weight, frobenius_weight)`` for a 1-based *epoch*."""
        if epoch <= END_WARMUP:
            # Linear LR warmup phase, DKAN off
            return PURE_MLP_LR * (epoch / WARMUP_EPOCHS), 0.0, INITIAL_FROBENIUS_WEIGHT
        if epoch <= END_PURE_MLP:
            # Pure MLP phase (DKAN off)
            return PURE_MLP_LR, 0.0, INITIAL_FROBENIUS_WEIGHT
        if epoch <= END_TURN_ON:
            # DKAN turn-on phase: ramp dkan_weight linearly, capped
            offset = epoch - END_PURE_MLP
            ramp = (offset / DKAN_TURN_ON_SCALE) if DKAN_TURN_ON_SCALE > 0 else 1.0
            return DKAN_BASE_LR, min(ramp, DKAN_TURN_ON_CAP), INITIAL_FROBENIUS_WEIGHT
        if epoch <= END_FRO:
            # Frobenius weight decay phase
            return DKAN_BASE_LR, DKAN_TURN_ON_CAP, _frobenius_weight(epoch - END_TURN_ON)
        # Learning rate decay phase: halve LR every DKAN_LEARNING_RATE_DECAY_SCALE
        # epochs and keep the final Frobenius weight.
        offset = epoch - END_FRO
        num_halvings = offset // DKAN_LEARNING_RATE_DECAY_SCALE if DKAN_LEARNING_RATE_DECAY_SCALE > 0 else 0
        lr = DKAN_BASE_LR * (0.5 ** num_halvings)
        return lr, DKAN_TURN_ON_CAP, _frobenius_weight(DKAN_FROBENIUS_DECAY_EPOCHS)

    # ---------------- training loop ---------
    best_val, best_val_rmse, best_epoch, best_state = (
        float("inf"), float("inf"), 0, None
    )
    hist = {
        "epoch": [], "lr": [], "dkan_weight": [], "frobenius_weight": [],
        "train_loss": [], "val_loss": [], "train_rmse": [], "val_rmse": [],
        "full_train_loss": [],
    }
    t0 = time.time()

    for epoch in range(1, TOTAL_EPOCHS + 1):
        # ---- Get schedule params for epoch ----
        lr_now, dkan_weight_now, frobenius_weight_now = get_params(epoch)
        epoch_start_time = time.time()

        # ---- Update optimizer LR ----
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_now

        # ---- train ----
        # NOTE(review): every rank iterates the full X_train with its own
        # shuffle order (the data is not sharded across ranks).
        ddp_model.train()
        sum_pure_loss = torch.tensor(0.0, device=device)
        sum_fro_reg   = torch.tensor(0.0, device=device)
        sum_full_loss = torch.tensor(0.0, device=device)
        n_samp        = torch.tensor(0,   device=device)

        for xb_t, yb_t in iterate_minibatches(X_train, y_train, BATCH_SIZE, True, TASK, "train", device):
            optimizer.zero_grad(set_to_none=True)
            outputs = ddp_model(xb_t, dkan_weight_now)
            pure_loss = criterion(outputs, yb_t)
            fro_reg = ddp_model.module.get_frobenius_regularization()  # from the wrapped module
            loss = pure_loss + frobenius_weight_now * fro_reg
            loss.backward()
            optimizer.step()

            sum_pure_loss += pure_loss.detach() * xb_t.size(0)
            sum_fro_reg   += fro_reg.detach() * xb_t.size(0)
            sum_full_loss += loss.detach() * xb_t.size(0)
            n_samp        += xb_t.size(0)

        dist.all_reduce(sum_pure_loss); dist.all_reduce(sum_fro_reg); dist.all_reduce(n_samp)
        dist.all_reduce(sum_full_loss)
        train_pure_loss = (sum_pure_loss / n_samp).item()
        train_full_loss = (sum_full_loss / n_samp).item()
        train_rmse = train_pure_loss**0.5  # RMSE based on pure MSE loss

        # ---- val ----
        # Each rank evaluates the full validation set (unshuffled).
        ddp_model.eval()
        with torch.no_grad():
            sum_val_pure_loss = torch.tensor(0.0, device=device)
            n_val             = torch.tensor(0,   device=device)
            for xb_t, yb_t in iterate_minibatches(X_val, y_val, BATCH_SIZE, False, TASK, "val", device):
                outputs = ddp_model(xb_t, dkan_weight_now)
                pure_loss_val = criterion(outputs, yb_t)
                sum_val_pure_loss += pure_loss_val * xb_t.size(0)
                n_val             += xb_t.size(0)

        dist.all_reduce(sum_val_pure_loss); dist.all_reduce(n_val)
        val_loss = (sum_val_pure_loss / n_val).item()  # Validation loss is pure MSE
        val_rmse = val_loss**0.5

        # ---- checkpoints ----
        if rank == 0 and val_loss < best_val:
            best_val, best_val_rmse, best_epoch = val_loss, val_rmse, epoch
            best_state = cpu_clone_state_dict(ddp_model.module)

        if rank == 0 and (epoch == 1 or epoch % RECORD_INT == 0 or epoch == TOTAL_EPOCHS):
            epoch_duration = time.time() - epoch_start_time
            print(f"[Ep {epoch:03d}/{TOTAL_EPOCHS}] LR {lr_now:.2e} "
                  f"DKAN {dkan_weight_now:.2f} Fro {frobenius_weight_now:.2e} | "
                  f"Train Loss(Full {train_full_loss:.2e}, Pure {train_pure_loss:.2e}, RMSE {train_rmse:.2e})  "
                  f"Val Loss {val_loss:.2e} ({val_rmse:.2e}) "
                  f"| Ep Time {epoch_duration:.1f}s / Total {(time.time()-t0):.1f}s")

        hist["epoch"].append(epoch)
        hist["lr"].append(lr_now)
        hist["dkan_weight"].append(dkan_weight_now)
        hist["frobenius_weight"].append(frobenius_weight_now)
        hist["train_loss"].append(train_pure_loss)  # pure train loss
        hist["val_loss"].append(val_loss)
        hist["train_rmse"].append(train_rmse)
        hist["val_rmse"].append(val_rmse)
        hist["full_train_loss"].append(train_full_loss)  # loss incl. Frobenius term

    # ---------------- save artefacts --------
    if rank == 0:
        os.makedirs(args.calc_folder, exist_ok=True)

        np.savez(os.path.join(args.calc_folder, "history.npz"), **hist)
        torch.save(ddp_model.module.state_dict(),
                   os.path.join(args.calc_folder, "model_weights_last_epoch.pth"))
        if best_state is not None:
            torch.save(best_state,
                       os.path.join(args.calc_folder, "model_weights_best_val.pth"))

        with open(os.path.join(args.calc_folder, "summary.txt"), "w") as f:
            f.write(f"TASK: {TASK}\n")
            f.write(f"final_val_loss: {val_loss:.4e}\n")
            f.write(f"final_val_rmse: {val_rmse:.4e}\n")
            f.write(f"best_val_loss:  {best_val:.4e}  (epoch {best_epoch})\n")
            f.write(f"best_val_rmse:  {best_val_rmse:.4e}  (epoch {best_epoch})\n")
            f.write(f"wall_time_s:    {time.time()-t0:.1f}\n")
            f.write(f"num_gpus:       {world_size}\n")
            f.write(f"effect_batch:   {BATCH_SIZE * world_size}\n")
            f.write("\n--- Schedule ---\n")
            f.write(f"WARMUP_EPOCHS: {WARMUP_EPOCHS}\n")
            f.write(f"PURE_MLP_EPOCHS: {PURE_MLP_EPOCHS} (LR: {PURE_MLP_LR:.2e})\n")
            f.write(f"DKAN_TURN_ON_EPOCHS: {DKAN_TURN_ON_EPOCHS} (Scale: {DKAN_TURN_ON_SCALE}, Cap: {DKAN_TURN_ON_CAP:.2f})\n")
            f.write(f"DKAN_FROBENIUS_DECAY_EPOCHS: {DKAN_FROBENIUS_DECAY_EPOCHS} (Scale: {DKAN_FROBENIUS_DECAY_SCALE}, Cap: {FROBENIUS_WEIGHT_CAP:.2e}, Init: {INITIAL_FROBENIUS_WEIGHT:.2e})\n")
            f.write(f"DKAN_LEARNING_RATE_DECAY_EPOCHS: {DKAN_LEARNING_RATE_DECAY_EPOCHS} (Scale: {DKAN_LEARNING_RATE_DECAY_SCALE}, Base LR: {DKAN_BASE_LR:.2e})\n")
            f.write(f"TOTAL_EPOCHS: {TOTAL_EPOCHS}\n")
            f.write("\n--- DKAN Params ---\n")
            f.write(f"N_CHUNKS: {N_CHUNKS}\n")
            f.write(f"BLOCK_SIZE_FORWARD: {BLOCK_SIZE_FORWARD}\n")
            f.write(f"BLOCK_SIZE_BACKWARD: {BLOCK_SIZE_BACKWARD}\n")
            f.write(f"TILE_SIZE_FORWARD: {TILE_SIZE_FORWARD}\n")
            f.write(f"TILE_SIZE_BACKWARD: {TILE_SIZE_BACKWARD}\n")
            f.write(f"INIT_SCALE: {INIT_SCALE}\n")

        shutil.copy(args.calc_specification_path,
                    os.path.join(args.calc_folder,
                                 os.path.basename(args.calc_specification_path)))


# --------------------------------------------------------------------------- #
# CLI
# --------------------------------------------------------------------------- #
def main() -> None:
    """Parse the command line and launch DDP training for this process."""
    parser = argparse.ArgumentParser(description="Train student network (DDP, methane dataset)")
    for flag in ("--calc_specification_path", "--calc_folder"):
        parser.add_argument(flag, required=True)
    cli_args = parser.parse_args()
    ddp_train(cli_args)


if __name__ == "__main__":
    main()
