###

### This was also used to generate the toy model masks - just with a different network, settings, etc.
### The instructions are also the same for the CIFAR mask generator, as well as Lenet5 (CNN) mask generators.


### Import the necessary libraries
import os
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader


### Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"   ### Uncomment to force CPU execution
print(device)

### All masks for the run are stored here
directory_name = "tests/LeNet_MNIST/99_test1"

### Create the output directories. os.makedirs also creates missing ancestors
### (e.g. "tests"), which os.mkdir cannot do, and real failures such as a
### permission error now surface instead of being reported as "already present".
if os.path.isdir("tests/LeNet_MNIST"):
    print("Parent directory already present!")
else:
    os.makedirs("tests/LeNet_MNIST", exist_ok=True)

if os.path.isdir(directory_name + "_various_masks"):
    print("Directory already present!")
else:
    os.makedirs(directory_name + "_various_masks", exist_ok=True)

### Load Datasets
batchsize = 1024

### Convert each image to a tensor, then normalize it with the mean / std constants below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

### dataset1 is the training split (downloaded to ../data if missing); dataset2 is the test split.
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, transform=transform)

### Both loaders iterate their dataset in order, in batches of `batchsize`.
train_loader = DataLoader(dataset1, batch_size=batchsize)
test_loader = DataLoader(dataset2, batch_size=batchsize)


### Define functions to find the smallest magnitude weights in a network and return a pruned mask

def find_smallest_in_layer(mask, weight_tensor, num_values = 1):
    """Return the smallest-magnitude weights of one layer that are not yet pruned.

    The inputs are not modified. The work is done on ``weight_tensor``'s own
    device; ``mask`` is moved there if necessary.

    Args:
        mask: Tensor with as many elements as ``weight_tensor``. Entries equal
            to 0 mark weights that are already pruned and must be skipped.
        weight_tensor: Weight tensor of the layer.
        num_values: Maximum number of weights to return. Values <= 0 return
            no weights.

    Returns:
        A tuple ``(prune_index_loc, pruned_index_val, controlfunc)``:

        * ``prune_index_loc`` - list of 0-d integer tensors holding flat
          indices into ``weight_tensor``, ordered by ascending magnitude.
        * ``pruned_index_val`` - list of floats, the absolute values of the
          weights at those indices (same order).
        * ``controlfunc`` - 1 if the layer still has at least one unpruned
          weight, 0 if every weight is already pruned.
    """
    magnitudes = torch.abs(weight_tensor.detach()).flatten()
    still_active = mask.to(magnitudes.device).flatten() != 0

    # Sort every weight by magnitude once, then drop the already-pruned
    # positions with a single vectorised lookup (the old per-element
    # membership test made this quadratic in the layer size).
    ascending_order = magnitudes.sort()[1]
    candidates = ascending_order[still_active[ascending_order]]

    # The flag reports whether ANY live weight exists. It must not depend on
    # whether the largest-magnitude position happens to be pruned.
    controlfunc = 1 if candidates.numel() > 0 else 0
    if controlfunc == 0:
        print("All weights pruned? Something is wrong...")

    chosen = candidates[:max(num_values, 0)]

    # Callers pop entries and call .item() on them, so keep the indices as a
    # list of 0-d tensors and the magnitudes as a list of plain floats.
    prune_index_loc = list(chosen)
    pruned_index_val = magnitudes[chosen].tolist()

    return prune_index_loc, pruned_index_val, controlfunc

def smallest_picker(list_of_masks, list_of_weights, total=1):
    """Pick the ``total`` globally smallest unpruned weights across several layers.

    Each layer contributes its own ascending candidate list (from
    ``find_smallest_in_layer``); the lists are then merged, always taking the
    smallest remaining head. On a tie the layer with the lowest index wins.

    Args:
        list_of_masks: One mask tensor per layer (0 marks a pruned weight).
        list_of_weights: One weight tensor per layer, aligned with
            ``list_of_masks``.
        total: Number of weights to select over all layers together.

    Returns:
        A tuple ``(masks, indicator)``:

        * ``masks`` - list with one entry per layer; each entry is a list of
          Python ints, the flat indices to prune in that layer.
        * ``indicator`` - 1 when no layer reported an unpruned weight or when
          fewer than ``total`` weights were available, otherwise 0.
    """
    mask_indexes = []
    mask_values = []
    controlfuncs = []

    for mask, weight in zip(list_of_masks, list_of_weights):
        prunelocs, prunevals, controlfunc = find_smallest_in_layer(
            mask.clone().to(device), weight.clone().to(device), total
        )
        mask_indexes.append(prunelocs)
        mask_values.append(prunevals)
        controlfuncs.append(controlfunc)

    masks = [[] for _ in list_of_masks]

    if not any(controlfuncs):
        print("No masks!")
        return masks, 1

    indicator = 0
    # cursors[j] is the position of the next unused candidate of layer j.
    # Advancing a cursor replaces the repeated list.pop(0) calls.
    cursors = [0] * len(mask_values)

    for _ in range(total):
        best_layer = -1
        best_value = None
        for layer_number, values in enumerate(mask_values):
            position = cursors[layer_number]
            if position >= len(values):
                continue  # this layer has no candidates left
            # Strict "<" keeps the first (lowest-index) layer on ties.
            if best_value is None or values[position] < best_value:
                best_layer = layer_number
                best_value = values[position]

        if best_layer < 0:
            # Every candidate list is exhausted before `total` was reached.
            print("Exception: empty list occured!")
            indicator = 1
            break

        masks[best_layer].append(int(mask_indexes[best_layer][cursors[best_layer]]))
        cursors[best_layer] += 1

    return masks, indicator
    


### Custom layers to force the mask to be applied during training

class CustomLinear(nn.Module):
    """Fully connected layer whose weight is multiplied by a fixed mask on every forward pass.

    Args:
        mask: Tensor multiplied element-wise with the weight; it is registered
            as a buffer named ``mask``, not as a trainable parameter.
        inputs: Number of input features.
        outputs: Number of output features.
    """

    def __init__(self, mask, inputs, outputs):
        super().__init__()
        self.register_buffer("mask", mask)

        # Weight and bias are both drawn from U(-bound, bound), bound = sqrt(1/inputs).
        ### Can change to other distributions as desired
        bound = np.sqrt(1/inputs)
        initial_weight = np.random.uniform(-bound, bound, size=(outputs, inputs))
        initial_bias = np.random.uniform(-bound, bound, size=outputs)

        self.weight = nn.Parameter(torch.tensor(initial_weight, dtype=torch.float32))
        self.bias = nn.Parameter(torch.tensor(initial_bias, dtype=torch.float32))

    def forward(self, x):
        # Masked-out weights contribute nothing to the output.
        return F.linear(x, self.weight * self.mask, self.bias)


class CustomConv2d(nn.Module):
    """2-D convolution whose kernel is multiplied by a fixed mask on every forward pass.

    Args:
        mask: Tensor multiplied element-wise with the kernel; it is registered
            as a buffer named ``mask``, not as a trainable parameter.
        inputs: Number of input channels.
        outputs: Number of output channels.
        kernalheight: Kernel height.
        kernelwidth: Kernel width.
        padding: Padding forwarded to ``F.conv2d`` (default 0).
    """

    def __init__(self, mask, inputs, outputs, kernalheight, kernelwidth, padding=0):
        super().__init__()
        self.register_buffer("mask", mask)
        self.padding = padding

        # Kernel and bias are both drawn from U(-bound, bound),
        # bound = sqrt(1 / (inputs * kernel height * kernel width)).
        bound = np.sqrt(1/(inputs*kernalheight*kernelwidth))
        kernel_shape = (outputs, inputs, kernalheight, kernelwidth)
        initial_kernel = np.random.uniform(-bound, bound, size=kernel_shape)
        initial_bias = np.random.uniform(-bound, bound, size=outputs)

        self.weight = nn.Parameter(torch.tensor(initial_kernel, dtype=torch.float32))
        self.bias = nn.Parameter(torch.tensor(initial_bias, dtype=torch.float32))

    def forward(self, x):
        # Masked-out kernel entries contribute nothing to the output.
        masked_kernel = self.weight * self.mask
        return F.conv2d(x, masked_kernel, self.bias, padding=self.padding)



### Define the network
class LeNet(nn.Module):
    """Masked 784-300-100-10 fully connected network used for lottery-ticket pruning.

    Every layer is a ``CustomLinear`` that multiplies its weight by a mask in
    the forward pass. The masks are kept twice: as plain attributes
    (``mask_fc1`` .. ``mask_fc3``) and as the layers' own ``mask`` buffers.

    The attributes ``w1_init`` .. ``w3_init`` and ``bias1_init`` ..
    ``bias3_init`` hold the initialization that ``reset_weights`` restores.
    """

    def __init__(self):
        super().__init__()

        # All-ones masks: nothing is pruned when the network is created.
        self.mask_fc1 = torch.ones((300, 784), requires_grad=False)
        self.mask_fc2 = torch.ones((100, 300), requires_grad=False)
        self.mask_fc3 = torch.ones((10, 100), requires_grad=False)

        self.flatten = nn.Flatten()

        self.fc1 = CustomLinear(self.mask_fc1, 784, 300) #3136 for CIFAR
        self.fc2 = CustomLinear(self.mask_fc2, 300, 100)
        self.fc3 = CustomLinear(self.mask_fc3, 100, 10)

        # Snapshot the real initialization right away, so reset_weights() is
        # meaningful even when no save/load call was made first (the snapshots
        # used to be filled with unrelated random tensors).
        self.save_init_weights()

    def forward(self, x):
        """Flatten ``x``, apply two ReLU hidden layers and return the output logits."""
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        return logits

    def save_init_weights(self): ### Function to save the original (lottery ticket) initialzation
        """Store copies of the current weights and biases as the initialization to reset to."""

        self.w1_init = self.fc1.weight.data.clone()
        self.w2_init = self.fc2.weight.data.clone()
        self.w3_init = self.fc3.weight.data.clone()

        self.bias1_init = self.fc1.bias.data.clone()
        self.bias2_init = self.fc2.bias.data.clone()
        self.bias3_init = self.fc3.bias.data.clone()


    def reset_weights(self): ### Resets the weights to their original (lottery ticket) initialization
        """Copy the stored initialization back into the layers' weights and biases.

        The values are copied in place, so each parameter keeps its current
        device and dtype even when the stored snapshot lives elsewhere.
        """

        with torch.no_grad():
            self.fc1.weight.copy_(self.w1_init)
            self.fc2.weight.copy_(self.w2_init)
            self.fc3.weight.copy_(self.w3_init)

            self.fc1.bias.copy_(self.bias1_init)
            self.fc2.bias.copy_(self.bias2_init)
            self.fc3.bias.copy_(self.bias3_init)

    def load_init_weights(self, weightlist, biaslist): ### Loads the initial weights and biases from a list of weights (saved to a numpy file - used for testing purposes and a way to force reset the network to original initialization)
        """Replace the stored initialization with numpy arrays.

        Only the stored snapshot changes; call ``reset_weights`` afterwards to
        apply it to the layers.

        Args:
            weightlist: Indexable with three numpy arrays (fc1, fc2, fc3 weights).
            biaslist: Indexable with three numpy arrays (fc1, fc2, fc3 biases).
        """

        self.w1_init = torch.from_numpy(weightlist[0]).clone().to(device)
        self.w2_init = torch.from_numpy(weightlist[1]).clone().to(device)
        self.w3_init = torch.from_numpy(weightlist[2]).clone().to(device)

        self.bias1_init = torch.from_numpy(biaslist[0]).clone().to(device)
        self.bias2_init = torch.from_numpy(biaslist[1]).clone().to(device)
        self.bias3_init = torch.from_numpy(biaslist[2]).clone().to(device)

    def load_mask(self, masklist): ### Loads the mask from a list of masks (saved to a numpy file - used for testing purposes and a way to force reset the network to a specific mask)
        """Install masks from numpy arrays, both as attributes and as layer buffers.

        Args:
            masklist: Indexable with three numpy arrays (fc1, fc2, fc3 masks).
        """

        self.mask_fc1 = torch.from_numpy(masklist[0]).clone().to(device)
        self.mask_fc2 = torch.from_numpy(masklist[1]).clone().to(device)
        self.mask_fc3 = torch.from_numpy(masklist[2]).clone().to(device)

        # The layers get their own copies, so the attribute masks and the
        # layer buffers never share storage after loading.
        self.fc1.mask = self.mask_fc1.clone()
        self.fc2.mask = self.mask_fc2.clone()
        self.fc3.mask = self.mask_fc3.clone()

        self.fc1.mask.requires_grad_(False)
        self.fc2.mask.requires_grad_(False)
        self.fc3.mask.requires_grad_(False)

    def mask_size(self):
        """Print and return the sum of all entries of the three layer masks."""
        size = (torch.sum(self.fc1.mask) + torch.sum(self.fc2.mask) + torch.sum(self.fc3.mask))
        print("Mask Size: {}".format(size))
        return size

### Train+test the model

def train_model(model, train_loader, criterion, optimizer, num_epochs=1):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: Network to train; it is moved to the global ``device``.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Callable returning the loss for ``(outputs, targets)``.
        optimizer: Optimizer that updates the model's parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    model.to(device)

    for epoch_index in range(num_epochs):
        model.train()
        summed_batch_losses = 0.0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            summed_batch_losses += batch_loss.item()

        # Report the mean batch loss; '\r' makes each epoch overwrite the previous line.
        print(f"Epoch {epoch_index+1}/{num_epochs}, Loss: {summed_batch_losses/len(train_loader)}", end='\r')

def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        ``(predictions, actuals)`` as numpy arrays: the raw model outputs for
        every sample and the matching targets.
    """
    model.eval()
    collected_outputs, collected_targets = [], []

    with torch.no_grad():
        for inputs, targets in test_loader:
            batch_outputs = model(inputs.to(device))

            # extend() adds one row per sample, so batches of any size stack up.
            collected_outputs.extend(batch_outputs.cpu().numpy())
            collected_targets.extend(targets.cpu().numpy())

    return np.array(collected_outputs), np.array(collected_targets)


def _as_object_array(tensors):
    """Convert a list of tensors into a numpy object array of per-tensor numpy arrays."""
    return np.array([tensor.cpu().detach().numpy() for tensor in tensors], dtype=object)


def _loss_and_accuracy(criterion, logits, labels):
    """Return (loss, accuracy) for numpy logits / labels as produced by evaluate_model."""
    logits_tensor = torch.tensor(logits)
    labels_tensor = torch.tensor(labels)
    loss = criterion(logits_tensor, labels_tensor)
    accuracy = torch.sum(torch.max(logits_tensor, dim=1)[1] == labels_tensor)/len(labels)
    return loss, accuracy


def _model_masks(model):
    """Return the three attribute masks of a LeNet as a list."""
    return [model.mask_fc1, model.mask_fc2, model.mask_fc3]


def _zero_mask_entries(model, indices_per_layer):
    """Set the given flat indices to 0 in both copies of every layer's mask."""
    layers = [model.fc1, model.fc2, model.fc3]
    for attribute_mask, layer, indices in zip(_model_masks(model), layers, indices_per_layer):
        if not indices:
            continue  # nothing selected in this layer
        # flatten() returns a view for these contiguous masks, so the
        # assignment writes through to the mask itself.
        attribute_mask.flatten()[indices] = 0
        layer.mask.flatten()[indices] = 0


def lotto_ticket_training(train_dl, test_dl, num_epochs, percent_pruned = .1, rounds_of_pruning = 20):
    """Iteratively train, magnitude-prune and rewind a LeNet to obtain a pruning mask.

    Each round builds a fresh LeNet, loads the saved initialization and the
    current mask from disk, trains for ``num_epochs`` epochs, removes the
    smallest-magnitude surviving weights and saves the new mask.

    Every round keeps the fraction ``percent_pruned ** (1 / rounds_of_pruning)``
    of the surviving weights, so after all rounds about ``percent_pruned`` of
    the original weights remain (.1 means "prune down to 10%").

    Files written (all paths start with the global ``directory_name``):

    * ``_initial_init_weights.npy`` / ``_initial_init_bias.npy`` - the initialization.
    * ``_mask.npy`` - the current mask, rewritten after every successful round.
    * ``_various_masks/mask_{percent}_size.npy`` - the mask after each round,
      labelled with the percentage of weights remaining after that round.

    Args:
        train_dl: Training loader passed to ``train_model``.
        test_dl: Test loader passed to ``evaluate_model``.
        num_epochs: Epochs of training before each pruning step.
        percent_pruned: Target fraction of weights remaining at the end.
        rounds_of_pruning: Number of train/prune rounds.

    Returns:
        ``(mask, orig_weights, orig_bias, losses)``: the masks after the last
        successful round (the all-ones masks if no round succeeded), the
        initial weights, the initial biases, and one
        ``[post_prune_loss, remaining_weight_count]`` pair of tensors per round.
    """
    weight_path = f"{directory_name}_initial_init_weights.npy"
    bias_path = f"{directory_name}_initial_init_bias.npy"
    mask_path = f"{directory_name}_mask.npy"

    # Fraction of the surviving weights that is KEPT in every round.
    keep_ratio = percent_pruned**(1/rounds_of_pruning)

    lotto_model = LeNet()
    lotto_model.to(device)
    lotto_model.save_init_weights()

    orig_weights = [lotto_model.fc1.weight.data.clone(), lotto_model.fc2.weight.data.clone(), lotto_model.fc3.weight.data.clone()]
    orig_masks = [mask.clone() for mask in _model_masks(lotto_model)]
    orig_bias = [lotto_model.fc1.bias.data.clone(), lotto_model.fc2.bias.data.clone(), lotto_model.fc3.bias.data.clone()]

    # The loop below rebuilds every model from these files.
    np.save(weight_path, _as_object_array(orig_weights))
    np.save(bias_path, _as_object_array(orig_bias))
    np.save(mask_path, _as_object_array(orig_masks))

    losses = []

    # Mask handed back to the caller; refreshed after every successful round
    # so the caller receives the pruned mask, not the initial all-ones one.
    current_masks = [mask.clone() for mask in orig_masks]

    orig_size = torch.sum(orig_masks[0])+torch.sum(orig_masks[1])+torch.sum(orig_masks[2])
    orig_loss = 0
    orig_accuracy = 0

    for round_number in range(rounds_of_pruning):
        # Drop the previous model before building the next one.
        del lotto_model
        torch.cuda.empty_cache()

        lotto_model = LeNet()

        weightstuff = np.load(weight_path, allow_pickle=True)
        biasstuff = np.load(bias_path, allow_pickle=True)
        maskstuff = np.load(mask_path, allow_pickle=True)

        lotto_model.load_init_weights(weightstuff, biasstuff)
        lotto_model.reset_weights()
        lotto_model.load_mask(maskstuff)
        lotto_model.to(device)

        model_optimizer = optim.Adam(lotto_model.parameters())
        model_criterion = nn.CrossEntropyLoss()

        train_model(lotto_model, train_dl, model_criterion, model_optimizer, num_epochs)
        pre_prune_logits, pre_prune_labels = evaluate_model(lotto_model, test_dl)
        pre_prune_test_loss, pre_prune_test_acc = _loss_and_accuracy(model_criterion, pre_prune_logits, pre_prune_labels)

        weights = [lotto_model.fc1.weight.data, lotto_model.fc2.weight.data, lotto_model.fc3.weight.data]
        masks = [mask.clone() for mask in _model_masks(lotto_model)]
        tmpsize = torch.sum(masks[0]) + torch.sum(masks[1]) + torch.sum(masks[2])

        prune_amnt = round((1.0-keep_ratio)*tmpsize.item())

        remove_mask_indeces, indicator = smallest_picker(masks, weights, total=prune_amnt)

        if round_number == 0:
            # The first round trains with the all-ones mask, i.e. the unpruned network.
            orig_loss = pre_prune_test_loss
            orig_accuracy = pre_prune_test_acc
            print("Loss of the Original Network: {}, Accuracy of Original Network: {}".format(orig_loss, orig_accuracy))

        if indicator == 1:
            ### This shouldn't ever trigger, but in rare cases with extreme pruning amounts (think, pruning 999 weights from a 1000 weight network) it will, because more weights are requested than are left. This was not a problem in our experiments.
            print("Stuff is broken!")
            break

        _zero_mask_entries(lotto_model, remove_mask_indeces)

        post_prune_logits, post_prune_labels = evaluate_model(lotto_model, test_dl)
        post_prune_loss, post_prune_acc = _loss_and_accuracy(model_criterion, post_prune_logits, post_prune_labels)

        # Measure the size on the pruned masks of the model; `masks` above
        # holds clones taken BEFORE pruning and would be one round stale.
        pruned_masks = _model_masks(lotto_model)
        newsize = torch.sum(pruned_masks[0])+torch.sum(pruned_masks[1])+torch.sum(pruned_masks[2])

        # Keep the recorded size on the CPU so np.array(losses) also works
        # when the masks live on the GPU.
        losses.append([post_prune_loss, newsize.detach().cpu()])

        percentagino = round(((newsize.item()/orig_size.item()))*1000)/10
        sizer = str(percentagino)

        savemask = _as_object_array(pruned_masks)
        np.save(mask_path, savemask)
        np.save(f"{directory_name}_various_masks/mask_{sizer}_size.npy", savemask) ### Save intermediate masks. Useful if the final one doesn't work. Can happen if you prune too much.

        current_masks = [mask.clone() for mask in pruned_masks]

        print("Model pruned to {}%. New Loss = {}, Orig Loss = {}, New Accuracy = {}, Original Accuracy = {}".format(percentagino, post_prune_loss, orig_loss, post_prune_acc, orig_accuracy))

        ### All important things are saved to disk. Release the per-round objects so the next round starts clean.
        del model_optimizer, model_criterion
        del maskstuff, weightstuff, biasstuff

    return current_masks, orig_weights, orig_bias, losses





### Run the thing yeah!
mask, weights, bias, losses = lotto_ticket_training(train_loader, test_loader, num_epochs=50, percent_pruned = .01, rounds_of_pruning = 50)

### Convert the returned mask, initial weights and initial biases to numpy and save them to disk.
numpymask = np.array([layer_mask.cpu().detach().numpy() for layer_mask in mask], dtype=object)

initweights = np.array([layer_weight.cpu().detach().numpy() for layer_weight in weights], dtype=object)

initbias = np.array([layer_bias.cpu().detach().numpy() for layer_bias in bias], dtype=object)

np.save(f"{directory_name}_mask.npy", numpymask)

### The initial weights file must contain the weights (the mask was written here by mistake before).
np.save(f"{directory_name}_initial_weights.npy", initweights)

np.save(f"{directory_name}_initial_bias.npy", initbias)

### Each entry of losses is a [loss, size] pair. float() turns both into plain numbers first, which also works when an entry is a tensor on the GPU.
lossarray = np.array([[float(entry[0]), float(entry[1])] for entry in losses])

np.save(f"{directory_name}_losses.npy", lossarray)

### Perfecto - one piping hot lottery ticket mask ready and saved to disk (along with a bunch of other masks at various pruning percentages)



