### By

### Train a masked LeNet5 using HMC. Unlike the LeNet300 case, 2800 samples no longer give good results here, because the mask is much larger and the model is a CNN. CNNs also behave somewhat unpredictably under HMC: to get consistently good results, several tens of thousands, potentially even hundreds of thousands, of samples may be necessary, which we cannot collect without crashing our GPU. We include this code in case you want to try, but you will need either a much better setup (better parallelization to use more GPU memory) or a custom way of running multiple intercommunicating chains at once.

### This code is the same as LeNet300_MNIST_BNN_pyro_HMC_with_mask.py, but for the CIFAR-10 dataset and the LeNet5 architecture. Refer to that file for details on how the code works.


import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn

import pyro
import pyro.distributions as dist
from pyro.nn import PyroSample
from pyro.nn.module import PyroModule

from pyro.infer.autoguide import AutoDiagonalNormal

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

from pyro.infer import MCMC, NUTS


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

# Per-layer masks, indexed 0..4 (conv1, conv2, fc1, fc2, fc3) by BayesianLeNet.load_mask.
# allow_pickle is needed to read the stored sequence of arrays; only load trusted files.
maskloc = "../tests/CNN_LeNet5_CIFAR/99_test1_various_masks/mask_1.1_size.npy"
mask = np.load(maskloc, allow_pickle=True)

batchsize_train = 1024
batchsize_test = 1024

# Per-channel (R, G, B) CIFAR-10 training-set statistics. The previous values,
# (0.1307,) / (0.3081,), are the single-channel MNIST statistics carried over from the
# MNIST script; applied to CIFAR-10 they left the inputs far from zero mean / unit std.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
dataset1 = datasets.CIFAR10('../data', train=True, download=True,
                            transform=transform)
# download=True so the test split does not rely on the train split having been fetched first.
dataset2 = datasets.CIFAR10('../data', train=False, download=True,
                            transform=transform)
train_loader = DataLoader(dataset1, batch_size=batchsize_train)
test_loader = DataLoader(dataset2, batch_size=batchsize_test)



class CustomConv2d(PyroModule):
    """Bayesian 2-D convolution whose sampled kernel is multiplied element-wise by a mask.

    Bias and weight are ``PyroSample`` sites with independent Normal(0, 2k) priors, where
    ``k = sqrt(1 / (inputs * kernalheight * kernelwidth)) * gain`` and ``gain = 3``.
    (A Uniform(-k, k) prior was previously tried as an alternative.)

    Args:
        mask: tensor multiplied with the sampled weight in ``forward``; presumably shaped
            (outputs, inputs, kernalheight, kernelwidth) — TODO confirm against callers.
        inputs: number of input channels.
        outputs: number of output channels.
        kernalheight: kernel height (parameter name kept as-is for existing callers).
        kernelwidth: kernel width.
        padding: zero-padding forwarded to ``F.conv2d``.
    """

    def __init__(self, mask, inputs, outputs, kernalheight, kernelwidth, padding=0):
        super().__init__()
        self.mask = mask
        self.padding = padding

        gain = 3
        fan_in = inputs * kernalheight * kernelwidth
        k = np.sqrt(1 / fan_in) * gain

        def normal_prior(shape):
            # Fresh float32 loc/scale tensors on `device`, expanded to `shape` and treated
            # as one event so the whole tensor is a single sample site.
            loc = torch.tensor(0.0, device=device, dtype=torch.float32)
            scale = torch.tensor(2.0 * k, device=device, dtype=torch.float32)
            return dist.Normal(loc, scale).expand(shape).to_event(len(shape))

        self.bias = PyroSample(prior=normal_prior([outputs]))
        self.weight = PyroSample(
            prior=normal_prior([outputs, inputs, kernalheight, kernelwidth])
        )

    def forward(self, x):
        """Convolve ``x`` with the masked sampled kernel plus the sampled bias."""
        masked_weight = self.weight * self.mask
        return F.conv2d(x, masked_weight, self.bias, padding=self.padding)
    

class BayesianLinear(PyroModule):
    """Bayesian fully connected layer whose sampled weight is multiplied by a mask.

    Bias and weight are ``PyroSample`` sites with independent Normal(0, 2k) priors, where
    ``k = sqrt(1 / in_size) * gain`` and ``gain = 3``. (A Uniform(-k, k) prior was
    previously tried as an alternative.)

    Args:
        mask: tensor multiplied with the sampled weight in ``forward``; presumably shaped
            (out_size, in_size) — TODO confirm against callers.
        in_size: number of input features.
        out_size: number of output features.
    """

    def __init__(self, mask, in_size, out_size):
        super().__init__()
        self.mask = mask

        gain = 3
        k = np.sqrt(1 / in_size) * gain

        def normal_prior(shape):
            # Fresh float32 loc/scale tensors on `device`, expanded to `shape` and treated
            # as one event so the whole tensor is a single sample site.
            loc = torch.tensor(0.0, device=device, dtype=torch.float32)
            scale = torch.tensor(2.0 * k, device=device, dtype=torch.float32)
            return dist.Normal(loc, scale).expand(shape).to_event(len(shape))

        self.bias = PyroSample(prior=normal_prior([out_size]))
        self.weight = PyroSample(prior=normal_prior([out_size, in_size]))

    def forward(self, input):
        """Apply ``input @ (weight * mask).T + bias`` with the sampled parameters."""
        masked_weight = self.weight * self.mask
        return F.linear(input, masked_weight, self.bias)
    


class BayesianLeNet(PyroModule):
    """Masked Bayesian LeNet5: two masked convolutions followed by three masked linear layers.

    Layer sizes are fixed: conv1 3->6 (5x5, padding 2), conv2 6->16 (5x5), then fully
    connected 576->120->84->10. Masks start as all-ones (no pruning) and can be replaced
    with ``load_mask``.
    """

    def __init__(self):
        super().__init__()
        # Default all-ones masks; each shape matches the weight of the layer it feeds.
        self.mask_conv1 = torch.ones((6, 3, 5, 5)).to(device)
        self.mask_conv2 = torch.ones((16, 6, 5, 5)).to(device)
        self.mask_fc1 = torch.ones((120, 576)).to(device)
        self.mask_fc2 = torch.ones((84, 120)).to(device)
        self.mask_fc3 = torch.ones((10, 84)).to(device)

        self.conv1 = CustomConv2d(self.mask_conv1, 3, 6, 5, 5, padding=2)
        self.conv2 = CustomConv2d(self.mask_conv2, 6, 16, 5, 5)

        self.fc1 = BayesianLinear(self.mask_fc1, 576, 120)
        self.fc2 = BayesianLinear(self.mask_fc2, 120, 84)
        self.fc3 = BayesianLinear(self.mask_fc3, 84, 10)

    def forward(self, x, y=None):
        """Return class logits for ``x``; if ``y`` is given, also observe it.

        ``x`` is expected to be (batch, 3, 32, 32). conv1 (padding 2) keeps 32x32, and the
        two 2x2 max-pools plus the unpadded conv2 reduce it to 16x6x6 = 576 features, which
        is fc1's input size.

        When ``y`` is provided it is registered as the observed ``"obs"`` site (Categorical
        over the logits) inside a ``"data"`` plate, which is what MCMC conditions on.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        if y is not None:
            with pyro.plate("data", x.shape[0]):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        return logits

    def load_mask(self, masklist):
        """Replace the five layer masks with the numpy arrays ``masklist[0..4]``.

        Order: conv1, conv2, fc1, fc2, fc3. Each mask is copied to ``device`` as float32.
        Without the cast, a float64 mask would promote the float32 sampled weight to
        float64 in the layer's ``torch.mul`` and make ``F.conv2d`` / ``F.linear`` fail
        against float32 inputs. Bool, int and float32 masks produce the same products as
        before.

        All shapes are checked before anything is replaced, so a bad mask list leaves the
        model unchanged.

        Raises:
            ValueError: if a mask's shape differs from the layer's current mask shape.
        """
        targets = (
            ("mask_conv1", self.conv1),
            ("mask_conv2", self.conv2),
            ("mask_fc1", self.fc1),
            ("mask_fc2", self.fc2),
            ("mask_fc3", self.fc3),
        )

        new_masks = []
        for index, (attr_name, layer) in enumerate(targets):
            # clone() so the tensor never aliases the caller's numpy buffer (.to() is a
            # no-op returning the same storage when device and dtype already match).
            new_mask = (
                torch.from_numpy(masklist[index])
                .to(device=device, dtype=torch.float32)
                .clone()
            )
            new_mask.requires_grad_(False)
            if new_mask.shape != layer.mask.shape:
                raise ValueError(
                    f"mask {index} ({attr_name}) has shape {tuple(new_mask.shape)}, "
                    f"expected {tuple(layer.mask.shape)}"
                )
            new_masks.append(new_mask)

        for (attr_name, layer), new_mask in zip(targets, new_masks):
            # The model attribute and the layer share one tensor; masks are not modified
            # in place anywhere in this file.
            setattr(self, attr_name, new_mask)
            layer.mask = new_mask


def train_bayesian_lenet_hmc(modelte, num_samples=100, warmup_steps=50):
    """Sample the masked Bayesian LeNet's posterior with single-chain NUTS.

    Args:
        modelte: unused; kept so existing calls keep working.
        num_samples: number of posterior draws kept after warm-up.
        warmup_steps: number of NUTS warm-up steps.

    Returns:
        ``(model, mcmc)``: the masked ``BayesianLeNet`` and the ``MCMC`` object after
        ``run``; the draws are available through ``mcmc.get_samples()``.
    """
    net = BayesianLeNet()
    net.load_mask(mask)  # module-level mask list loaded from `maskloc`
    net.to(device)

    # Conditioning on an empty dict leaves every site free; NUTS samples all of them.
    kernel = NUTS(pyro.condition(net, data={}))
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        num_chains=1,
    )

    # HMC/NUTS needs the full-data log-likelihood (no minibatching), so the whole
    # training set is pulled as one batch.
    whole_train = DataLoader(dataset1, batch_size=len(dataset1))
    inputs, targets = next(iter(whole_train))

    sampler.run(inputs.to(device), targets.to(device))

    return net, sampler

def evaluate_bayesian_lenet_hmc(model, mcmc, test_loader, num_samples=10, device="cpu"):
    """Evaluate the HMC posterior predictive on ``test_loader`` and print loss/accuracy.

    For every batch, ``num_samples`` posterior draws are picked uniformly at random (with
    replacement) from ``mcmc.get_samples()``. The model is run with its sample sites
    conditioned on each draw, and the resulting class *probabilities* are averaged
    (Bayesian model averaging). Averaging logits instead, as before, is not the posterior
    predictive.

    Args:
        model: the model whose sample-site names match the keys of the posterior samples.
        mcmc: a fitted ``pyro.infer.MCMC`` object.
        test_loader: iterable of ``(x_batch, y_batch)`` pairs.
        num_samples: posterior draws averaged per batch.
        device: device the batches are moved to.

    Returns:
        Mean negative log-likelihood per test example under the averaged predictive.

    Raises:
        ValueError: if ``num_samples`` < 1 or ``test_loader`` yields no examples.
    """
    import math

    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    model.eval()
    posterior_samples = mcmc.get_samples()
    # Every site shares the same leading (draw) dimension, so any key gives the count.
    num_draws = next(iter(posterior_samples.values())).shape[0]
    log_num_samples = math.log(num_samples)

    total_nll = 0.0
    total_correct = 0
    total_examples = 0

    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            draw_log_probs = []
            for _ in range(num_samples):
                draw = np.random.randint(0, num_draws)
                sampled_sites = {name: value[draw] for name, value in posterior_samples.items()}
                # Fix every sample site to this draw; y is not passed, so there is no
                # observed site and the call simply returns the logits.
                logits = pyro.poutine.condition(model, data=sampled_sites)(x_batch)
                draw_log_probs.append(nn.functional.log_softmax(logits, dim=1))

            # log(mean_s p_s) = logsumexp_s(log p_s) - log(S), computed stably in log space.
            log_mean_probs = torch.logsumexp(torch.stack(draw_log_probs), dim=0) - log_num_samples

            total_nll += nn.functional.nll_loss(log_mean_probs, y_batch, reduction='sum').item()
            # Count correct predictions so the final accuracy is weighted by batch size
            # (the last batch is usually smaller than the others).
            total_correct += (log_mean_probs.argmax(dim=1) == y_batch).sum().item()
            total_examples += y_batch.size(0)

    if total_examples == 0:
        raise ValueError("test_loader yielded no examples")

    avg_loss = total_nll / total_examples
    avg_accuracy = total_correct / total_examples
    print(f"Cross-Entropy Loss on Test Set: {avg_loss:.4f}")
    print(f"Accuracy on Test Set: {avg_accuracy:.4f}")
    return avg_loss


if __name__ == "__main__":
    # Guarded so that importing this module does not start the (very long) full-batch
    # NUTS run over the CIFAR-10 training set.
    modelt = 0  # placeholder: train_bayesian_lenet_hmc ignores its first argument

    model, mcmc = train_bayesian_lenet_hmc(modelt, num_samples=2800, warmup_steps=200)
    evaluate_bayesian_lenet_hmc(model, mcmc, test_loader, num_samples=1500, device=device)
