#!/usr/bin/env python3
# MNIST_Lenet300_SVI_PBT.py
# PBT of a masked Bayesian LeNet-300-100 on MNIST
# Parallelized across 8 GPUs with Ray Tune + AIR-style checkpointing.
# Reports cross-entropy (train/val) and accuracy; the main script runs one PBT search per seed (5 seeds)
# and summarizes the best trial's test accuracy and test cross-entropy across those runs.

import os
import random
import tempfile
import numpy as np
from typing import Dict, Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

import torchvision.transforms as T
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

# --- Ray AIR imports (version-compatible) ---
try:
    from ray.air import session
except Exception:  # very old Ray
    from ray.tune import session  # type: ignore
try:
    from ray.air.checkpoint import Checkpoint  # Ray >= 2.3 usually
except Exception:  # other Ray builds
    try:
        from ray.tune import Checkpoint  # modern deprecation path wants this
    except Exception:
        try:
            from ray.train import Checkpoint  # fallback
        except Exception:
            Checkpoint = None  # no checkpoint API available

# ---------- Perf / threading ----------
# Module-level side effects: these run once when the file is imported or executed.
# cuDNN benchmark mode favors speed; it is a performance setting, not a determinism setting.
torch.backends.cudnn.benchmark = True
# Limit PyTorch intra-op CPU parallelism in this process to a single thread.
torch.set_num_threads(1)
# setdefault only writes the variable when it is not already present in the environment.
os.environ.setdefault("OMP_NUM_THREADS", "1")
# Presumably silences torchvision's dataset download progress output; confirm against torchvision.
os.environ.setdefault("TORCHVISION_DISABLE_DOWNLOAD_PROGRESS", "1")

# ---------- Helpers ----------
def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def set_global_seed(seed: int):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus every CUDA device if present)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ---------- Masks ----------
# Relative path: it is resolved against the current working directory when this module is loaded.
MASKS_PATH = "../tests/LeNet_MNIST/99_test1_various_masks/mask_1.1_size.npy"
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects from the file;
# only point MASKS_PATH at files from a trusted source.
raw_masks = np.load(MASKS_PATH, allow_pickle=True)
# One float32 CPU tensor per entry of the loaded array; apply_masks() pairs them with
# layers_names in order.
weight_masks_cpu = [torch.tensor(m, dtype=torch.float32) for m in raw_masks]
# Attribute names of the masked layers on the model, in mask order.
layers_names = ["fc1", "fc2", "fc3"]
# Length of the bias vector of each layer (used to build all-ones bias masks).
bias_sizes: Dict[str, int] = {"fc1": 300, "fc2": 100, "fc3": 10}

def kl_gaussian(mu: torch.Tensor, rho: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return the masked KL divergence KL(N(mu, sigma^2) || N(0, 1)), summed over elements.

    Args:
        mu: Posterior means.
        rho: Unconstrained scale parameters; ``sigma = softplus(rho)``.
        mask: Per-element weights multiplied into the KL terms before summing
            (a 0 entry removes that element's contribution).

    Returns:
        A scalar tensor: ``0.5 * sum(mask * (sigma^2 + mu^2 - 1 - 2*log(sigma)))``.
    """
    # F.softplus is the overflow-safe form of log1p(exp(rho)): the naive expression
    # turns into inf for large rho and the KL then evaluates to inf - inf = NaN.
    sigma = F.softplus(rho)
    # The 1e-12 guards log() when softplus underflows to exactly 0 for very negative rho.
    kl_elem = sigma.pow(2) + mu.pow(2) - 1.0 - 2.0 * torch.log(sigma + 1e-12)
    return 0.5 * torch.sum(kl_elem * mask)

class MaskedLRDense(nn.Module):
    """Masked dense layer with Gaussian weights and biases.

    Each weight/bias has a mean (``mu_*``) and an unconstrained scale (``rho_*``,
    mapped to ``sigma = softplus(rho)``). The forward pass samples the
    pre-activations directly from their mean and variance instead of sampling
    individual weights. ``mask_w`` / ``mask_b`` are buffers (all ones by default)
    multiplied into the means, the variances and the KL terms.
    """

    def __init__(self, fin: int, fout: int):
        """Create a layer mapping ``fin`` input features to ``fout`` outputs."""
        super().__init__()
        self.mu_w  = nn.Parameter(0.05 * torch.randn(fout, fin))
        self.rho_w = nn.Parameter(-3.0 * torch.ones(fout, fin))
        self.mu_b  = nn.Parameter(torch.zeros(fout))
        self.rho_b = nn.Parameter(-3.0 * torch.ones(fout))
        # Buffers, not parameters: they follow .to(device) and state_dict but are not trained.
        self.register_buffer("mask_w", torch.ones(fout, fin))
        self.register_buffer("mask_b", torch.ones(fout))

    def forward(self, x: torch.Tensor, kl_scale: float):
        """Sample the layer output for ``x``.

        Args:
            x: Input tensor accepted by ``F.linear`` with ``fin`` trailing features.
            kl_scale: Multiplier applied to this layer's KL term.

        Returns:
            ``(z, kl)`` where ``z`` is the sampled pre-activation and ``kl`` is the
            layer's masked KL divergence multiplied by ``kl_scale``.
        """
        mu_w = self.mu_w * self.mask_w
        mu_b = self.mu_b * self.mask_b
        # softplus is the overflow-safe equivalent of log1p(exp(rho)).
        sigma_w = F.softplus(self.rho_w)
        sigma_b = F.softplus(self.rho_b)

        mu_z = F.linear(x, mu_w, mu_b)
        # Variance of the pre-activation: squared inputs times masked weight variances,
        # plus the bias variance. The bias variance is masked too so that it stays
        # consistent with the masked bias mean and masked bias KL (no-op for an all-ones mask).
        var_z = F.linear(x.pow(2), (sigma_w ** 2) * (self.mask_w ** 2)) \
            + (sigma_b ** 2) * (self.mask_b ** 2)
        eps = torch.randn_like(mu_z)
        # The small constant keeps sqrt() away from an exact zero.
        z = mu_z + eps * torch.sqrt(var_z + 1e-8)

        kl = kl_gaussian(self.mu_w, self.rho_w, self.mask_w) + \
             kl_gaussian(self.mu_b, self.rho_b, self.mask_b)
        return z, kl * kl_scale

INPUT_DIM = 28 * 28  # number of features after flattening each input image


class MaskedBayesianLeNet(nn.Module):
    """Three stacked MaskedLRDense layers (INPUT_DIM -> 300 -> 100 -> 10) with ReLU in between.

    Reads from ``config``: ``max_beta``, ``warmup_epochs`` and the per-layer
    ``kl_scale_fc1`` / ``kl_scale_fc2`` / ``kl_scale_fc3`` multipliers.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.fc1 = MaskedLRDense(INPUT_DIM, 300)
        self.fc2 = MaskedLRDense(300, 100)
        self.fc3 = MaskedLRDense(100, 10)

        self.max_beta = float(config["max_beta"])
        self.warmup_epochs = int(config["warmup_epochs"])
        self.kl_scales = {
            "fc1": float(config["kl_scale_fc1"]),
            "fc2": float(config["kl_scale_fc2"]),
            "fc3": float(config["kl_scale_fc3"]),
        }

    def forward(self, x: torch.Tensor, epoch: int):
        """Return ``(logits, kl)`` for a batch ``x`` at the given epoch.

        The KL weight ramps linearly with ``epoch`` until ``warmup_epochs`` and is
        then capped at ``max_beta``; each layer additionally applies its own scale.
        """
        flat = x.view(x.size(0), -1)

        # Linear warm-up factor in [0, 1]; max(1, ...) avoids dividing by zero.
        ramp = min(1.0, epoch / max(1, self.warmup_epochs))
        beta = ramp * self.max_beta

        hidden, kl_a = self.fc1(flat, beta * self.kl_scales["fc1"])
        hidden = F.relu(hidden)
        hidden, kl_b = self.fc2(hidden, beta * self.kl_scales["fc2"])
        hidden = F.relu(hidden)
        logits, kl_c = self.fc3(hidden, beta * self.kl_scales["fc3"])

        return logits, kl_a + kl_b + kl_c

DATA_ROOT = os.path.abspath("./data")


def get_data_loaders(batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the (train, validation, test) MNIST loaders.

    The 60,000-image training set is split 55,000 / 5,000 with a generator seeded
    with 0, so the split is the same on every call. Only the training loader shuffles.
    """
    preprocess = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

    full_train = MNIST(DATA_ROOT, train=True,  download=True, transform=preprocess)
    split_rng = torch.Generator().manual_seed(0)
    train_part, valid_part = torch.utils.data.random_split(
        full_train, [55_000, 5_000], generator=split_rng
    )
    test_set = MNIST(DATA_ROOT, train=False, download=True, transform=preprocess)

    def _make_loader(dataset, shuffle: bool) -> DataLoader:
        # Settings shared by all three loaders.
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=4,
            pin_memory=torch.cuda.is_available(),
            persistent_workers=True,
        )

    train_loader = _make_loader(train_part, True)
    valid_loader = _make_loader(valid_part, False)
    test_loader = _make_loader(test_set, False)
    return train_loader, valid_loader, test_loader

def apply_masks(model: nn.Module, device: torch.device):
    """Copy the preloaded weight masks into the model's layers and reset bias masks to ones.

    Layers are looked up by name from ``layers_names`` and paired, in order, with
    the tensors in ``weight_masks_cpu``.
    """
    for name, cpu_mask in zip(layers_names, weight_masks_cpu):
        target = getattr(model, name)
        # In-place copies keep the registered buffers (and their device) intact.
        target.mask_w.data.copy_(cpu_mask.to(device))
        all_ones = torch.ones(bias_sizes[name], dtype=torch.float32, device=device)
        target.mask_b.data.copy_(all_ones)

# ------------------------ Tune trainable (reports CE + accuracy) ------------------------
def _evaluate(model, loader, device, epoch, use_amp, autocast_ctx):
    """Return ``(summed cross-entropy, number of correct predictions)`` over ``loader``."""
    ce_total = 0.0
    n_correct = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            if use_amp:
                with autocast_ctx():
                    logits, _ = model(x, epoch)
                    ce_total += F.cross_entropy(logits, y, reduction="sum").item()
            else:
                logits, _ = model(x, epoch)
                ce_total += F.cross_entropy(logits, y, reduction="sum").item()
            n_correct += (logits.argmax(1) == y).sum().item()
    return ce_total, n_correct


def train_and_evaluate(config: dict):
    """Ray Tune function trainable with AIR-style restore/save. Reports CE and accuracy.

    Reads from ``config``: ``lr``, ``batch_size``, ``epochs``, optional ``seed``
    (default 1234) and ``amp`` (default True), plus the keys consumed by
    ``MaskedBayesianLeNet``.

    After every epoch it reports ``epoch``, ``train_ce``, ``val_ce``,
    ``mean_accuracy`` and ``best_mean_accuracy``; the final epoch additionally
    reports ``test_accuracy`` and ``test_ce``. A checkpoint is attached every
    5 epochs and on the final epoch when a Checkpoint API is available.
    """
    ckpt = session.get_checkpoint()
    seed = int(config.get("seed", 1234))
    base_lr = float(config["lr"])
    best_val_acc = 0.0
    start_epoch = 1

    device = get_device()
    set_global_seed(seed)

    train_ld, val_ld, test_ld = get_data_loaders(int(config["batch_size"]))
    N = len(train_ld.dataset)

    model = MaskedBayesianLeNet(config).to(device)
    apply_masks(model, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)
    milestones = [int(0.8 * config["epochs"]), int(0.9 * config["epochs"])]
    gamma = 0.2
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

    use_amp = bool(config.get("amp", True)) and torch.cuda.is_available()
    try:
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)  # newer torch API
        autocast_ctx = lambda: torch.amp.autocast("cuda")
    except Exception:
        # Best-effort fallback for torch builds without torch.amp.GradScaler.
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        autocast_ctx = lambda: torch.cuda.amp.autocast()

    # ----- restore -----
    if ckpt is not None:
        with ckpt.as_directory() as ckdir:
            p = os.path.join(ckdir, "checkpoint.pt")
            if os.path.exists(p):
                state = torch.load(p, map_location="cpu")
                model.load_state_dict(state["model"])
                optimizer.load_state_dict(state["optimizer"])
                scheduler.load_state_dict(state["scheduler"])
                best_val_acc = state.get("best_val_acc", 0.0)
                start_epoch = state.get("epoch", 0) + 1
                seed = int(state.get("seed", seed))
                for st in optimizer.state.values():
                    for k, v in st.items():
                        if torch.is_tensor(v):
                            st[k] = v.to(device)

                # Loading the optimizer/scheduler state also restores the learning
                # rate stored in the checkpoint. Under PBT the checkpoint may come
                # from a trial with a different (pre-perturbation) lr, so re-apply
                # this trial's config["lr"], decayed by the milestones already passed.
                passed = sum(1 for m in milestones if m <= scheduler.last_epoch)
                scheduler.base_lrs = [base_lr for _ in optimizer.param_groups]
                for group in optimizer.param_groups:
                    group["initial_lr"] = base_lr
                    group["lr"] = base_lr * gamma ** passed

    epochs = int(config["epochs"])

    for epoch in range(start_epoch, epochs + 1):
        # ------------------------------- Train ------------------------------
        model.train()
        ce_sum = 0.0

        for x, y in train_ld:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            # NOTE(review): CE is summed over the batch while KL is divided by the
            # full training-set size N; confirm this relative weighting is intended.
            if use_amp:
                with autocast_ctx():
                    logits, kl = model(x, epoch)
                    ce = F.cross_entropy(logits, y, reduction="sum")
                    loss = ce + kl / N
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits, kl = model(x, epoch)
                ce = F.cross_entropy(logits, y, reduction="sum")
                loss = ce + kl / N
                loss.backward()
                optimizer.step()

            ce_sum += ce.item()

        scheduler.step()
        train_ce = ce_sum / N  # average per-sample cross-entropy

        # ------------------------------ Validate ---------------------------
        val_ce_sum, correct = _evaluate(model, val_ld, device, epoch, use_amp, autocast_ctx)
        val_acc = correct / len(val_ld.dataset)
        best_val_acc = max(best_val_acc, val_acc)
        val_ce = val_ce_sum / len(val_ld.dataset)

        # --------------------------- Report & Checkpoint --------------------
        payload = {
            "epoch": epoch,
            "train_ce": train_ce,
            "val_ce": val_ce,
            "mean_accuracy": val_acc,
            "best_mean_accuracy": best_val_acc,
        }

        # On the final epoch, compute and attach test metrics so they are present
        # in the trial's last reported result.
        if epoch == epochs:
            test_ce_sum, test_correct = _evaluate(
                model, test_ld, device, epochs, use_amp, autocast_ctx
            )
            payload["test_accuracy"] = test_correct / len(test_ld.dataset)
            payload["test_ce"] = test_ce_sum / len(test_ld.dataset)

        if Checkpoint is not None and (epoch % 5 == 0 or epoch == epochs):
            # Report inside the `with` so the directory still exists while Ray
            # ingests the checkpoint.
            with tempfile.TemporaryDirectory() as tmp:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "best_val_acc": best_val_acc,
                        "epoch": epoch,
                        "seed": seed,
                        "config": config,
                    },
                    os.path.join(tmp, "checkpoint.pt"),
                )
                session.report(payload, checkpoint=Checkpoint.from_directory(tmp))
        else:
            session.report(payload)



# ------------------------ Re-train from scratch on 5 seeds ------------------------
def train_eval_single(config: dict, seed: int) -> Tuple[float, torch.Tensor]:
    """Train a fresh model with the given seed and evaluate it on the test set.

    Args:
        config: Hyperparameters; reads ``batch_size``, ``lr``, ``epochs``, optional
            ``amp`` (default True) and the keys consumed by ``MaskedBayesianLeNet``.
        seed: Seed passed to ``set_global_seed`` before anything is built.

    Returns:
        ``(test_accuracy, last_loss)`` where ``last_loss`` is the detached training
        loss of the final batch (a NaN tensor if no training batch was processed).
    """
    device = get_device()
    set_global_seed(seed)

    # The validation loader is built by the helper but not used here.
    train_ld, _val_ld, test_ld = get_data_loaders(int(config["batch_size"]))
    N = len(train_ld.dataset)

    model = MaskedBayesianLeNet(config).to(device)
    apply_masks(model, device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(config["lr"]))
    scheduler = MultiStepLR(
        optimizer,
        milestones=[int(0.8 * config["epochs"]), int(0.9 * config["epochs"])],
        gamma=0.2,
    )

    use_amp = bool(config.get("amp", True)) and torch.cuda.is_available()
    try:
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
        autocast_ctx = lambda: torch.amp.autocast("cuda")
    except Exception:
        # Best-effort fallback for torch builds without torch.amp.GradScaler.
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        autocast_ctx = lambda: torch.cuda.amp.autocast()

    epochs = int(config["epochs"])

    # Defined up front so the return statement is valid even when the training
    # loop never executes a batch (e.g. epochs == 0).
    loss = torch.tensor(float("nan"))

    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in train_ld:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            if use_amp:
                with autocast_ctx():
                    logits, kl = model(x, epoch)
                    ce = F.cross_entropy(logits, y, reduction="sum")
                    loss = ce + kl / N
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits, kl = model(x, epoch)
                ce = F.cross_entropy(logits, y, reduction="sum")
                loss = ce + kl / N
                loss.backward()
                optimizer.step()
        scheduler.step()

    # Test
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_ld:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            if use_amp:
                with autocast_ctx():
                    logits, _ = model(x, epochs)
            else:
                logits, _ = model(x, epochs)
            correct += (logits.argmax(1) == y).sum().item()

    # Detach so the caller does not receive a tensor tied to the autograd graph.
    return correct / len(test_ld.dataset), loss.detach()

# ---------------------------------- Main -----------------------------------
if __name__ == "__main__":
    # Script entry point: run one PBT search per seed and summarize the best
    # trial's test accuracy / test cross-entropy across the runs.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-samples", type=int, default=64)
    parser.add_argument("--max-concurrent", type=int, default=8)
    args = parser.parse_args()

    seeds = [42, 420, 90210, 1492, 911]
    print("\n[Retrain best config on 5 seeds]")
    per_seed_acc: List[float] = []
    per_seed_ce: List[float] = []

    # One Ray runtime is enough for all runs; ignore_reinit_error keeps this safe
    # if Ray was already initialized.
    ray.init(include_dashboard=False, ignore_reinit_error=True)

    for s in seeds:
        # Make the loop seed actually take effect: the mutation lambdas below draw
        # from the driver's `random` module, which this call seeds (along with
        # NumPy and torch). Previously `s` was never used.
        set_global_seed(s)

        # A fresh scheduler per run.
        # NOTE(review): "prior_std" is sampled and mutated, but nothing in the visible
        # code reads it (kl_gaussian uses a fixed unit prior). Confirm whether it
        # should be wired into the model.
        pbt = PopulationBasedTraining(
            time_attr="training_iteration",
            metric="mean_accuracy",
            mode="max",
            perturbation_interval=5,
            hyperparam_mutations={
                "lr": lambda: 10 ** random.uniform(-4, -2),
                "prior_std": lambda: random.uniform(0.05, 0.5),
                "max_beta": lambda: random.uniform(0.05, 0.5),
                "warmup_epochs": lambda: random.randint(5, 50),
                "kl_scale_fc1": lambda: random.uniform(0.1, 2.0),
                "kl_scale_fc2": lambda: random.uniform(0.1, 2.0),
                "kl_scale_fc3": lambda: random.uniform(0.01, 0.5),
                "batch_size": lambda: random.choice([64, 128, 256]),
                "seed": lambda: random.randint(1, 10_000),
                "amp": lambda: random.choice([True, False]),
            },
        )

        search_space = {
            "lr": tune.loguniform(1e-4, 1e-2),
            "prior_std": tune.uniform(0.05, 0.5),
            "max_beta": tune.uniform(0.05, 0.5),
            "warmup_epochs": tune.randint(5, 50),
            "kl_scale_fc1": tune.uniform(0.1, 2.0),
            "kl_scale_fc2": tune.uniform(0.1, 2.0),
            "kl_scale_fc3": tune.uniform(0.01, 0.5),
            "batch_size": tune.choice([64, 128, 256]),
            "epochs": 50,
            "seed": tune.randint(1, 10_000),
            "amp": tune.choice([True, False]),
        }

        result = tune.run(
            train_and_evaluate,
            name="pbt_mnist_bnn",
            scheduler=pbt,
            num_samples=args.num_samples,
            resources_per_trial={"cpu": 2, "gpu": 1},
            config=search_space,
            stop={"training_iteration": 50},
            max_concurrent_trials=args.max_concurrent,
            reuse_actors=False,
            verbose=1,
            raise_on_failed_trial=False,  # do not throw if a trial fails
            fail_fast=False,  # keep going
            max_failures=2,  # auto-retry a flaky trial
        )

        # Prefer test_accuracy if available, otherwise fall back to mean_accuracy
        best_trial = result.get_best_trial("test_accuracy", "max", "last")
        if best_trial is None or best_trial.last_result.get("test_accuracy") is None:
            print("[WARN] No trials reported test_accuracy; falling back to mean_accuracy.")
            best_trial = result.get_best_trial("mean_accuracy", "max", "last")

        if best_trial is None:
            print("[ERROR] No completed trials with usable metrics. Trial summaries:")
            for t in result.trials:
                print(f"- {t.trial_id} status={t.status} last_result_keys={list((t.last_result or {}).keys())}")
            raise SystemExit(1)

        best_cfg = best_trial.config
        test_acc = best_trial.last_result.get("test_accuracy")
        test_ce = best_trial.last_result.get("test_ce")
        print("\n[Best trial]")
        print("Config:", best_cfg)
        print("Val acc (last):", best_trial.last_result.get("mean_accuracy"))
        print("Test acc (last):", test_acc)

        # After the mean_accuracy fallback the test metrics can be missing; a None
        # in the lists would make np.mean/np.std raise, so leave that run out.
        if test_acc is None or test_ce is None:
            print(f"[WARN] Seed {s}: best trial has no test metrics; excluded from the final summary.")
            continue
        per_seed_acc.append(float(test_acc))
        per_seed_ce.append(float(test_ce))

    if not per_seed_acc:
        print("[ERROR] No run produced test metrics; nothing to summarize.")
        raise SystemExit(1)

    mean_acc = float(np.mean(per_seed_acc))
    std_acc  = float(np.std(per_seed_acc))
    ce_acc = float(np.mean(per_seed_ce))
    print(f"\nFinal ({len(per_seed_acc)} seeds): mean={mean_acc:.4f}  std={std_acc:.4f} ce={ce_acc:.4f}")

