# scripts/neural_network.py
# Selectable cases:
#   hyper  -> train hyper-network across priors
#   pooled -> pooled MAML across priors
#   grid   -> MAML library on integer grid priors
#   maml   -> per-prior MAML (oracle) for one test mu
#   all    -> run hyper + pooled + grid, then compare them (plus per-prior MAML) on up to 10 test priors
#
# Usage example:
#   python scripts/neural_network.py --case hyper --device cuda --hidden 32 --hyper_lr 5e-4 --hyper_steps 1600

import os
import math
import random
import argparse
import json
from datetime import datetime

import torch
from torch import nn
from torch.optim import Adam, AdamW
import torch.nn.functional as F

# Defaults (can be overridden via CLI; main() rebinds these module globals from the parsed flags).
# NOTE: functions below use these names as default argument values, which Python
# evaluates once at definition time, so callers should pass overridden values explicitly.
FIRST_ORDER         = True   # first-order approx: inner-loop grads are taken with create_graph=False
USE_CRN_TRAIN       = True   # use common random numbers in TRAIN loops (generator seeded via crn_seed)

# Learner architecture (FunctionalMLP defaults)
IN_DIM              = 2                   # input feature dimension
HIDDEN              = 32                  # hidden-layer width
OUT_DIM             = 1                   # output dimension
NONLIN              = nn.LeakyReLU(0.1)   # hidden activation
# Task distribution noise (see sample_task_data)
SIGMA_W             = 0.3                 # std of task weights around the prior mean m(mu)
SIGMA_Y             = 0.1                 # std of additive observation noise on targets

# Inner-loop adaptation
INNER_STEPS         = 3                   # number of inner gradient steps
INNER_LR_INIT       = 0.01                # initial inner step size (rho is initialised via inverse softplus of it)
ALPHA_MIN           = 1e-5                # lower clamp for per-parameter step sizes (alpha_from_rho)
ALPHA_MAX           = 5e-2                # upper clamp for per-parameter step sizes (alpha_from_rho)

# Outer (meta) loop for the MAML baselines
TASKS_PER_BS        = 32                  # tasks per meta-batch
MAX_META_STEPS      = 400                 # meta-step budget for pooled / per-prior MAML
MAX_META_STEPS_GRID = 300                 # meta-step budget per grid prior (passed by main for the grid case)
META_LR             = 1e-2                # Adam learning rate for (theta, rho)

# Hyper-network training
HYPER_STEPS         = 1600                # training-step budget
HYPER_LR            = 5e-4                # AdamW learning rate
INIT_REG            = 1e-4                # weight of the squared-magnitude penalty on the emitted init

# Logging / early stopping
VAL_INTERVAL        = 10                  # validate every this many steps
PATIENCE            = 20                  # validation checks without improvement before stopping
PRINT_INTERVAL      = 20                  # print/log training loss every this many steps

# Prior sets (see make_priors)
NUM_RANDOM_PRIORS   = 200                 # number of random priors drawn before the train/val/test split
GRID_SIDE           = 7                   # handed to make_priors as grid_side

# Evaluation settings
EVAL_TASKS          = 16                  # tasks per eval_adaptation call
EVAL_BASE_SEED      = 12345               # base seed for evaluation generators
K_SPT_TRAIN         = 10                  # support-set size during training
K_QRY_TRAIN         = 20                  # query-set size during training
K_SPT_EVAL          = 20                  # support-set size during evaluation
K_QRY_EVAL          = 200                 # query-set size during evaluation
TASKS_EVAL_LARGE    = 128                 # tasks used for the eval_metrics_nn reports in main

CLIP_NORM           = 0.5                 # max gradient norm passed to clip_grad_norm_


def set_seed(seed: int = 0):
    """Seed Python's ``random`` module and torch's RNGs with ``seed``.

    CUDA generators are seeded as well whenever CUDA is available.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def make_generator(seed: int, dev: torch.device):
    """Return a ``torch.Generator`` seeded with ``seed``.

    The generator is created on ``dev``; if the constructor rejects the
    ``device`` keyword with a ``TypeError``, a default generator is used instead.
    """
    try:
        rng = torch.Generator(device=dev)
    except TypeError:
        rng = torch.Generator()
    # manual_seed returns the generator itself.
    return rng.manual_seed(seed)


def crn_seed(step: int, task_idx: int, base: int = 777000):
    """Derive the common-random-numbers seed for a (step, task index) pair."""
    step_offset = 10007 * step
    task_offset = 97 * task_idx
    return base + step_offset + task_offset


# Functional Learner MLP
class FunctionalMLP:
    """Two-layer MLP in functional form: y = W2 * f(W1 x + b1) + b2.

    The network owns no weights; every call receives them as one flat vector
    ``theta`` laid out as [W1 (hidden*in_dim), b1 (hidden), W2 (out_dim*hidden), b2 (out_dim)].
    """

    def __init__(self, in_dim=IN_DIM, hidden=HIDDEN, out_dim=OUT_DIM, nonlin=NONLIN):
        self.in_dim = in_dim
        self.hidden = hidden
        self.out_dim = out_dim
        self.nonlin = nonlin
        # Total length of the flat parameter vector expected by forward().
        first_layer = hidden * in_dim + hidden
        second_layer = out_dim * hidden + out_dim
        self.num_params = first_layer + second_layer

    def _unpack(self, theta):
        """Slice the flat vector ``theta`` into (W1, b1, W2, b2)."""
        assert theta.ndim == 1, "theta must be a flat vector"
        h, d, o = self.hidden, self.in_dim, self.out_dim
        # Cumulative end offsets of each segment inside theta.
        end_w1 = h * d
        end_b1 = end_w1 + h
        end_w2 = end_b1 + o * h
        end_b2 = end_w2 + o
        W1 = theta[:end_w1].view(h, d)
        b1 = theta[end_w1:end_b1]
        W2 = theta[end_b1:end_w2].view(o, h)
        b2 = theta[end_w2:end_b2]
        assert end_b2 == theta.numel(), "θ length mismatch"
        return W1, b1, W2, b2

    def forward(self, theta, x):
        """Apply the MLP with weights ``theta`` to ``x``; the last output dim is squeezed."""
        W1, b1, W2, b2 = self._unpack(theta)
        hidden_act = self.nonlin(x @ W1.t() + b1)
        out = hidden_act @ W2.t() + b2
        return out.squeeze(-1)


# ----------------------------
# Parameter-Space Prior Mean m(mu)
# ----------------------------
class WeightPriorMean(nn.Module):
    """Frozen linear map m: R^2 -> R^P giving the MEAN parameter vector for prior mu."""

    def __init__(self, P):
        super().__init__()
        self.lin = nn.Linear(2, P)
        # Small random weights, zero bias.
        nn.init.normal_(self.lin.weight, mean=0.0, std=0.05)
        nn.init.zeros_(self.lin.bias)
        # The map is fixed: none of its parameters are trained.
        self.requires_grad_(False)

    def forward(self, mu):  # mu: (2,) or (B,2)
        """Map a single prior (2,) to (P,), or a batch (B,2) to (B,P)."""
        if mu.ndim != 1:
            return self.lin(mu)
        # Add a batch axis for the linear layer, then drop it again.
        return self.lin(mu.unsqueeze(0)).squeeze(0)


# ----------------------------
# Data & Task Sampler
# ----------------------------
def _randn(shape, dev, gen=None):
    if gen is None:
        return torch.randn(shape, device=dev)
    else:
        return torch.randn(shape, device=dev, generator=gen)


def sample_task_data(mean_map, fmlp, mu_w, k_spt, k_qry, sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
                     gen: torch.Generator | None = None):
    """Sample ONE regression task for prior ``mu_w``.

    1) w ~ N(m(mu), sigma_w^2 I)
    2) X_s, X_q ~ N(0, I)
    3) y = f_w(x) + noise   (noise std ``sigma_y``)

    All draws go through ``gen`` when provided, on the device of ``mu_w``.
    Returns: (X_s, y_s), (X_q, y_q)
    """
    dev = mu_w.device

    def draw(shape):
        return _randn(shape, dev, gen=gen)

    # Task weights: prior mean plus isotropic perturbation.
    with torch.no_grad():
        prior_mean = mean_map(mu_w)
    task_w = prior_mean + sigma_w * draw(prior_mean.shape)

    # Support and query inputs (drawn in this order).
    X_s = draw((k_spt, fmlp.in_dim))
    X_q = draw((k_qry, fmlp.in_dim))

    # Noise-free targets from the task network.
    with torch.no_grad():
        clean_s = fmlp.forward(task_w, X_s)
        clean_q = fmlp.forward(task_w, X_q)

    # Observation noise, support first and then query.
    y_s = clean_s + sigma_y * draw(clean_s.shape)
    y_q = clean_q + sigma_y * draw(clean_q.shape)
    return (X_s, y_s), (X_q, y_q)


# ----------------------------
# Inner adaptation helper (K steps, per-parameter alpha)
# ----------------------------
def alpha_from_rho(rho: torch.Tensor) -> torch.Tensor:
    """Map unconstrained ``rho`` to step sizes: softplus, then clamp to [ALPHA_MIN, ALPHA_MAX]."""
    return torch.clamp(F.softplus(rho), ALPHA_MIN, ALPHA_MAX)


def inner_adapt(fmlp, theta_init, X_s, y_s, steps, alpha_vec, inner_lr_scalar, first_order):
    """Run ``steps`` gradient-descent steps on the support-set MSE, starting from ``theta_init``.

    The step size is ``alpha_vec`` (element-wise) when it is not None, otherwise the
    scalar ``inner_lr_scalar``. With ``first_order`` True the inner gradients are
    computed without building a graph through them.
    Returns the adapted parameter vector (``theta_init`` itself when ``steps`` is 0).
    """
    step_size = inner_lr_scalar if alpha_vec is None else alpha_vec
    build_graph = not first_order
    params = theta_init
    for _ in range(steps):
        residual = fmlp.forward(params, X_s) - y_s
        support_loss = (residual ** 2).mean()
        (grad,) = torch.autograd.grad(support_loss, params, create_graph=build_graph)
        params = params - step_size * grad
    return params


# ----------------------------
# Eval helpers
# ----------------------------
def eval_adaptation(mean_map, fmlp, theta_init, mu_w,
                    sigma_w=SIGMA_W, sigma_y=SIGMA_Y, inner_lr=INNER_LR_INIT,
                    tasks=EVAL_TASKS, k_spt=K_SPT_EVAL, k_qry=K_QRY_EVAL,
                    steps=INNER_STEPS, alpha_vec=None, gen: torch.Generator | None = None,
                    first_order: bool = FIRST_ORDER):
    """Average post-adaptation query MSE over ``tasks`` tasks sampled from prior ``mu_w``.

    Each task starts from a detached copy of ``theta_init`` (moved to the device of
    ``mu_w``), adapts on the support set with ``inner_adapt`` and is scored on the
    query set. Returns a Python float.
    """
    dev = mu_w.device
    query_losses = []
    for _ in range(tasks):
        # Fresh leaf tensor so inner-loop gradients can be taken w.r.t. it.
        start = theta_init.clone().detach().to(dev).requires_grad_(True)
        support, query = sample_task_data(
            mean_map, fmlp, mu_w, k_spt, k_qry, sigma_w=sigma_w, sigma_y=sigma_y, gen=gen
        )
        X_s, y_s = support
        X_q, y_q = query
        adapted = inner_adapt(fmlp, start, X_s, y_s, steps, alpha_vec, inner_lr, first_order)
        with torch.no_grad():
            residual = fmlp.forward(adapted, X_q) - y_q
            query_losses.append((residual ** 2).mean().item())
    return sum(query_losses, 0.0) / tasks


def eval_metrics_nn(mean_map, fmlp, theta_init, mu,
                    sigma_w=SIGMA_W, sigma_y=SIGMA_Y, inner_lr=INNER_LR_INIT,
                    tasks=EVAL_TASKS, k_spt=K_SPT_EVAL, k_qry=K_QRY_EVAL,
                    steps=INNER_STEPS, alpha_vec=None, gen: torch.Generator | None = None,
                    first_order: bool = FIRST_ORDER):
    """Post-adaptation query metrics averaged over ``tasks`` tasks from prior ``mu``.

    Per task: MSE, R2 = 1 - SSE/SST, and nMSE = MSE / unbiased variance of the
    query targets. Returns a dict with the task-averaged "MSE", "R2", "nMSE" and
    "ExcessOverNoise" = max(0, MSE - sigma_y^2).
    """
    dev = theta_init.device
    eps = 1e-12  # keeps the variance / SST denominators non-zero
    records = []  # one (mse, r2, nmse) triple per task
    for _ in range(tasks):
        start = theta_init.clone().detach().to(dev).requires_grad_(True)
        support, query = sample_task_data(
            mean_map, fmlp, mu, k_spt, k_qry, sigma_w=sigma_w, sigma_y=sigma_y, gen=gen
        )
        X_s, y_s = support
        X_q, y_q = query
        adapted = inner_adapt(fmlp, start, X_s, y_s, steps, alpha_vec, inner_lr, first_order)
        with torch.no_grad():
            residual = fmlp.forward(adapted, X_q) - y_q
            mse = (residual**2).mean().item()
            target_var = y_q.var(unbiased=True).item() + eps
            total_ss = ((y_q - y_q.mean())**2).sum().item() + eps
            resid_ss = residual.pow(2).sum().item()
        records.append((mse, 1.0 - (resid_ss / total_ss), mse / target_var))

    def column_mean(col):
        return sum(rec[col] for rec in records) / len(records)

    mse_avg = column_mean(0)
    return {
        "MSE": mse_avg,
        "R2": column_mean(1),
        "nMSE": column_mean(2),
        "ExcessOverNoise": max(0.0, mse_avg - sigma_y**2),
    }


# ----------------------------
# MAML baselines (with rho->alpha)
# ----------------------------
def run_maml(mean_map, fmlp, mu_w,
             sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
             inner_lr=INNER_LR_INIT, meta_lr=META_LR,
             tasks_per_bs=TASKS_PER_BS, max_meta_steps=MAX_META_STEPS,
             k_spt=K_SPT_TRAIN, k_qry=K_QRY_TRAIN,
             steps=INNER_STEPS,
             val_priors=None,
             print_interval=PRINT_INTERVAL,
             first_order=FIRST_ORDER,
             use_crn_train=USE_CRN_TRAIN):
    """Per-prior MAML (one theta_init for THIS mu).

    Meta-learns an initialisation ``theta`` and per-parameter inner step sizes
    (parameterised by ``rho``, see ``alpha_from_rho``) on tasks sampled from the
    single prior ``mu_w``.

    Args:
        mean_map, fmlp: prior-mean map and functional learner used to sample and fit tasks.
        mu_w: prior to train on; its device is used for all tensors.
        inner_lr: initial inner step size (must be > 0); ``rho`` is initialised so that
            softplus(rho) equals it. Also forwarded as the scalar fallback step size.
        val_priors: optional list of priors; when non-empty, validation runs every
            ``VAL_INTERVAL`` steps and training stops after ``PATIENCE`` checks
            without improvement.

    Returns:
        (theta, alpha, train_loss_log, val_loss_log): detached init, clamped step
        sizes, and the logged training / validation losses.
    """
    dev = mu_w.device
    theta = nn.Parameter(0.1 * torch.randn(fmlp.num_params, device=dev))
    # Inverse softplus of the requested inner learning rate (expm1 is accurate for small values).
    rho_init = math.log(math.expm1(inner_lr))
    rho = nn.Parameter(torch.full((fmlp.num_params,), rho_init, device=dev))
    meta_params = [theta, rho]
    opt = Adam(meta_params, lr=meta_lr)

    best_val = float('inf')
    patience, since = PATIENCE, 0
    tol = 1e-4

    train_loss_log, val_loss_log = [], []
    meta_loss = None  # stays None when max_meta_steps < 1

    for step in range(1, max_meta_steps + 1):
        meta_loss = 0.0
        alpha_eff = alpha_from_rho(rho)
        for t in range(tasks_per_bs):
            # Common random numbers: the task noise depends only on (step, t).
            gen = make_generator(crn_seed(step, t), dev) if use_crn_train else None
            (X_s, y_s), (X_q, y_q) = sample_task_data(
                mean_map, fmlp, mu_w, k_spt, k_qry, sigma_w=sigma_w, sigma_y=sigma_y, gen=gen
            )
            theta_i = inner_adapt(fmlp, theta, X_s, y_s, steps, alpha_eff, inner_lr, first_order)
            pred_q = fmlp.forward(theta_i, X_q)
            meta_loss = meta_loss + ((pred_q - y_q) ** 2).mean()
        meta_loss = meta_loss / tasks_per_bs

        # Abort before the optimizer step if training diverged.
        if not torch.isfinite(meta_loss):
            with torch.no_grad():
                print("[NaN][MAML] stopping | alpha min/max:", float(alpha_eff.min()), float(alpha_eff.max()))
            break

        opt.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_norm_(meta_params, CLIP_NORM)
        opt.step()

        if step % print_interval == 0 or step == 1:
            with torch.no_grad():
                a_min = float(alpha_eff.min())
                a_max = float(alpha_eff.max())
            print(f"[MAML] step {step:4d} loss={meta_loss.item():.6f} | alpha[min,max]=({a_min:.3e},{a_max:.3e})")
            train_loss_log.append(meta_loss.item())

        # Validation (across provided val_priors)
        if val_priors and step % VAL_INTERVAL == 0:
            alpha_val = alpha_from_rho(rho.detach())
            val_losses = [
                eval_adaptation(mean_map, fmlp, theta.detach(), mu,
                                sigma_w=sigma_w, sigma_y=sigma_y, inner_lr=inner_lr,
                                tasks=EVAL_TASKS, k_spt=k_spt, k_qry=k_qry, steps=steps,
                                alpha_vec=alpha_val, gen=None, first_order=first_order)
                for mu in val_priors
            ]
            avg_val = sum(val_losses) / len(val_losses)
            val_loss_log.append(avg_val)
            print(f"[MAML-VAL] step {step} avg_val={avg_val:.6f} best={best_val:.6f} since={since}")
            if avg_val < best_val - tol:
                best_val = avg_val
                since = 0
            else:
                since += 1
            if since >= patience:
                break

    # If the run ended early, record the last computed loss so the log has a final entry.
    if meta_loss is not None and len(train_loss_log) < (max_meta_steps // print_interval + 1):
        train_loss_log.append(meta_loss.item())

    return theta.detach(), alpha_from_rho(rho.detach()), train_loss_log, val_loss_log


def run_maml_across_priors(mean_map, fmlp, priors,
                           sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
                           inner_lr=INNER_LR_INIT, meta_lr=META_LR,
                           tasks_per_bs=TASKS_PER_BS, max_meta_steps=MAX_META_STEPS,
                           k_spt=K_SPT_TRAIN, k_qry=K_QRY_TRAIN, steps=INNER_STEPS,
                           val_priors=None,
                           print_interval=PRINT_INTERVAL,
                           first_order=FIRST_ORDER,
                           device=torch.device("cpu"),
                           use_crn_train=USE_CRN_TRAIN):
    """Pooled MAML: ONE shared theta_init across a LIST of priors.

    Every task in a meta-batch picks its prior uniformly at random from ``priors``
    (via Python's ``random``). ``theta`` and the step-size parameters ``rho`` are
    trained jointly with Adam.

    Args:
        priors: non-empty list of prior tensors to sample training tasks from.
        inner_lr: initial inner step size (must be > 0); ``rho`` is initialised so that
            softplus(rho) equals it. Also forwarded as the scalar fallback step size.
        val_priors: optional list of priors; when non-empty, validation runs every
            ``VAL_INTERVAL`` steps and training stops after ``PATIENCE`` checks
            without improvement.
        device: device on which ``theta``, ``rho`` and the CRN generators are created.

    Returns:
        (theta, alpha, train_loss_log, val_loss_log): detached init, clamped step
        sizes, and the logged training / validation losses.
    """
    theta = nn.Parameter(0.1 * torch.randn(fmlp.num_params, device=device))
    # Inverse softplus of the requested inner learning rate (expm1 is accurate for small values).
    rho_init = math.log(math.expm1(inner_lr))
    rho = nn.Parameter(torch.full((fmlp.num_params,), rho_init, device=device))
    meta_params = [theta, rho]
    opt = Adam(meta_params, lr=meta_lr)

    best_val = float('inf')
    patience, since = PATIENCE, 0
    tol = 1e-4

    train_loss_log, val_loss_log = [], []
    meta_loss = None  # stays None when max_meta_steps < 1

    for step in range(1, max_meta_steps + 1):
        meta_loss = 0.0
        alpha_eff = alpha_from_rho(rho)
        for t in range(tasks_per_bs):
            mu = priors[random.randrange(len(priors))]
            # Common random numbers: the task noise depends only on (step, t).
            gen = make_generator(crn_seed(step, t), device) if use_crn_train else None
            (X_s, y_s), (X_q, y_q) = sample_task_data(
                mean_map, fmlp, mu, k_spt, k_qry, sigma_w=sigma_w, sigma_y=sigma_y, gen=gen
            )
            theta_i = inner_adapt(fmlp, theta, X_s, y_s, steps, alpha_eff, inner_lr, first_order)
            pred_q = fmlp.forward(theta_i, X_q)
            meta_loss = meta_loss + ((pred_q - y_q) ** 2).mean()
        meta_loss = meta_loss / tasks_per_bs

        # Abort before the optimizer step if training diverged.
        if not torch.isfinite(meta_loss):
            with torch.no_grad():
                print("[NaN][Pooled] stopping | alpha min/max:", float(alpha_eff.min()), float(alpha_eff.max()))
            break

        opt.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_norm_(meta_params, CLIP_NORM)
        opt.step()

        if step % print_interval == 0 or step == 1:
            with torch.no_grad():
                a_min = float(alpha_eff.min())
                a_max = float(alpha_eff.max())
            print(f"[Pooled] step {step:4d} loss={meta_loss.item():.6f} | alpha[min,max]=({a_min:.3e},{a_max:.3e})")
            train_loss_log.append(meta_loss.item())

        if val_priors and step % VAL_INTERVAL == 0:
            alpha_val = alpha_from_rho(rho.detach())
            val_losses = [
                eval_adaptation(mean_map, fmlp, theta.detach(), mu,
                                sigma_w=sigma_w, sigma_y=sigma_y, inner_lr=inner_lr,
                                tasks=EVAL_TASKS, k_spt=k_spt, k_qry=k_qry, steps=steps,
                                alpha_vec=alpha_val, gen=None, first_order=first_order)
                for mu in val_priors
            ]
            avg_val = sum(val_losses) / len(val_losses)
            val_loss_log.append(avg_val)
            print(f"[Pooled-VAL] step {step} avg_val={avg_val:.6f} best={best_val:.6f} since={since}")
            if avg_val < best_val - tol:
                best_val = avg_val
                since = 0
            else:
                since += 1
            if since >= patience:
                break

    # If the run ended early, record the last computed loss so the log has a final entry.
    if meta_loss is not None and len(train_loss_log) < (max_meta_steps // print_interval + 1):
        train_loss_log.append(meta_loss.item())

    return theta.detach(), alpha_from_rho(rho.detach()), train_loss_log, val_loss_log


# ----------------------------
# Hyper-network h_phi(mu) -> theta_init  (with shared rho->alpha)
# ----------------------------
class HyperNet(nn.Module):
    """Small MLP (2 -> 64 -> out_dim, ReLU) mapping a prior mu to a flat parameter vector."""

    def __init__(self, out_dim):
        super().__init__()
        layers = [
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, mu):  # (2,) or (B,2)
        """Return (out_dim,) for a single prior or (B, out_dim) for a batch."""
        if mu.ndim != 1:
            return self.net(mu)
        # Add a batch axis for the network, then drop it again.
        return self.net(mu.unsqueeze(0)).squeeze(0)


def train_hyper(mean_map, fmlp, train_priors, val_priors,
                sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
                inner_lr=INNER_LR_INIT, hyper_lr=HYPER_LR,
                hyper_steps=HYPER_STEPS,
                priors_per_batch=4, tasks_per_prior=8,
                k_spt=K_SPT_TRAIN, k_qry=K_QRY_TRAIN, steps=INNER_STEPS,
                print_interval=PRINT_INTERVAL,
                first_order=FIRST_ORDER,
                device=torch.device("cpu"),
                use_crn_train=USE_CRN_TRAIN,
                init_reg=INIT_REG):
    """Train a HyperNet that emits the learner initialisation for a given prior.

    Each step samples ``priors_per_batch`` priors (with replacement) from
    ``train_priors``, emits an init for each, adapts it on ``tasks_per_prior`` tasks
    and minimises the mean post-adaptation query MSE plus ``init_reg`` times the
    mean squared magnitude of the emitted inits (averaged over the priors in the
    batch). A shared per-parameter step-size vector (via ``rho``) is trained jointly.

    Args:
        train_priors: non-empty list of priors to sample training tasks from.
        val_priors: list of priors for validation; when non-empty, validation runs
            every ``VAL_INTERVAL`` steps with early stopping after ``PATIENCE``
            checks without improvement.
        inner_lr: initial inner step size (must be > 0).

    Returns:
        (hyper, alpha, train_loss_log, val_loss_log): the trained network, clamped
        step sizes, and the logged training / validation losses.
    """
    hyper = HyperNet(out_dim=fmlp.num_params).to(device)
    # Inverse softplus of the requested inner learning rate (expm1 is accurate for small values).
    rho_init = math.log(math.expm1(inner_lr))
    rho = nn.Parameter(torch.full((fmlp.num_params,), rho_init, device=device))
    trainable = list(hyper.parameters()) + [rho]
    h_opt = AdamW(trainable, lr=hyper_lr, weight_decay=1e-4)

    best_val = float('inf')
    patience, since = PATIENCE, 0
    tol = 1e-4

    train_loss_log, val_loss_log = [], []

    for step in range(1, hyper_steps + 1):
        h_loss = 0.0
        reg_sum = 0.0
        alpha_eff = alpha_from_rho(rho)
        for p in range(priors_per_batch):
            mu = train_priors[random.randrange(len(train_priors))]
            theta_init = hyper(mu)
            # Accumulate the penalty for EVERY emitted init, not just the last one.
            reg_sum = reg_sum + theta_init.pow(2).mean()
            for t in range(tasks_per_prior):
                gen = make_generator(crn_seed(step, p * 100 + t), device) if use_crn_train else None
                (X_s, y_s), (X_q, y_q) = sample_task_data(
                    mean_map, fmlp, mu, k_spt, k_qry, sigma_w=sigma_w, sigma_y=sigma_y, gen=gen
                )
                theta_i = inner_adapt(fmlp, theta_init, X_s, y_s, steps,
                                      alpha_eff, inner_lr, first_order)
                pred_q = fmlp.forward(theta_i, X_q)
                h_loss = h_loss + ((pred_q - y_q) ** 2).mean()
        h_loss = h_loss / (priors_per_batch * tasks_per_prior)
        # Mild regularization on emitted init (mean over the priors in this batch)
        h_loss = h_loss + init_reg * (reg_sum / priors_per_batch)

        # Abort before the optimizer step if training diverged.
        if not torch.isfinite(h_loss):
            with torch.no_grad():
                print("[NaN][Hyper] stopping | alpha min/max:",
                      float(alpha_eff.min()), float(alpha_eff.max()))
            break

        h_opt.zero_grad()
        h_loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable, CLIP_NORM)
        h_opt.step()

        if step % print_interval == 0 or step == 1:
            with torch.no_grad():
                a_min = float(alpha_eff.min())
                a_max = float(alpha_eff.max())
            print(f"[Hyper] step {step:4d} loss={h_loss.item():.6f} "
                  f"| alpha[min,max]=({a_min:.3e},{a_max:.3e})")
            train_loss_log.append(h_loss.item())

        if val_priors and step % VAL_INTERVAL == 0:
            # Step sizes are shared across priors, so compute them once per validation pass.
            a_eval = alpha_from_rho(rho.detach())
            val_losses = []
            for mu in val_priors:
                with torch.no_grad():
                    theta_pred = hyper(mu)
                val_losses.append(
                    eval_adaptation(mean_map, fmlp, theta_pred, mu,
                                    sigma_w=sigma_w, sigma_y=sigma_y, inner_lr=inner_lr,
                                    tasks=EVAL_TASKS, k_spt=k_spt, k_qry=k_qry, steps=steps,
                                    alpha_vec=a_eval, gen=None, first_order=first_order)
                )
            avg_val = sum(val_losses) / len(val_losses)
            val_loss_log.append(avg_val)
            print(f"[Hyper-VAL] step {step} avg_val={avg_val:.6f} best={best_val:.6f} since={since}")
            if avg_val < best_val - tol:
                best_val = avg_val
                since = 0
            else:
                since += 1
            if since >= patience:
                break

    # If the run ended early, repeat the last logged loss so the log has a final entry.
    if train_loss_log and len(train_loss_log) < (hyper_steps // print_interval + 1):
        train_loss_log.append(train_loss_log[-1])

    return hyper, alpha_from_rho(rho.detach()), train_loss_log, val_loss_log


# ----------------------------
# Misc helpers
# ----------------------------
def nearest_prior(mu: torch.Tensor, prior_list: list[torch.Tensor]) -> torch.Tensor:
    """Return the element of ``prior_list`` with the smallest squared Euclidean distance to ``mu``."""
    candidates = torch.stack(prior_list, dim=0)
    offsets = candidates - mu.view(1, -1)
    sq_dist = offsets.pow(2).sum(dim=1)
    best = torch.argmin(sq_dist).item()
    return prior_list[best]


def make_priors(device, num_random_priors=NUM_RANDOM_PRIORS, grid_side=GRID_SIDE):
    """Build the random train/val/test priors and the integer grid priors.

    Random priors are drawn uniformly from [-3, 3]^2, shuffled, and split
    10% validation / 10% test / remainder train.

    The grid is ``grid_side`` x ``grid_side`` integer points centred on the origin
    (e.g. ``grid_side=7`` gives coordinates -3..3; an even side such as 6 gives
    -3..2). A non-positive ``grid_side`` yields an empty grid.

    Returns:
        (train_priors, val_priors, test_priors, grid_priors), each a list of
        shape-(2,) tensors on ``device``.
    """
    priors = [torch.empty(2, device=device).uniform_(-3, 3) for _ in range(num_random_priors)]
    random.shuffle(priors)
    n_val = int(0.10 * num_random_priors)
    n_test = int(0.10 * num_random_priors)
    val_priors = priors[:n_val]
    test_priors = priors[n_val:n_val + n_test]
    train_priors = priors[n_val + n_test:]

    # Honour grid_side instead of a hard-coded -3..3 range.
    half = grid_side // 2
    coords = range(-half, grid_side - half)
    grid_priors = [torch.tensor([i, j], dtype=torch.float32, device=device)
                   for i in coords for j in coords]
    return train_priors, val_priors, test_priors, grid_priors


# ----------------------------
# Main CLI
# ----------------------------
def parse_args():
    """Define the command-line flags and return the parsed ``argparse.Namespace``.

    Defaults come from the module-level constants. ``--no_first_order`` and
    ``--no_use_crn_train`` are the switches that turn the respective features off.
    """
    parser = argparse.ArgumentParser(description="Neural-network meta-learning cases")

    def add_typed(options):
        # Register plain "--flag VALUE" options that carry no help text.
        for flag, caster, fallback in options:
            parser.add_argument(flag, type=caster, default=fallback)

    # Run selection / bookkeeping
    parser.add_argument("--case", type=str, default="hyper",
                        choices=["hyper", "pooled", "grid", "maml", "all"],
                        help="Which case to run")
    add_typed([
        ("--device", str, "cuda"),
        ("--seed", int, 0),
    ])
    parser.add_argument("--save_dir", type=str, default="./outputs",
                        help="Where to save artifacts/logs (created if missing)")
    parser.add_argument("--save_prefix", type=str, default="run",
                        help="Prefix for filenames")

    # Model/data
    add_typed([
        ("--hidden", int, HIDDEN),
        ("--sigma_w", float, SIGMA_W),
        ("--sigma_y", float, SIGMA_Y),
    ])
    parser.add_argument("--first_order", action="store_true", default=FIRST_ORDER)
    parser.add_argument("--no_first_order", action="store_true", help="Disable first-order approx")
    parser.add_argument("--use_crn_train", action="store_true", default=USE_CRN_TRAIN)
    parser.add_argument("--no_use_crn_train", action="store_true", help="Disable CRN during training")

    add_typed([
        # Inner loop
        ("--inner_steps", int, INNER_STEPS),
        ("--inner_lr_init", float, INNER_LR_INIT),
        ("--alpha_min", float, ALPHA_MIN),
        ("--alpha_max", float, ALPHA_MAX),
        # Meta / Hyper
        ("--tasks_per_bs", int, TASKS_PER_BS),
        ("--max_meta_steps", int, MAX_META_STEPS),
        ("--max_meta_steps_grid", int, MAX_META_STEPS_GRID),
        ("--meta_lr", float, META_LR),
        ("--hyper_steps", int, HYPER_STEPS),
        ("--hyper_lr", float, HYPER_LR),
        ("--init_reg", float, INIT_REG),
        # Priors / grids
        ("--num_random_priors", int, NUM_RANDOM_PRIORS),
        ("--grid_side", int, GRID_SIDE),
        # Eval
        ("--eval_tasks", int, EVAL_TASKS),
        ("--eval_base_seed", int, EVAL_BASE_SEED),
        ("--k_spt_train", int, K_SPT_TRAIN),
        ("--k_qry_train", int, K_QRY_TRAIN),
        ("--k_spt_eval", int, K_SPT_EVAL),
        ("--k_qry_eval", int, K_QRY_EVAL),
        ("--tasks_eval_large", int, TASKS_EVAL_LARGE),
    ])

    # Misc
    parser.add_argument("--skip_plots", action="store_true", default=True,
                        help="No plotting (safe for headless jobs)")
    parser.add_argument("--mu_eval_x", type=float, default=None,
                        help="If set and case=maml, use this x coordinate for test prior")
    parser.add_argument("--mu_eval_y", type=float, default=None,
                        help="If set and case=maml, use this y coordinate for test prior")

    return parser.parse_args()


def _select_eval_prior(args, test_priors, device):
    """Return the prior used for single-prior metric reports (CLI override or a held-out test prior)."""
    if args.mu_eval_x is not None and args.mu_eval_y is not None:
        return torch.tensor([args.mu_eval_x, args.mu_eval_y],
                            dtype=torch.float32, device=device)
    if not test_priors:
        raise ValueError(
            "No held-out test priors available: pass --mu_eval_x/--mu_eval_y or "
            "increase --num_random_priors (10% of them are reserved for testing)."
        )
    # default: last of the first 10 test priors
    return test_priors[:10][-1]


def main():
    """CLI entry point: parse flags, run the selected case(s) and save artifacts.

    Creates a timestamped run directory under ``--save_dir``, writes the flags to
    ``flags.json`` and stores the trained models / metrics produced by the chosen
    case there. Every evaluation call receives the (possibly overridden) noise
    levels explicitly, because the helpers' default arguments were bound before
    the CLI overrides took effect.
    """
    args = parse_args()

    # Derive booleans from paired flags
    first_order = False if args.no_first_order else args.first_order
    use_crn_train = False if args.no_use_crn_train else args.use_crn_train

    # Device: fall back to CPU unless a CUDA device was requested and is available
    device = torch.device(args.device if torch.cuda.is_available() and args.device.startswith("cuda") else "cpu")
    print(f"Using device: {device}")

    # Seed
    set_seed(args.seed)

    # Allow global-like constants to be overridden by args for downstream helpers
    global HIDDEN, SIGMA_W, SIGMA_Y, INNER_STEPS, INNER_LR_INIT, ALPHA_MIN, ALPHA_MAX
    global TASKS_PER_BS, MAX_META_STEPS, MAX_META_STEPS_GRID, META_LR, HYPER_STEPS, HYPER_LR, INIT_REG
    global NUM_RANDOM_PRIORS, GRID_SIDE, EVAL_TASKS, EVAL_BASE_SEED, K_SPT_TRAIN, K_QRY_TRAIN, K_SPT_EVAL, K_QRY_EVAL, TASKS_EVAL_LARGE
    global FIRST_ORDER, USE_CRN_TRAIN

    HIDDEN              = args.hidden
    SIGMA_W             = args.sigma_w
    SIGMA_Y             = args.sigma_y
    INNER_STEPS         = args.inner_steps
    INNER_LR_INIT       = args.inner_lr_init
    ALPHA_MIN           = args.alpha_min
    ALPHA_MAX           = args.alpha_max

    TASKS_PER_BS        = args.tasks_per_bs
    MAX_META_STEPS      = args.max_meta_steps
    MAX_META_STEPS_GRID = args.max_meta_steps_grid
    META_LR             = args.meta_lr

    HYPER_STEPS         = args.hyper_steps
    HYPER_LR            = args.hyper_lr
    INIT_REG            = args.init_reg

    NUM_RANDOM_PRIORS   = args.num_random_priors
    GRID_SIDE           = args.grid_side

    EVAL_TASKS          = args.eval_tasks
    EVAL_BASE_SEED      = args.eval_base_seed
    K_SPT_TRAIN         = args.k_spt_train
    K_QRY_TRAIN         = args.k_qry_train
    K_SPT_EVAL          = args.k_spt_eval
    K_QRY_EVAL          = args.k_qry_eval
    TASKS_EVAL_LARGE    = args.tasks_eval_large

    FIRST_ORDER         = first_order
    USE_CRN_TRAIN       = use_crn_train

    # Prepare output dir
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    save_dir = os.path.join(args.save_dir, f"{args.save_prefix}-{args.case}-seed{args.seed}-{timestamp}")
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "flags.json"), "w") as f:
        json.dump(vars(args), f, indent=2)

    # Build learner + frozen mean-map
    fmlp = FunctionalMLP(in_dim=IN_DIM, hidden=HIDDEN, out_dim=OUT_DIM, nonlin=NONLIN)
    P = fmlp.num_params
    print(f"# Learner params (P): {P}")

    mean_map = WeightPriorMean(P).to(device)

    # Priors
    train_priors, val_priors, test_priors, grid_priors = make_priors(
        device=device, num_random_priors=NUM_RANDOM_PRIORS, grid_side=GRID_SIDE
    )
    print(f"Random Priors: train={len(train_priors)}, val={len(val_priors)}, test={len(test_priors)}")
    print(f"Grid priors: {len(grid_priors)} points ({GRID_SIDE}x{GRID_SIDE})")

    # Keyword arguments shared by every per-prior MAML run launched below.
    maml_kwargs = dict(
        sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
        inner_lr=INNER_LR_INIT, meta_lr=META_LR,
        tasks_per_bs=TASKS_PER_BS,
        k_spt=K_SPT_TRAIN, k_qry=K_QRY_TRAIN, steps=INNER_STEPS,
        print_interval=PRINT_INTERVAL,
        first_order=FIRST_ORDER,
        use_crn_train=USE_CRN_TRAIN,
    )
    # Keyword arguments shared by every evaluation call; the noise levels are passed
    # explicitly so that --sigma_w / --sigma_y also apply at evaluation time.
    eval_kwargs = dict(
        sigma_w=SIGMA_W, sigma_y=SIGMA_Y, inner_lr=INNER_LR_INIT,
        k_spt=K_SPT_EVAL, k_qry=K_QRY_EVAL, steps=INNER_STEPS,
        first_order=FIRST_ORDER,
    )

    # ============== CASES ==============
    if args.case in ("hyper", "all"):
        # Choose a held-out test prior (or let user specify); resolved before
        # training so a bad configuration fails fast.
        mu_eval = _select_eval_prior(args, test_priors, device)

        print("\n[1/?] Hypernetwork Training started")
        hyper, alpha_h, hyper_train_log, hyper_val_log = train_hyper(
            mean_map, fmlp, train_priors, val_priors,
            sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
            inner_lr=INNER_LR_INIT, hyper_lr=HYPER_LR,
            hyper_steps=HYPER_STEPS,
            priors_per_batch=4, tasks_per_prior=8,
            k_spt=K_SPT_TRAIN, k_qry=K_QRY_TRAIN, steps=INNER_STEPS,
            print_interval=PRINT_INTERVAL,
            first_order=FIRST_ORDER,
            device=device,
            use_crn_train=USE_CRN_TRAIN,
            init_reg=INIT_REG
        )
        # Save hyper weights
        torch.save(hyper.state_dict(), os.path.join(save_dir, "hypernet.pt"))
        torch.save(alpha_h, os.path.join(save_dir, "alpha_h.pt"))

        # Print R^2 / MSE for the hypernetwork on the chosen prior
        with torch.no_grad():
            theta_hyp_eval = hyper(mu_eval)

        seed_eval = args.eval_base_seed + 999
        gen_eval = make_generator(seed_eval, dev=device)

        met_hyp_eval = eval_metrics_nn(
            mean_map, fmlp, theta_hyp_eval, mu_eval,
            tasks=TASKS_EVAL_LARGE,
            alpha_vec=alpha_h, gen=gen_eval,
            **eval_kwargs
        )

        print("\n[Hyper] metrics on held-out prior:")
        print(f"  MSE={met_hyp_eval['MSE']:.6f}  R2={met_hyp_eval['R2']:.6f}  "
              f"nMSE={met_hyp_eval['nMSE']:.6f}  ExcessOverNoise={met_hyp_eval['ExcessOverNoise']:.6f}")

        # Save metrics artifact
        torch.save({
            "mu_eval": mu_eval.detach().cpu(),
            "metrics": met_hyp_eval,
            "alpha": alpha_h
        }, os.path.join(save_dir, "hyper_metrics.pt"))
    else:
        hyper = None
        alpha_h = None

    if args.case in ("pooled", "all"):
        print("\n[?/??] Pooled MAML Training started")
        theta_maml_pooled, alpha_pooled, pooled_train_log, pooled_val_log = run_maml_across_priors(
            mean_map, fmlp, train_priors,
            sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
            inner_lr=INNER_LR_INIT, meta_lr=META_LR,
            tasks_per_bs=TASKS_PER_BS, max_meta_steps=MAX_META_STEPS,
            k_spt=K_SPT_TRAIN, k_qry=K_QRY_TRAIN, steps=INNER_STEPS,
            val_priors=val_priors,
            print_interval=PRINT_INTERVAL,
            first_order=FIRST_ORDER,
            device=device,
            use_crn_train=USE_CRN_TRAIN
        )
        torch.save(theta_maml_pooled, os.path.join(save_dir, "theta_maml_pooled.pt"))
        torch.save(alpha_pooled, os.path.join(save_dir, "alpha_pooled.pt"))
    else:
        theta_maml_pooled = None
        alpha_pooled = None

    if args.case in ("grid", "all"):
        print("\n[??/??] Grid MAML Training started (per grid prior)")
        grid_maml_thetas = {}
        grid_maml_alphas = {}
        for mu in grid_priors:
            theta_init, alpha_grid, _, _ = run_maml(
                mean_map, fmlp, mu_w=mu,
                max_meta_steps=MAX_META_STEPS_GRID,
                val_priors=val_priors,
                **maml_kwargs
            )
            # Library entries are keyed by the grid coordinates as a tuple of floats.
            key = tuple(mu.tolist())
            grid_maml_thetas[key] = theta_init
            grid_maml_alphas[key] = alpha_grid
        torch.save({"theta": grid_maml_thetas, "alpha": grid_maml_alphas},
                   os.path.join(save_dir, "grid_maml_library.pt"))
    else:
        grid_maml_thetas = None
        grid_maml_alphas = None

    if args.case == "maml":
        # Per-prior MAML on a single test prior
        mu_eval = _select_eval_prior(args, test_priors, device)
        theta_init_maml, alpha_maml, _, _ = run_maml(
            mean_map, fmlp, mu_eval, val_priors=None,
            max_meta_steps=MAX_META_STEPS,
            **maml_kwargs
        )
        seed_last = args.eval_base_seed + 999
        gen_maml = make_generator(seed_last, dev=device)
        met_maml_last = eval_metrics_nn(
            mean_map, fmlp, theta_init_maml, mu_eval,
            tasks=TASKS_EVAL_LARGE,
            alpha_vec=alpha_maml, gen=gen_maml,
            **eval_kwargs
        )
        print("\n[MAML-per-mu] metrics on chosen prior:", met_maml_last)
        torch.save({"mu_eval": mu_eval.detach().cpu(), "metrics": met_maml_last,
                    "theta": theta_init_maml, "alpha": alpha_maml},
                   os.path.join(save_dir, "maml_single_results.pt"))

    # If 'all', print comparison on test priors (first up to 10)
    if args.case == "all":
        assert hyper is not None and theta_maml_pooled is not None and grid_maml_thetas is not None

        def adapted_loss(theta_init, alpha_vec, mu, seed):
            # A fresh generator with the same seed per method, so every method
            # is scored on the same sequence of evaluation tasks.
            return eval_adaptation(mean_map, fmlp, theta_init, mu,
                                   tasks=EVAL_TASKS, alpha_vec=alpha_vec,
                                   gen=make_generator(seed, dev=device),
                                   **eval_kwargs)

        print("\nTEST priors: adaptation loss after K steps (lower is better)")
        for j, mu in enumerate(test_priors[:10]):
            # Oracle per-prior MAML
            theta_init_maml, alpha_maml, _, _ = run_maml(
                mean_map, fmlp, mu, val_priors=None,
                max_meta_steps=MAX_META_STEPS,
                **maml_kwargs
            )
            with torch.no_grad():
                theta_init_hyp = hyper(mu)
            # Grid library: use the entry trained on the nearest grid prior.
            mu_nn = nearest_prior(mu, grid_priors)
            key = tuple(mu_nn.tolist())
            theta_init_grid = grid_maml_thetas[key]
            alpha_grid_loc = grid_maml_alphas[key]

            seed_j = args.eval_base_seed + j
            loss_maml = adapted_loss(theta_init_maml, alpha_maml, mu, seed_j)
            loss_hyp = adapted_loss(theta_init_hyp, alpha_h, mu, seed_j)
            loss_pool = adapted_loss(theta_maml_pooled, alpha_pooled, mu, seed_j)
            loss_grid = adapted_loss(theta_init_grid, alpha_grid_loc, mu, seed_j)

            print(f"mu={mu.detach().cpu().numpy()} | MAML-per-mu: {loss_maml:.4f} | "
                  f"Hyper: {loss_hyp:.4f} | Pooled: {loss_pool:.4f} | Grid(NN): {loss_grid:.4f}")

    print(f"\nDone. Artifacts saved to: {save_dir}")


if __name__ == "__main__":
    main()

