#!/usr/bin/env python3
# CIFAR_Resnet_SVI_mask_only.py
# Standard SVI (unit KL weights) for a mask-driven Bayesian ResNet on CIFAR-10.

import os, math, random, argparse
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import CIFAR10

# -------------------- Repro/Perf --------------------
# Environment defaults are applied with setdefault, so any value the user has
# already exported takes precedence over the ones chosen here.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("TORCHVISION_DISABLE_DOWNLOAD_PROGRESS", "1")
# Let cuDNN benchmark convolution algorithms for the observed input shapes.
# NOTE(review): this favors speed; algorithm selection may vary between runs,
# so bit-exact reproducibility is not guaranteed even with a fixed seed.
torch.backends.cudnn.benchmark = True

def set_seed(seed: int):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus all CUDA devices if present).

    Args:
        seed: Integer seed applied identically to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# -------------------- CIFAR loaders --------------------
def cifar10_loaders(data_root: str, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build train / validation / test DataLoaders for CIFAR-10.

    The 50,000 official training images are split 45,000 / 5,000 into train and
    validation with a fixed generator seed (0), so the split is the same on
    every run. Only the training subset is augmented (random crop with 4-pixel
    padding + horizontal flip); validation and test images are only converted
    to tensors and normalized.

    Args:
        data_root: Directory holding (or receiving) the CIFAR-10 files.
        batch_size: Batch size used by all three loaders.

    Returns:
        (train_loader, valid_loader, test_loader). Only the train loader shuffles.
    """
    normalize = T.Normalize((0.4914, 0.4822, 0.4465),
                            (0.2023, 0.1994, 0.2010))
    transform_train = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize,
    ])
    transform_eval = T.Compose([
        T.ToTensor(),
        normalize,
    ])

    # Only download if not present
    cache_dir = os.path.join(data_root, "cifar-10-batches-py")
    download_flag = not os.path.exists(cache_dir)
    train_aug = CIFAR10(data_root, train=True, download=download_flag, transform=transform_train)
    # Second view of the same training images WITHOUT augmentation. The files
    # are on disk once the constructor above has returned, so no download here.
    train_plain = CIFAR10(data_root, train=True, download=False, transform=transform_eval)
    testset = CIFAR10(data_root, train=False, download=download_flag, transform=transform_eval)

    # A seeded permutation, first 45,000 indices -> train, last 5,000 -> valid.
    # Splitting by index (instead of random_split on the augmented dataset)
    # lets the validation subset read from the un-augmented view, so validation
    # metrics are not perturbed by random crops/flips.
    # NOTE(review): this is meant to mirror random_split's own
    # randperm-then-slice partition for the same generator — confirm if exact
    # split parity with earlier runs matters.
    g = torch.Generator().manual_seed(0)
    perm = torch.randperm(len(train_aug), generator=g).tolist()
    trainset = torch.utils.data.Subset(train_aug, perm[:45000])
    validset = torch.utils.data.Subset(train_plain, perm[45000:50000])

    dl_kwargs = dict(batch_size=batch_size, num_workers=4,
                     pin_memory=torch.cuda.is_available(),
                     persistent_workers=True)
    return (DataLoader(trainset, shuffle=True,  **dl_kwargs),
            DataLoader(validset, shuffle=False, **dl_kwargs),
            DataLoader(testset,  shuffle=False, **dl_kwargs))

# -------------------- Mask I/O --------------------
def load_mask_list(npy_path: str) -> List[np.ndarray]:
    """Load a `.npy` file and return its contents as a Python list.

    An object-dtype array (e.g. masks of differing shapes) is unpacked through
    `tolist()`; any other loaded value is iterated along its first axis.

    Args:
        npy_path: Path of the file to load.

    Returns:
        A list with one entry per stored element.
    """
    # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary Python
    # objects, which can execute code. Only load mask files from trusted sources.
    loaded = np.load(npy_path, allow_pickle=True)
    holds_objects = isinstance(loaded, np.ndarray) and loaded.dtype == object
    entries = loaded.tolist() if holds_objects else loaded
    return list(entries)

# -------------------- Bayesian layers (weight sampling) --------------------
def _softplus(x: torch.Tensor) -> torch.Tensor:
    return torch.log1p(torch.exp(x))

def kl_gaussian(mu: torch.Tensor, rho: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Masked KL( N(mu, sigma^2) || N(0, 1) ), summed over all elements.

    `sigma` is obtained as softplus(rho). Each element contributes
    0.5 * (sigma^2 + mu^2 - 1 - 2*log(sigma)), weighted by `mask`, so entries
    whose mask value is 0 add nothing to the total.

    Args:
        mu: Posterior means.
        rho: Unconstrained posterior scale parameters.
        mask: Per-element weights multiplied into the KL terms.

    Returns:
        Scalar tensor holding the masked KL sum.
    """
    std = _softplus(rho)
    variance = std.pow(2)
    # The small constant keeps log() finite should softplus underflow to zero.
    log_variance = 2.0 * torch.log(std + 1e-12)
    per_weight = variance + mu.pow(2) - 1.0 - log_variance
    return 0.5 * torch.sum(per_weight * mask)

class BNNConv2d(nn.Module):
    """Masked Bayesian 2D convolution using reparameterized weight sampling.

    Every weight has a Gaussian posterior with mean `mu_w` and standard
    deviation softplus(`rho_w`). Each forward call draws ONE weight tensor
    (shared by the whole batch), multiplies it by the fixed 0/1-style mask and
    convolves. The optional bias is an ordinary point parameter.
    """

    def __init__(self, mask_w: torch.Tensor, bias: bool=False, stride=1, padding=0):
        """
        Args:
            mask_w: 4D mask; its shape (out_ch, in_ch, kh, kw) defines the
                weight shape. Stored as a float32 buffer.
            bias: Whether to add a learnable (non-Bayesian) bias.
            stride: Stride handed to `F.conv2d`.
            padding: Padding handed to `F.conv2d`.
        """
        super().__init__()
        assert mask_w.ndim == 4, f"Conv mask must be 4D, got {mask_w.shape}"
        weight_shape = tuple(int(dim) for dim in mask_w.shape)
        self.stride, self.padding = stride, padding

        # Parameter registration order is mu_w, rho_w, bias (kept stable for
        # state_dict / optimizer ordering).
        self.mu_w = nn.Parameter(0.05 * torch.randn(weight_shape))
        self.rho_w = nn.Parameter(torch.full(weight_shape, -3.0))
        self.bias = nn.Parameter(torch.zeros(weight_shape[0])) if bias else None
        self.register_buffer("mask_w", mask_w.to(torch.float32))

    @property
    def out_channels(self):
        """Number of output channels (first dimension of the weight)."""
        return self.mu_w.shape[0]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convolve `x` with a freshly sampled masked weight.

        Returns:
            (output, kl): the convolution result and this layer's scalar KL
            term. When a bias exists, 0.5 * sum(bias^2) is added to the KL
            (std normal prior on bias).
        """
        std = _softplus(self.rho_w)
        noise = torch.randn_like(self.mu_w)
        sampled_w = (self.mu_w + noise * std) * self.mask_w
        out = F.conv2d(x, sampled_w, bias=self.bias,
                       stride=self.stride, padding=self.padding)

        kl = kl_gaussian(self.mu_w, self.rho_w, self.mask_w)
        if self.bias is None:
            return out, kl
        return out, kl + 0.5 * torch.sum(self.bias**2)

class BNNLinear(nn.Module):
    """Masked Bayesian fully connected layer using reparameterized weight sampling.

    Every weight has a Gaussian posterior with mean `mu_w` and standard
    deviation softplus(`rho_w`). Each forward call draws ONE weight matrix
    (shared by the whole batch) and multiplies it by the fixed mask. The
    optional bias is an ordinary point parameter.
    """

    def __init__(self, mask_w: torch.Tensor, bias: bool=True):
        """
        Args:
            mask_w: 2D mask; its shape (out_features, in_features) defines the
                weight shape. Stored as a float32 buffer.
            bias: Whether to add a learnable (non-Bayesian) bias.
        """
        super().__init__()
        assert mask_w.ndim == 2, f"Linear mask must be 2D, got {mask_w.shape}"
        weight_shape = tuple(int(dim) for dim in mask_w.shape)

        # Parameter registration order is mu_w, rho_w, bias (kept stable for
        # state_dict / optimizer ordering).
        self.mu_w = nn.Parameter(0.05 * torch.randn(weight_shape))
        self.rho_w = nn.Parameter(torch.full(weight_shape, -3.0))
        self.bias = nn.Parameter(torch.zeros(weight_shape[0])) if bias else None
        self.register_buffer("mask_w", mask_w.to(torch.float32))

    @property
    def out_features(self):
        """Number of output features (first dimension of the weight)."""
        return self.mu_w.shape[0]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply a freshly sampled masked weight to `x`.

        Returns:
            (output, kl): the affine map's result and this layer's scalar KL
            term. When a bias exists, 0.5 * sum(bias^2) is added to the KL.
        """
        std = _softplus(self.rho_w)
        noise = torch.randn_like(self.mu_w)
        sampled_w = (self.mu_w + noise * std) * self.mask_w
        out = F.linear(x, sampled_w, self.bias)

        kl = kl_gaussian(self.mu_w, self.rho_w, self.mask_w)
        if self.bias is None:
            return out, kl
        return out, kl + 0.5 * torch.sum(self.bias**2)

# -------------------- Small helpers --------------------
class ChannelAdapter1x1(nn.Module):
    """Deterministic (unmasked) 1x1 convolution without bias, followed by BatchNorm2d.

    Used to change the channel count from `c_in` to `c_out`.
    """

    def __init__(self, c_in: int, c_out: int, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=c_in, out_channels=c_out,
                              kernel_size=1, stride=stride, bias=False)
        self.bn = nn.BatchNorm2d(num_features=c_out)

    def forward(self, x):
        """Project `x` to `c_out` channels, then batch-normalize."""
        projected = self.conv(x)
        return self.bn(projected)

class InputRGBAdapter(nn.Module):
    """Deterministic 1x1 convolution (no bias) mapping a 3-channel input to `c_out` channels."""

    def __init__(self, c_out: int):
        super().__init__()
        self.proj = nn.Conv2d(in_channels=3, out_channels=c_out,
                              kernel_size=1, bias=False)

    def forward(self, x):
        """Return the channel-projected input."""
        projected = self.proj(x)
        return projected

class MaskSubLayer(nn.Module):
    """One conv stage: optional input adapter -> Bayesian conv -> BatchNorm -> ReLU."""

    def __init__(self, conv: BNNConv2d, adapter_in: Optional[nn.Module]):
        """
        Args:
            conv: The Bayesian convolution; its `out_channels` sizes the BatchNorm.
            adapter_in: Optional module applied to the input before `conv`
                (None means the input is used as is).
        """
        super().__init__()
        # Assignment order fixes the submodule order: adapter_in, conv, bn, relu.
        self.adapter_in = adapter_in
        self.conv = conv
        self.bn = nn.BatchNorm2d(conv.out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (activation, kl), where `kl` is the KL term reported by `conv`."""
        conv_in = x if self.adapter_in is None else self.adapter_in(x)
        conv_out, kl = self.conv(conv_in)
        activated = self.relu(self.bn(conv_out))
        return activated, kl

class ResidualBlock(nn.Module):
    """Residual block of one or two `MaskSubLayer`s plus a skip connection.

    If the skip path's channel count differs from the block's output channel
    count, the skip goes through a `ChannelAdapter1x1`; otherwise it is the
    identity.
    """

    def __init__(self, sub1: MaskSubLayer, sub2: Optional[MaskSubLayer],
                 skip_c_in: int, block_out_c: int):
        """
        Args:
            sub1: First sublayer (always present).
            sub2: Optional second sublayer.
            skip_c_in: Channel count of the block input (skip path).
            block_out_c: Channel count produced by the last sublayer.
        """
        super().__init__()
        self.sub1 = sub1
        self.sub2 = sub2
        c_in, c_out = int(skip_c_in), int(block_out_c)
        self.res = ChannelAdapter1x1(c_in, c_out, stride=1) if c_in != c_out else None

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (relu(main_path + skip), summed KL of the sublayers)."""
        stages = (self.sub1,) if self.sub2 is None else (self.sub1, self.sub2)

        kl_sum = x.new_zeros(())
        h = x
        for stage in stages:
            h, stage_kl = stage(h)
            kl_sum = kl_sum + stage_kl

        shortcut = self.res(x) if self.res is not None else x
        return torch.relu(h + shortcut), kl_sum

# -------------------- Mask-driven BNN ResNet --------------------
class MaskDrivenBNNResNet(nn.Module):
    """
    Build a ResNet-ish graph from a list of masks (npy object array):
      • leading 4D entries -> sequence of conv layers (paired into residual blocks)
      • trailing 2D entries -> linear layers after global pooling
    Unmasked 1x1 adapters are inserted as needed to match channels, and an RGB adapter
    is used if the stem conv expects !=3 input channels.

    All convolutions use stride 1 and padding (kh // 2, kw // 2). The forward
    pass always returns 2D features of shape (batch, penultimate_dim): the
    feature map is globally average-pooled and flattened whether or not any
    Linear masks are present.
    """
    def __init__(self, masks: List[np.ndarray]):
        """
        Args:
            masks: Mask arrays in network order; every 4D mask must precede
                every 2D mask.

        Raises:
            AssertionError: If a mask following the conv section is not 2D.
        """
        super().__init__()
        tlist: List[torch.Tensor] = [torch.tensor(m, dtype=torch.float32) for m in masks]

        conv_modules: List[MaskSubLayer] = []
        conv_entry_channels: List[int] = []
        current_c: Optional[int] = None
        self.input_adapter: Optional[nn.Module] = None

        idx = 0
        # ----- Conv sequence -----
        while idx < len(tlist) and tlist[idx].ndim == 4:
            m = tlist[idx]
            oc, ic, kh, kw = map(int, m.shape)
            pad = (kh // 2, kw // 2)

            if current_c is None:
                # Stem: project the RGB image if the first mask expects != 3
                # input channels; either way the stem then sees `ic` channels.
                if ic != 3:
                    self.input_adapter = InputRGBAdapter(ic)
                current_c = ic

            # Bridge a channel mismatch between the previous conv's output and
            # this mask's expected input.
            adapter_in = None
            if current_c != ic:
                adapter_in = ChannelAdapter1x1(current_c, ic, stride=1)

            conv = BNNConv2d(m, bias=False, stride=1, padding=pad)
            conv_modules.append(MaskSubLayer(conv, adapter_in))
            conv_entry_channels.append(current_c)  # skip path channels
            current_c = oc
            idx += 1

        # Pair convs into residual blocks (last block may be single-conv)
        self.blocks = nn.ModuleList()
        for bi in range(0, len(conv_modules), 2):
            sub1 = conv_modules[bi]
            sub2 = conv_modules[bi + 1] if (bi + 1) < len(conv_modules) else None
            skip_c_in = int(conv_entry_channels[bi])
            last_sub = sub2 if sub2 is not None else sub1
            block_out_c = int(last_sub.conv.out_channels)
            self.blocks.append(ResidualBlock(sub1, sub2, skip_c_in, block_out_c))

        # ----- Linear sequence -----
        # NOTE(review): no adapter is inserted between the pooled conv features
        # and the first Linear mask, nor between consecutive Linear masks; their
        # feature sizes are assumed to line up — confirm against the mask files.
        self.linears = nn.ModuleList()
        while idx < len(tlist):
            m = tlist[idx]
            assert m.ndim == 2, f"Remaining masks must be 2D (Linear), got {m.shape}"
            self.linears.append(BNNLinear(m, bias=True))
            idx += 1

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    @property
    def penultimate_dim(self) -> int:
        """Width of the feature vector returned by `forward`.

        This is the last Linear layer's output size, or the last conv's output
        channel count when there are no Linear layers. Raises IndexError if the
        model has neither Linear layers nor conv blocks.
        """
        if len(self.linears) > 0:
            return self.linears[-1].out_features
        # else use channels of last block output
        last = self.blocks[-1]
        return last.sub2.conv.out_channels if last.sub2 is not None else last.sub1.conv.out_channels

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network on an image batch.

        Returns:
            (features, kl_total): 2D features of shape (batch, D) and the
            scalar sum of the KL terms of every Bayesian layer.
        """
        kl_total = x.new_zeros(())
        if self.input_adapter is not None:
            x = self.input_adapter(x)
        h = x
        for blk in self.blocks:
            h, dkl = blk(h)
            kl_total = kl_total + dkl
        # Always pool + flatten, even when there are no Linear masks: the
        # classifier head is sized from `penultimate_dim` (a channel count),
        # so the features must be (batch, channels), not a 4D map.
        h = torch.flatten(self.avgpool(h), 1)
        for lin in self.linears:
            h, dkl = lin(h)
            kl_total = kl_total + dkl
        return h, kl_total

# -------------------- Train / Eval --------------------
def _evaluate(model: nn.Module, head: nn.Module, loader: DataLoader,
              ce_loss: nn.Module, device: torch.device, use_amp: bool) -> Tuple[float, float]:
    """Return (accuracy, mean cross-entropy per sample) of model+head over `loader`.

    Puts both modules in eval mode. `ce_loss` must use reduction="sum" so the
    accumulated value can be divided by the sample count.
    """
    model.eval(); head.eval()
    correct = 0
    count = 0
    ce_sum = 0.0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device).to(memory_format=torch.channels_last)
            y = y.to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                feats, _ = model(x)
                logits = head(feats)
                ce = ce_loss(logits, y)
            ce_sum += ce.item()
            correct += (logits.argmax(1) == y).sum().item()
            count += y.size(0)
    denom = max(1, count)  # guard against an empty loader
    return correct / denom, ce_sum / denom


def train_one(model: nn.Module, head: nn.Module,
              train_ld: DataLoader, val_ld: DataLoader, test_ld: DataLoader,
              epochs: int, lr: float, max_beta: float, warmup_epochs: int,
              device: torch.device, amp: bool, microbatch: Optional[int]):
    """Train `model` + `head` with Adam, report validation each epoch, then test.

    Per batch the loss is  sum-CE(batch) + beta(epoch) * KL / N,  where N is the
    training-set size and beta ramps linearly from max_beta / warmup_epochs up
    to max_beta over `warmup_epochs` epochs. The learning rate is multiplied by
    0.2 at 60% and 80% of `epochs`.

    NOTE(review): cross-entropy is SUMMED over the batch while KL is divided by
    N, so the KL term's relative weight shrinks with batch size. If a
    per-sample ELBO is intended, use mean-CE or scale KL by batch_size / N.
    Left as is to avoid silently changing tuned hyperparameters.

    Args:
        model: Module returning (features, kl) for an image batch.
        head: Classifier mapping features to logits.
        train_ld, val_ld, test_ld: Data loaders yielding (images, labels).
        epochs: Number of training epochs.
        lr: Adam learning rate.
        max_beta: Final KL multiplier.
        warmup_epochs: Epochs over which the KL multiplier ramps up.
        device: Device to train on.
        amp: Enable CUDA autocast / gradient scaling (ignored without CUDA).
        microbatch: If set and smaller than a batch, the batch is processed in
            chunks of this size with gradients ACCUMULATED across chunks and a
            single optimizer step per batch.

    Returns:
        (test_accuracy, test_mean_cross_entropy).
    """
    model.to(device); head.to(device)
    model = model.to(memory_format=torch.channels_last)
    ce_loss = nn.CrossEntropyLoss(reduction="sum")
    opt = torch.optim.Adam(list(model.parameters()) + list(head.parameters()), lr=lr)
    scheduler = MultiStepLR(opt, milestones=[int(0.6*epochs), int(0.8*epochs)], gamma=0.2)
    use_amp = amp and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    N = len(train_ld.dataset)
    mb = microbatch or 0

    def beta(e): return max_beta * min(1.0, e / max(1, warmup_epochs))

    for epoch in range(1, epochs + 1):
        model.train(); head.train()
        beta_e = beta(epoch)

        for x, y in train_ld:
            x = x.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            y = y.to(device, non_blocking=True)
            batch_n = x.size(0)

            if mb and batch_n > mb:
                chunks = list(zip(x.split(mb), y.split(mb)))
            else:
                chunks = [(x, y)]

            # One zero_grad / one optimizer step per batch; gradients from all
            # chunks accumulate in between (true gradient accumulation).
            opt.zero_grad(set_to_none=True)
            for xb, yb in chunks:
                # Each chunk carries its share of the batch's KL term, so the
                # accumulated loss equals the full-batch loss
                # sum-CE + beta * KL / N (share is exactly 1.0 when unsplit).
                share = yb.size(0) / batch_n
                with torch.cuda.amp.autocast(enabled=use_amp):
                    feats, kl = model(xb)
                    logits = head(feats)
                    ce = ce_loss(logits, yb)
                    loss = ce + share * beta_e * kl / N
                scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()

        scheduler.step()

        # ---- validation ----
        val_acc, val_ce = _evaluate(model, head, val_ld, ce_loss, device, use_amp)
        print(f"[epoch {epoch:03d}] val_acc={val_acc:.4f}  val_CE={val_ce:.4f}")

    # ---- test ----
    test_acc, test_ce = _evaluate(model, head, test_ld, ce_loss, device, use_amp)
    print(f"[test] acc={test_acc:.4f}  CE={test_ce:.4f}")
    return test_acc, test_ce

# -------------------- Main --------------------
def main():
    """Command-line entry point: parse arguments, build the model, train and test.

    Steps: seed the RNGs, load the mask list, build the mask-driven network and
    an unmasked linear head (penultimate_dim -> 10), run a small synthetic
    forward pass as a shape check, then create the CIFAR-10 loaders and train.
    """
    p = argparse.ArgumentParser()
    # Not marked required: argparse never applies a default to a required
    # option, so required=True made the default path below unreachable.
    p.add_argument("--masks-path", type=str, default='../tests/Resnet18/99_test1_various_masks/mask_8.3_size.npy')
    p.add_argument("--data-root", type=str, default="./data")
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch-size", type=int, default=256)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--max-beta", type=float, default=0.1, help="final KL multiplier (warmup target)")
    p.add_argument("--warmup-epochs", type=int, default=20)
    p.add_argument("--seed", type=int, default=1234)
    p.add_argument("--amp", action="store_true")
    p.add_argument("--microbatch", type=int, default=0, help="0=off; else splits each batch")
    args = p.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    masks = load_mask_list(args.masks_path)
    model = MaskDrivenBNNResNet(masks)

    # Derive and attach an unmasked classifier head (penultimate_dim -> 10)
    feat_dim = model.penultimate_dim
    head = nn.Linear(feat_dim, 10)

    # preflight
    model.to(device)
    # Eval mode so the synthetic noise batch does not update BatchNorm running
    # statistics before real training starts.
    model.eval()
    with torch.no_grad():
        x = torch.randn(2, 3, 32, 32, device=device).to(memory_format=torch.channels_last)
        f, kl = model(x)
        assert f.dim() == 2 and f.shape[1] == feat_dim, f"Unexpected feature shape {tuple(f.shape)}"
        print(f"[mask-check] forward OK — features: {tuple(f.shape)} KL scalar: {float(kl):.2f}")
    model.train()

    train_ld, val_ld, test_ld = cifar10_loaders(args.data_root, args.batch_size)
    train_one(model, head, train_ld, val_ld, test_ld,
              epochs=args.epochs, lr=args.lr, max_beta=args.max_beta,
              warmup_epochs=args.warmup_epochs, device=device,
              amp=args.amp, microbatch=(args.microbatch if args.microbatch>0 else None))

if __name__ == "__main__":
    main()
