#!/usr/bin/env python3
# bayes_lenet300100_masked.py
#
# Bayesian LeNet‑300‑100 on CIFAR‑10 with weight masks + KL / N scaling
# Does not search for hyperparameters; they must be supplied (e.g. from a prior search run).
# --------------------------------------------------------------------
import os, argparse, math, random, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import torchvision.transforms as T, torchvision.datasets as dset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.datasets import MNIST

#!/usr/bin/env python3
# Lenet300_CIFAR_your_pick_hyperparameters.py
#
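# Population-Based Training of a masked Bayesian LeNet-300-100 on CIFAR-10
# using Ray Tune’s PBT scheduler (this header describes the search; this file
# itself runs no search and only consumes the tuned values). Hyperparameters: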
#   • learning rate
#   • prior_std
#   • max_beta (KL weight cap)
#   • warmup_epochs (KL warm-up length)
#   • kl_scale_fc1, kl_scale_fc2, kl_scale_fc3 (per-layer KL scaling)
#   • batch_size
#
# Drop this file into your project, install ray[tune]:
#   pip install "ray[tune]"
#
# Then run:
#   ray start --head
#   python Lenet300_CIFAR_your_pick_hyperparameters.py
# --------------------------------------------------------------------------

import os
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import torchvision.transforms as T
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR

from ray import tune
from ray.tune.schedulers import PopulationBasedTraining


'''
Example best-trial result from one of our hyperparameter-search runs. Feel free to use different values.
Best trial config: {'lr': 0.003089884483769271, 'prior_std': 0.41480028624720355, 'max_beta': 0.19515047577824146, 'warmup_epochs': 24,
 'kl_scale_fc1': 1.8653370380077794, 'kl_scale_fc2': 1.3930227773466246, 'kl_scale_fc3': 1.2410948125831944, 'batch_size': 64, 'epochs': 100}
Best validation accuracy: 0.5528
'''

# ─── Run configuration ───────────────────────────────────────────────────
# This script performs no hyperparameter search. The values below are the
# example best-trial config quoted in the string above; replace them with your
# own search results. They must be defined before the seed loop because the
# loop body reads them immediately (without them the script stops with a
# NameError).
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LR            = 0.003089884483769271
PRIOR_STD     = 0.41480028624720355
MAX_BETA      = 0.19515047577824146   # KL warm-up target
WARMUP_EPOCHS = 24
KL_SCALE_FC1  = 1.8653370380077794
KL_SCALE_FC2  = 1.3930227773466246
KL_SCALE_FC3  = 1.2410948125831944
BATCH_SIZE    = 64

MASK_PATH = "../tests/LeNet_CIFAR/99_test1_various_masks/mask_1.1_size.npy"

seeds = [42, 420, 90210, 1492, 911]
mnist_results_acc = []
cifar_results_acc = []  # per-seed results, filled in by main()

# The mask file is identical for every seed, so read it from disk only once.
# allow_pickle=True is needed because the file stores a list of arrays; only
# load trusted files this way.
raw = np.load(MASK_PATH, allow_pickle=True)  # list len 3

for SEED in seeds:
    print(SEED)

    # Seed every RNG so each seed's run is reproducible (and runs with
    # different masks start from comparable randomness). Delete these lines if
    # you want fresh randomness on every run.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is absent

    # ─── Masks (3 layers: fc1, fc2, fc3) ─────────────────────────────────
    weight_masks  = [torch.tensor(m, dtype=torch.float32, device=DEVICE) for m in raw]
    layers_names  = ["fc1", "fc2", "fc3"]
    bias_sizes    = {"fc1": 300, "fc2": 100, "fc3": 10}               # vector masks

    # ─── Hyper‑parameters ────────────────────────────────────────────────
    KL_MAX_BETA      = MAX_BETA      # warm‑up target
    KL_WARM_EPOCHS   = WARMUP_EPOCHS
    MC_SAMPLES_TEST  = 100
    # Milestones sized for a 120-epoch run; with the default 100 epochs only
    # the first milestone is reached.
    LR_STEPS         = [96, 108]

    # ─── Helper: KL for masked Gaussian ──────────────────────────────────
    def kl_gaussian(mu, log_sigma, mask, prior_std=1.0):
        """Masked KL( N(mu, sigma^2) || N(0, prior_std^2) ), summed to a scalar.

        Args:
            mu: posterior means.
            log_sigma: log of the posterior standard deviations (same shape as
                ``mu``). This is log(sigma) itself, not a softplus
                pre-activation.
            mask: per-element multiplier (0/1 for pruning) applied to each
                element's KL before summing.
            prior_std: standard deviation of the zero-mean Gaussian prior. The
                default 1.0 reproduces the standard-normal prior exactly.

        Returns:
            Scalar tensor with the masked KL sum.

        Raises:
            ValueError: if ``prior_std`` is not positive.
        """
        if prior_std <= 0:
            raise ValueError(f"prior_std must be positive, got {prior_std}")
        # Per element: 0.5 * [ (sigma^2 + mu^2) / s^2 - 1 - 2 log(sigma / s) ]
        log_ratio = log_sigma - math.log(prior_std)
        kl_elem = torch.exp(2 * log_ratio) + (mu / prior_std) ** 2 - 1.0 - 2 * log_ratio
        return 0.5 * torch.sum(kl_elem * mask)

    # ─── Mask‑aware BBB layer with local reparam ─────────────────────────
    class MaskedLRDense(nn.Module):
        """Mask-aware Bayes-by-Backprop linear layer with local reparameterisation.

        Weights and biases have factorised Gaussian posteriors N(mu, sigma^2)
        with sigma = softplus(rho). The binary masks (buffers, all ones until
        apply_masks overwrites them) zero pruned entries in both the mean and
        variance paths and exclude them from the KL term. The prior is
        N(0, prior_std^2).

        forward(x) returns ``(z, kl)``: sampled pre-activations and the masked
        KL(q || prior) summed over this layer's weights and biases.

        NOTE(review): the KL used to treat ``rho`` as log(sigma) against an
        N(0, 1) prior, ignoring ``prior_std``. Hyperparameters tuned under that
        formula may need re-tuning.
        """

        def __init__(self, fin, fout, prior_std=PRIOR_STD):
            super().__init__()
            if prior_std <= 0:
                raise ValueError(f"prior_std must be positive, got {prior_std}")
            self.mu_w   = nn.Parameter(0.05*torch.randn(fout, fin))
            self.rho_w  = nn.Parameter(-3.*torch.ones(fout, fin))
            self.mu_b   = nn.Parameter(torch.zeros(fout))
            self.rho_b  = nn.Parameter(-3.*torch.ones(fout))
            # buffers for binary masks (overwritten later by apply_masks)
            self.register_buffer("mask_w", torch.ones(fout, fin))
            self.register_buffer("mask_b", torch.ones(fout))
            self.prior_std = prior_std

        def _kl(self, mu, sigma, mask):
            """Masked KL( N(mu, sigma^2) || N(0, prior_std^2) ).

            The Gaussian KL is invariant to rescaling both distributions, so
            dividing mean and std by prior_std reduces the prior to N(0, 1),
            which kl_gaussian's default handles.
            """
            s = self.prior_std
            return kl_gaussian(mu / s, torch.log(sigma / s), mask)

        def forward(self, x):
            # softplus(rho) == log1p(exp(rho)), but it does not overflow for large rho.
            sigma_w = F.softplus(self.rho_w)
            sigma_b = F.softplus(self.rho_b)

            # Local reparameterisation: sample the activations, not the weights.
            mu_z  = F.linear(x, self.mu_w * self.mask_w, self.mu_b * self.mask_b)
            var_z = F.linear(x.pow(2),
                             (sigma_w**2) * self.mask_w**2) + (sigma_b**2) * self.mask_b**2
            eps   = torch.randn_like(mu_z)
            z     = mu_z + eps * torch.sqrt(var_z + 1e-8)

            # The KL must use the same sigma as the sampling path and the
            # configured prior_std.
            kl = self._kl(self.mu_w, sigma_w, self.mask_w) + \
                 self._kl(self.mu_b, sigma_b, self.mask_b)
            return z, kl

    # ─── Pure FC LeNet‑300‑100 with masks ────────────────────────────────
    # Flattened CIFAR-10 image size (32*32*3 = 3072); fc1's weight mask must
    # have this many columns.
    INPUT_DIM = 32*32*3

    class MaskedBayesianLeNet(nn.Module):
        """Fully connected INPUT_DIM-300-100-10 Bayesian MLP built from MaskedLRDense.

        forward(x) flattens each sample, applies ReLU between layers, and
        returns ``(logits, kl)``. Here ``kl`` is the per-layer KL weighted by
        KL_SCALE_FC1/2/3.
        """

        def __init__(self):
            super().__init__()
            self.fc1 = MaskedLRDense(INPUT_DIM, 300)
            self.fc2 = MaskedLRDense(300, 100)
            self.fc3 = MaskedLRDense(100, 10)

        def forward(self, x):
            flat = x.view(x.size(0), -1)           # raw image -> vector
            pre1, kl_first = self.fc1(flat)
            pre2, kl_second = self.fc2(F.relu(pre1))
            logits, kl_last = self.fc3(F.relu(pre2))
            # Per-layer KL scaling factors come from the run configuration.
            weighted_kl = KL_SCALE_FC1*kl_first + KL_SCALE_FC2*kl_second + KL_SCALE_FC3*kl_last
            return logits, weighted_kl




    # ─── CIFAR‑10 loaders with augmentation ──────────────────────────────
    def get_loaders(bs, val_size=5000, root="./data"):
        """Return (train, val, test) CIFAR-10 DataLoaders.

        The official 50k training images are split by one seeded permutation:
        ``val_size`` images form the validation set (no augmentation) and the
        rest form the training set (random crop + horizontal flip). The splits
        are disjoint, so validation accuracy is measured on images the network
        never trains on. The official test set is loaded unchanged.

        Args:
            bs: batch size for all three loaders.
            val_size: number of training images held out for validation.
            root: dataset directory (downloaded on first use).

        Raises:
            ValueError: if ``val_size`` is not in ``[1, n_train - 1]``.
        """
        norm  = T.Normalize((0.5,)*3, (0.5,)*3)
        aug   = T.Compose([T.RandomCrop(32, 4),
                           T.RandomHorizontalFlip(),
                           T.ToTensor(),
                           norm])
        plain = T.Compose([T.ToTensor(), norm])

        # Two views of the SAME training images: augmented for fitting,
        # plain for validation.
        train_view = dset.CIFAR10(root=root, train=True, transform=aug, download=True)
        val_view   = dset.CIFAR10(root=root, train=True, transform=plain, download=True)

        n_total = len(train_view)
        if not 0 < val_size < n_total:
            raise ValueError(f"val_size must be in [1, {n_total - 1}], got {val_size}")

        # Hold the validation images OUT of the training subset so they cannot
        # leak into training. The permutation is drawn the same way
        # random_split(..., generator seeded with 0) draws it, so the first
        # val_size indices should select the same validation images as before.
        perm = torch.randperm(n_total, generator=torch.Generator().manual_seed(0)).tolist()
        va = torch.utils.data.Subset(val_view, perm[:val_size])
        tr = torch.utils.data.Subset(train_view, perm[val_size:])

        te = dset.CIFAR10(root=root, train=False, transform=plain, download=True)

        kwargs = dict(batch_size=bs, num_workers=2, pin_memory=True)
        return (DataLoader(tr, shuffle=True, **kwargs),
                DataLoader(va, shuffle=False, **kwargs),
                DataLoader(te, shuffle=False, **kwargs))


    # ─── Apply masks to network after construction ───────────────────────
    def apply_masks(model):
        """Install the loaded weight masks (and all-ones bias masks) into ``model``.

        For each name in ``layers_names``, the matching entry of
        ``weight_masks`` replaces that layer's ``mask_w`` buffer, and ``mask_b``
        is reset to ones of length ``bias_sizes[name]``.

        Raises:
            ValueError: if the number of masks differs from the number of
                layers (``zip`` would otherwise silently leave layers unmasked),
                or if a mask's shape differs from the layer's weight mask
                (which would otherwise fail, or broadcast, only later in
                forward).
        """
        if len(weight_masks) != len(layers_names):
            raise ValueError(f"expected {len(layers_names)} weight masks "
                             f"({layers_names}), got {len(weight_masks)}")
        for ln, wm in zip(layers_names, weight_masks):
            layer = getattr(model, ln)
            if tuple(wm.shape) != tuple(layer.mask_w.shape):
                raise ValueError(f"mask for {ln} has shape {tuple(wm.shape)}, "
                                 f"expected {tuple(layer.mask_w.shape)}")
            layer.mask_w.data = wm.to(DEVICE)
            layer.mask_b.data = torch.ones(bias_sizes[ln], device=DEVICE)

    # ─── Training / evaluation ───────────────────────────────────────────
    def beta(epoch, max_beta=None, warm_epochs=None):
        """KL warm-up weight: linear ramp from 0 to ``max_beta`` over ``warm_epochs``.

        Args:
            epoch: current epoch number (main passes 1-based epochs).
            max_beta: weight reached after warm-up; defaults to KL_MAX_BETA.
            warm_epochs: ramp length; defaults to KL_WARM_EPOCHS. A
                non-positive value disables warm-up (full weight immediately)
                instead of dividing by zero or producing a negative weight.

        Returns:
            The KL weight for this epoch.
        """
        if max_beta is None:
            max_beta = KL_MAX_BETA
        if warm_epochs is None:
            warm_epochs = KL_WARM_EPOCHS
        if warm_epochs <= 0:
            return max_beta
        return max_beta * min(1.0, epoch / warm_epochs)

    def train_epoch(net, loader, opt, epoch, N):
        """Run one optimisation epoch.

        Per batch the loss is the batch-summed cross-entropy plus
        beta(epoch) * kl / N, where ``kl`` is the network's (scaled) KL sum.

        Returns:
            (summed CE / N, summed per-batch KL / N, the KL weight used).

        NOTE(review): the CE is summed over the batch, while the KL is divided
        by N once per batch. Over an epoch the KL therefore enters with weight
        of roughly beta * n_batches / N (about beta / batch_size), not the
        beta * KL of the usual ELBO. The hyperparameters were presumably tuned
        with this scaling; confirm before changing it.
        """
        net.train()
        kl_weight = beta(epoch)
        ce_sum, kl_sum = 0, 0
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            opt.zero_grad()
            logits, kl = net(inputs)                 # kl: sum over all weights
            ce = F.cross_entropy(logits, targets, reduction='sum')
            batch_loss = ce + kl_weight * kl / N
            batch_loss.backward()
            opt.step()
            ce_sum += ce.item()
            kl_sum += kl.item()
        return ce_sum / N, kl_sum / N, kl_weight

    @torch.no_grad()
    def mc_accuracy(net, loader, mc=MC_SAMPLES_TEST):
        """Monte-Carlo classification accuracy over ``loader``.

        Each batch is passed through the stochastic network ``mc`` times. The
        sampled logits are summed (the argmax of the sum equals the argmax of
        their mean) and compared with the labels.

        Returns:
            Fraction of correctly classified examples.
        """
        net.eval()
        n_correct, n_seen = 0, 0
        for inputs, targets in loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            summed_logits = sum(net(inputs)[0] for _ in range(mc))
            predictions = summed_logits.argmax(1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
        return n_correct / n_seen

    # ─── Main entry ───────────────────────────────────────────────────────
    def main():
        """Train one masked Bayesian LeNet for the current seed and record its test accuracy.

        The command-line flags ``--epochs``, ``--batch`` and ``--lr`` override
        the configured defaults. Each epoch prints the training CE/KL, the KL
        weight, the validation accuracy and the best validation accuracy so
        far.

        Returns:
            Final MC test accuracy in percent (also appended to
            ``cifar_results_acc``).
        """
        p = argparse.ArgumentParser()
        p.add_argument('--epochs', type=int, default=100)
        p.add_argument('--batch',  type=int, default=BATCH_SIZE)
        p.add_argument('--lr',    type=float, default=LR)
        args = p.parse_args()

        train_ld, val_ld, test_ld = get_loaders(args.batch)
        N = len(train_ld.dataset)

        net = MaskedBayesianLeNet().to(DEVICE)
        apply_masks(net)                       # ← masks injected here
        opt = torch.optim.Adam(net.parameters(), lr=args.lr)
        sched = MultiStepLR(opt, milestones=LR_STEPS, gamma=0.2)

        best = 0
        for ep in range(1, args.epochs+1):
            ce, kl, kl_weight = train_epoch(net, train_ld, opt, ep, N)
            val_acc = mc_accuracy(net, val_ld)
            best = max(best, val_acc)
            sched.step()
            print(f"E{ep:03d}  CE {ce:.3f}  KL {kl:.3f}  β {kl_weight:.2f}  "
                  f"val {val_acc*100:.2f}%  best {best*100:.2f}%")

        # Evaluate the test set ONCE. MC evaluation is stochastic and costly,
        # so the printed and recorded numbers must come from the same pass.
        # Record the accuracy value, not the mc_accuracy function object.
        final_acc = mc_accuracy(net, test_ld)*100
        print(f"Test accuracy: {final_acc:.2f}%")
        cifar_results_acc.append(final_acc)
        return final_acc

    if __name__ == "__main__":
        main()

