### By

### Train a masked LeNet5 using HMC. 2800 samples will sometimes give good results, but more samples are needed to get good results consistently. CNNs behave a little strangely here, so several tens of thousands of samples may be necessary, which we cannot collect without crashing our GPU. Even so, 2800 samples does sometimes work, and on occasion it has reached, or come very near, lottery-ticket accuracy. Another option would be to run multiple intercommunicating chains, but that is not implemented in this code.

### The code is the same as LeNet300_MNIST_BNN_pyro_HMC_with_mask.py, but for the MNIST dataset and the LeNet5 architecture. Refer to that file for details on how the code works.


import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn

import pyro
import pyro.distributions as dist
from pyro.nn import PyroSample
from pyro.nn.module import PyroModule

from pyro.infer.autoguide import AutoDiagonalNormal

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

from pyro.infer import MCMC, NUTS


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

# Per-layer pruning masks; indexed 0..4 by BayesianLeNet.load_mask.
# NOTE: allow_pickle=True unpickles object arrays, so only load trusted files.
maskloc = "../tests/CNN_LeNet5_MNIST/99_test1_various_masks/mask_1.1_size.npy"
mask = np.load(maskloc, allow_pickle=True)

batchsize_train = 1024
batchsize_test = 1024

# Convert images to tensors, then normalise with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# dataset1 is the training split (downloaded if missing); dataset2 is the test split.
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, transform=transform)

train_loader = DataLoader(dataset1, batch_size=batchsize_train)
test_loader = DataLoader(dataset2, batch_size=batchsize_test)



class CustomConv2d(PyroModule):
    """2-D convolution with Bayesian parameters and a fixed element-wise weight mask.

    The kernel and the bias are Pyro sample sites with zero-mean Normal
    priors. On each forward pass the sampled kernel is multiplied
    element-wise by ``mask`` before the convolution is applied.

    Args:
        mask: Tensor multiplied element-wise with the sampled kernel.
        inputs: Number of input channels.
        outputs: Number of output channels.
        kernalheight: Kernel height (the spelling is part of the signature).
        kernelwidth: Kernel width.
        padding: Padding forwarded to ``F.conv2d``.
    """

    def __init__(self, mask, inputs, outputs, kernalheight, kernelwidth, padding=0):
        super().__init__()

        self.mask = mask
        self.padding = padding

        # k = gain / sqrt(fan_in), where fan_in = inputs * kernel area.
        # The Normal priors below use a standard deviation of 2 * k.
        # (An earlier variant of this layer used Uniform(-k, k) priors instead.)
        gain = 3
        fan_in = inputs * kernalheight * kernelwidth
        k = np.sqrt(1 / fan_in) * gain

        prior_loc = torch.tensor(0.0, device=device, dtype=torch.float32)
        prior_scale = torch.tensor(2.0 * k, device=device, dtype=torch.float32)
        kernel_shape = [outputs, inputs, kernalheight, kernelwidth]

        self.bias = PyroSample(
            prior=dist.Normal(prior_loc, prior_scale).expand([outputs]).to_event(1)
        )
        self.weight = PyroSample(
            prior=dist.Normal(prior_loc, prior_scale).expand(kernel_shape).to_event(4)
        )

    def forward(self, x):
        """Convolve ``x`` with the masked kernel and return the result."""
        # Read the weight site before the bias site, then zero out masked entries.
        masked_kernel = self.weight.mul(self.mask)
        return F.conv2d(x, masked_kernel, self.bias, padding=self.padding)
    

class BayesianLinear(PyroModule):
    """Fully connected layer with Bayesian parameters and a fixed weight mask.

    The weight matrix and the bias are Pyro sample sites with zero-mean
    Normal priors. On each forward pass the sampled weight is multiplied
    element-wise by ``mask`` before the affine map is applied.

    Args:
        mask: Tensor multiplied element-wise with the sampled weight.
        in_size: Number of input features.
        out_size: Number of output features.
    """

    def __init__(self, mask, in_size, out_size):
        super().__init__()
        self.mask = mask

        # k = gain / sqrt(in_size); the Normal priors use a std of 2 * k.
        # (An earlier variant of this layer used Uniform(-k, k) priors instead.)
        gain = 3
        k = np.sqrt(1 / in_size) * gain

        prior_loc = torch.tensor(0.0, device=device, dtype=torch.float32)
        prior_scale = torch.tensor(2.0 * k, device=device, dtype=torch.float32)

        self.bias = PyroSample(
            prior=dist.Normal(prior_loc, prior_scale).expand([out_size]).to_event(1)
        )
        self.weight = PyroSample(
            prior=dist.Normal(prior_loc, prior_scale)
            .expand([out_size, in_size])
            .to_event(2)
        )

    def forward(self, input):
        """Apply the masked affine map to ``input`` and return the result."""
        # Read the weight site before the bias site, then zero out masked entries.
        masked_weight = self.weight.mul(self.mask)
        return F.linear(input, masked_weight, self.bias)
    


class BayesianLeNet(PyroModule):
    """Masked Bayesian LeNet5: two masked conv layers followed by three masked linear layers.

    Every layer starts with an all-ones mask (nothing pruned). Call
    ``load_mask`` to install pruning masks before sampling.
    """

    def __init__(self):
        super().__init__()
        # All-ones masks shaped like each layer's weight tensor.
        self.mask_conv1 = torch.ones((6, 1, 5, 5)).to(device)
        self.mask_conv2 = torch.ones((16, 6, 5, 5)).to(device)
        self.mask_fc1 = torch.ones((120, 16 * 5 * 5)).to(device)
        self.mask_fc2 = torch.ones((84, 120)).to(device)
        self.mask_fc3 = torch.ones((10, 84)).to(device)

        self.conv1 = CustomConv2d(self.mask_conv1, 1, 6, 5, 5, padding=2)
        self.conv2 = CustomConv2d(self.mask_conv2, 6, 16, 5, 5)

        self.fc1 = BayesianLinear(self.mask_fc1, 16 * 5 * 5, 120)
        self.fc2 = BayesianLinear(self.mask_fc2, 120, 84)
        self.fc3 = BayesianLinear(self.mask_fc3, 84, 10)

    def forward(self, x, y=None):
        """Compute class logits and, when labels are given, observe them.

        Args:
            x: Image batch. It must flatten to 16 * 5 * 5 features per sample
                after the two conv/pool stages (e.g. ``(N, 1, 28, 28)``).
            y: Optional integer class labels. When provided, a Categorical
                likelihood over the logits is registered at the "obs" site.

        Returns:
            The logits produced by ``fc3``.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        if y is not None:
            # x is now the fc2 activation, so x.shape[0] is the batch size.
            with pyro.plate("data", x.shape[0]):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        return logits

    @staticmethod
    def _mask_tensor(array):
        """Return an independent float32 copy of a NumPy mask on ``device``."""
        # Cast to float32: a float64 mask (NumPy's default dtype) would promote
        # the masked weight to float64, which F.conv2d / F.linear reject when
        # the inputs are float32.
        return torch.from_numpy(array).to(device=device, dtype=torch.float32).clone()

    def load_mask(self, masklist):
        """Install pruning masks on the model and on each layer.

        Args:
            masklist: Indexable collection of five NumPy arrays, ordered
                conv1, conv2, fc1, fc2, fc3. Each is multiplied element-wise
                with the corresponding layer's weight.

        Raises:
            IndexError: If ``masklist`` holds fewer than five entries.
        """
        targets = (
            ("mask_conv1", self.conv1),
            ("mask_conv2", self.conv2),
            ("mask_fc1", self.fc1),
            ("mask_fc2", self.fc2),
            ("mask_fc3", self.fc3),
        )
        for index, (attr_name, layer) in enumerate(targets):
            array = masklist[index]
            # The model-level attribute and the layer keep separate copies.
            setattr(self, attr_name, self._mask_tensor(array))
            layer.mask = self._mask_tensor(array)
            # Masks are constants; they must never receive gradients.
            layer.mask.requires_grad_(False)



def train_bayesian_lenet_hmc(modelte, num_samples=100, warmup_steps=50):
    """Sample the posterior of a masked ``BayesianLeNet`` with NUTS.

    A fresh model is built, the module-level ``mask`` is loaded into it, and
    a single MCMC chain is run on the whole of ``dataset1``.

    Args:
        modelte: Unused; kept so existing callers keep working.
        num_samples: Number of posterior draws to collect.
        warmup_steps: Number of warm-up steps before collecting draws.

    Returns:
        A ``(model, mcmc)`` tuple: the model and the MCMC object after its run.
    """
    net = BayesianLeNet()
    net.load_mask(mask)
    net.to(device)

    # Conditioning on an empty dict fixes no sites; the model is sampled as-is.
    kernel = NUTS(pyro.condition(net, data={}))
    sampler = MCMC(
        kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        num_chains=1,
    )

    # HMC doesn't support minibatching, so fetch the training set as one batch.
    whole_set_loader = DataLoader(dataset1, batch_size=len(dataset1))
    full_x, full_y = next(iter(whole_set_loader))
    sampler.run(full_x.to(device), full_y.to(device))

    return net, sampler


def evaluate_bayesian_lenet_hmc(model, mcmc, test_loader, num_samples=10, device="cpu"):
    """Evaluate posterior draws on a test set and print loss and accuracy.

    For every batch, ``num_samples`` posterior draws are picked uniformly at
    random (with replacement). The model is run once per draw and the logits
    are averaged across draws before the loss and predictions are computed.

    Args:
        model: The Bayesian model whose sample sites are conditioned on draws.
        mcmc: A finished MCMC run; draws are read with ``get_samples()``.
        test_loader: Iterable yielding ``(x_batch, y_batch)`` pairs.
        num_samples: Number of posterior draws averaged per batch.
        device: Device the batches are moved to before the forward passes.

    Returns:
        The cross-entropy of the averaged logits, averaged over all test samples.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    model.eval()
    posterior_samples = mcmc.get_samples()
    # Number of available posterior draws (leading dimension of each site).
    num_draws = posterior_samples["fc3.bias"].shape[0]

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            logits_samples = []
            for _ in range(num_samples):
                sample_idx = np.random.randint(0, num_draws)
                sampled_trace = {k: v[sample_idx] for k, v in posterior_samples.items()}

                # Fix every sample site to this draw and read the model's return value.
                conditioned_model = pyro.poutine.condition(model, data=sampled_trace)
                trace = pyro.poutine.trace(conditioned_model).get_trace(x_batch)
                logits_samples.append(trace.nodes["_RETURN"]["value"])

            mean_logits = torch.stack(logits_samples).mean(dim=0)
            loss = nn.functional.cross_entropy(mean_logits, y_batch, reduction='sum')

            # Count correct predictions rather than averaging per-batch means,
            # so a smaller final batch is not over-weighted.
            total_correct += (mean_logits.argmax(dim=1) == y_batch).sum().item()
            total_loss += loss.item()
            total_samples += y_batch.size(0)

    avg_loss = total_loss / total_samples
    avg_accuracy = total_correct / total_samples
    print(f"Cross-Entropy Loss on Test Set: {avg_loss:.4f}")
    print(f"Accuracy on Test Set: {avg_accuracy:.4f}")
    return avg_loss


# Placeholder for the first argument of train_bayesian_lenet_hmc, which ignores it.
modelt = 0

# Draw 2800 posterior samples after 200 warm-up steps.
model, mcmc = train_bayesian_lenet_hmc(
    modelt,
    num_samples=2800,
    warmup_steps=200,
)

# Evaluate on the test set, averaging 1500 random posterior draws per batch.
evaluate_bayesian_lenet_hmc(
    model,
    mcmc,
    test_loader,
    num_samples=1500,
    device=device,
)
