# Linear tasks: y = x^T w_i + noise, with task weights w_i ~ N(mu_w, sigma_w^2 I_2)
# Compare:
#   (1) Hyper-network h(mu_w) -> theta_init
#   (2) Pooled MAML (one shared init across priors)
#   (3) Grid MAML (separate init per grid point; at test time use the init of the grid point nearest to mu)
#   (4) MAML (per-prior re-trained oracle on the test mu)

# Prints the four metrics dicts at the end:
#   "MAML (per-prior re-trained):", "Hyper:", "Pooled MAML (shared init):", "Grid MAML (NN to mu):"

import torch
from torch import nn
from torch.optim import Adam
import random
import math
import matplotlib.pyplot as plt


# Config
# FIRST_ORDER selects the MAML variant for every trainer below; each one calls
# torch.autograd.grad(..., create_graph=not FIRST_ORDER):
#   False -> create_graph=True  -> second-order MAML (backprop through the inner step)
#   True  -> create_graph=False -> first-order MAML (inner gradient treated as a constant)
FIRST_ORDER    = False
EVAL_TASKS     = 50      # tasks sampled per prior during validation and evaluation
TASKS_PER_BS   = 100     # meta-batch for MAML (tasks per outer step)
TASKS_PER_PRIOR= 50      # tasks per sampled prior in each hyper-network outer step
MAX_META_STEPS = 1000    # MAML outer steps (upper bound; early stopping may end sooner)
HYPER_STEPS    = 1000    # Hyper outer steps (upper bound; early stopping may end sooner)
VAL_INTERVAL   = 50      # validate every 50 outer steps
PRINT_INTERVAL = 200     # record the training loss at step 1 and every 200 outer steps
PATIENCE       = 10      # consecutive validation checks without improvement before stopping

SIGMA_W          = 1.0   # std of task weights around the prior mean: w_i = mu + SIGMA_W * N(0, I)
SIGMA_Y          = 0.1   # std of the Gaussian label noise added to y
INNER_LR         = 0.01  # step size of the single inner (adaptation) gradient step
META_LR          = 1e-2  # Adam learning rate for the MAML outer loop
HYPER_LR         = 1e-3  # Adam learning rate for the hyper-network

# Priors / grids
NUM_RANDOM_PRIORS = 200      # for random (train/val/test) splits
GRID_SIDE         = 7        # 7x7 integer grid from -3 to 3 => 49 points

# CRN (common random numbers) seed
BASE_EVAL_SEED   = 1337

EPS = 1e-12  # guards divisions by a (near-)zero variance / total sum of squares
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


def set_seed(seed: int = 0):
    """Reset torch's and Python's ``random`` generators to ``seed``.

    When CUDA is available, the generators of all CUDA devices are seeded as
    well, so re-seeding before each evaluation reproduces the same draws.
    """
    seeders = [torch.manual_seed, random.seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

def nearest_prior(mu: torch.Tensor, prior_list: list[torch.Tensor]) -> torch.Tensor:
    """Pick the entry of ``prior_list`` with the smallest Euclidean distance to ``mu``.

    Squared distances are compared. On ties the earliest entry wins, because
    ``torch.argmin`` returns the first minimal index. The list element itself
    is returned (not a copy).
    """
    candidates = torch.stack(prior_list, dim=0)
    diffs = candidates - mu.view(1, -1)
    sq_dist = diffs.square().sum(dim=1)
    best = int(torch.argmin(sq_dist))
    return prior_list[best]

# Evaluation helpers (linear)
def eval_adaptation(theta_init, mu_w,
                    sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
                    inner_lr=INNER_LR,
                    tasks=EVAL_TASKS, k_spt=10, k_qry=20):
    """
    Mean query MSE after one inner SGD step from ``theta_init``, over tasks from prior ``mu_w``.

    Each of the ``tasks`` episodes does the following:
      1. Draw task weights w_i = mu_w + sigma_w * eps.
      2. Take one gradient step of size ``inner_lr`` on ``k_spt`` noisy support points.
      3. Score the adapted weights on ``k_qry`` fresh query points.

    Draws come from torch's global generator in a fixed order: w_i, support X,
    support noise, query X, query noise. For a fixed seed, the sampled data
    therefore does not depend on the values in ``theta_init``.
    """
    dim = theta_init.numel()
    dev = mu_w.device
    query_losses = []
    for _ in range(tasks):
        w_task = mu_w + sigma_w * torch.randn_like(mu_w)

        xs = torch.randn(k_spt, dim, device=dev)
        ys = xs @ w_task + sigma_y * torch.randn(k_spt, device=dev)
        # Fresh leaf per episode so every task adapts from the same starting point.
        theta = theta_init.clone().detach().to(dev).requires_grad_(True)
        support_loss = ((xs @ theta - ys) ** 2).mean()
        (g,) = torch.autograd.grad(support_loss, theta)
        adapted = theta - inner_lr * g

        xq = torch.randn(k_qry, dim, device=dev)
        yq = xq @ w_task + sigma_y * torch.randn(k_qry, device=dev)
        with torch.no_grad():
            query_losses.append(((xq @ adapted - yq) ** 2).mean().item())
    return sum(query_losses) / tasks

def eval_metrics_lr(theta_init, mu, sigma_w=SIGMA_W, sigma_y=SIGMA_Y, inner_lr=INNER_LR,
                    tasks=EVAL_TASKS, k_spt=10, k_qry=20):
    """
    Score ``theta_init`` after one inner SGD step, over ``tasks`` episodes from prior ``mu``.

    Each episode draws w_i = mu + sigma_w * eps, adapts on ``k_spt`` support
    points, and evaluates on ``k_qry`` query points. Random draws happen in the
    order w_i, support X, support noise, query X, query noise.

    Returns a dict with:
      - "MSE": mean of the per-episode query MSE
      - "R2_agg": 1 - sum(SSE) / sum(SST), pooled over all episodes, where SST
        is measured around each episode's own query mean
      - "nMSE": mean of per-episode MSE / unbiased Var(y_q)
      - "ExcessOverNoise": max(0, MSE - sigma_y**2)
    """
    dim = theta_init.numel()
    dev = mu.device
    per_task_mse = []
    per_task_nmse = []
    sse_sum = 0.0
    sst_sum = 0.0

    for _ in range(tasks):
        w_task = mu + sigma_w * torch.randn_like(mu)
        xs = torch.randn(k_spt, dim, device=dev)
        ys = xs @ w_task + sigma_y * torch.randn(k_spt, device=dev)

        theta = theta_init.clone().detach().to(dev).requires_grad_(True)
        (g,) = torch.autograd.grad(((xs @ theta - ys) ** 2).mean(), theta)
        adapted = theta - inner_lr * g

        xq = torch.randn(k_qry, dim, device=dev)
        yq = xq @ w_task + sigma_y * torch.randn(k_qry, device=dev)
        sq_err = (xq @ adapted - yq) ** 2

        mse = sq_err.mean().item()
        per_task_mse.append(mse)
        # EPS keeps the ratio finite when the query targets have ~zero variance.
        per_task_nmse.append(mse / (float(yq.var(unbiased=True).item()) + EPS))
        sse_sum += float(sq_err.sum().item())
        sst_sum += float(((yq - yq.mean()) ** 2).sum().item())

    mean_mse = sum(per_task_mse) / len(per_task_mse)
    return {
        "MSE": mean_mse,
        "R2_agg": 1.0 - sse_sum / (sst_sum + EPS),
        "nMSE": sum(per_task_nmse) / len(per_task_nmse),
        "ExcessOverNoise": max(0.0, mean_mse - sigma_y**2),
    }


# MAML baselines (linear)
def run_maml(mu_w,
             sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
             inner_lr=INNER_LR, meta_lr=META_LR,
             tasks_per_bs=TASKS_PER_BS, max_meta_steps=MAX_META_STEPS,
             p=2, k_spt=10, k_qry=10,
             val_priors=None,
             print_interval=PRINT_INTERVAL,
             first_order=FIRST_ORDER):
    """
    Per-prior MAML: learn a theta_init for THIS mu_w.

    Each outer step does the following:
      1. Sample ``tasks_per_bs`` tasks w_i = mu_w + sigma_w * eps.
      2. Take one inner SGD step (size ``inner_lr``) on ``k_spt`` support points.
      3. Update theta with Adam on the mean query MSE (``k_qry`` points).

    ``first_order=False`` differentiates through the inner step; ``True`` gives
    first-order MAML.

    If ``val_priors`` is non-empty, theta is scored every VAL_INTERVAL steps with
    ``eval_adaptation``, averaged over ``val_priors``. Training stops after
    PATIENCE consecutive checks that fail to beat the best score by more than 1e-4.

    Returns:
        (theta, train_loss_log, val_loss_log)
        - theta: the detached parameters from the LAST step taken (not the
          best validation checkpoint).
        - train_loss_log: the meta-loss at step 1 and every ``print_interval``
          steps, plus the final step's loss if it was not already recorded.
        - val_loss_log: one averaged validation loss per check.
        If ``max_meta_steps < 1``, no training happens and both logs are empty.
    """
    dev = mu_w.device
    theta = nn.Parameter(torch.randn(p, device=dev))
    opt   = Adam([theta], lr=meta_lr)

    best_val = float('inf')
    patience, steps_since = PATIENCE, 0
    val_interval, tol = VAL_INTERVAL, 1e-4

    train_loss_log, val_loss_log = [], []
    # Last step taken / last step whose loss was recorded; used after the loop so
    # the final meta-loss is appended exactly once.
    last_step, last_logged_step = 0, 0

    for step in range(1, max_meta_steps + 1):
        last_step = step
        meta_loss = 0.0
        for _ in range(tasks_per_bs):
            w_i = mu_w + sigma_w * torch.randn(p, device=dev)
            X_s = torch.randn(k_spt, p, device=dev)
            y_s = X_s @ w_i + sigma_y * torch.randn(k_spt, device=dev)
            loss_s = ((X_s @ theta - y_s) ** 2).mean()
            # create_graph=True keeps the inner step differentiable (second-order MAML);
            # with first_order the inner gradient is treated as a constant.
            grad   = torch.autograd.grad(loss_s, theta, create_graph=not first_order)[0]
            theta_i = theta - inner_lr * grad
            X_q = torch.randn(k_qry, p, device=dev)
            y_q = X_q @ w_i + sigma_y * torch.randn(k_qry, device=dev)
            meta_loss += ((X_q @ theta_i - y_q) ** 2).mean()

        meta_loss = meta_loss / tasks_per_bs
        opt.zero_grad(); meta_loss.backward(); opt.step()

        if step % print_interval == 0 or step == 1:
            train_loss_log.append(meta_loss.item())
            last_logged_step = step

        if val_priors and step % val_interval == 0:
            # NOTE: eval_adaptation draws from the global torch RNG, so validation
            # also shifts the random stream seen by subsequent training steps.
            val_losses = [eval_adaptation(theta.detach(), mu,
                                          sigma_w=sigma_w, sigma_y=sigma_y,
                                          inner_lr=inner_lr,
                                          tasks=EVAL_TASKS, k_spt=k_spt, k_qry=k_qry)
                          for mu in val_priors]
            avg_val = sum(val_losses)/len(val_losses)
            val_loss_log.append(avg_val)
            if avg_val < best_val - tol:
                best_val, steps_since = avg_val, 0
            else:
                steps_since += 1
            if steps_since >= patience:
                break

    # Record the final meta-loss unless that step was already logged (e.g. an early
    # stop on a multiple of print_interval) or no step ran at all.
    if last_step and last_step != last_logged_step:
        train_loss_log.append(meta_loss.item())

    return theta.detach(), train_loss_log, val_loss_log

def run_maml_across_priors(priors,
                           sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
                           inner_lr=INNER_LR, meta_lr=META_LR,
                           tasks_per_bs=TASKS_PER_BS, max_meta_steps=MAX_META_STEPS,
                           p=2, k_spt=10, k_qry=10,
                           val_priors=None,
                           print_interval=PRINT_INTERVAL,
                           first_order=FIRST_ORDER,
                           device=device):
    """
    Pooled MAML: one shared theta_init across a LIST of priors.

    Each of the ``tasks_per_bs`` tasks in an outer step works as follows:
      1. Pick its prior uniformly at random from ``priors``, using Python's
         ``random`` module, with replacement.
      2. Draw w_i = mu + sigma_w * eps.
      3. Take one inner SGD step (size ``inner_lr``) on ``k_spt`` support points.
      4. Contribute its query MSE (``k_qry`` points) to the Adam meta-update.

    ``first_order=False`` differentiates through the inner step; ``True`` gives
    first-order MAML.

    If ``val_priors`` is non-empty, theta is scored every VAL_INTERVAL steps with
    ``eval_adaptation``, averaged over ``val_priors``. Training stops after
    PATIENCE consecutive checks that fail to beat the best score by more than 1e-4.

    Returns:
        (theta, train_loss_log, val_loss_log)
        - theta: the detached parameters from the LAST step taken (not the
          best validation checkpoint).
        - train_loss_log: the meta-loss at step 1 and every ``print_interval``
          steps, plus the final step's loss if it was not already recorded.
        - val_loss_log: one averaged validation loss per check.
        If ``max_meta_steps < 1``, no training happens and both logs are empty.
    """
    theta = nn.Parameter(torch.randn(p, device=device))
    opt   = Adam([theta], lr=meta_lr)

    best_val = float('inf')
    patience, steps_since = PATIENCE, 0
    val_interval, tol = VAL_INTERVAL, 1e-4

    train_loss_log, val_loss_log = [], []
    # Last step taken / last step whose loss was recorded; used after the loop so
    # the final meta-loss is appended exactly once.
    last_step, last_logged_step = 0, 0

    for step in range(1, max_meta_steps + 1):
        last_step = step
        meta_loss = 0.0
        for _ in range(tasks_per_bs):
            mu = priors[random.randrange(len(priors))]
            w_i = mu + sigma_w * torch.randn(p, device=device)
            X_s = torch.randn(k_spt, p, device=device)
            y_s = X_s @ w_i + sigma_y * torch.randn(k_spt, device=device)
            loss_s = ((X_s @ theta - y_s) ** 2).mean()
            # create_graph=True keeps the inner step differentiable (second-order MAML);
            # with first_order the inner gradient is treated as a constant.
            grad   = torch.autograd.grad(loss_s, theta, create_graph=not first_order)[0]
            theta_i = theta - inner_lr * grad
            X_q = torch.randn(k_qry, p, device=device)
            y_q = X_q @ w_i + sigma_y * torch.randn(k_qry, device=device)
            meta_loss += ((X_q @ theta_i - y_q) ** 2).mean()

        meta_loss = meta_loss / tasks_per_bs
        opt.zero_grad(); meta_loss.backward(); opt.step()

        if step % print_interval == 0 or step == 1:
            train_loss_log.append(meta_loss.item())
            last_logged_step = step

        if val_priors and step % val_interval == 0:
            # NOTE: eval_adaptation draws from the global torch RNG, so validation
            # also shifts the random stream seen by subsequent training steps.
            val_losses = [eval_adaptation(theta.detach(), mu,
                                          sigma_w=sigma_w, sigma_y=sigma_y,
                                          inner_lr=inner_lr,
                                          tasks=EVAL_TASKS, k_spt=k_spt, k_qry=k_qry)
                          for mu in val_priors]
            avg_val = sum(val_losses)/len(val_losses)
            val_loss_log.append(avg_val)
            if avg_val < best_val - tol:
                best_val, steps_since = avg_val, 0
            else:
                steps_since += 1
            if steps_since >= patience:
                break

    # Record the final meta-loss unless that step was already logged (e.g. an early
    # stop on a multiple of print_interval) or no step ran at all.
    if last_step and last_step != last_logged_step:
        train_loss_log.append(meta_loss.item())

    return theta.detach(), train_loss_log, val_loss_log


# Hyper-network h_phi(mu) -> theta_init (linear)
class HyperNet(nn.Module):
    """Small MLP mapping a 2-D prior mean ``mu`` to an ``out_dim``-sized initialisation.

    Layers: Linear(2, 64) -> ReLU -> Linear(64, out_dim), held in ``self.net``.
    A single ``mu`` of shape (2,) yields shape (out_dim,). A batch of shape
    (B, 2) yields (B, out_dim).
    """

    def __init__(self, out_dim):
        super().__init__()
        layers = [nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, out_dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, mu):
        single = mu.ndim == 1
        batch = mu.unsqueeze(0) if single else mu
        out = self.net(batch)
        return out.squeeze(0) if single else out

def train_hyper(train_priors, val_priors,
                sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
                inner_lr=INNER_LR, hyper_lr=HYPER_LR,
                hyper_steps=HYPER_STEPS,
                priors_per_batch=4, tasks_per_prior=TASKS_PER_PRIOR,
                k_spt=10, k_qry=10,
                print_interval=PRINT_INTERVAL,
                first_order=FIRST_ORDER,
                device=device):
    """
    Train a hyper-network to map mu in R^2 -> theta_init in R^2 (linear learner).

    Each outer step works as follows:
      1. Sample ``priors_per_batch`` priors from ``train_priors`` (uniformly,
         with replacement, via Python's ``random``).
      2. Map each prior to theta_init = hyper(mu).
      3. Run ``tasks_per_prior`` one-step-adaptation tasks from each theta_init.
      4. Update the hyper-network with Adam on the mean query MSE.

    ``first_order=False`` differentiates through the inner step.

    If ``val_priors`` is non-empty, the network is scored every VAL_INTERVAL
    steps with ``eval_adaptation`` on hyper(mu) for each validation prior.
    Training stops after PATIENCE consecutive checks that fail to beat the best
    score by more than 1e-4.

    Returns:
        (hyper, train_loss_log, val_loss_log)
        - hyper: the network as of the LAST step taken (not the best checkpoint).
        - train_loss_log: the loss at step 1 and every ``print_interval`` steps
          (each also printed), plus the final step's loss if it was not
          already recorded.
        - val_loss_log: one averaged validation loss per check.
    """
    hyper = HyperNet(out_dim=2).to(device)
    h_opt = Adam(hyper.parameters(), lr=hyper_lr)

    best_val = float('inf')
    patience, steps_since = PATIENCE, 0
    val_interval, tol = VAL_INTERVAL, 1e-4

    train_loss_log, val_loss_log = [], []
    # Last step taken / last step whose loss was recorded; used after the loop so
    # the final loss is appended exactly once.
    last_step, last_logged_step = 0, 0

    for step in range(1, hyper_steps + 1):
        last_step = step
        h_loss = 0.0
        for _ in range(priors_per_batch):
            mu = train_priors[random.randrange(len(train_priors))]
            theta_init = hyper(mu)  # (2,)
            for _ in range(tasks_per_prior):
                w_i = mu + sigma_w * torch.randn(2, device=device)
                X_s = torch.randn(k_spt, 2, device=device)
                y_s = X_s @ w_i + sigma_y * torch.randn(k_spt, device=device)
                loss_s = ((X_s @ theta_init - y_s) ** 2).mean()
                grad   = torch.autograd.grad(loss_s, theta_init, create_graph=not first_order)[0]
                theta_i = theta_init - inner_lr * grad
                X_q = torch.randn(k_qry, 2, device=device)
                y_q = X_q @ w_i + sigma_y * torch.randn(k_qry, device=device)
                h_loss = h_loss + ((X_q @ theta_i - y_q) ** 2).mean()

        h_loss = h_loss / (priors_per_batch * tasks_per_prior)
        h_opt.zero_grad(); h_loss.backward(); h_opt.step()

        if step % print_interval == 0 or step == 1:
            train_loss_log.append(h_loss.item())
            last_logged_step = step
            print(f'step: {step}, hyperparameter loss: {h_loss:.4f}')

        if val_priors and step % val_interval == 0:
            val_losses = []
            for mu in val_priors:
                with torch.no_grad():
                    theta_pred = hyper(mu)
                val_losses.append(eval_adaptation(theta_pred, mu,
                                                  sigma_w=sigma_w, sigma_y=sigma_y,
                                                  inner_lr=inner_lr,
                                                  tasks=EVAL_TASKS, k_spt=k_spt, k_qry=k_qry))
            avg_val = sum(val_losses)/len(val_losses)
            val_loss_log.append(avg_val)
            if avg_val < best_val - tol:
                best_val, steps_since = avg_val, 0
            else:
                steps_since += 1
            if steps_since >= patience:
                break

    # Record the final loss unless that step was already logged (e.g. an early stop
    # on a multiple of print_interval) or no step ran at all.
    if last_step and last_step != last_logged_step:
        train_loss_log.append(h_loss.item())
    return hyper, train_loss_log, val_loss_log


# Experiments
if __name__ == "__main__":
    set_seed(0)

    # Random priors (train/val/test)
    priors = [torch.empty(2, device=device).uniform_(-3, 3) for _ in range(NUM_RANDOM_PRIORS)]
    random.shuffle(priors)
    n_val = int(0.10 * NUM_RANDOM_PRIORS)
    n_test = int(0.10 * NUM_RANDOM_PRIORS)
    val_priors   = priors[:n_val]
    test_priors  = priors[n_val:n_val + n_test]
    train_priors = priors[n_val + n_test:]
    print(f"Random Priors: train={len(train_priors)}, val={len(val_priors)}, test={len(test_priors)}")

    # Grid priors for Grid MAML: GRID_SIDE evenly spaced values per axis over the same
    # [-3, 3] box the random priors are drawn from (the integers -3..3 when GRID_SIDE=7).
    grid_axis = torch.linspace(-3.0, 3.0, GRID_SIDE).tolist()
    grid_priors = [torch.tensor([i, j], dtype=torch.float32, device=device)
                   for i in grid_axis for j in grid_axis]
    print(f"Grid priors: {len(grid_priors)} points ({GRID_SIDE}x{GRID_SIDE})")

    # 1) Hyper-network (across priors)
    print("Hypernetwork Training started")
    hyper, hyper_train_log, hyper_val_log = train_hyper(
        train_priors, val_priors,
        sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
        inner_lr=INNER_LR, hyper_lr=HYPER_LR,
        hyper_steps=HYPER_STEPS,
        priors_per_batch=4, tasks_per_prior=TASKS_PER_PRIOR,
        k_spt=10, k_qry=10,
        print_interval=PRINT_INTERVAL,
        first_order=FIRST_ORDER,
        device=device
    )

    # 2) Pooled MAML (one shared init across random train_priors)
    print("Pooled MAML Training started (shared init across random train priors)")
    theta_maml_pooled, pooled_train_log, pooled_val_log = run_maml_across_priors(
        train_priors,
        sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
        inner_lr=INNER_LR, meta_lr=META_LR,
        tasks_per_bs=TASKS_PER_BS, max_meta_steps=MAX_META_STEPS,
        p=2, k_spt=10, k_qry=10,
        val_priors=val_priors,
        print_interval=PRINT_INTERVAL,
        first_order=FIRST_ORDER,
        device=device
    )

    # 3) Grid MAML (separate init per grid mu)
    print("Grid MAML Training started (separate init per grid prior)")
    grid_maml_thetas = {}
    grid_maml_logs = {}
    for mu in grid_priors:
        # Early-stop each per-prior run on fresh tasks from its OWN prior. The random
        # val_priors span the whole [-3, 3]^2 box, so their average loss does not
        # measure how good an init specialised to this single mu is.
        theta_init, train_log, val_log = run_maml(
            mu_w=mu,
            sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
            inner_lr=INNER_LR, meta_lr=META_LR,
            tasks_per_bs=TASKS_PER_BS, max_meta_steps=MAX_META_STEPS,
            p=2, k_spt=10, k_qry=10,
            val_priors=[mu],
            print_interval=PRINT_INTERVAL,
            first_order=FIRST_ORDER
        )
        grid_key = tuple(mu.tolist())
        grid_maml_thetas[grid_key] = theta_init
        grid_maml_logs[grid_key] = (train_log, val_log)

    def _train_log_steps(train_log, val_log, max_steps):
        """Outer-step index of each recorded training loss, for plotting.

        Regular entries are recorded at step 1 and at every multiple of
        PRINT_INTERVAL. A trailing entry appended after the training loop is
        clamped to the last step taken. That step is approximated as
        len(val_log) * VAL_INTERVAL, or ``max_steps`` when there was no
        validation.
        """
        last_step = len(val_log) * VAL_INTERVAL if val_log else max_steps
        return [min(max(1, i * PRINT_INTERVAL), last_step) for i in range(len(train_log))]

    # plot hyper & pooled curves
    plt.figure(figsize=(10,4))
    # Hyper curves
    plt.subplot(1,2,1)
    x_h_train = _train_log_steps(hyper_train_log, hyper_val_log, HYPER_STEPS)
    x_h_val   = [i * VAL_INTERVAL for i in range(1, len(hyper_val_log) + 1)]
    plt.plot(x_h_train, hyper_train_log, label="Hyper Train Loss")
    plt.plot(x_h_val,   hyper_val_log,   label="Hyper Val Loss")
    plt.xlabel("Hyper-step"); plt.ylabel("Loss"); plt.title("Hyper-network Loss"); plt.legend()
    # Pooled MAML curves
    plt.subplot(1,2,2)
    x_p_train = _train_log_steps(pooled_train_log, pooled_val_log, MAX_META_STEPS)
    x_p_val   = [i * VAL_INTERVAL for i in range(1, len(pooled_val_log) + 1)]
    plt.plot(x_p_train, pooled_train_log, label="Pooled MAML Train Loss")
    plt.plot(x_p_val,   pooled_val_log,   label="Pooled MAML Val Loss")
    plt.xlabel("Meta-step"); plt.ylabel("Loss"); plt.title("Pooled MAML Loss"); plt.legend()
    plt.tight_layout(); plt.show()


    # Average training-set adaptation losses (Hyper vs Pooled vs Grid).
    # Re-seeding with the same seed before each method gives common random numbers:
    # every method is scored on identical sampled tasks for a given prior.
    total_hyp_loss = 0.0
    total_pooled_loss = 0.0
    total_grid_loss = 0.0

    for i, mu in enumerate(train_priors):
        seed_i = BASE_EVAL_SEED + i
        # Hyper
        with torch.no_grad():
            theta_hyp = hyper(mu)
        set_seed(seed_i)
        total_hyp_loss += eval_adaptation(theta_hyp, mu, tasks=EVAL_TASKS)
        # Pooled
        set_seed(seed_i)
        total_pooled_loss += eval_adaptation(theta_maml_pooled, mu, tasks=EVAL_TASKS)
        # Grid NN
        mu_nn = nearest_prior(mu, grid_priors)
        theta_grid = grid_maml_thetas[tuple(mu_nn.tolist())]
        set_seed(seed_i)
        total_grid_loss += eval_adaptation(theta_grid, mu, tasks=EVAL_TASKS)

    avg_hyp_loss     = total_hyp_loss    / len(train_priors)
    avg_pooled_loss  = total_pooled_loss / len(train_priors)
    avg_grid_loss    = total_grid_loss   / len(train_priors)

    print("\n=== Average Training-set Adaptation Loss (MSE) ===")
    print(f"Hyper (across priors)      = {avg_hyp_loss:.4f}")
    print(f"Pooled MAML (shared init)   = {avg_pooled_loss:.4f}")
    print(f"Grid MAML (NN to mu)        = {avg_grid_loss:.4f}")

    # Test priors: compare all four (including per-prior MAML oracle)
    print("\nTEST priors (CRN): adaptation loss after one step (lower is better)")
    oracle_thetas = []  # per-prior MAML oracle for each evaluated test prior, in order
    for j, mu in enumerate(test_priors[:10]):
        # per-prior MAML retrained on this exact test mu
        theta_init_maml, _, _ = run_maml(
            mu_w=mu, val_priors=None,
            sigma_w=SIGMA_W, sigma_y=SIGMA_Y,
            inner_lr=INNER_LR, meta_lr=META_LR,
            tasks_per_bs=TASKS_PER_BS, max_meta_steps=MAX_META_STEPS,
            p=2, k_spt=10, k_qry=10, print_interval=PRINT_INTERVAL,
            first_order=FIRST_ORDER
        )
        oracle_thetas.append(theta_init_maml)
        with torch.no_grad():
            theta_init_hyp = hyper(mu)
        theta_init_pooled = theta_maml_pooled
        mu_nn = nearest_prior(mu, grid_priors)
        theta_init_grid = grid_maml_thetas[tuple(mu_nn.tolist())]

        seed_j = BASE_EVAL_SEED + 10_000 + j
        set_seed(seed_j)
        loss_maml = eval_adaptation(theta_init_maml, mu, tasks=EVAL_TASKS)
        set_seed(seed_j)
        loss_hyp  = eval_adaptation(theta_init_hyp,  mu, tasks=EVAL_TASKS)
        set_seed(seed_j)
        loss_pool = eval_adaptation(theta_init_pooled, mu, tasks=EVAL_TASKS)
        set_seed(seed_j)
        loss_grid = eval_adaptation(theta_init_grid,  mu, tasks=EVAL_TASKS)

        print(f"μ={mu.detach().cpu().numpy()} | "
              f"MAML-per-mu: {loss_maml:.4f} | Hyper: {loss_hyp:.4f} | "
              f"Pooled: {loss_pool:.4f} | Grid(NN): {loss_grid:.4f}")

    # Detailed metrics on one test prior (larger eval)
    mu_last = test_priors[:10][-1]

    # Reuse the per-prior MAML oracle already trained for this prior in the loop above.
    # Retraining would cost another full run and would produce a different oracle
    # from the one reported in the table.
    theta_init_maml_last = oracle_thetas[-1]
    with torch.no_grad():
        theta_hyp_last = hyper(mu_last)
    theta_pool_last = theta_maml_pooled
    mu_nn_last = nearest_prior(mu_last, grid_priors)
    theta_grid_last = grid_maml_thetas[tuple(mu_nn_last.tolist())]

    # Larger eval
    k_spt_eval = 10
    k_qry_eval = 200
    tasks_eval = 128
    seed_last = BASE_EVAL_SEED + 999

    set_seed(seed_last)
    met_maml_last = eval_metrics_lr(theta_init_maml_last, mu_last,
                                    tasks=tasks_eval, k_spt=k_spt_eval, k_qry=k_qry_eval)
    set_seed(seed_last)
    met_hyp_last  = eval_metrics_lr(theta_hyp_last,  mu_last,
                                    tasks=tasks_eval, k_spt=k_spt_eval, k_qry=k_qry_eval)
    set_seed(seed_last)
    met_pool_last = eval_metrics_lr(theta_pool_last, mu_last,
                                    tasks=tasks_eval, k_spt=k_spt_eval, k_qry=k_qry_eval)
    set_seed(seed_last)
    met_grid_last = eval_metrics_lr(theta_grid_last, mu_last,
                                    tasks=tasks_eval, k_spt=k_spt_eval, k_qry=k_qry_eval)

    print("\nDetailed metrics on one test prior (CRN, larger eval):")
    print("MAML (per-prior re-trained):", met_maml_last)
    print("Hyper:", met_hyp_last)
    print("Pooled MAML (shared init):", met_pool_last)
    print("Grid MAML (NN to mu):", met_grid_last)
