### By

### Train a masked LeNet300-100 using HMC. 2800 samples consistently gives good results (even 1800 samples does for LeNet300-100; dense feedforward networks seem to be easier to train than CNNs).

### Code is the same as LeNet300_MNIST_BNN_pyro_HMC_with_mask.py, but for CIFAR-10 dataset. Refer to that file for details on how the code works. 


import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn

import pyro
import pyro.distributions as dist
from pyro.nn import PyroSample
from pyro.nn.module import PyroModule

from pyro.infer.autoguide import AutoDiagonalNormal

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

from pyro.infer import MCMC, NUTS


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

# Location of the saved pruning masks used by the training function below.
maskloc = '../tests/LeNet_CIFAR/99_test1_various_masks/mask_1.1_size.npy'

# NOTE(review): allow_pickle=True lets NumPy unpickle arbitrary objects from
# the file, so this must only ever be pointed at trusted mask files.
mask = np.load(maskloc, allow_pickle=True)

# Mini-batch sizes for the two loaders defined below.
batchsize_train = 1024
batchsize_test = 1024

# NOTE(review): 0.1307 / 0.3081 look like the usual MNIST mean/std rather than
# CIFAR-10 per-channel statistics -- presumably a leftover from the MNIST
# version of this script. Confirm before changing, since previously generated
# masks/results may depend on this exact preprocessing.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# CIFAR-10 train split (downloaded on demand) and test split.
dataset1 = datasets.CIFAR10(
    '../data',
    train=True,
    download=True,
    transform=transform,
)
dataset2 = datasets.CIFAR10(
    '../data',
    train=False,
    transform=transform,
)

train_loader = DataLoader(dataset1, batch_size=batchsize_train)
test_loader = DataLoader(dataset2, batch_size=batchsize_test)




class BayesianConv2d(PyroModule):
    """2-D convolution whose kernel and bias are Pyro sample sites.

    The kernel prior is Normal(0, 1) with the same shape as ``mask`` and is
    masked with ``mask.bool()``; the sampled kernel is additionally multiplied
    by ``mask`` in ``forward`` so pruned entries never affect the output.
    The bias prior is Normal(0, 1) over ``out_channels`` entries.

    Args:
        mask: Tensor of shape
            ``(out_channels, in_channels, kernel_height, kernel_width)``;
            non-zero entries mark the kernel weights that are kept.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_height: Kernel height.
        kernel_width: Kernel width.
        padding: Padding forwarded to ``F.conv2d``.

    Raises:
        ValueError: If ``mask`` does not have the kernel shape implied by the
            other arguments.
    """

    def __init__(self, mask, in_channels, out_channels, kernel_height, kernel_width, padding=0):
        super().__init__()

        # The kernel prior takes its shape from the mask, so a wrongly shaped
        # mask would silently override the requested kernel geometry.
        weight_shape = (out_channels, in_channels, kernel_height, kernel_width)
        if tuple(mask.shape) != weight_shape:
            raise ValueError(
                f"mask shape {tuple(mask.shape)} does not match the kernel shape {weight_shape}"
            )

        # Buffer (not a parameter): moves with .to(...) but is never sampled.
        self.register_buffer("mask", mask)
        self.padding = padding

        # Priors are built lazily from the module so they follow the mask
        # buffer's current device.
        self.weight = PyroSample(lambda self: dist.Normal(
            torch.zeros_like(self.mask), torch.ones_like(self.mask)).mask(self.mask.bool()).to_event(4))

        # Place the bias prior on the mask's device (instead of the global
        # `device`) so weight and bias always live on the same device.
        self.bias = PyroSample(lambda self: dist.Normal(
            torch.zeros(out_channels, device=self.mask.device),
            torch.ones(out_channels, device=self.mask.device)).to_event(1))

    def forward(self, x):
        """Apply the masked convolution to ``x`` and return the result."""
        # Zero out pruned kernel entries before convolving.
        weight = self.weight * self.mask
        return F.conv2d(x, weight, self.bias, padding=self.padding)
    


class BayesianLinear(PyroModule):
    """Fully connected layer with Normal(0, 1) priors and a fixed weight mask.

    Both ``weight`` (shape ``[out_size, in_size]``) and ``bias`` (shape
    ``[out_size]``) are Pyro sample sites with float32 standard-normal priors
    created on the module-level ``device``. The sampled weight is multiplied
    element-wise by ``mask`` in ``forward``, so masked-out connections never
    contribute to the output.

    Args:
        mask: Tensor multiplied element-wise with the sampled weight;
            expected to have shape ``[out_size, in_size]``.
        in_size: Number of input features.
        out_size: Number of output features.
    """

    def __init__(self, mask, in_size, out_size):
        super().__init__()
        # Registered as a buffer (as BayesianConv2d does) so the mask follows
        # the module through .to(...) and is saved with the state dict. It can
        # still be replaced later by plain assignment (`layer.mask = tensor`).
        self.register_buffer("mask", mask)

        loc = torch.tensor(0.0, device=device, dtype=torch.float32)
        scale = torch.tensor(1.0, device=device, dtype=torch.float32)
        # Alternative (not used): Uniform(-k, k) priors with
        # k = sqrt(1 / in_size).
        self.bias = PyroSample(
            prior=dist.Normal(loc, scale).expand([out_size]).to_event(1))
        self.weight = PyroSample(
            prior=dist.Normal(loc, scale).expand([out_size, in_size]).to_event(2))

    def forward(self, input):
        """Return ``input @ (weight * mask).T + bias`` for one posterior draw."""
        # Access the sample site once; cast the mask to its dtype so that, for
        # example, a float64 mask cannot promote the weight and make F.linear
        # fail on a dtype mismatch.
        sampled_weight = self.weight
        weight = torch.mul(sampled_weight, self.mask.to(sampled_weight.dtype))
        return F.linear(input, weight, self.bias)






class BayesianLeNet(PyroModule):
    """Masked Bayesian LeNet-300-100 classifier (3072 -> 300 -> 100 -> 10).

    The network is built with all-ones masks (i.e. dense); call
    ``load_mask`` to install pruning masks for the three linear layers.
    """

    def __init__(self):
        super().__init__()
        # Start fully dense; the masks share the [out, in] layout of the
        # corresponding layer weights.
        self.mask_fc1 = torch.ones(300, 3072).to(device)
        self.mask_fc2 = torch.ones(100, 300).to(device)
        self.mask_fc3 = torch.ones(10, 100).to(device)

        self.flatten = nn.Flatten()
        self.fc1 = BayesianLinear(self.mask_fc1, 3072, 300)
        self.fc2 = BayesianLinear(self.mask_fc2, 300, 100)
        self.fc3 = BayesianLinear(self.mask_fc3, 100, 10)

    def forward(self, x, y=None):
        """Compute class logits and, if labels are given, observe them.

        Args:
            x: Input batch; it is flattened to one feature vector per example
                before the first linear layer.
            y: Optional class labels. When provided, a ``Categorical``
                likelihood is registered as the observed site ``"obs"``
                inside the ``"data"`` plate.

        Returns:
            The logits produced by the last linear layer.
        """
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        if y is not None:
            with pyro.plate("data", x.shape[0]):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        return logits

    def load_mask(self, masklist):
        """Install pruning masks for ``fc1``, ``fc2`` and ``fc3``.

        Args:
            masklist: Indexable collection whose first three entries are the
                masks for ``fc1``, ``fc2`` and ``fc3`` (array-likes with the
                same shape as the respective layer weight).

        Raises:
            ValueError: If a mask does not have the shape of the mask it
                replaces. Without this check a wrongly shaped but
                broadcastable mask would be applied silently.
        """
        targets = (
            ("mask_fc1", self.fc1),
            ("mask_fc2", self.fc2),
            ("mask_fc3", self.fc3),
        )
        for index, (attr_name, layer) in enumerate(targets):
            # Convert once, as float32 to match the float32 layer priors, and
            # copy so the tensor never aliases the caller's array.
            new_mask = torch.as_tensor(masklist[index], dtype=torch.float32).to(device).clone()

            expected_shape = tuple(layer.mask.shape)
            if tuple(new_mask.shape) != expected_shape:
                raise ValueError(
                    f"mask {index} has shape {tuple(new_mask.shape)}, expected {expected_shape}"
                )

            # Masks are constants of the model, never differentiated.
            new_mask.requires_grad_(False)
            setattr(self, attr_name, new_mask)
            # The layer keeps its own copy, independent of the model-level one.
            layer.mask = new_mask.clone()
            layer.mask.requires_grad_(False)


def train_bayesian_lenet_hmc(modelte, num_samples=100, warmup_steps=50):
    """Run NUTS on a freshly built, masked ``BayesianLeNet``.

    The model is created here, the module-level ``mask`` is loaded into it,
    and a single NUTS chain is run on the complete training set
    (``dataset1``).

    Args:
        modelte: Unused; kept only so existing call sites keep working.
        num_samples: Number of posterior samples to draw after warm-up.
        warmup_steps: Number of warm-up (adaptation) steps.

    Returns:
        Tuple ``(model, mcmc)`` with the model and the finished ``MCMC``
        object holding the posterior samples.
    """
    model = BayesianLeNet()
    model.load_mask(mask)
    model.to(device)

    # The model is passed to NUTS directly: conditioning it on an empty data
    # dict (as was done before) does not change any sample site.
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1)

    # We'll collect all data at once because HMC doesn't support minibatching
    full_x, full_y = next(iter(DataLoader(dataset1, batch_size=len(dataset1))))

    mcmc.run(full_x.to(device), full_y.to(device))

    return model, mcmc



def evaluate_bayesian_lenet_hmc(model, mcmc, test_loader, num_samples=10, device="cpu"):
    """Evaluate the posterior-averaged model on a test set.

    For every batch, ``num_samples`` posterior draws are picked uniformly at
    random (with replacement), the model is run once per draw with its sample
    sites fixed to that draw, and the resulting logits are averaged. The
    cross-entropy loss and the accuracy are computed from the averaged logits
    and printed.

    NOTE(review): this averages logits, not class probabilities, so the
    reported loss is not the log-likelihood of the probability-averaged
    posterior predictive -- confirm this is the intended metric.

    Args:
        model: The Pyro model whose sample sites appear in the MCMC samples.
        mcmc: A finished ``MCMC`` run providing ``get_samples()``.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        num_samples: Number of posterior draws averaged per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The mean cross-entropy loss per test example.

    Raises:
        ValueError: If ``test_loader`` yields no examples.
    """
    model.eval()
    posterior_samples = mcmc.get_samples()
    # All sites are indexed by draw along dim 0, so any site gives the number
    # of available draws (no dependence on a particular layer name).
    num_posterior_draws = next(iter(posterior_samples.values())).shape[0]

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            logits_samples = []
            for _ in range(num_samples):
                sample_idx = np.random.randint(0, num_posterior_draws)
                sampled_trace = {k: v[sample_idx] for k, v in posterior_samples.items()}

                # Fix every sample site to the chosen draw and read the
                # model's return value (the logits) from the recorded trace.
                conditioned_model = pyro.poutine.condition(model, data=sampled_trace)
                trace = pyro.poutine.trace(conditioned_model).get_trace(x_batch)
                logits_samples.append(trace.nodes["_RETURN"]["value"])

            mean_logits = torch.stack(logits_samples).mean(dim=0)
            loss = nn.functional.cross_entropy(mean_logits, y_batch, reduction='sum')

            total_loss += loss.item()
            # Count correct predictions per example: averaging per-batch
            # accuracies would over-weight a smaller final batch.
            total_correct += (mean_logits.argmax(dim=1) == y_batch).sum().item()
            total_samples += y_batch.size(0)

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    avg_loss = total_loss / total_samples
    avg_accuracy = total_correct / total_samples
    print(f"Cross-Entropy Loss on Test Set: {avg_loss:.4f}")
    print(f"Accuracy on Test Set: {avg_accuracy:.4f}")
    return avg_loss


# Placeholder passed as the first positional argument of the training
# function, which does not use it.
modelt = 0

# Draw 2800 posterior samples after 200 warm-up steps, then score the test set
# by averaging 1500 randomly chosen posterior draws per batch.
model, mcmc = train_bayesian_lenet_hmc(
    modelt,
    num_samples=2800,
    warmup_steps=200,
)
evaluate_bayesian_lenet_hmc(
    model,
    mcmc,
    test_loader,
    num_samples=1500,
    device=device,
)
