### By Peter Marsh

### This was also used to generate the toy model masks - just with a different network, settings, etc.
### The instructions are also the same for the CIFAR mask generator, as well as Lenet5 (CNN) mask generators.


### Import the necessary libraries
import os
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader
import gc


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"  # uncomment to force CPU-only execution
print(device)

### All masks for the run are stored here
directory_name = "tests/Resnet18_CIFAR/99_test1"

# Create the output directories. os.makedirs also creates missing parents:
# os.mkdir failed when e.g. "tests/" did not exist, and the former bare
# `except` misreported that as "already present", so the run only crashed
# later at the first np.save. Only FileExistsError is treated as benign now;
# any other OS error (permissions, ...) surfaces immediately.
for _dir_path, _exists_msg in (
    ("tests/Resnet18_CIFAR", "Parent directory already present!"),
    (directory_name + "_various_masks", "Directory already present!"),
):
    try:
        os.makedirs(_dir_path)
    except FileExistsError:
        print(_exists_msg)

### Load Datasets
batchsize = 1024

# Tensor conversion followed by a single-value normalization, which
# torchvision broadcasts over all image channels.
# NOTE(review): (0.1307, 0.3081) look like the usual MNIST statistics rather
# than CIFAR-10 per-channel ones -- confirm this is intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Only the training split is created with download=True; the test split is
# read from the same '../data' root.
dataset1 = datasets.CIFAR10('../data', train=True, download=True, transform=transform)
dataset2 = datasets.CIFAR10('../data', train=False, transform=transform)

# NOTE(review): neither loader passes shuffle=True, so training batches come
# in a fixed order every epoch -- confirm this is intended.
train_loader = DataLoader(dataset1, batch_size=batchsize)
test_loader = DataLoader(dataset2, batch_size=batchsize)


### Define functions to find the smallest magnitude weights in a network and return a pruned mask

def find_smallest_in_layer(mask, weight_tensor, num_values = 1):
    """Return the smallest-magnitude weights of one layer that are not yet pruned.

    Args:
        mask: tensor with as many elements as ``weight_tensor``; entries equal
            to 0 mark weights that are already pruned.
        weight_tensor: the layer's weight tensor (any shape; flat indices are
            used).
        num_values: how many surviving weights to return. At least one is
            returned whenever any weight survives, even if ``num_values`` < 1
            (this mirrors the earlier loop, which checked the limit only after
            appending).

    Returns:
        ``(prune_index_loc, pruned_index_val, controlfunc)``:

        * ``prune_index_loc`` -- list of 0-dim integer tensors, flat indices
          into ``weight_tensor`` in ascending order of ``|weight|``;
        * ``pruned_index_val`` -- list of Python floats, the matching
          ``|weight|`` values;
        * ``controlfunc`` -- 1 if at least one unpruned weight was found,
          otherwise 0.
    """
    # Computation stays on the weight's own device; callers pass tensors that
    # are already on the working device.
    magnitudes = weight_tensor.detach().abs().flatten()
    alive = mask.detach().flatten().to(magnitudes.device) != 0

    # Ascending order of |weight|, then drop already-pruned positions with one
    # vectorized boolean index. The previous per-element `index in pruned`
    # test was a linear scan each time (O(n^2) per layer).
    order = torch.sort(magnitudes)[1]
    candidates = order[alive[order]]
    selected = candidates[:max(num_values, 1)]

    # Report whether anything survives. The previous loop instead reported the
    # status of the *last* element it examined, so it could claim "all pruned"
    # while returning survivors.
    controlfunc = 1 if selected.numel() > 0 else 0
    if controlfunc == 0:
        print("All weights pruned? Something is wrong...")

    prune_index_loc = list(selected)  # 0-dim tensors; callers use .item()
    pruned_index_val = magnitudes[selected].tolist()
    return prune_index_loc, pruned_index_val, controlfunc

def smallest_picker(list_of_masks, list_of_weights, layerlist, total=1):
    """Prune the ``total`` globally smallest-magnitude surviving weights.

    For every layer, ``find_smallest_in_layer`` yields its surviving weights in
    ascending ``|weight|`` order (at most ``total`` of them). These per-layer
    sorted lists are merged and the first ``total`` entries are selected. Ties
    go to the lower layer index, as with the argmin used previously. The
    selected flat indices are then set to 0 in ``layer.mask`` for the matching
    entry of ``layerlist``.

    Args:
        list_of_masks: per-layer mask tensors (0 = already pruned).
        list_of_weights: per-layer weight tensors, same order as the masks.
        layerlist: layers whose ``mask`` buffers are updated, same order.
        total: number of weights to prune.

    Returns:
        ``(masks, indicator)``: ``masks[i]`` is the list of flat indices pruned
        in layer ``i``, in selection order. ``indicator`` is 1 if no layer had
        any surviving weight, or if fewer than ``total`` weights could be
        pruned; otherwise 0.
    """
    import heapq
    from itertools import islice, repeat
    from operator import itemgetter

    indicator = 0
    per_layer_locs = []
    per_layer_vals = []
    found_any = False

    total_len = len(list_of_masks)
    for counter, (mask, weight) in enumerate(zip(list_of_masks, list_of_weights), start=1):
        print(f"Processing {counter}/{total_len}", end='\r')
        prunelocs, prunevals, controlfunc = find_smallest_in_layer(
            mask.clone().to(device), weight.clone().to(device), total
        )
        per_layer_locs.append(prunelocs)
        per_layer_vals.append(prunevals)
        found_any = found_any or controlfunc != 0

    masks = [[] for _ in list_of_masks]

    if not found_any:
        print("No masks!")
        indicator = 1
    else:
        # k-way merge of the per-layer ascending lists: O(total * log(layers))
        # instead of rebuilding a tensor and doing list.pop(0) for every pick.
        # heapq.merge is stable across its inputs, so equal magnitudes are
        # taken from the lower layer index first. repeat(layer_idx) is bound
        # eagerly inside each zip, avoiding late-binding surprises.
        streams = [
            zip(vals, repeat(layer_idx), locs)
            for layer_idx, (locs, vals) in enumerate(zip(per_layer_locs, per_layer_vals))
        ]
        wanted = max(total, 0)
        picks = list(islice(heapq.merge(*streams, key=itemgetter(0)), wanted))
        for _, layer_idx, loc in picks:
            masks[layer_idx].append(int(loc))
        if len(picks) < wanted:
            # Fewer surviving weights than requested.
            print("Exception: empty list occured!")
            indicator = 1

    # NOTE(review): flatten() returns a view for contiguous masks, so this
    # writes through to the buffer; a non-contiguous mask would silently be
    # copied instead -- masks in this file are created contiguous.
    for layer, mask in zip(layerlist, masks):
        layer.mask.flatten()[mask] = 0

    return masks, indicator

def flatten_model(modules):
    """Flatten a module container into a flat list of modules (best-effort).

    The strategy is chosen by trial and error through bare ``except`` clauses:

    1. ``modules`` is iterated as ``(name, child)`` pairs and every child is
       flattened recursively.
    2. If that raises anything, the object's ``_modules`` registry is used:
       an object whose registry prints as empty is kept as a leaf, otherwise
       each registered child is flattened recursively.
    3. If that also raises (e.g. ``modules`` has no ``_modules`` attribute,
       as for a plain list), ``modules`` itself is appended as one element.

    ``flatten_list`` then removes exactly one level of list nesting from the
    collected results.

    NOTE(review): ``get_mask_layers``/``get_batchnorm_layers`` call this with
    ``list(model.modules())``. As far as can be traced here, unpacking the
    first element (the model itself) as a pair fails, the list has no
    ``_modules``, and so the input list is returned as an element-wise copy --
    confirm before relying on any deeper flattening.
    NOTE(review): the bare ``except`` clauses also swallow errors raised inside
    recursive calls, and ``ret`` can already hold partial results from step 1
    when step 2 runs (possible duplicates).
    NOTE(review): the leaf test compares against the OrderedDict repr
    ``"odict_items([])"``; if the installed PyTorch stores ``_modules`` as a
    plain ``dict``, an empty registry prints differently and a leaf would
    contribute nothing -- verify against the PyTorch version in use.
    """
    def flatten_list(_2d_list):
        # Remove one level of nesting: list elements are spliced in, anything
        # else is kept as-is.
        flat_list = []
        # Iterate through the outer list
        for element in _2d_list:
            if type(element) is list:
                # If the element is of type list, iterate through the sublist
                for item in element:
                    flat_list.append(item)
            else:
                flat_list.append(element)
        return flat_list

    ret = []
    try:
        # Step 1: treat `modules` as an iterable of (name, child) pairs.
        for _, n in modules:
            ret.append(flatten_model(n))
    except:
        try:
            # Step 2: use the module's registered children.
            if str(modules._modules.items()) == "odict_items([])":
                # No children: keep the module itself as a leaf.
                ret.append(modules)
            else:
                for _, n in modules._modules.items():
                    ret.append(flatten_model(n))
        except:
            # Step 3: fall back to keeping the object itself.
            ret.append(modules)
    return flatten_list(ret)

def get_mask_layers(model):
    """Return every masked layer (``CustomConv2d`` / ``CustomLinear``) of ``model``.

    The candidates are ``model.modules()`` passed through ``flatten_model``,
    so the result follows that traversal order. Plain ``nn.Conv2d`` layers
    are not included.
    """
    all_modules = list(model.modules())  # this is needed
    return [
        candidate
        for candidate in flatten_model(all_modules)
        if isinstance(candidate, (CustomConv2d, CustomLinear))
    ]

def get_batchnorm_layers(model):
    """Return every ``nn.BatchNorm2d`` layer of ``model``.

    The candidates are ``model.modules()`` passed through ``flatten_model``,
    so the result follows that traversal order.
    """
    all_modules = list(model.modules())  # this is needed
    return [
        candidate
        for candidate in flatten_model(all_modules)
        if isinstance(candidate, nn.BatchNorm2d)
    ]

def save_weights(layerlist):
    """Snapshot the ``weight`` and ``bias`` of each layer as CPU copies.

    Returns:
        ``(weightlist, biaslist)``: two lists of cloned CPU tensors, in the
        same order as ``layerlist``; later training does not affect them.
    """
    weightlist = [layer.weight.data.clone().to('cpu') for layer in layerlist]
    biaslist = [layer.bias.data.clone().to('cpu') for layer in layerlist]
    return weightlist, biaslist

def save_batchnorms(layerlist):
    """Snapshot the affine ``weight`` and ``bias`` of each batch-norm layer.

    Only the affine parameters are copied (running statistics are not).

    Returns:
        ``(weightlist, biaslist)``: cloned CPU tensors in ``layerlist`` order.
    """
    weightlist = [bn.weight.data.clone().to('cpu') for bn in layerlist]
    biaslist = [bn.bias.data.clone().to('cpu') for bn in layerlist]
    return weightlist, biaslist

def save_mask(layerlist):
    """Return cloned CPU copies of each layer's ``mask``, in ``layerlist`` order."""
    return [layer.mask.data.clone().to('cpu') for layer in layerlist]

def save_mask_to_disk(masklist, masksavename):
    """Save per-layer mask tensors to ``masksavename`` as a 1-D numpy object array.

    Element ``i`` of the saved array is ``masklist[i]`` as a numpy array, so
    ``np.load(masksavename, allow_pickle=True)`` yields one array per layer
    whatever the layers' shapes are.
    """
    masks = list(masklist)
    # Fill a preallocated 1-D object array slot by slot. Passing the list to
    # np.array(..., dtype=object) instead builds a multi-dimensional object
    # array when all masks share a shape (torch.from_numpy then fails on
    # reload), and can raise when only their leading dimensions match.
    npmaskarray = np.empty(len(masks), dtype=object)
    for i, mask in enumerate(masks):
        npmaskarray[i] = mask.cpu().detach().numpy()
    np.save(masksavename, npmaskarray)


def save_weights_to_disk(weightlist, biaslist, weightsavename, biassavename):
    """Save per-layer weights and biases to two ``.npy`` files as 1-D object arrays.

    Element ``i`` of each saved array is the numpy copy of ``weightlist[i]`` /
    ``biaslist[i]``, so ``np.load(..., allow_pickle=True)`` yields one array
    per layer regardless of the layers' shapes.
    """
    def _pack(tensors):
        # Fill a preallocated 1-D object array slot by slot: np.array(list,
        # dtype=object) would build a multi-dimensional object array when all
        # tensors share a shape, and can raise when only the leading
        # dimensions match.
        tensors = list(tensors)
        packed = np.empty(len(tensors), dtype=object)
        for i, tensor in enumerate(tensors):
            packed[i] = tensor.cpu().detach().numpy()
        return packed

    # Convert both lists before writing, so nothing is saved if either fails.
    npweightarray = _pack(weightlist)
    npbiasarray = _pack(biaslist)
    np.save(weightsavename, npweightarray)
    np.save(biassavename, npbiasarray)

def save_batchnorms_to_disk(weightlist, biaslist, weightsavename, biassavename):
    """Save batch-norm affine weights and biases to two ``.npy`` object-array files.

    Element ``i`` of each saved array is the numpy copy of ``weightlist[i]`` /
    ``biaslist[i]``; ``np.load(..., allow_pickle=True)`` yields one array per
    layer regardless of the layers' sizes.
    """
    def _pack(tensors):
        # Preallocated 1-D object array: np.array(list, dtype=object) would
        # produce a 2-D object array when all layers have the same channel
        # count (e.g. a toy network), breaking per-layer reloads.
        tensors = list(tensors)
        packed = np.empty(len(tensors), dtype=object)
        for i, tensor in enumerate(tensors):
            packed[i] = tensor.cpu().detach().numpy()
        return packed

    npweightarray = _pack(weightlist)
    npbiasarray = _pack(biaslist)
    np.save(weightsavename, npweightarray)
    np.save(biassavename, npbiasarray)


def reset_weights(layerlist, weightlist, biaslist):
    """Overwrite each layer's weight/bias data with copies of the given tensors on ``device``.

    Layers and tensors are paired positionally; iteration stops at the
    shortest of the three sequences.
    """
    for layer, (new_weight, new_bias) in zip(layerlist, zip(weightlist, biaslist)):
        layer.weight.data = new_weight.clone().to(device)
        layer.bias.data = new_bias.clone().to(device)

def reset_batchnorm_weights(layerlist, weightlist, biaslist):
    """Overwrite each batch-norm layer's affine weight/bias with copies on ``device``.

    Layers and tensors are paired positionally; iteration stops at the
    shortest of the three sequences.
    """
    for bn, (new_weight, new_bias) in zip(layerlist, zip(weightlist, biaslist)):
        bn.weight.data = new_weight.clone().to(device)
        bn.bias.data = new_bias.clone().to(device)

def load_weights_from_disk(layerlist, weightsavename, biassavename):
    """Load per-layer weights/biases from ``.npy`` object-array files onto ``device``.

    Both files are read with ``allow_pickle=True`` (they hold object arrays),
    so only load files produced by this script. Layers are paired with the
    stored entries positionally, and each layer's ``weight``/``bias`` data is
    replaced by a copy on ``device``.
    """
    stored_weights = np.load(weightsavename, allow_pickle=True)
    stored_biases = np.load(biassavename, allow_pickle=True)
    for layer, (w_np, b_np) in zip(layerlist, zip(stored_weights, stored_biases)):
        layer.weight.data = torch.from_numpy(w_np).clone().to(device)
        layer.bias.data = torch.from_numpy(b_np).clone().to(device)

def load_batchnorm_weights_from_disk(layerlist, weightsavename, biassavename):
    """Load batch-norm affine weights/biases from ``.npy`` object-array files onto ``device``.

    Files are read with ``allow_pickle=True``; only load files produced by
    this script. Layers are paired with stored entries positionally.
    """
    stored_weights = np.load(weightsavename, allow_pickle=True)
    stored_biases = np.load(biassavename, allow_pickle=True)
    for bn, (w_np, b_np) in zip(layerlist, zip(stored_weights, stored_biases)):
        bn.weight.data = torch.from_numpy(w_np).clone().to(device)
        bn.bias.data = torch.from_numpy(b_np).clone().to(device)

def load_mask(layerlist, masklist):
    """Replace each layer's mask data with a copy of the matching tensor on ``device``."""
    paired = zip(layerlist, masklist)
    for layer, new_mask in paired:
        layer.mask.data = new_mask.clone().to(device)

def load_mask_from_disk(layerlist, masksavename):
    """Load per-layer masks from a ``.npy`` object-array file onto ``device``.

    The file is read with ``allow_pickle=True``; only load files produced by
    this script. Layers are paired with stored masks positionally.
    """
    stored_masks = np.load(masksavename, allow_pickle=True)
    for layer, mask_np in zip(layerlist, stored_masks):
        layer.mask.data = torch.from_numpy(mask_np).clone().to(device)

def mask_size(layerlist):
    """Sum all mask entries over ``layerlist``.

    For 0/1 masks this is the number of unpruned weights. The result is a
    tensor, or the int ``0`` when ``layerlist`` is empty.
    """
    return sum((torch.sum(layer.mask.data) for layer in layerlist), 0)


### Custom layers to force the mask to be applied during training

class CustomLinear(nn.Module):
    """Fully connected layer whose weight is multiplied by ``mask`` on every forward pass.

    Args:
        mask: tensor multiplied element-wise with the ``(outputs, inputs)``
            weight; registered as a buffer named ``mask``.
        inputs: input feature count.
        outputs: output feature count.

    Weight and bias are drawn from U(-k, k) with k = sqrt(1 / inputs) using
    numpy's global RNG (weight first, then bias).
    """

    def __init__(self, mask, inputs, outputs):
        super().__init__()
        self.register_buffer("mask", mask)
        bound = np.sqrt(1 / inputs)
        # Weight is sampled before bias, keeping the numpy RNG stream order.
        weight_init = np.random.uniform(-bound, bound, size=(outputs, inputs))
        bias_init = np.random.uniform(-bound, bound, size=(outputs))
        self.weight = nn.Parameter(torch.tensor(weight_init, dtype=torch.float32))  ### Can change to other distributions as desired
        self.bias = nn.Parameter(torch.tensor(bias_init, dtype=torch.float32))

    def forward(self, x):
        masked_weight = self.weight * self.mask
        return F.linear(x, masked_weight, self.bias)


class CustomConv2d(nn.Module):
    """2-D convolution whose weight is multiplied by ``mask`` on every forward pass.

    Args:
        mask: tensor multiplied element-wise with the weight of shape
            ``(outputs, inputs, kernalheight, kernelwidth)``; registered as a
            buffer named ``mask``.
        inputs: input channel count.
        outputs: output channel count.
        kernalheight: kernel height.
        kernelwidth: kernel width.
        stride: forwarded to ``F.conv2d``.
        padding: forwarded to ``F.conv2d``.
        bias: whether ``forward`` adds ``self.bias`` (interpreted by
            truthiness).

    Weight and bias are drawn from U(-k, k) with
    k = sqrt(1 / (inputs * kernalheight * kernelwidth)) using numpy's global
    RNG, weight first.
    """

    def __init__(self, mask, inputs, outputs, kernalheight, kernelwidth, stride=1, padding=0, bias=True):
        super().__init__()
        self.register_buffer("mask", mask)
        self.padding = padding
        self.stride = stride
        self.bias_truthyness = bias
        k = np.sqrt(1/(inputs*kernalheight*kernelwidth))
        self.weight = nn.Parameter(torch.tensor(np.random.uniform(-1*k, k, size = (outputs, inputs, kernalheight, kernelwidth)), dtype=torch.float32))
        # The bias parameter is created (and drawn from the RNG) even when
        # bias is falsy; forward then simply does not use it.
        self.bias = nn.Parameter(torch.tensor(np.random.uniform(-1*k, k, size=(outputs)), dtype=torch.float32))

    def forward(self, x):
        weight = torch.mul(self.weight, self.mask)
        # Truthiness test: the former `== True` / `== False` branches fell
        # through and returned None for values such as bias=None.
        bias = self.bias if self.bias_truthyness else None
        return F.conv2d(x, weight, bias, stride=self.stride, padding=self.padding)


class BasicBlock(nn.Module):
    """Residual block of two masked 3x3 convolutions.

    Path: conv1 (3x3, ``stride``) -> bn1 -> ReLU -> conv2 (3x3) -> bn2, plus a
    shortcut, then ReLU. The shortcut is the identity unless ``stride != 1``
    or the channel count changes; in that case it is a masked 1x1
    ``CustomConv2d`` followed by batch norm.

    When ``is_last`` is true, ``forward`` returns ``(out, preact)``, where
    ``preact`` is the sum before the final ReLU.

    NOTE(review): ``conv1mask``/``conv2mask``/``jicmask`` are the initial
    all-ones mask tensors; they may go stale once the layers' buffers are
    moved or replaced, so ``conv1.mask`` etc. should be treated as
    authoritative. ``jicmask`` is created even without a projection shortcut.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        out_planes = self.expansion * planes

        self.conv1mask = torch.ones((planes, in_planes, 3, 3), requires_grad=False)
        self.conv2mask = torch.ones((planes, planes, 3, 3), requires_grad=False)
        self.jicmask = torch.ones((out_planes, in_planes, 1, 1), requires_grad=False)

        # Registration order (conv1, bn1, conv2, bn2, shortcut) fixes the
        # order of modules() and the numpy RNG draws -- keep it.
        self.conv1 = CustomConv2d(self.conv1mask, in_planes, planes, kernalheight=3, kernelwidth=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = CustomConv2d(self.conv2mask, planes, planes, kernalheight=3, kernelwidth=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                CustomConv2d(self.jicmask, in_planes, out_planes, kernalheight=1, kernelwidth=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        residual += self.shortcut(x)
        preact = residual
        activated = F.relu(residual)
        if not self.is_last:
            return activated
        return activated, preact

class ResNet(nn.Module):
    """ResNet feature encoder without a classification layer.

    A 3x3 stem (conv -> batch norm -> ReLU) is followed by four stages built
    from ``block`` (64/128/256/512 planes; strides 1/2/2/2) and global
    average pooling. ``forward`` returns the flattened pooled output of
    ``layer4`` (512 features for ``BasicBlock``).

    NOTE(review): the stem ``conv1`` is a plain ``nn.Conv2d``, not a
    ``CustomConv2d``, so it is never masked and is not returned by
    ``get_mask_layers``.

    Args:
        block: residual block class exposing ``expansion`` and taking
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        in_channel: input image channel count.
        zero_init_residual: if true, zero the last batch-norm scale of each
            residual block (``bn2`` for ``BasicBlock``, ``bn3`` for blocks
            that have one).
    """

    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Convolutions keep PyTorch's default initialization; normalization
        # layers start as the identity affine transform.
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
                elif isinstance(getattr(m, "bn3", None), nn.BatchNorm2d):
                    # Bottleneck-style blocks end with bn3. The former
                    # `isinstance(m, Bottleneck)` raised NameError because no
                    # Bottleneck class is defined in this file.
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, layer=100):
        """Return pooled features of shape ``(batch, channels)``; ``layer`` is unused."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return out

### Define the network

def resnet18(**kwargs):
    """Build a ResNet-18 encoder: ``BasicBlock`` stages with 2 blocks each.

    Keyword arguments are forwarded to ``ResNet``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)


# Encoder name -> [constructor, size of the encoder's output features].
model_dict = {
    'resnet18': [resnet18, 512],
    # 'resnet34': [resnet34, 512],
    # 'resnet50': [resnet50, 2048],
    # 'resnet101': [resnet101, 2048],
}

class ResNetModel(nn.Module):
    """backbone + projection head

    Wraps an encoder from ``model_dict`` and a projection head built from
    ``CustomLinear`` layers; ``forward`` L2-normalizes the head output along
    dim 1.

    Args:
        name: key into ``model_dict`` selecting the encoder.
        head: ``'linear'`` (one CustomLinear) or ``'mlp'``
            (CustomLinear -> ReLU -> CustomLinear).
        feat_dim: size of the projected features.

    Raises:
        NotImplementedError: for any other ``head`` (raised after the encoder
            has already been built).
    """

    def __init__(self, name='resnet18', head='mlp', feat_dim=128):
        super().__init__()
        build_encoder, dim_in = model_dict[name]
        self.encoder = build_encoder()
        if head == 'linear':
            self.head = CustomLinear(torch.ones((feat_dim, dim_in), requires_grad=False), dim_in, feat_dim)
        elif head == 'mlp':
            hidden_mask = torch.ones((dim_in, dim_in), requires_grad=False)
            output_mask = torch.ones((feat_dim, dim_in), requires_grad=False)
            self.head = nn.Sequential(
                CustomLinear(hidden_mask, dim_in, dim_in),
                nn.ReLU(inplace=True),
                CustomLinear(output_mask, dim_in, feat_dim),
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        projected = self.head(self.encoder(x))
        return F.normalize(projected, dim=1)

### Train+test the model

def train_model(model, train_loader, criterion, optimizer, num_epochs=1):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    The model is moved to ``device`` first and put in train mode each epoch.
    After every epoch the mean batch loss is printed on one overwritten line.
    """
    model.to(device)

    for epoch_idx in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            epoch_loss += batch_loss.item()

        mean_loss = epoch_loss/len(train_loader)
        print(f"Epoch {epoch_idx+1}/{num_epochs}, Loss: {mean_loss}", end='\r')

def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` in eval mode without gradients.

    Only the inputs are moved to ``device``; the model is expected to be
    there already.

    Returns:
        ``(predictions, actuals)``: numpy arrays of per-sample model outputs
        and the matching labels, in loader order.
    """
    model.eval()
    outputs_per_sample = []
    labels_per_sample = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            batch_out = model(inputs.to(device))
            outputs_per_sample.extend(batch_out.cpu().numpy())
            labels_per_sample.extend(labels.cpu().numpy())

    return np.array(outputs_per_sample), np.array(labels_per_sample)


def _test_loss_and_accuracy(model, test_dl, criterion):
    """Evaluate ``model`` on ``test_dl`` and return ``(loss, accuracy)`` tensors.

    ``loss`` is ``criterion`` applied to all stacked outputs at once;
    ``accuracy`` is the fraction of samples whose arg-max output equals the
    label.
    """
    predictions, actuals = evaluate_model(model, test_dl)
    outputs = torch.tensor(predictions)
    labels = torch.tensor(actuals)
    loss = criterion(outputs, labels)
    accuracy = torch.sum(torch.max(outputs, dim=1)[1] == labels)/len(actuals)
    return loss, accuracy


def lotto_ticket_training(train_dl, test_dl, num_epochs, percent_pruned = .1, rounds_of_pruning = 20):
    """Iterative magnitude pruning with rewinding to the original initialization.

    The initialization of a fresh model is saved to disk once. Each round then:

    1. rebuilds the model and rewinds it to that initialization: the masked
       layers, the batch-norm affine parameters and the unmasked stem
       convolution;
    2. applies the current mask from disk;
    3. trains for ``num_epochs`` epochs;
    4. prunes the smallest-magnitude surviving weights across all masked
       layers;
    5. writes the new mask to ``{directory_name}_mask.npy`` and to
       ``{directory_name}_various_masks/``.

    Args:
        train_dl: training data loader.
        test_dl: evaluation data loader.
        num_epochs: epochs of training in every round.
        percent_pruned: final target size as a fraction of the original,
            e.g. ``.1`` prunes down to 10% of the original weights.
        rounds_of_pruning: number of rounds n. Each round keeps
            ``percent_pruned ** (1/n)`` of the currently surviving weights, so
            roughly ``percent_pruned`` remains after n rounds.

    Returns:
        ``(tmpmask, orig_weights, orig_bias, losses)``:

        * ``tmpmask`` -- the masks as they were before any pruning (the pruned
          masks are written to disk);
        * ``orig_weights`` / ``orig_bias`` -- the initial masked-layer
          parameters (CPU);
        * ``losses`` -- one ``[post-prune test loss, surviving weight count]``
          entry per completed round.
    """
    batchweightloc = f"{directory_name}_initial_batchnorm_init_weights.npy"
    batchbiasloc = f"{directory_name}_initial_batchnorm_init_bias.npy"
    weightloc = f"{directory_name}_initial_init_weights.npy"
    biasloc = f"{directory_name}_initial_init_bias.npy"
    maskloc = f"{directory_name}_mask.npy"
    stemloc = f"{directory_name}_initial_stem_conv_weights.npy"

    lotto_model = ResNetModel(name='resnet18', head='mlp', feat_dim=128)
    lotto_model.to(device)

    # Fraction of the currently surviving weights kept by each round.
    pruning_percentages = percent_pruned**(1/rounds_of_pruning)

    target_layers = get_mask_layers(lotto_model)
    batchnorm_layers = get_batchnorm_layers(lotto_model)

    orig_weights, orig_bias = save_weights(target_layers)
    orig_batchnorm_weights, orig_batchnorm_bias = save_batchnorms(batchnorm_layers)
    orig_masks = save_mask(target_layers)

    save_weights_to_disk(orig_weights, orig_bias, weightloc, biasloc)
    save_batchnorms_to_disk(orig_batchnorm_weights, orig_batchnorm_bias, batchweightloc, batchbiasloc)
    save_mask_to_disk(orig_masks, maskloc)

    # The encoder's stem is a plain nn.Conv2d rather than a CustomConv2d, so
    # get_mask_layers() does not return it and the saves above miss it. Save
    # its initial weight too, so every round rewinds it instead of starting
    # from a freshly random stem.
    np.save(stemloc, lotto_model.encoder.conv1.weight.detach().cpu().numpy())

    losses = []
    tmpmask = save_mask(target_layers)  # masks before any pruning
    orig_size = mask_size(target_layers)

    orig_loss = 0
    orig_accuracy = 0
    have_orig_metrics = False

    del target_layers
    del batchnorm_layers
    gc.collect()

    for _ in range(rounds_of_pruning):
        # Free the previous model before building the next one.
        del lotto_model
        torch.cuda.empty_cache()

        lotto_model = ResNetModel(name='resnet18', head='mlp', feat_dim=128)
        target_layers = get_mask_layers(lotto_model)
        batchnorm_layers = get_batchnorm_layers(lotto_model)

        # Rewind every parameter to the saved initialization, then apply the
        # current mask.
        load_weights_from_disk(target_layers, weightloc, biasloc)
        load_batchnorm_weights_from_disk(batchnorm_layers, batchweightloc, batchbiasloc)
        lotto_model.encoder.conv1.weight.data = torch.from_numpy(np.load(stemloc)).clone().to(device)
        load_mask_from_disk(target_layers, maskloc)

        print(f"Model Size: {mask_size(target_layers)}")

        lotto_model.to(device)
        model_optimizer = optim.Adam(lotto_model.parameters())
        # NOTE(review): ResNetModel returns L2-normalized 128-d features, which
        # are fed directly to CrossEntropyLoss as logits against the class
        # labels. This runs, but confirm it is the intended objective.
        model_criterion = nn.CrossEntropyLoss()

        train_model(lotto_model, train_dl, model_criterion, model_optimizer, num_epochs)
        pre_prune_test_loss, pre_prune_test_acc = _test_loss_and_accuracy(lotto_model, test_dl, model_criterion)

        weights, biases = save_weights(target_layers)
        masks = save_mask(target_layers)
        tmpsize = mask_size(target_layers)

        prune_amnt = round((1.0-pruning_percentages)*tmpsize.item())

        # smallest_picker also zeroes the selected entries of each layer's
        # mask in place.
        remove_mask_indeces, indicator = smallest_picker(masks, weights, target_layers, total=prune_amnt)

        print(f"New Model Size: {mask_size(target_layers)}")

        if not have_orig_metrics:
            # The first round trains the unpruned network: use it as baseline.
            orig_loss = pre_prune_test_loss
            orig_accuracy = pre_prune_test_acc
            print("Loss of the Original Network: {}, Accuracy of Original Network: {}".format(orig_loss, orig_accuracy))
            have_orig_metrics = True

        post_prune_loss, post_prune_acc = _test_loss_and_accuracy(lotto_model, test_dl, model_criterion)

        if indicator == 1:
            print("Stuff is broken!") ### This shouldn't ever trigger, but it can in rare cases with extreme pruning amounts (think: pruning 999 weights from a 1000-weight network), since the pruner will try to remove the last weight from a layer. One could conceivably add a check to the pruning function to ensure it never leaves a layer with 0 weights, but this was not a problem in our experiments, and that solution is more complex than necessary.
            break

        newsize = mask_size(target_layers)
        losses.append([post_prune_loss, newsize])
        percentagino = round(((newsize.item()/orig_size.item()))*1000)/10

        masklist = save_mask(target_layers)
        save_mask_to_disk(masklist, maskloc)
        # Keep every intermediate mask too; useful if the final one is pruned
        # too far to work.
        save_mask_to_disk(masklist, f"{directory_name}_various_masks/mask_{percentagino}_size.npy")

        print("Model pruned to {}%. New Loss = {}, Orig Loss = {}, New Accuracy = {}, Original Accuracy = {}".format(percentagino, post_prune_loss, orig_loss, post_prune_acc, orig_accuracy))

        ### All important things are saved to disk. The rest are deleted to ensure caches are cleared and no interference occurs.
        del tmpsize
        del newsize
        del pre_prune_test_loss
        del model_optimizer
        del model_criterion
        del weights
        del biases
        del masks
        del remove_mask_indeces
        del target_layers
        del batchnorm_layers

    return tmpmask, orig_weights, orig_bias, losses





### Run the thing yeah!
# 200 training epochs per round; 50 rounds that together prune the masked
# layers down to 0.1% of their original size.
mask, weights, bias, losses = lotto_ticket_training(
    train_loader,
    test_loader,
    num_epochs=200,
    percent_pruned=.001,
    rounds_of_pruning=50,
)

### Get that mask, the lottery ticket initialization, and save it to disk, yeah!
# numpymask = np.array([mask[0].cpu().detach().numpy(), mask[1].cpu().detach().numpy(), mask[2].cpu().detach().numpy()], dtype=object)

# initweights = np.array([weights[0].cpu().detach().numpy(), weights[1].cpu().detach().numpy(), weights[2].cpu().detach().numpy()], dtype=object)

# initbias = np.array([bias[0].cpu().detach().numpy(), bias[1].cpu().detach().numpy(), bias[2].cpu().detach().numpy()], dtype=object)

# np.save(f"{directory_name}_mask.npy", numpymask)

# np.save(f"{directory_name}_initial_weights.npy", numpymask)

# np.save(f"{directory_name}_initial_bias.npy", initbias)

# lossarray = np.array(losses)

# np.save(f"{directory_name}_losses.npy", lossarray)

### Perfecto - one piping hot lottery ticket mask ready and saved to disk (along with a bunch of other masks at various pruning percentages)



