### By

### Trains `num_trials` masked LeNet-5 networks, each from a different random initialization. This is the "random initialization" baseline: no Bayesian components, just plain deterministic networks. CIFAR-10 version.

import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader
import argparse 
import random

parser = argparse.ArgumentParser()
# The seed is mandatory and parsed as an int by argparse. Previously a missing
# --seed surfaced as an opaque TypeError from int(None); argparse now prints a
# usage error instead.
parser.add_argument("-s", "--seed", type=int, required=True,
                    help="Random seed for Python, NumPy and PyTorch.")

args = parser.parse_args()

SEED = int(args.seed)
# Seed every RNG the script draws from: Python's `random`, NumPy (used by the
# custom layers to initialise weights) and PyTorch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Also seed all CUDA devices; PyTorch silently ignores this without CUDA.
torch.cuda.manual_seed_all(SEED)

# ---- Experiment configuration ---------------------------------------------
epochs = 100       # training epochs per trial
num_trials = 50    # number of independently initialised models to train
# Pruning mask file: LeNet.load_mask() reads entries 0-4 as the masks for
# conv1, conv2, fc1, fc2 and fc3 respectively.
mask = np.load(
    '../tests/CNN_LeNet5_CIFAR/99_test1_various_masks/mask_17.4_size.npy',
    allow_pickle=True,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"  # uncomment to force CPU execution

batchsize = 1024

# ---- Data -------------------------------------------------------------------
# NOTE(review): (0.1307,) / (0.3081,) look like the standard MNIST mean/std,
# broadcast here over all three CIFAR-10 channels. Left unchanged so results
# stay comparable with runs using the same transform; confirm this is intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
dataset1 = datasets.CIFAR10('../data', train=True, download=True, transform=transform)
dataset2 = datasets.CIFAR10('../data', train=False, transform=transform)

# NOTE(review): the training loader does not shuffle, so every epoch visits
# the batches in the same order; confirm this is intended.
train_loader = DataLoader(dataset1, batch_size=batchsize)
test_loader = DataLoader(dataset2, batch_size=batchsize)



class CustomLinear(nn.Module):
    """Fully connected layer whose weight is multiplied element-wise by a mask.

    Args:
        mask: tensor multiplied with the ``(outputs, inputs)`` weight on every
            forward pass; registered as a buffer (not a trainable parameter).
        inputs: number of input features.
        outputs: number of output features.

    Weight and bias are sampled from U(-k, k), k = sqrt(1 / inputs), using
    NumPy's global RNG -- weight first, then bias.
    """

    def __init__(self, mask, inputs, outputs):
        super().__init__()
        self.register_buffer("mask", mask)

        bound = np.sqrt(1 / inputs)
        # Draw order (weight, then bias) matters for seeded reproducibility.
        weight_init = np.random.uniform(-bound, bound, size=(outputs, inputs))
        bias_init = np.random.uniform(-bound, bound, size=outputs)
        self.weight = nn.Parameter(torch.tensor(weight_init, dtype=torch.float32))
        self.bias = nn.Parameter(torch.tensor(bias_init, dtype=torch.float32))

    def forward(self, x):
        """Apply ``x @ (weight * mask).T + bias``."""
        masked_weight = self.weight * self.mask
        return F.linear(x, masked_weight, self.bias)
    

class CustomConv2d(nn.Module):
    """2-D convolution whose kernel is multiplied element-wise by a mask.

    Args:
        mask: tensor multiplied with the
            ``(outputs, inputs, kernalheight, kernelwidth)`` kernel on every
            forward pass; registered as a buffer.
        inputs: number of input channels.
        outputs: number of output channels (filters).
        kernalheight, kernelwidth: kernel height and width.
        padding: zero padding forwarded to ``F.conv2d`` (default 0).

    Weight and bias are sampled from U(-k, k), with
    k = sqrt(1 / (inputs * kernalheight * kernelwidth)), using NumPy's global
    RNG -- weight first, then bias.
    """

    def __init__(self, mask, inputs, outputs, kernalheight, kernelwidth, padding=0):
        super().__init__()
        self.register_buffer("mask", mask)
        self.padding = padding

        fan_in = inputs * kernalheight * kernelwidth
        bound = np.sqrt(1 / fan_in)
        kernel_shape = (outputs, inputs, kernalheight, kernelwidth)
        # Draw order (weight, then bias) matters for seeded reproducibility.
        weight_init = np.random.uniform(-bound, bound, size=kernel_shape)
        bias_init = np.random.uniform(-bound, bound, size=outputs)
        self.weight = nn.Parameter(torch.tensor(weight_init, dtype=torch.float32))
        self.bias = nn.Parameter(torch.tensor(bias_init, dtype=torch.float32))

    def forward(self, x):
        """Convolve ``x`` with the masked kernel ``weight * mask`` plus bias."""
        masked_kernel = self.weight * self.mask
        return F.conv2d(x, masked_kernel, self.bias, padding=self.padding)
    


class LeNet(nn.Module):
    """LeNet-5-style classifier for 3x32x32 images, built from masked layers.

    Layout: conv1 (3->6, 5x5, padding 2) -> ReLU -> 2x2 max-pool
    -> conv2 (6->16, 5x5) -> ReLU -> 2x2 max-pool -> flatten (16*6*6 = 576)
    -> fc1 (576->120) -> ReLU -> fc2 (120->84) -> ReLU -> fc3 (84->10 logits).

    Every layer starts with an all-ones mask (a dense network); ``load_mask``
    installs pruning masks.
    """

    def __init__(self):

        super(LeNet, self).__init__()

        # Initial all-ones masks (no pruning). Each one is handed to its layer,
        # which registers it as a buffer.
        self.mask_conv1 = torch.ones((6, 3, 5, 5), requires_grad=False)
        self.mask_conv2 = torch.ones((16, 6, 5, 5), requires_grad=False)

        self.mask_fc1 = torch.ones((120, 576), requires_grad=False)

        self.mask_fc2 = torch.ones((84, 120), requires_grad=False)

        self.mask_fc3 = torch.ones((10, 84), requires_grad=False)

        self.flatten = nn.Flatten()

        # NOTE: the custom layers draw their initial weights from NumPy's global
        # RNG, so construction order is part of seeded reproducibility.
        self.conv1 = CustomConv2d(self.mask_conv1, 3, 6, 5, 5, padding=2)  # 3 input channels
        self.conv2 = CustomConv2d(self.mask_conv2, 6, 16, 5, 5)

        # Not used by forward(), which calls F.max_pool2d directly; kept so the
        # attribute remains available to external code.
        self.pooling1 = nn.MaxPool2d(2)

        self.fc1 = CustomLinear(self.mask_fc1, 576, 120)  # 16 channels * 6 * 6 after the second pool
        self.fc2 = CustomLinear(self.mask_fc2, 120, 84)
        self.fc3 = CustomLinear(self.mask_fc3, 84, 10)

        # Placeholders, overwritten by save_init_weights() / load_init_weights().
        # NOTE: these torch.rand calls advance torch's global RNG; removing them
        # would change the torch RNG state seen by later code for a given seed.
        self.w1_init = torch.rand(6, 3, 5, 5)
        self.w2_init = torch.rand(16, 6, 5, 5)
        self.w3_init = torch.rand(120, 576)
        self.w4_init = torch.rand(84, 120)
        self.w5_init = torch.rand(10, 84)

        self.bias1_init = torch.rand(6)
        self.bias2_init = torch.rand(16)
        self.bias3_init = torch.rand(120)
        self.bias4_init = torch.rand(84)
        self.bias5_init = torch.rand(10)

    def forward(self, x):
        """Return the (batch, 10) logits for input batch ``x``."""

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        return logits

    def save_init_weights(self):
        """Snapshot the current weights and biases of all five layers."""

        self.w1_init = self.conv1.weight.data.clone()
        self.w2_init = self.conv2.weight.data.clone()
        self.w3_init = self.fc1.weight.data.clone()
        self.w4_init = self.fc2.weight.data.clone()
        self.w5_init = self.fc3.weight.data.clone()

        self.bias1_init = self.conv1.bias.data.clone()
        self.bias2_init = self.conv2.bias.data.clone()
        self.bias3_init = self.fc1.bias.data.clone()
        self.bias4_init = self.fc2.bias.data.clone()
        self.bias5_init = self.fc3.bias.data.clone()


    def reset_weights(self):
        """Overwrite all layer weights/biases with copies of the stored ``*_init`` tensors."""

        self.conv1.weight.data = self.w1_init.clone()
        self.conv2.weight.data = self.w2_init.clone()
        self.fc1.weight.data = self.w3_init.clone()
        self.fc2.weight.data = self.w4_init.clone()
        self.fc3.weight.data = self.w5_init.clone()

        self.conv1.bias.data = self.bias1_init.clone()
        self.conv2.bias.data = self.bias2_init.clone()
        self.fc1.bias.data = self.bias3_init.clone()
        self.fc2.bias.data = self.bias4_init.clone()
        self.fc3.bias.data = self.bias5_init.clone()

    def load_init_weights(self, weightlist, biaslist):
        """Store initial weights/biases from NumPy arrays (applied by ``reset_weights``).

        Args:
            weightlist: five NumPy arrays ordered (conv1, conv2, fc1, fc2, fc3).
            biaslist: five NumPy arrays in the same order.

        The arrays keep their NumPy dtype and are moved to the module-level
        ``device``.
        """

        self.w1_init = torch.from_numpy(weightlist[0]).clone().to(device)
        self.w2_init = torch.from_numpy(weightlist[1]).clone().to(device)
        self.w3_init = torch.from_numpy(weightlist[2]).clone().to(device)
        self.w4_init = torch.from_numpy(weightlist[3]).clone().to(device)
        self.w5_init = torch.from_numpy(weightlist[4]).clone().to(device)

        self.bias1_init = torch.from_numpy(biaslist[0]).clone().to(device)
        self.bias2_init = torch.from_numpy(biaslist[1]).clone().to(device)
        self.bias3_init = torch.from_numpy(biaslist[2]).clone().to(device)
        self.bias4_init = torch.from_numpy(biaslist[3]).clone().to(device)
        self.bias5_init = torch.from_numpy(biaslist[4]).clone().to(device)

    def load_mask(self, masklist):
        """Install pruning masks on all five layers.

        Args:
            masklist: indexable collection of (at least) five NumPy arrays,
                ordered (conv1, conv2, fc1, fc2, fc3). Each mask is multiplied
                element-wise with the matching layer weight.

        Raises:
            IndexError: if ``masklist`` has fewer than five entries.

        Each mask is copied to the module-level ``device`` and cast to the
        layer's weight dtype. Without the cast, a float64 mask (NumPy's default
        float) would promote the masked weight to float64 and make
        ``F.conv2d``/``F.linear`` fail on float32 inputs.
        """
        targets = (
            ("mask_conv1", self.conv1),
            ("mask_conv2", self.conv2),
            ("mask_fc1", self.fc1),
            ("mask_fc2", self.fc2),
            ("mask_fc3", self.fc3),
        )
        for idx, (attr_name, layer) in enumerate(targets):
            # Indexing (rather than zip) keeps the IndexError on short lists.
            mask = torch.from_numpy(masklist[idx]).clone().to(
                device=device, dtype=layer.weight.dtype
            )
            mask.requires_grad_(False)
            # The parent keeps its own copy, as before.
            setattr(self, attr_name, mask.clone())
            # `mask` is a registered buffer on the layer, so this assignment
            # replaces the buffer (it stays in state_dict and follows .to()).
            layer.mask = mask

    def mask_size(self):
        """Print and return the sum of all five layer masks as a 0-d tensor.

        For 0/1 masks this is the number of weights left unpruned.
        """
        size = (torch.sum(self.fc1.mask) + torch.sum(self.fc2.mask) + torch.sum(self.fc3.mask) + torch.sum(self.conv1.mask)+torch.sum(self.conv2.mask))
        print("Mask Size: {}".format(size))
        return size
    


def train_model(model, train_loader, criterion, optimizer, num_epochs=1):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    The model and every batch are moved to the module-level ``device``. After
    each epoch the mean batch loss is printed on a line ending in ``\\r``, so
    successive epochs overwrite each other in a terminal.

    Args:
        model: module to train; switched to train mode at the start of each epoch.
        train_loader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        list[float]: mean batch loss of each epoch. The value is ``nan`` for an
        epoch over an empty loader, which previously raised ZeroDivisionError.
    """
    model.to(device)
    n_batches = len(train_loader)
    epoch_losses = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        mean_loss = running_loss / n_batches if n_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}", end="\r")

    return epoch_losses

def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` without gradients and collect its outputs.

    The model is switched to eval mode and left there. Inputs are moved to the
    module-level ``device``; targets are only copied to CPU.

    Args:
        model: module to evaluate (assumed already on ``device``).
        test_loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(predictions, actuals)``. The first
        holds the model outputs stacked along the batch axis, the second the
        matching targets. Both are empty 1-D float arrays if the loader yields
        no batches.
    """
    model.eval()
    output_batches = []
    target_batches = []

    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            outputs = model(X_batch.to(device))
            output_batches.append(outputs.cpu().numpy())
            target_batches.append(y_batch.cpu().numpy())

    if not output_batches:
        # np.concatenate rejects an empty list; match np.array([]) semantics.
        return np.array([]), np.array([])
    # Concatenating whole batches avoids building one Python object per sample.
    return np.concatenate(output_batches), np.concatenate(target_batches)

# Per-trial test-set metrics, aggregated after the loop.
modellosses = []
modelaccs = []

# Built once before the loop and discarded at the start of the first trial.
# Kept so the RNG draws made before each trial stay exactly the same.
model = LeNet()

for i in range(num_trials):
    # Release the previous model before building a freshly initialised one.
    del model
    model = LeNet()
    model.to(device)
    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    model.load_mask(mask)

    train_model(model, train_loader, criterion, optimizer, num_epochs=epochs)

    # Test-set metrics are computed on CPU tensors built from the NumPy outputs.
    logits_np, labels_np = evaluate_model(model, test_loader)
    logits = torch.tensor(logits_np)
    labels = torch.tensor(labels_np)

    predicted = logits.max(dim=1).indices
    accuracy = (predicted == labels).sum() / len(labels_np)
    loss = criterion(logits, labels)

    modelaccs.append(accuracy.item())
    modellosses.append(loss.item())

    print(f"Trial: {i+1}/{num_trials}                                                        ")
    print(f"Accuracy: {accuracy.item()}")
    print(f"Loss: {loss.item()}")
    del logits_np, labels_np, logits, labels

print(f"{num_trials} Trials:  Avg Acc: {np.mean(modelaccs)}, Std Acc: {np.std(modelaccs)}, Avg. Loss: {np.mean(modellosses)}, Std Loss: {np.std(modellosses)}")