### By

### Train a masked LeNet-300-100 using HMC. 2800 samples consistently give good results (even 1800 samples suffice for LeNet-300-100; dense feed-forward networks seem to be easier to train than CNNs).

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn

import pyro
import pyro.distributions as dist
from pyro.nn import PyroSample
from pyro.nn.module import PyroModule

from pyro.infer.autoguide import AutoDiagonalNormal

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

from pyro.infer import MCMC, NUTS


# Run on the GPU when one is available; the model, priors and data all use this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

# Lottery-ticket mask file. BayesianLeNet.load_mask reads entries 0..2 of it
# (presumably one mask per linear layer — confirm against how the file was saved).
# allow_pickle=True unpickles arbitrary objects, so only point this at trusted files.
maskloc = '../tests/LeNet_MNIST/99_test1_various_masks/mask_1.1_size.npy'
mask = np.load(maskloc, allow_pickle=True)

batchsize_train = 1024
batchsize_test = 1024


### Load the Dataset

# Tensor conversion followed by the commonly used MNIST mean/std normalisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, transform=transform)

train_loader = DataLoader(dataset1, batch_size=batchsize_train)
test_loader = DataLoader(dataset2, batch_size=batchsize_test)


### Create custom Pyro/PyTorch layers. They use PyroSample so their weights and biases become sample sites in the MCMC run. The mask is applied the same way as in the deterministic networks, so that pruned (zero) weights stay zero during training.

class BayesianConv2d(PyroModule):
    """2-D convolution whose weight and bias are Pyro sample sites, with a fixed mask.

    The effective kernel is ``weight * mask``. The mask is registered as a buffer,
    so ``module.to(...)`` moves it. Both priors are built lazily from the mask's
    device, so they always live on the same device as the mask.

    Args:
        mask: tensor or array-like of shape
            ``(out_channels, in_channels, kernel_height, kernel_width)``;
            nonzero entries are kept, zero entries are pruned.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_height: kernel height.
        kernel_width: kernel width.
        padding: passed through to ``F.conv2d``.

    Raises:
        ValueError: if ``mask`` does not have the weight's shape.
    """

    def __init__(self, mask, in_channels, out_channels, kernel_height, kernel_width, padding=0):
        super().__init__()
        mask = torch.as_tensor(mask)
        weight_shape = (out_channels, in_channels, kernel_height, kernel_width)
        if tuple(mask.shape) != weight_shape:
            raise ValueError(
                f"mask has shape {tuple(mask.shape)}, expected {weight_shape}"
            )
        self.register_buffer("mask", mask)
        self.padding = padding

        # Float priors of the weight's shape on the mask's device; the mask's own
        # dtype (possibly bool) is deliberately not inherited.
        # NOTE(review): `.mask(...)` zeroes the log-prob of pruned entries, so
        # under HMC those coordinates are unconstrained (flat). They never reach
        # the output because forward() multiplies by the mask.
        self.weight = PyroSample(lambda module: dist.Normal(
            torch.zeros(module.mask.shape, device=module.mask.device),
            torch.ones(module.mask.shape, device=module.mask.device),
        ).mask(module.mask.bool()).to_event(4))

        # Same device source as the weight prior, so the two can never disagree.
        self.bias = PyroSample(lambda module: dist.Normal(
            torch.zeros(out_channels, device=module.mask.device),
            torch.ones(out_channels, device=module.mask.device),
        ).to_event(1))

    def forward(self, x):
        weight = self.weight
        # Cast the mask to the weight dtype so a bool/float64 mask cannot change it.
        return F.conv2d(x, weight * self.mask.to(weight.dtype), self.bias, padding=self.padding)
    

class BayesianLinear(PyroModule):
    """Fully connected layer whose weight and bias are Pyro sample sites.

    The effective weight is ``weight * mask``, so entries where the mask is 0 do
    not affect the output, whatever value is sampled for them. Unlike a masked
    prior, every weight keeps a proper prior here.

    Args:
        mask: tensor or array-like broadcastable to ``(out_size, in_size)``,
            presumably a 0/1 lottery-ticket mask. Registered as a buffer so that
            ``module.to(...)`` moves it. Assigning ``layer.mask = tensor`` later
            still replaces it.
        in_size: number of input features.
        out_size: number of output features.
        prior: ``"normal"`` (default) for i.i.d. Normal(0, 1) priors, or
            ``"uniform"`` for i.i.d. Uniform(-k, k) with ``k = sqrt(1 / in_size)``.
            The uniform prior tends to work better. LeNet-300-100 is robust
            enough to reach good results with the standard normal in 1800/2800
            samples, but networks like LeNet-5 need the better prior in so few
            samples. Even then LeNet-5 on MNIST has high variance, and LeNet-5
            on CIFAR-10 needs vastly more samples.

    Raises:
        ValueError: if ``prior`` is not ``"normal"`` or ``"uniform"``.
    """

    def __init__(self, mask, in_size, out_size, prior="normal"):
        super().__init__()
        self.register_buffer("mask", torch.as_tensor(mask))

        if prior == "normal":
            base = dist.Normal(
                torch.tensor(0.0, device=device, dtype=torch.float32),
                torch.tensor(1.0, device=device, dtype=torch.float32),
            )
        elif prior == "uniform":
            k = np.sqrt(1 / in_size)
            base = dist.Uniform(
                torch.tensor(-1.0 * k, device=device, dtype=torch.float32),
                torch.tensor(1.0 * k, device=device, dtype=torch.float32),
            )
        else:
            raise ValueError(f"unknown prior {prior!r}; expected 'normal' or 'uniform'")

        self.bias = PyroSample(prior=base.expand([out_size]).to_event(1))
        self.weight = PyroSample(prior=base.expand([out_size, in_size]).to_event(2))

    def forward(self, input):
        weight = self.weight
        # Cast so a float64/bool mask cannot promote the float32 weights (which
        # would make F.linear fail against float32 inputs).
        return F.linear(input, torch.mul(weight, self.mask.to(weight.dtype)), self.bias)






class BayesianLeNet(PyroModule):
    """Masked Bayesian LeNet-300-100: linear 784 -> 300 -> 100 -> 10 with ReLU.

    Every weight and bias is a Pyro sample site (via BayesianLinear). The masks
    start as all-ones (no pruning) and are replaced by ``load_mask``.
    """

    # (out_features, in_features) of fc1, fc2, fc3, i.e. the shape each mask must have.
    _LAYER_SHAPES = ((300, 784), (100, 300), (10, 100))

    def __init__(self):
        super().__init__()
        self.mask_fc1 = torch.ones(300, 784).to(device)
        self.mask_fc2 = torch.ones(100, 300).to(device)
        self.mask_fc3 = torch.ones(10, 100).to(device)

        self.flatten = nn.Flatten()
        self.fc1 = BayesianLinear(self.mask_fc1, 784, 300)
        self.fc2 = BayesianLinear(self.mask_fc2, 300, 100)
        self.fc3 = BayesianLinear(self.mask_fc3, 100, 10)

    def forward(self, x, y=None):
        """Return class logits for ``x``; when ``y`` is given, also observe it.

        Args:
            x: input batch. Assumes it flattens to 784 features per example
               (e.g. 1x28x28 MNIST images) — TODO confirm for other datasets.
            y: optional integer labels, observed under a Categorical likelihood.
        """
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        if y is not None:
            # One conditionally independent observation per example in the batch.
            with pyro.plate("data", x.shape[0]):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        return logits

    def load_mask(self, masklist):
        """Install ``masklist[0]``, ``[1]``, ``[2]`` as the masks of fc1, fc2, fc3.

        Each mask is copied into a new float32 tensor on ``device``, so
        ``weight * mask`` keeps the float32 dtype of the weights. The same tensor
        is stored as both ``self.mask_fcN`` and ``self.fcN.mask``, so the two
        cannot drift apart. All three masks are validated before any is installed.

        Raises:
            ValueError: if a mask's shape differs from its layer's weight shape.
        """
        # Unpacking consumes the whole generator first, so a bad mask raises
        # before any attribute is modified.
        mask_fc1, mask_fc2, mask_fc3 = (
            self._mask_tensor(masklist[index], shape, index)
            for index, shape in enumerate(self._LAYER_SHAPES)
        )
        self.mask_fc1 = self.fc1.mask = mask_fc1
        self.mask_fc2 = self.fc2.mask = mask_fc2
        self.mask_fc3 = self.fc3.mask = mask_fc3

    @staticmethod
    def _mask_tensor(array, shape, index):
        """Copy one per-layer mask into a float32 tensor on ``device`` and check its shape."""
        # torch.tensor() always copies and returns a tensor that does not require
        # grad, so the mask is an independent constant.
        tensor = torch.tensor(np.asarray(array), dtype=torch.float32, device=device)
        if tuple(tensor.shape) != shape:
            raise ValueError(
                f"mask {index} has shape {tuple(tensor.shape)}, expected {shape}"
            )
        return tensor



def train_bayesian_lenet_hmc(modelte, num_samples=100, warmup_steps=50, masklist=None, train_dataset=None):
    """Fit a masked BayesianLeNet to a whole training set with NUTS (adaptive HMC).

    Args:
        modelte: Unused. Kept only so existing call sites, which pass a placeholder,
            keep working. A fresh BayesianLeNet is always built here.
        num_samples: number of posterior samples kept after warm-up.
        warmup_steps: number of NUTS warm-up iterations.
        masklist: per-layer masks for ``BayesianLeNet.load_mask``. Defaults to the
            module-level ``mask`` loaded from ``maskloc``.
        train_dataset: dataset used for the full-batch likelihood. Defaults to the
            module-level ``dataset1`` (MNIST train split).

    Returns:
        ``(model, mcmc)``: the masked model and the finished ``MCMC`` object, whose
        ``get_samples()`` holds the posterior draws. This can take a long time.
    """
    if masklist is None:
        masklist = mask
    if train_dataset is None:
        train_dataset = dataset1

    model = BayesianLeNet()
    model.load_mask(masklist)
    model.to(device)

    # Conditioning on an empty dict is a no-op; labels are observed through the
    # `y` argument of BayesianLeNet.forward instead.
    conditioned_model = pyro.condition(model, data={})

    nuts_kernel = NUTS(conditioned_model)
    mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1)

    # The whole dataset is used as a single batch: every HMC step needs the full
    # log-likelihood, so minibatching is not used.
    full_x, full_y = next(iter(DataLoader(train_dataset, batch_size=len(train_dataset))))
    mcmc.run(full_x.to(device), full_y.to(device))

    return model, mcmc



def _posterior_predictive_log_probs(model, posterior_samples, num_draws, x_batch, num_samples):
    """Return log of the mean class probabilities over ``num_samples`` random posterior draws."""
    draw_log_probs = []
    for _ in range(num_samples):
        # Pick one stored MCMC draw uniformly at random (with replacement).
        sample_idx = np.random.randint(0, num_draws)
        sampled_trace = {k: v[sample_idx] for k, v in posterior_samples.items()}

        conditioned_model = pyro.poutine.condition(model, data=sampled_trace)
        trace = pyro.poutine.trace(conditioned_model).get_trace(x_batch)
        logits = trace.nodes["_RETURN"]["value"]
        draw_log_probs.append(F.log_softmax(logits, dim=-1))

    # log(mean_s p_s), computed stably as logsumexp_s(log p_s) - log(S).
    return torch.logsumexp(torch.stack(draw_log_probs), dim=0) - float(np.log(num_samples))


def evaluate_bayesian_lenet_hmc(model, mcmc, test_loader, num_samples=10, device="cpu"):
    """Evaluate the HMC posterior predictive on ``test_loader`` and print its metrics.

    For each batch, ``num_samples`` draws are taken at random from the stored
    MCMC samples. The model is run once per draw, and the per-draw class
    probabilities (not logits) are averaged, which gives the Bayesian
    posterior-predictive distribution. Because draws come from the stored trace,
    taking more than were stored during training adds little.

    Args:
        model: the model the samples were drawn for (from train_bayesian_lenet_hmc).
        mcmc: finished ``MCMC`` object providing ``get_samples()``.
        test_loader: iterable of ``(x, y)`` batches.
        num_samples: number of posterior draws averaged per batch (must be >= 1).
        device: device the batches are moved to.

    Returns:
        Average negative log-likelihood (cross-entropy) per test example.

    Raises:
        ValueError: if ``num_samples`` < 1 or ``test_loader`` yields no examples.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()
    posterior_samples = mcmc.get_samples()
    # Every site shares the same leading (draw) dimension, so any site gives the count.
    num_draws = next(iter(posterior_samples.values())).shape[0]

    total_loss = 0.0
    total_correct = 0
    total_examples = 0

    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            log_probs = _posterior_predictive_log_probs(
                model, posterior_samples, num_draws, x_batch, num_samples
            )
            total_loss += F.nll_loss(log_probs, y_batch, reduction='sum').item()
            # Count correct predictions so a smaller last batch is weighted correctly.
            total_correct += (log_probs.argmax(dim=1) == y_batch).sum().item()
            total_examples += y_batch.size(0)

    if total_examples == 0:
        raise ValueError("test_loader yielded no examples")

    avg_loss = total_loss / total_examples
    avg_accuracy = total_correct / total_examples
    print(f"Cross-Entropy Loss on Test Set: {avg_loss:.4f}")
    print(f"Accuracy on Test Set: {avg_accuracy:.4f}")
    return avg_loss




### Entry point: sample the posterior with NUTS on the training set, then evaluate on the test set.
# train_bayesian_lenet_hmc ignores its first argument, so any placeholder works here.
modelt = 0

model, mcmc = train_bayesian_lenet_hmc(
    modelt,
    num_samples=2800,
    warmup_steps=200,
)
evaluate_bayesian_lenet_hmc(
    model,
    mcmc,
    test_loader,
    num_samples=1500,
    device=device,
)
