#!/usr/bin/env python3
# MNIST_Lenet5_SVI_PBT_pbt_then_retrain.py
# Run PBT once (find best hyperparams) -> retrain best config for 5 seeds (50 epochs) -> report mean/std.
# Adapted from the MNIST_Lenet300_SVI_PBT reference script.

import os
import random
import tempfile
import numpy as np
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

import torchvision.transforms as T
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

# Ray AIR compat
try:
    from ray.air import session
except Exception:
    from ray.tune import session  # type: ignore
try:
    from ray.air.checkpoint import Checkpoint
except Exception:
    try:
        from ray.tune import Checkpoint
    except Exception:
        Checkpoint = None

# ---------- Determinism ----------
# cudnn.benchmark autotunes convolution algorithms by timing them, so different
# runs can end up on different (possibly non-deterministic) kernels. Disable it
# and request deterministic cuDNN kernels so seeded runs are reproducible.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(1)
# Only affects processes started after this point (they inherit the environment);
# this process's torch thread count is already fixed by set_num_threads above.
os.environ.setdefault("OMP_NUM_THREADS", "1")

# ---------- Small helpers ----------
def get_device():
    """Return the current CUDA device if CUDA is usable, else the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def set_global_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs (plus every CUDA device if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ---------- Bayesian masked layers (conv + dense) ----------
def kl_gaussian(mu: torch.Tensor, rho: torch.Tensor, mask: torch.Tensor, prior_std: float = 1.0) -> torch.Tensor:
    """Masked KL divergence from a diagonal Gaussian posterior to a zero-mean Gaussian prior.

    Computes ``sum(mask * KL(N(mu, sigma^2) || N(0, prior_std^2)))`` with
    ``sigma = softplus(rho)``, using the closed form

        KL = log(prior_std / sigma) + (sigma^2 + mu^2) / (2 * prior_std^2) - 1/2

    Args:
        mu: posterior means.
        rho: unconstrained posterior scales, same shape as ``mu``.
        mask: per-element multiplier; entries where it is 0 contribute nothing.
        prior_std: prior standard deviation (> 0). The default 1.0 gives the
            standard-normal prior.

    Returns:
        0-dim tensor.
    """
    # softplus(rho) == log1p(exp(rho)) mathematically, but does not overflow to
    # inf for large rho.
    sigma = F.softplus(rho)
    prior_var = float(prior_std) ** 2
    kl_elem = (
        (sigma.pow(2) + mu.pow(2)) / prior_var
        - 1.0
        - 2.0 * torch.log(sigma + 1e-12)
        + 2.0 * float(np.log(prior_std))
    )
    return 0.5 * torch.sum(kl_elem * mask)


class MaskedLRDense(nn.Module):
    """Masked Bayesian fully-connected layer using the local reparameterization trick.

    Holds a mean-field Gaussian posterior over weights and biases
    (``mu_*``, ``rho_*``; sigma = softplus(rho)) and multiplicative masks
    ``mask_w`` / ``mask_b`` (buffers, all ones by default). Masked-out entries
    contribute neither mean, variance nor KL.
    """

    def __init__(self, fin: int, fout: int, prior_std: float = 1.0):
        super().__init__()
        self.mu_w  = nn.Parameter(0.05 * torch.randn(fout, fin))
        self.rho_w = nn.Parameter(-3.0 * torch.ones(fout, fin))
        self.mu_b  = nn.Parameter(torch.zeros(fout))
        self.rho_b = nn.Parameter(-3.0 * torch.ones(fout))
        self.register_buffer("mask_w", torch.ones(fout, fin))
        self.register_buffer("mask_b", torch.ones(fout))
        self.prior_std = float(prior_std)  # std of the N(0, prior_std^2) prior in the KL term

    def forward(self, x: torch.Tensor, kl_scale: float):
        """Sample pre-activations for ``x`` of shape (..., fin).

        Instead of sampling a weight matrix, z is drawn directly from
        N(x @ mu_w^T + mu_b, x^2 @ sigma_w^2^T + sigma_b^2); fresh noise is drawn
        on every call, regardless of train/eval mode.

        Returns:
            ``(z, kl * kl_scale)``: z has shape (..., fout); the second item is this
            layer's masked KL to its prior, multiplied by ``kl_scale``.
        """
        mu_w = self.mu_w * self.mask_w
        mu_b = self.mu_b * self.mask_b
        sigma_w = F.softplus(self.rho_w)
        sigma_b = F.softplus(self.rho_b)

        mu_z  = F.linear(x, mu_w, mu_b)
        # Masked biases get zero variance, consistent with their zeroed mean.
        var_z = F.linear(x.pow(2), (sigma_w ** 2) * (self.mask_w ** 2), None) + (sigma_b ** 2) * (self.mask_b ** 2)
        eps   = torch.randn_like(mu_z)
        z     = mu_z + eps * torch.sqrt(var_z + 1e-8)

        kl = (kl_gaussian(self.mu_w, self.rho_w, self.mask_w, self.prior_std)
              + kl_gaussian(self.mu_b, self.rho_b, self.mask_b, self.prior_std))
        return z, kl * kl_scale


class MaskedLRConv(nn.Module):
    """Masked Bayesian 2-D convolution using the local reparameterization trick.

    Same posterior/mask/KL structure as ``MaskedLRDense``, applied to square
    ``kernel_size x kernel_size`` kernels (a tuple ``kernel_size`` uses its
    first element for both dimensions).
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int, stride: int = 1, padding: int = 0,
                 prior_std: float = 1.0):
        super().__init__()
        k = kernel_size if isinstance(kernel_size, int) else kernel_size[0]
        self.mu_w  = nn.Parameter(0.05 * torch.randn(out_ch, in_ch, k, k))
        self.rho_w = nn.Parameter(-3.0 * torch.ones(out_ch, in_ch, k, k))
        self.mu_b  = nn.Parameter(torch.zeros(out_ch))
        self.rho_b = nn.Parameter(-3.0 * torch.ones(out_ch))
        self.register_buffer("mask_w", torch.ones(out_ch, in_ch, k, k))
        self.register_buffer("mask_b", torch.ones(out_ch))
        self.stride = stride
        self.padding = padding
        self.prior_std = float(prior_std)  # std of the N(0, prior_std^2) prior in the KL term

    def forward(self, x: torch.Tensor, kl_scale: float):
        """Sample conv outputs for ``x`` of shape (batch, in_ch, H, W).

        Mean and variance maps are obtained by convolving ``x`` with the masked
        means and ``x**2`` with the masked variances; one Gaussian sample per
        output element is drawn on every call.

        Returns:
            ``(z, kl * kl_scale)`` with z of shape (batch, out_ch, H', W').
        """
        mu_w = self.mu_w * self.mask_w
        mu_b = self.mu_b * self.mask_b
        sigma_w = F.softplus(self.rho_w)
        sigma_b = F.softplus(self.rho_b)

        mu_z = F.conv2d(x, mu_w, bias=mu_b, stride=self.stride, padding=self.padding)
        var_z = F.conv2d(x.pow(2), (sigma_w ** 2) * (self.mask_w ** 2), bias=None, stride=self.stride, padding=self.padding)
        # Masked biases get zero variance, consistent with their zeroed mean.
        var_z = var_z + ((sigma_b ** 2) * (self.mask_b ** 2)).view(1, -1, 1, 1)
        eps = torch.randn_like(mu_z)
        z = mu_z + eps * torch.sqrt(var_z + 1e-8)

        kl = (kl_gaussian(self.mu_w, self.rho_w, self.mask_w, self.prior_std)
              + kl_gaussian(self.mu_b, self.rho_b, self.mask_b, self.prior_std))
        return z, kl * kl_scale


# ---------- LeNet-5 Bayesian masked model ----------
class MaskedBayesianLeNet5(nn.Module):
    """LeNet-5 built from masked Bayesian layers.

    ``forward(x, epoch)`` returns ``(logits, kl)``: logits of shape (batch, 10)
    and the sum of the per-layer KL terms, each multiplied by
    ``min(1, epoch / warmup_epochs) * max_beta * kl_scale_<layer>``.

    Config keys read: ``max_beta`` and ``warmup_epochs`` (required);
    ``prior_std`` and ``kl_scale_{conv1,conv2,fc1,fc2,fc3}`` (optional,
    default 1.0). ``prior_std`` is the std of the zero-mean Gaussian prior
    used by every layer and must be > 0.
    """

    def __init__(self, config: dict):
        super().__init__()
        prior_std = float(config.get("prior_std", 1.0))
        if prior_std <= 0:
            raise ValueError(f"prior_std must be > 0, got {prior_std}")

        # Sizes below assume 1x28x28 inputs (FLAT_IN depends on it).
        self.conv1 = MaskedLRConv(1, 6, kernel_size=5, stride=1, padding=0, prior_std=prior_std)   # 28 -> 24
        self.conv2 = MaskedLRConv(6, 16, kernel_size=5, stride=1, padding=0, prior_std=prior_std)  # 12 -> 8
        FLAT_IN = 16 * 4 * 4  # 16 channels x 4 x 4 after the second 2x2 pool
        self.fc1 = MaskedLRDense(FLAT_IN, 120, prior_std=prior_std)
        self.fc2 = MaskedLRDense(120, 84, prior_std=prior_std)
        self.fc3 = MaskedLRDense(84, 10, prior_std=prior_std)

        self.max_beta      = float(config["max_beta"])
        self.warmup_epochs = int(config["warmup_epochs"])
        self.kl_scales = {
            "conv1": float(config.get("kl_scale_conv1", 1.0)),
            "conv2": float(config.get("kl_scale_conv2", 1.0)),
            "fc1":   float(config.get("kl_scale_fc1", 1.0)),
            "fc2":   float(config.get("kl_scale_fc2", 1.0)),
            "fc3":   float(config.get("kl_scale_fc3", 1.0)),
        }

    def forward(self, x: torch.Tensor, epoch: int):
        # Linear KL warm-up: reaches full weight max_beta at epoch == warmup_epochs.
        warm = min(1.0, epoch / max(1, self.warmup_epochs))
        kl1_scale = warm * self.max_beta * self.kl_scales["conv1"]
        kl2_scale = warm * self.max_beta * self.kl_scales["conv2"]
        kl3_scale = warm * self.max_beta * self.kl_scales["fc1"]
        kl4_scale = warm * self.max_beta * self.kl_scales["fc2"]
        kl5_scale = warm * self.max_beta * self.kl_scales["fc3"]

        h, kl1 = self.conv1(x, kl1_scale)
        h = F.relu(h)
        h = F.avg_pool2d(h, 2)  # 24 -> 12

        h, kl2 = self.conv2(h, kl2_scale)
        h = F.relu(h)
        h = F.avg_pool2d(h, 2)  # 8 -> 4, so the flattened size is 16 * 4 * 4

        h = h.view(h.size(0), -1)
        h, kl3 = self.fc1(h, kl3_scale); h = F.relu(h)
        h, kl4 = self.fc2(h, kl4_scale); h = F.relu(h)
        out, kl5 = self.fc3(h, kl5_scale)
        return out, kl1 + kl2 + kl3 + kl4 + kl5

# ---------- Data loaders ----------
DATA_ROOT = os.path.abspath("./data")


def get_data_loaders(batch_size: int, num_workers: int = 4) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build (train, validation, test) MNIST loaders, downloading the data if needed.

    The 60k official training set is split 55k/5k into train/validation with a
    fixed generator seed (0), so the validation split is identical across trials
    and seeds. Only the train loader shuffles.

    Args:
        batch_size: batch size for all three loaders.
        num_workers: DataLoader worker processes; 0 loads in the calling process.
            Workers are kept alive between epochs only when num_workers > 0
            (persistent_workers=True is rejected by DataLoader otherwise).
    """
    transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
    train_full = MNIST(DATA_ROOT, train=True,  download=True, transform=transform)
    train, valid = torch.utils.data.random_split(
        train_full, [55_000, 5_000], generator=torch.Generator().manual_seed(0)
    )
    test = MNIST(DATA_ROOT, train=False, download=True, transform=transform)
    dl_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=num_workers > 0,
    )
    return (
        DataLoader(train, shuffle=True, **dl_kwargs),
        DataLoader(valid, shuffle=False, **dl_kwargs),
        DataLoader(test, shuffle=False, **dl_kwargs),
    )

# ---------- Mask application (fallback to ones if no masks found) ----------
# NOTE: relative path, resolved against the current working directory at import time.
MASKS_PATH = "../tests/LeNet_MNIST/99_test1_various_masks/mask_1.1_size.npy"
try:
    # allow_pickle is presumably needed because the file holds an object array of
    # differently-shaped per-layer masks; only load trusted, locally produced files.
    raw_masks = list(np.load(MASKS_PATH, allow_pickle=True))
except Exception as exc:  # missing/corrupt file: train unpruned, but make it visible
    print(f"[WARN] could not load masks from {os.path.abspath(MASKS_PATH)} ({exc!r}); "
          "falling back to all-ones masks")
    raw_masks = []

def apply_masks(model: nn.Module, device: torch.device, masks=None) -> List[str]:
    """Copy pruning masks into the masked layers of ``model``.

    Layers are visited in the order conv1, conv2, fc1, fc2, fc3. Each layer takes
    the first not-yet-used candidate whose shape equals its ``mask_w`` shape, so two
    layers with identical shapes receive successive candidates rather than the same
    one. A layer without a matching candidate gets an all-ones weight mask (no
    pruning). Bias masks are always reset to ones.

    Args:
        model: module exposing ``conv1`` ... ``fc3``, each with ``mask_w`` and
            ``mask_b`` buffers.
        device: device on which the copied mask tensors are created.
        masks: iterable of array-like candidate masks; defaults to the module-level
            ``raw_masks``.

    Returns:
        Names of the layers whose weight mask came from a candidate.
    """
    source = raw_masks if masks is None else masks
    candidates = [np.asarray(m) for m in source]
    used = [False] * len(candidates)
    masked_layers: List[str] = []
    with torch.no_grad():
        for name in ("conv1", "conv2", "fc1", "fc2", "fc3"):
            layer = getattr(model, name)
            target_shape = tuple(layer.mask_w.shape)
            layer.mask_b.fill_(1.0)
            match = next(
                (i for i, m in enumerate(candidates) if not used[i] and tuple(m.shape) == target_shape),
                None,
            )
            if match is None:
                layer.mask_w.fill_(1.0)
                continue
            used[match] = True
            layer.mask_w.copy_(torch.tensor(candidates[match], dtype=torch.float32, device=device))
            masked_layers.append(name)
    return masked_layers

# ------------------------ Tune trainable ------------------------
def train_and_evaluate(config: dict):
    """Ray Tune function trainable used by PBT.

    Trains a MaskedBayesianLeNet5 for ``config["epochs"]`` epochs and reports once
    per epoch: ``epoch``, ``train_ce`` (summed CE / training-set size),
    ``train_kl`` (mean scaled KL term per batch), ``val_ce``, ``mean_accuracy``
    (validation accuracy) and ``best_mean_accuracy``. The final epoch also reports
    ``test_accuracy`` and ``test_ce``. A checkpoint is attached every 5 epochs and
    on the final epoch so PBT can exploit/perturb.

    When started from a checkpoint (e.g. after a PBT exploit), model, optimizer and
    scheduler state are restored, and the learning rate is then re-applied from
    ``config["lr"]``. Without this, the donor's saved lr would silently override
    PBT's lr perturbation.
    """
    import contextlib

    ckpt = session.get_checkpoint()
    seed = int(config.get("seed", 1234))
    epochs = int(config["epochs"])
    best_val_acc = 0.0
    start_epoch = 1

    device = get_device()
    set_global_seed(seed)

    train_ld, val_ld, test_ld = get_data_loaders(int(config["batch_size"]))
    N = len(train_ld.dataset)

    model = MaskedBayesianLeNet5(config).to(device)
    apply_masks(model, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
    scheduler = MultiStepLR(optimizer, milestones=[int(0.8 * config["epochs"]), int(0.9 * config["epochs"])], gamma=0.2)

    use_amp = bool(config.get("amp", True)) and torch.cuda.is_available()
    try:
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
        autocast_ctx = lambda: torch.amp.autocast("cuda")
    except Exception:
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        autocast_ctx = lambda: torch.cuda.amp.autocast()

    def _amp():
        """Autocast context when AMP is enabled, otherwise a no-op context."""
        return autocast_ctx() if use_amp else contextlib.nullcontext()

    def _evaluate(loader, kl_epoch):
        """Return (accuracy, mean CE) over ``loader`` using one stochastic forward pass per batch."""
        model.eval()
        correct = 0
        ce_total = 0.0
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                with _amp():
                    logits, _ = model(x, kl_epoch)
                    ce_total += F.cross_entropy(logits, y, reduction="sum").item()
                correct += (logits.argmax(1) == y).sum().item()
        n = len(loader.dataset)
        return correct / n, ce_total / n

    # restore if checkpoint
    if ckpt is not None:
        with ckpt.as_directory() as ckdir:
            p = os.path.join(ckdir, "checkpoint.pt")
            if os.path.exists(p):
                state = torch.load(p, map_location="cpu")
                model.load_state_dict(state["model"])
                optimizer.load_state_dict(state["optimizer"])
                scheduler.load_state_dict(state["scheduler"])
                best_val_acc = state.get("best_val_acc", 0.0)
                start_epoch = state.get("epoch", 0) + 1
                for st in optimizer.state.values():
                    for k, v in st.items():
                        if torch.is_tensor(v):
                            st[k] = v.to(device)
                # The restored optimizer/scheduler carry the checkpoint's lr. Re-apply
                # the (possibly PBT-perturbed) config lr, keeping the decay already
                # applied by the milestones (MultiStepLR closed form:
                # base_lr * gamma ** #milestones <= last_epoch).
                new_lr = float(config["lr"])
                n_decays = sum(1 for m in scheduler.milestones.elements() if m <= scheduler.last_epoch)
                scheduler.base_lrs = [new_lr] * len(optimizer.param_groups)
                for group in optimizer.param_groups:
                    group["initial_lr"] = new_lr
                    group["lr"] = new_lr * scheduler.gamma ** n_decays

    for epoch in range(start_epoch, epochs + 1):
        model.train()
        ce_sum = 0.0
        kl_sum = 0.0
        for x, y in train_ld:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with _amp():
                logits, kl = model(x, epoch)
                ce = F.cross_entropy(logits, y, reduction="sum")
                # NOTE(review): batch-summed CE plus kl / N weights the KL B times
                # less than the usual minibatch ELBO (ce_sum + kl * B / N); max_beta
                # and kl_scale_* absorb this, but the effective KL weight therefore
                # also depends on batch_size.
                loss = ce + kl / N
            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            ce_sum += ce.item()
            kl_sum += float(kl.item())

        scheduler.step()
        train_ce = ce_sum / N

        # NOTE(review): the Bayesian layers sample on every forward pass, so this is
        # a single stochastic estimate of validation accuracy, not the posterior mean.
        val_acc, val_ce = _evaluate(val_ld, epoch)
        best_val_acc = max(best_val_acc, val_acc)

        payload = {
            "epoch": epoch,
            "train_ce": train_ce,
            "train_kl": kl_sum / max(1, len(train_ld)),
            "val_ce": val_ce,
            "mean_accuracy": val_acc,
            "best_mean_accuracy": best_val_acc,
        }

        # attach test metrics on the final epoch (informational only; never select on them)
        if epoch == epochs:
            test_acc, test_ce = _evaluate(test_ld, epochs)
            payload["test_accuracy"] = test_acc
            payload["test_ce"] = test_ce

        # Save checkpoint every 5 epochs (= PBT perturbation_interval) and at the end.
        if Checkpoint is not None and (epoch % 5 == 0 or epoch == epochs):
            with tempfile.TemporaryDirectory() as tmp:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "best_val_acc": best_val_acc,
                        "epoch": epoch,
                        "config": config,
                    },
                    os.path.join(tmp, "checkpoint.pt"),
                )
                session.report(payload, checkpoint=Checkpoint.from_directory(tmp))
        else:
            session.report(payload)

# ------------------------ Single retrain function (no Ray) ------------------------
def train_single_eval(config: dict, seed: int) -> float:
    """Train a fresh model with ``config`` and ``seed`` (no Ray) and return its test accuracy.

    Uses the same objective, optimizer, LR schedule and AMP handling as
    ``train_and_evaluate`` but performs no validation, checkpointing or reporting.

    Args:
        config: hyperparameters; reads ``batch_size``, ``lr``, ``epochs``, optional
            ``amp``, plus the keys read by ``MaskedBayesianLeNet5``.
        seed: global RNG seed for this run.

    Returns:
        Test-set accuracy in [0, 1] from one stochastic forward pass per test batch.
    """
    import contextlib

    device = get_device()
    set_global_seed(seed)

    # The validation loader is not needed for a fixed-config retrain.
    train_ld, _, test_ld = get_data_loaders(int(config["batch_size"]))
    N = len(train_ld.dataset)

    model = MaskedBayesianLeNet5(config).to(device)
    apply_masks(model, device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
    scheduler = MultiStepLR(optimizer, milestones=[int(0.8 * config["epochs"]), int(0.9 * config["epochs"])], gamma=0.2)

    use_amp = bool(config.get("amp", True)) and torch.cuda.is_available()
    try:
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
        autocast_ctx = lambda: torch.amp.autocast("cuda")
    except Exception:
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        autocast_ctx = lambda: torch.cuda.amp.autocast()

    def _amp():
        """Autocast context when AMP is enabled, otherwise a no-op context."""
        return autocast_ctx() if use_amp else contextlib.nullcontext()

    epochs = int(config["epochs"])
    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in train_ld:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with _amp():
                logits, kl = model(x, epoch)
                loss = F.cross_entropy(logits, y, reduction="sum") + kl / N
            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
        scheduler.step()

    # final test
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_ld:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            with _amp():
                logits, _ = model(x, epochs)
            correct += (logits.argmax(1) == y).sum().item()
    return correct / len(test_ld.dataset)

# ---------------------------------- Main -----------------------------------
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-samples", type=int, default=48)   # number of PBT trials
    parser.add_argument("--max-concurrent", type=int, default=6)
    args = parser.parse_args()

    PBT_EPOCHS = 60       # epochs per PBT trial; one report per epoch, so also the stop criterion
    RETRAIN_EPOCHS = 50
    RETRAIN_SEEDS = [42, 420, 90210, 1492, 911]

    # ----------------- 0) Fetch MNIST once, before any trial starts -----------------
    # Every trial calls MNIST(..., download=True); with an empty DATA_ROOT several
    # concurrent trials would otherwise download/extract into the same files at once.
    MNIST(DATA_ROOT, train=True, download=True)
    MNIST(DATA_ROOT, train=False, download=True)

    # ----------------- 1) Run PBT once to get best hyperparams -----------------
    ray.init(ignore_reinit_error=True, include_dashboard=False)
    # Request a GPU per trial only if the cluster has one; on a CPU-only machine a
    # {"gpu": 1} request can never be scheduled and trials would stay pending.
    gpus_per_trial = 1 if ray.cluster_resources().get("GPU", 0) >= 1 else 0

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations={
            "lr": lambda: 10 ** random.uniform(-4, -2),
            "prior_std": lambda: random.uniform(0.05, 0.5),
            "max_beta": lambda: random.uniform(0.05, 0.5),
            "warmup_epochs": lambda: random.randint(5, 50),
            "kl_scale_conv1": lambda: random.uniform(0.01, 2.0),
            "kl_scale_conv2": lambda: random.uniform(0.01, 2.0),
            "kl_scale_fc1": lambda: random.uniform(0.1, 2.0),
            "kl_scale_fc2": lambda: random.uniform(0.1, 2.0),
            "kl_scale_fc3": lambda: random.uniform(0.01, 0.5),
            "batch_size": lambda: random.choice([64, 128, 256]),
            "seed": lambda: random.randint(1, 10_000),
            "amp": lambda: random.choice([True, False]),
        },
    )

    search_space = {
        "lr": tune.loguniform(1e-4, 1e-2),
        "prior_std": tune.uniform(0.05, 0.5),
        "max_beta": tune.uniform(0.05, 0.5),
        "warmup_epochs": tune.randint(5, 50),
        "kl_scale_conv1": tune.uniform(0.01, 2.0),
        "kl_scale_conv2": tune.uniform(0.01, 2.0),
        "kl_scale_fc1": tune.uniform(0.1, 2.0),
        "kl_scale_fc2": tune.uniform(0.1, 2.0),
        "kl_scale_fc3": tune.uniform(0.01, 0.5),
        "batch_size": tune.choice([64, 128, 256]),
        "epochs": PBT_EPOCHS,
        "seed": tune.randint(1, 10_000),
        "amp": tune.choice([True, False]),
    }

    print("[PBT] starting hyperparameter search...")
    analysis = tune.run(
        train_and_evaluate,
        name="pbt_mnist_lenet5",
        scheduler=pbt,
        num_samples=args.num_samples,
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=search_space,
        stop={"training_iteration": PBT_EPOCHS},
        max_concurrent_trials=args.max_concurrent,
        reuse_actors=False,
        verbose=1,
        raise_on_failed_trial=False,
        fail_fast=False,
        max_failures=2,
    )

    # Select on VALIDATION accuracy only. Selecting on test_accuracy would leak the
    # test set into model selection and bias the final test numbers upward.
    # Averaging the last 5 reports smooths the single-sample stochastic evaluation.
    best_trial = analysis.get_best_trial("mean_accuracy", "max", "last-5-avg")
    if best_trial is None:
        raise SystemExit("No successful trials from PBT - cannot continue.")

    # With PBT this is the best trial's final (possibly perturbed) config.
    best_cfg = dict(best_trial.config)
    print("\n[PBT] Best config found (from best trial):")
    print(best_cfg)
    pbt_test_acc = best_trial.last_result.get("test_accuracy")
    if pbt_test_acc is not None:
        print(f"[PBT] best trial test accuracy (informational, not used for selection): {pbt_test_acc:.4f}")
    # Shutdown Ray before retraining for seeds (not needed for single-process retraining)
    ray.shutdown()

    # ----------------- 2) Retrain best config for several seeds -----------------
    best_cfg["epochs"] = RETRAIN_EPOCHS
    per_seed_results: List[float] = []
    for s in RETRAIN_SEEDS:
        print(f"\n[Retrain] seed={s} with best cfg (epochs={RETRAIN_EPOCHS})")
        acc = train_single_eval(best_cfg, seed=s)
        print(f" -> test accuracy: {acc:.4f}")
        per_seed_results.append(acc)

    mean_acc = float(np.mean(per_seed_results))
    # Sample standard deviation (ddof=1): the seeds are a sample of possible runs.
    std_acc = float(np.std(per_seed_results, ddof=1)) if len(per_seed_results) > 1 else 0.0
    print(f"\nFinal ({len(RETRAIN_SEEDS)} seeds) using PBT-best hyperparams (epochs={RETRAIN_EPOCHS}): "
          f"mean={mean_acc:.4f}  std={std_acc:.4f} (sample std, ddof=1)")
