### By

### LeNet-5 on MNIST. For details of the specific functions and how they work, see LeNet_MNIST_lotto_ticket_generator.py (same procedure, different network).
import os
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# To force CPU execution instead, use: device = "cpu"
print(device)

directory_name = "tests/CNN_LeNet5_MNIST/99_test1"

# os.makedirs also creates missing parents (e.g. "tests/"). Only the
# "already exists" case is treated as benign; any other OSError (permissions,
# invalid path) propagates instead of being misreported as "already present".
try:
    os.makedirs(os.path.dirname(directory_name))
except FileExistsError:
    print("Parent directory already present!")

try:
    os.makedirs(directory_name + "_various_masks")
except FileExistsError:
    print("Directory already present!")

batchsize = 1024

# Normalize(mean, std) with the per-channel statistics used for MNIST here.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                          transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                          transform=transform)
# NOTE(review): no shuffle=True, so training batches arrive in the same order
# every epoch — confirm this is intended.
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batchsize)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batchsize)

def find_smallest_in_layer(mask, weight_tensor, num_values = 1):
    """Find the smallest-magnitude weights of one layer that are not yet pruned.

    Neither input is modified.

    Args:
        mask: Tensor with as many elements as ``weight_tensor``; entries equal
            to 0 mark weights that are already pruned.
        weight_tensor: The layer's weight tensor.
        num_values: Maximum number of candidates to return. Values <= 0 yield
            no candidates.

    Returns:
        ``(prune_index_loc, pruned_index_val, controlfunc)``:

        - ``prune_index_loc``: list of 0-d index tensors into the flattened
          weight, in ascending order of ``|weight|``.
        - ``pruned_index_val``: the matching ``|weight|`` values as Python floats.
        - ``controlfunc``: 1 if the layer still has at least one unpruned
          weight, else 0. When it is 0, a warning is printed.
    """
    magnitudes = torch.abs(weight_tensor.detach()).flatten()
    flat_mask = mask.detach().flatten().to(magnitudes.device)

    # Order every weight by ascending magnitude, then drop the already-masked
    # ones. Boolean indexing keeps the sort order, replacing the per-element
    # membership test (O(n) each) with one vectorized pass.
    order = magnitudes.sort()[1]
    unpruned_order = order[flat_mask[order] != 0]

    # Whether anything is left to prune depends only on the mask, not on which
    # index happened to be scanned last.
    controlfunc = 1 if unpruned_order.numel() > 0 else 0
    if controlfunc == 0:
        print("All weights pruned? Something is wrong...")

    selected = unpruned_order[:max(int(num_values), 0)]
    prune_index_loc = list(selected)  # 0-d tensors; callers use .item()
    pruned_index_val = magnitudes[selected].tolist()

    return prune_index_loc, pruned_index_val, controlfunc

def smallest_picker(list_of_masks, list_of_weights, total=1):
    """Select the ``total`` smallest-magnitude unpruned weights across all layers.

    Each (mask, weight) pair is scanned with ``find_smallest_in_layer`` for its
    own ``total`` smallest unpruned weights. Those per-layer candidates are then
    merged, and the globally smallest ``total`` are kept. Ties are broken by
    layer order, then by the order returned for that layer.

    Args:
        list_of_masks: Per-layer mask tensors (0 marks an already-pruned weight).
        list_of_weights: Per-layer weight tensors, in the same order.
        total: Number of weights to select overall.

    Returns:
        ``(masks, indicator)``:

        - ``masks[i]``: list of int flat indices selected in layer ``i``.
        - ``indicator``: 1 when no layer has any unpruned weight, or when fewer
          than ``total`` candidates exist; otherwise 0.
    """
    masks = [[] for _ in range(len(list_of_masks))]

    per_layer = [
        find_smallest_in_layer(mask, weight, total)
        for mask, weight in zip(list_of_masks, list_of_weights)
    ]

    if sum(controlfunc for _, _, controlfunc in per_layer) == 0:
        print("No masks!")
        return masks, 1

    # Each per-layer list is already in ascending magnitude order. A single sort
    # on (value, layer, position) therefore reproduces the greedy "take the
    # smallest head" merge, including its tie-breaking, without repeated
    # list.pop(0) calls. (layer, position) is unique, so the index tensor in the
    # last slot is never compared.
    candidates = sorted(
        (value, layer, position, loc)
        for layer, (locs, values, _) in enumerate(per_layer)
        for position, (loc, value) in enumerate(zip(locs, values))
    )

    wanted = max(int(total), 0)
    chosen = candidates[:wanted]
    for _, layer, _, loc in chosen:
        masks[layer].append(int(loc))

    if len(chosen) < wanted:
        # Not enough unpruned weights left to satisfy the request.
        print("Exception: empty list occured!")
        return masks, 1

    return masks, 0
    


# %%
class CustomLinear(nn.Module):
    """Fully connected layer whose weight is multiplied elementwise by a fixed mask.

    ``mask`` is registered as a buffer. Weight (``outputs x inputs``) and bias
    (``outputs``) are drawn with numpy's global RNG from U(-k, k), where
    k = sqrt(1 / inputs). They are stored as float32 parameters.
    """

    def __init__(self, mask, inputs, outputs):
        super().__init__()
        self.register_buffer("mask", mask)

        bound = np.sqrt(1 / inputs)
        # Draw order (weight first, then bias) fixes which numpy samples each gets.
        weight_init = np.random.uniform(-bound, bound, size=(outputs, inputs))
        self.weight = nn.Parameter(torch.tensor(weight_init, dtype=torch.float32))
        bias_init = np.random.uniform(-bound, bound, size=outputs)
        self.bias = nn.Parameter(torch.tensor(bias_init, dtype=torch.float32))

    def forward(self, x):
        # Masked entries contribute nothing to the output and receive zero gradient.
        return F.linear(x, self.weight * self.mask, self.bias)

# %%
class CustomConv2d(nn.Module):
    """2-D convolution whose kernel is multiplied elementwise by a fixed mask.

    ``mask`` is registered as a buffer. Kernel
    (``outputs x inputs x kernalheight x kernelwidth``) and bias (``outputs``)
    are drawn with numpy's global RNG from U(-k, k), where
    k = sqrt(1 / (inputs * kernalheight * kernelwidth)). They are stored as
    float32 parameters. ``padding`` is forwarded to ``F.conv2d``.
    """

    def __init__(self, mask, inputs, outputs, kernalheight, kernelwidth, padding=0):
        super().__init__()
        self.register_buffer("mask", mask)
        self.padding = padding

        bound = np.sqrt(1 / (inputs * kernalheight * kernelwidth))
        kernel_shape = (outputs, inputs, kernalheight, kernelwidth)
        # Draw order (kernel first, then bias) fixes which numpy samples each gets.
        kernel_init = np.random.uniform(-bound, bound, size=kernel_shape)
        self.weight = nn.Parameter(torch.tensor(kernel_init, dtype=torch.float32))
        bias_init = np.random.uniform(-bound, bound, size=outputs)
        self.bias = nn.Parameter(torch.tensor(bias_init, dtype=torch.float32))

    def forward(self, x):
        masked_kernel = self.weight * self.mask
        return F.conv2d(x, masked_kernel, self.bias, padding=self.padding)


class LeNet(nn.Module):
    """LeNet-5 style classifier built from masked conv/linear layers.

    Stack: conv1 -> relu -> pool -> conv2 -> relu -> pool -> fc1 -> fc2 -> fc3.
    The shapes imply single-channel 28x28 inputs. conv1 keeps 28x28 via
    padding=2; two 2x2 poolings and the unpadded 5x5 conv2 give 16x5x5, which
    matches fc1's input size.

    Besides each layer's ``mask`` buffer, the model keeps its own
    ``mask_conv1`` ... ``mask_fc3`` attributes. It also keeps per-layer
    snapshots ``w1_init`` ... ``w5_init`` / ``bias1_init`` ... ``bias5_init``,
    which are used to rewind parameters. Layer order everywhere is
    conv1, conv2, fc1, fc2, fc3.
    """

    def __init__(self):
        super().__init__()

        # Dense (all-ones) masks. These same tensor objects are handed to the
        # layers below, which register them as buffers.
        self.mask_conv1 = torch.ones((6, 1, 5, 5), requires_grad=False)
        self.mask_conv2 = torch.ones((16, 6, 5, 5), requires_grad=False)
        self.mask_fc1 = torch.ones((120, 16 * 5 * 5), requires_grad=False)
        self.mask_fc2 = torch.ones((84, 120), requires_grad=False)
        self.mask_fc3 = torch.ones((10, 84), requires_grad=False)

        self.flatten = nn.Flatten()

        self.conv1 = CustomConv2d(self.mask_conv1, 1, 6, 5, 5, padding=2)  # 3, 6, 5, 5 for CIFAR
        self.conv2 = CustomConv2d(self.mask_conv2, 6, 16, 5, 5)

        # Not used by forward(), which calls F.max_pool2d directly.
        self.pooling1 = nn.MaxPool2d(2)

        self.fc1 = CustomLinear(self.mask_fc1, 16 * 5 * 5, 120)  # 3136 for CIFAR
        self.fc2 = CustomLinear(self.mask_fc2, 120, 84)
        self.fc3 = CustomLinear(self.mask_fc3, 84, 10)

        # Random placeholders until save_init_weights() or load_init_weights() runs.
        self.w1_init, self.w2_init, self.w3_init, self.w4_init, self.w5_init = (
            torch.rand(6, 1, 5, 5),
            torch.rand(16, 6, 5, 5),
            torch.rand(120, 16 * 5 * 5),
            torch.rand(84, 120),
            torch.rand(10, 84),
        )
        self.bias1_init, self.bias2_init, self.bias3_init, self.bias4_init, self.bias5_init = (
            torch.rand(6),
            torch.rand(16),
            torch.rand(120),
            torch.rand(84),
            torch.rand(10),
        )

    def forward(self, x):
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)
        features = self.flatten(features)
        features = F.relu(self.fc2(F.relu(self.fc1(features))))
        return self.fc3(features)

    def save_init_weights(self):
        """Snapshot the current weights and biases into the ``*_init`` attributes."""
        layers = (self.conv1, self.conv2, self.fc1, self.fc2, self.fc3)
        for n, layer in enumerate(layers, start=1):
            setattr(self, f"w{n}_init", layer.weight.data.clone())
        for n, layer in enumerate(layers, start=1):
            setattr(self, f"bias{n}_init", layer.bias.data.clone())

    def reset_weights(self):
        """Overwrite each layer's weight and bias data with a copy of its snapshot."""
        layers = (self.conv1, self.conv2, self.fc1, self.fc2, self.fc3)
        for n, layer in enumerate(layers, start=1):
            layer.weight.data = getattr(self, f"w{n}_init").clone()
        for n, layer in enumerate(layers, start=1):
            layer.bias.data = getattr(self, f"bias{n}_init").clone()

    def load_init_weights(self, weightlist, biaslist):
        """Set the snapshots from per-layer numpy arrays, moved to ``device``."""
        for n in range(1, 6):
            setattr(self, f"w{n}_init", torch.from_numpy(weightlist[n - 1]).clone().to(device))
        for n in range(1, 6):
            setattr(self, f"bias{n}_init", torch.from_numpy(biaslist[n - 1]).clone().to(device))

    def load_mask(self, masklist):
        """Install per-layer numpy masks into both the model attributes and the layer buffers.

        The two receive independent copies, each moved to ``device``.
        """
        layers = (self.conv1, self.conv2, self.fc1, self.fc2, self.fc3)
        names = ("mask_conv1", "mask_conv2", "mask_fc1", "mask_fc2", "mask_fc3")

        for position, name in enumerate(names):
            setattr(self, name, torch.from_numpy(masklist[position]).clone().to(device))
        for position, layer in enumerate(layers):
            layer.mask = torch.from_numpy(masklist[position]).clone().to(device)
        for layer in layers:
            layer.mask.requires_grad_(False)

    def mask_size(self):
        """Print and return the number of unmasked entries summed over the layer buffers."""
        size = torch.sum(self.fc1.mask)
        for layer in (self.fc2, self.fc3, self.conv1, self.conv2):
            size = size + torch.sum(layer.mask)
        print("Mask Size: {}".format(size))
        return size



def train_model(model, train_loader, criterion, optimizer, num_epochs=1):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    The model is moved to the module-level ``device``. After each epoch the mean
    batch loss is printed on one line terminated by a carriage return, so
    successive epochs overwrite each other in a terminal.
    """
    model.to(device)

    for epoch_index in range(num_epochs):
        model.train()
        loss_total = 0.0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            loss_total += batch_loss.item()

        mean_loss = loss_total / len(train_loader)
        print(f"Epoch {epoch_index + 1}/{num_epochs}, Loss: {mean_loss}", end='\r')

def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` without gradients.

    Returns:
        ``(predictions, actuals)`` as numpy arrays. ``predictions`` stacks the
        raw model outputs row by row; ``actuals`` holds the matching labels.
    """
    model.eval()
    output_rows = []
    label_values = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            batch_outputs = model(inputs.to(device))
            output_rows += list(batch_outputs.cpu().numpy())
            label_values += list(labels.cpu().numpy())

    return np.array(output_rows), np.array(label_values)

# %%
def _pack_arrays(tensors):
    """Return a 1-D object ndarray holding each tensor as a CPU numpy array.

    Filling ``np.empty(..., dtype=object)`` element by element guarantees one
    entry per tensor. ``np.array(list, dtype=object)`` would instead stack
    equal-shaped arrays into a single higher-dimensional array.
    """
    packed = np.empty(len(tensors), dtype=object)
    for position, tensor in enumerate(tensors):
        packed[position] = tensor.cpu().detach().numpy()
    return packed


def _prunable_layers(model):
    """Return LeNet's masked layers in the order used for weights, biases and masks."""
    return [model.conv1, model.conv2, model.fc1, model.fc2, model.fc3]


def _model_masks(model):
    """Return LeNet's model-level mask tensors (``mask_conv1`` ... ``mask_fc3``) in layer order."""
    return [model.mask_conv1, model.mask_conv2, model.mask_fc1, model.mask_fc2, model.mask_fc3]


def _loss_and_accuracy(criterion, results):
    """Return ``(loss, accuracy)`` tensors for the ``(logits, labels)`` arrays from evaluate_model."""
    logits = torch.tensor(results[0])
    labels = torch.tensor(results[1])
    loss = criterion(logits, labels)
    accuracy = torch.sum(torch.max(logits, dim=1)[1] == labels) / len(results[1])
    return loss, accuracy


def lotto_ticket_training(train_dl, test_dl, num_epochs, percent_pruned = .1, rounds_of_pruning = 20):
    """Iterative magnitude pruning of LeNet with rewinding to one fixed initialisation.

    One initialisation is drawn and saved. Then, for each of
    ``rounds_of_pruning`` rounds:

    1. Rebuild LeNet from the saved initial weights/biases and the current mask file.
    2. Train for ``num_epochs``.
    3. Evaluate.
    4. Zero out the smallest-magnitude unpruned weights across all layers.
    5. Save the new mask.

    Files written: ``{directory_name}_initial_init_weights.npy``,
    ``{directory_name}_initial_init_bias.npy``, ``{directory_name}_mask.npy``
    (updated every round) and
    ``{directory_name}_various_masks/mask_{percent}_size.npy``.

    Args:
        train_dl: Training data loader.
        test_dl: Evaluation data loader.
        num_epochs: Training epochs per round.
        percent_pruned: Approximate fraction of weights that remain after all
            rounds. Each round keeps ``percent_pruned ** (1 / rounds_of_pruning)``
            of the currently unpruned weights.
        rounds_of_pruning: Number of train/prune rounds.

    Returns:
        ``(final_mask, orig_weights, orig_bias, losses)``:

        - ``final_mask``: the last successfully saved mask tensors (the initial
          all-ones masks if no round completed).
        - ``orig_weights`` / ``orig_bias``: the initial parameters.
        - ``losses``: one ``[post-prune test loss, remaining weight count]`` per
          completed round.
    """
    keep_fraction = percent_pruned ** (1 / rounds_of_pruning)

    weights_path = f"{directory_name}_initial_init_weights.npy"
    bias_path = f"{directory_name}_initial_init_bias.npy"
    mask_path = f"{directory_name}_mask.npy"

    # Draw a single initialisation and persist it; every round rewinds to it.
    lotto_model = LeNet()
    lotto_model.to(device)
    lotto_model.save_init_weights()

    orig_weights = [layer.weight.data.clone() for layer in _prunable_layers(lotto_model)]
    orig_bias = [layer.bias.data.clone() for layer in _prunable_layers(lotto_model)]
    orig_masks = [mask.clone() for mask in _model_masks(lotto_model)]

    np.save(weights_path, _pack_arrays(orig_weights))
    np.save(bias_path, _pack_arrays(orig_bias))
    np.save(mask_path, _pack_arrays(orig_masks))

    orig_size = sum(torch.sum(mask) for mask in orig_masks)
    lotto_model.mask_size()

    losses = []
    final_mask = [mask.clone() for mask in orig_masks]
    orig_loss = 0
    orig_accuracy = 0
    have_original_metrics = False

    for _ in range(rounds_of_pruning):
        # Release the previous round's model before building the next one.
        del lotto_model
        torch.cuda.empty_cache()

        lotto_model = LeNet()
        lotto_model.load_init_weights(np.load(weights_path, allow_pickle=True),
                                      np.load(bias_path, allow_pickle=True))
        lotto_model.reset_weights()
        lotto_model.load_mask(np.load(mask_path, allow_pickle=True))
        lotto_model.to(device)

        optimizer = optim.Adam(lotto_model.parameters())
        criterion = nn.CrossEntropyLoss()

        train_model(lotto_model, train_dl, criterion, optimizer, num_epochs)
        pre_loss, pre_acc = _loss_and_accuracy(criterion, evaluate_model(lotto_model, test_dl))

        weights = [layer.weight.data for layer in _prunable_layers(lotto_model)]
        current_masks = _model_masks(lotto_model)  # live references, mutated below
        current_size = sum(torch.sum(mask) for mask in current_masks)
        prune_amnt = round((1.0 - keep_fraction) * current_size.item())
        remove_mask_indeces, indicator = smallest_picker(
            [mask.clone() for mask in current_masks], weights, total=prune_amnt)

        if not have_original_metrics:
            orig_loss, orig_accuracy = pre_loss, pre_acc
            print("Loss of the Original Network: {}, Accuracy of Original Network: {}".format(orig_loss, orig_accuracy))
            have_original_metrics = True

        if indicator != 0:
            print("Stuff is broken!")
            break

        # Zero the selected entries in both copies of each mask: the model-level
        # tensors (saved to disk) and the layer buffers (used by forward()).
        # view(-1) either aliases the storage or raises; flatten() may silently
        # return a copy, which would drop the write.
        for model_mask, layer, indices in zip(current_masks, _prunable_layers(lotto_model), remove_mask_indeces):
            index_tensor = torch.as_tensor(indices, dtype=torch.long, device=model_mask.device)
            model_mask.view(-1)[index_tensor] = 0
            layer.mask.view(-1)[index_tensor.to(layer.mask.device)] = 0

        post_loss, post_acc = _loss_and_accuracy(criterion, evaluate_model(lotto_model, test_dl))

        # Measure the size AFTER pruning so the log line, `losses` and the saved
        # filename all describe the new mask.
        final_mask = [mask.clone() for mask in current_masks]
        newsize = sum(torch.sum(mask) for mask in final_mask)
        percentage = round((newsize.item() / orig_size.item()) * 1000) / 10
        losses.append([post_loss.item(), newsize.item()])

        savemask = _pack_arrays(final_mask)
        np.save(mask_path, savemask)
        np.save(f"{directory_name}_various_masks/mask_{percentage}_size.npy", savemask)

        print("Model pruned to {}%. New Loss = {}, Orig Loss = {}, New Accuracy = {}, Original Accuracy = {}".format(percentage, post_loss, orig_loss, post_acc, orig_accuracy))

    return final_mask, orig_weights, orig_bias, losses






mask, weights, bias, losses = lotto_ticket_training(train_loader, test_loader, num_epochs=50, percent_pruned = .01, rounds_of_pruning = 50)

# "{directory_name}_mask.npy" is rewritten by lotto_ticket_training after every
# pruning round and already holds the latest mask. It is deliberately NOT
# overwritten here with the returned `mask`.

# Per-layer arrays have differing shapes, so dtype=object yields one entry per layer.
initweights = np.array([w.cpu().detach().numpy() for w in weights], dtype=object)
initbias = np.array([b.cpu().detach().numpy() for b in bias], dtype=object)

# Save the initial weights (previously this file mistakenly received the mask array).
np.save(f"{directory_name}_initial_weights.npy", initweights)
np.save(f"{directory_name}_initial_bias.npy", initbias)

lossarray = np.array(losses)
np.save(f"{directory_name}_losses.npy", lossarray)




