import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
epochs = 200 #How many epochs to train for
num_trials = 50 #How many times you want to run this in a loop
# mask = np.load('tests/from_labserv/newstuff/CNN_LeNet5_CIFAR/99_test1_various_masks/mask_17.4_size.npy', allow_pickle=True) #Put in the location to the .npy file
# maskloc = "tests/Resnet18_CIFAR/99_test1_various_masks/mask_1.0_size.npy"
maskloc = "tests/from_labserv/newstuff/Resnet18/99_test1_various_masks/mask_8.3_size.npy"

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu" #Uncomment if u want CPU for whatever godforsaken reason...

batchsize = 1024

# ---------------------------------------------------------------------------
# Data pipeline
# ---------------------------------------------------------------------------
# Commonly published per-channel mean / std of the CIFAR-10 training set.
# The previous single-channel constants (0.1307, 0.3081) are the usual MNIST
# values and were broadcast over all three colour channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
dataset1 = datasets.CIFAR10('../data', train=True, download=True,
                            transform=transform)
dataset2 = datasets.CIFAR10('../data', train=False,
                            transform=transform)
# Reshuffle the training samples every epoch so batches are not identical
# from epoch to epoch; the evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batchsize)
  

def flatten_model(modules):
    """Flatten a module hierarchy into a flat list of leaf modules.

    Args:
        modules: One of

            * an ``nn.Module`` -- expanded recursively through its children; a
              module without children is returned as a single-element list;
            * an iterable of ``(name, module)`` pairs (for example
              ``list(model.named_children())``) -- the second item of every
              pair is flattened and the results are concatenated;
            * anything else -- a plain ``list`` is returned as a shallow copy,
              any other object is wrapped in a single-element list.

    Returns:
        list: The flattened objects in traversal order.
    """
    # Modules are handled first and explicitly. Container modules such as
    # nn.Sequential are iterable, so letting them fall into the pair-unpacking
    # path below would mis-read a two-element child as a (name, module) pair.
    if isinstance(modules, nn.Module):
        children = list(modules.children())
        if not children:
            return [modules]
        return [leaf for child in children for leaf in flatten_model(child)]

    try:
        # Build the result in one expression so that a failure half-way
        # through leaves no partially collected entries behind.
        nested = [flatten_model(n) for _, n in modules]
    except Exception:
        # Not an iterable of pairs: hand the input back unchanged (lists are
        # copied one level deep, everything else is wrapped).
        if isinstance(modules, list):
            return list(modules)
        return [modules]

    return [leaf for group in nested for leaf in group]

def get_mask_layers(model):
    """Return the maskable layers of ``model``.

    Every module reported by ``model.modules()`` is passed through
    ``flatten_model`` and the ``CustomConv2d`` / ``CustomLinear`` instances
    are kept, in the order ``flatten_model`` yields them.

    Args:
        model: The ``nn.Module`` to inspect.

    Returns:
        list: The ``CustomConv2d`` and ``CustomLinear`` layers found.
    """
    # flatten_model is given a real list, not the generator from modules().
    every_module = list(model.modules())
    maskable_types = (CustomConv2d, CustomLinear)
    return [
        candidate
        for candidate in flatten_model(every_module)
        if isinstance(candidate, maskable_types)
    ]

def get_batchnorm_layers(model):
    """Return the ``nn.BatchNorm2d`` layers of ``model``.

    The modules reported by ``model.modules()`` are passed through
    ``flatten_model``; only ``nn.BatchNorm2d`` instances are kept and their
    relative order is preserved.

    Args:
        model: The ``nn.Module`` to inspect.

    Returns:
        list: The batch-norm layers found.
    """
    # flatten_model is given a real list, not the generator from modules().
    flattened = flatten_model(list(model.modules()))

    def _is_batchnorm(candidate):
        return isinstance(candidate, nn.BatchNorm2d)

    return list(filter(_is_batchnorm, flattened))

def save_weights(layerlist):
    """Snapshot the weights and biases of the given layers on the CPU.

    Args:
        layerlist: Iterable of layers exposing ``weight`` and ``bias``.

    Returns:
        tuple: ``(weightlist, biaslist)`` -- two lists of cloned tensors, one
        entry per layer, in the order of ``layerlist``.
    """
    def _cpu_copy(parameter):
        # clone() first so the snapshot never aliases the live parameter.
        return parameter.data.clone().to('cpu')

    snapshots = [(_cpu_copy(layer.weight), _cpu_copy(layer.bias)) for layer in layerlist]
    weightlist = [weight for weight, _ in snapshots]
    biaslist = [bias for _, bias in snapshots]
    return weightlist, biaslist

def save_batchnorms(layerlist):
    """Snapshot ``weight`` and ``bias`` of the given layers on the CPU.

    Only the ``weight`` and ``bias`` attributes are read; nothing else held
    by the layers is captured.

    Args:
        layerlist: Iterable of layers exposing ``weight`` and ``bias``.

    Returns:
        tuple: ``(weightlist, biaslist)`` of cloned CPU tensors, one entry
        per layer.
    """
    weightlist, biaslist = [], []
    for norm_layer in layerlist:
        targets = ((weightlist, norm_layer.weight), (biaslist, norm_layer.bias))
        for store, parameter in targets:
            # Clone before moving so the copy is independent of the layer.
            store.append(parameter.data.clone().to('cpu'))
    return weightlist, biaslist

def save_mask(layerlist):
    """Return CPU clones of the ``mask`` tensor of every layer.

    Args:
        layerlist: Iterable of layers exposing a ``mask`` tensor.

    Returns:
        list: One cloned mask per layer, in the order of ``layerlist``.
    """
    return [layer.mask.data.clone().to('cpu') for layer in layerlist]

def save_mask_to_disk(masklist, masksavename):
    """Write a list of mask tensors to a single ``.npy`` file.

    The file holds a one-dimensional object array with one NumPy array per
    mask, so iterating over the loaded array yields the masks one by one.

    Args:
        masklist: Iterable of mask tensors.
        masksavename: Destination path handed to ``np.save``.
    """
    arrays = [mask.cpu().detach().numpy() for mask in masklist]

    # Fill a pre-sized object array element by element.
    # np.array(arrays, dtype=object) would merge equally shaped masks into
    # one N-d object array (whose rows torch.from_numpy cannot convert) and
    # can raise when only the leading dimensions agree.
    npmaskarray = np.empty(len(arrays), dtype=object)
    for index, array in enumerate(arrays):
        npmaskarray[index] = array

    np.save(masksavename, npmaskarray)


def save_weights_to_disk(weightlist, biaslist, weightsavename, biassavename):
    """Write weight and bias tensors to two ``.npy`` files.

    Each file holds a one-dimensional object array with one NumPy array per
    tensor, so iterating over a loaded file yields the tensors one by one.

    Args:
        weightlist: Iterable of weight tensors.
        biaslist: Iterable of bias tensors.
        weightsavename: Destination path for the weights (``np.save``).
        biassavename: Destination path for the biases (``np.save``).
    """
    def _pack(tensors):
        # Fill a pre-sized object array element by element.
        # np.array(..., dtype=object) would merge equally shaped tensors into
        # one N-d object array (whose rows torch.from_numpy cannot convert)
        # and can raise when only the leading dimensions agree.
        arrays = [tensor.cpu().detach().numpy() for tensor in tensors]
        packed = np.empty(len(arrays), dtype=object)
        for index, array in enumerate(arrays):
            packed[index] = array
        return packed

    # Convert both lists before writing, so a bad tensor fails before any
    # file is touched.
    npweightarray = _pack(weightlist)
    npbiasarray = _pack(biaslist)
    np.save(weightsavename, npweightarray)
    np.save(biassavename, npbiasarray)

def save_batchnorms_to_disk(weightlist, biaslist, weightsavename, biassavename):
    """Write batch-norm weight and bias tensors to two ``.npy`` files.

    Each file holds a one-dimensional object array with one NumPy array per
    tensor, so iterating over a loaded file yields the tensors one by one.

    Args:
        weightlist: Iterable of weight tensors.
        biaslist: Iterable of bias tensors.
        weightsavename: Destination path for the weights (``np.save``).
        biassavename: Destination path for the biases (``np.save``).
    """
    packed_arrays = []
    for tensors in (weightlist, biaslist):
        arrays = [tensor.cpu().detach().numpy() for tensor in tensors]
        # Fill a pre-sized object array element by element. Vectors of equal
        # length would otherwise be merged by np.array(..., dtype=object)
        # into a 2-D object array whose rows torch.from_numpy cannot convert.
        packed = np.empty(len(arrays), dtype=object)
        for index, array in enumerate(arrays):
            packed[index] = array
        packed_arrays.append(packed)

    npweightarray, npbiasarray = packed_arrays
    np.save(weightsavename, npweightarray)
    np.save(biassavename, npbiasarray)


def reset_weights(layerlist, weightlist, biaslist):
    """Overwrite each layer's weight and bias with saved tensors.

    The three sequences are walked in lockstep and iteration stops at the
    shortest one. Every saved tensor is cloned and moved to the module-level
    ``device`` before it replaces the parameter's data.

    Args:
        layerlist: Layers exposing ``weight`` and ``bias``.
        weightlist: Saved weight tensors, one per layer.
        biaslist: Saved bias tensors, one per layer.
    """
    for layer, saved_pair in zip(layerlist, zip(weightlist, biaslist)):
        for field, saved_tensor in zip(("weight", "bias"), saved_pair):
            # Clone so the saved copy is never shared with the live layer.
            getattr(layer, field).data = saved_tensor.clone().to(device)

def reset_batchnorm_weights(layerlist, weightlist, biaslist):
    """Overwrite each layer's ``weight`` and ``bias`` with saved tensors.

    Iteration stops at the shortest of the three sequences. Only ``weight``
    and ``bias`` are written.

    Args:
        layerlist: Layers exposing ``weight`` and ``bias``.
        weightlist: Saved weight tensors, one per layer.
        biaslist: Saved bias tensors, one per layer.
    """
    def _fresh_copy(saved_tensor):
        # Independent copy placed on the module-level ``device``.
        return saved_tensor.clone().to(device)

    for norm_layer, saved_weight, saved_bias in zip(layerlist, weightlist, biaslist):
        norm_layer.weight.data = _fresh_copy(saved_weight)
        norm_layer.bias.data = _fresh_copy(saved_bias)

def load_weights_from_disk(layerlist, weightsavename, biassavename):
    """Load weights and biases from two ``.npy`` files into the layers.

    Layers and stored arrays are paired positionally; pairing stops at the
    shortest sequence. Each array is converted to a tensor, cloned and moved
    to the module-level ``device``.

    Args:
        layerlist: Layers exposing ``weight`` and ``bias``.
        weightsavename: Path of the stored weights.
        biassavename: Path of the stored biases.
    """
    # NOTE: allow_pickle=True makes np.load unpickle the file contents, so
    # only point this at files you trust.
    stored_weights, stored_biases = (
        np.load(path, allow_pickle=True) for path in (weightsavename, biassavename)
    )
    for layer, weight_array, bias_array in zip(layerlist, stored_weights, stored_biases):
        layer.weight.data = torch.from_numpy(weight_array).clone().to(device)
        layer.bias.data = torch.from_numpy(bias_array).clone().to(device)

def load_batchnorm_weights_from_disk(layerlist, weightsavename, biassavename):
    """Load ``weight`` and ``bias`` from two ``.npy`` files into the layers.

    Layers and stored arrays are paired positionally; pairing stops at the
    shortest sequence. Only ``weight`` and ``bias`` are written.

    Args:
        layerlist: Layers exposing ``weight`` and ``bias``.
        weightsavename: Path of the stored weights.
        biassavename: Path of the stored biases.
    """
    # NOTE: allow_pickle=True makes np.load unpickle the file contents, so
    # only point this at files you trust.
    stored_weights = np.load(weightsavename, allow_pickle=True)
    stored_biases = np.load(biassavename, allow_pickle=True)

    def _to_device_tensor(array):
        # from_numpy shares memory with the array, hence the clone.
        return torch.from_numpy(array).clone().to(device)

    for norm_layer, weight_array, bias_array in zip(layerlist, stored_weights, stored_biases):
        norm_layer.weight.data = _to_device_tensor(weight_array)
        norm_layer.bias.data = _to_device_tensor(bias_array)

def load_mask(layerlist, masklist):
    """Install in-memory masks into the given layers.

    Layers and masks are paired positionally; pairing stops at the shorter
    sequence. Each mask is cloned and moved to the module-level ``device``.

    Args:
        layerlist: Layers exposing a ``mask`` tensor.
        masklist: Mask tensors, one per layer.
    """
    pairs = zip(layerlist, masklist)
    for target_layer, source_mask in pairs:
        target_layer.mask.data = source_mask.clone().to(device)

def load_mask_from_disk(layerlist, masksavename):
    """Install masks stored in a ``.npy`` file into the given layers.

    Layers and stored masks are paired positionally; pairing stops at the
    shorter sequence, so surplus layers keep their current mask.

    Args:
        layerlist: Layers exposing a ``mask`` tensor.
        masksavename: Path of the stored masks.
    """
    # NOTE: allow_pickle=True makes np.load unpickle the file contents, so
    # only point this at files you trust.
    stored_masks = np.load(masksavename, allow_pickle=True)
    for target_layer, mask_array in zip(layerlist, stored_masks):
        restored = torch.from_numpy(mask_array).clone()
        target_layer.mask.data = restored.to(device)

def mask_size(layerlist):
    """Sum all mask entries over the given layers.

    Args:
        layerlist: Iterable of layers exposing a ``mask`` tensor.

    Returns:
        The integer ``0`` for an empty ``layerlist``, otherwise a scalar
        tensor holding the sum of every mask element.
    """
    return sum(layer.mask.data.sum() for layer in layerlist)
    

class CustomLinear(nn.Module):
    """Fully connected layer whose weight is multiplied by a fixed mask.

    The mask is registered as a buffer, so it is not a trainable parameter.
    Weight and bias are drawn from ``U(-k, k)`` with ``k = sqrt(1 / inputs)``
    using NumPy's global random generator.

    Args:
        mask: Tensor multiplied element-wise with the weight in ``forward``.
        inputs: Number of input features.
        outputs: Number of output features.
    """

    def __init__(self, mask, inputs, outputs):
        super().__init__()
        self.register_buffer("mask", mask)

        bound = np.sqrt(1 / inputs)
        # Draw order (weight first, then bias) fixes how the NumPy random
        # stream is consumed.  ### Can change to other distributions as desired
        weight_init = np.random.uniform(-bound, bound, size=(outputs, inputs))
        bias_init = np.random.uniform(-bound, bound, size=outputs)

        self.weight = nn.Parameter(torch.tensor(weight_init, dtype=torch.float32))
        self.bias = nn.Parameter(torch.tensor(bias_init, dtype=torch.float32))

    def forward(self, x):
        """Apply the linear map using the masked weight."""
        return F.linear(x, self.weight * self.mask, self.bias)


class CustomConv2d(nn.Module):
    """2-D convolution whose kernel is multiplied by a fixed mask.

    The mask is registered as a buffer, so it is not a trainable parameter.
    Weight and bias are drawn from ``U(-k, k)`` with
    ``k = sqrt(1 / (inputs * kernalheight * kernelwidth))`` using NumPy's
    global random generator.

    Args:
        mask: Tensor multiplied element-wise with the kernel in ``forward``.
        inputs: Number of input channels.
        outputs: Number of output channels.
        kernalheight: Kernel height (spelling kept for keyword callers).
        kernelwidth: Kernel width.
        stride: Stride passed to ``F.conv2d``.
        padding: Padding passed to ``F.conv2d``.
        bias: Add the bias term in ``forward`` when truthy.
    """

    def __init__(self, mask, inputs, outputs, kernalheight, kernelwidth, stride=1, padding=0, bias=True):
        super().__init__()
        self.register_buffer("mask", mask)
        self.padding = padding
        self.stride = stride
        self.bias_truthyness = bias
        k = np.sqrt(1 / (inputs * kernalheight * kernelwidth))
        self.weight = nn.Parameter(torch.tensor(
            np.random.uniform(-1 * k, k, size=(outputs, inputs, kernalheight, kernelwidth)),
            dtype=torch.float32))
        # The bias parameter is allocated even when ``bias`` is disabled:
        # save_weights / reset_weights in this file read ``layer.bias``
        # unconditionally, and dropping it would also change how the NumPy
        # random stream is consumed.
        self.bias = nn.Parameter(torch.tensor(
            np.random.uniform(-1 * k, k, size=(outputs)),
            dtype=torch.float32))

    def forward(self, x):
        """Convolve ``x`` with the masked kernel.

        Always returns a tensor. The flag is tested for truthiness, so values
        such as ``None`` simply disable the bias instead of making the call
        fall through and return ``None``.
        """
        weight = torch.mul(self.weight, self.mask)
        bias = self.bias if self.bias_truthyness else None
        return F.conv2d(x, weight, bias, stride=self.stride, padding=self.padding)
    


class BasicBlock(nn.Module):
    """Residual block built from two masked 3x3 convolutions.

    A masked 1x1 convolution followed by batch norm is used on the shortcut
    when the stride is not 1 or the channel count changes; otherwise the
    shortcut is an empty ``nn.Sequential``.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution and of the shortcut.
        is_last: When true, ``forward`` also returns the pre-activation.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        self.is_last = is_last
        out_planes = self.expansion * planes

        # All-ones masks, one per convolution kernel shape. ``jicmask`` is
        # created whether or not the shortcut convolution ends up being used.
        self.conv1mask = torch.ones((planes, in_planes, 3, 3), requires_grad=False)
        self.conv2mask = torch.ones((planes, planes, 3, 3), requires_grad=False)
        self.jicmask = torch.ones((out_planes, in_planes, 1, 1), requires_grad=False)

        # Sub-modules are registered in the order conv1, bn1, conv2, bn2,
        # shortcut; this order is what ``modules()`` reports.
        self.conv1 = CustomConv2d(self.conv1mask, in_planes, planes, kernalheight=3, kernelwidth=3,
                                  stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = CustomConv2d(self.conv2mask, planes, planes, kernalheight=3, kernelwidth=3,
                                  stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                CustomConv2d(self.jicmask, in_planes, out_planes, kernalheight=1, kernelwidth=1,
                             stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return the block output, plus the pre-activation if ``is_last``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        preact = self.bn2(self.conv2(hidden))
        preact += self.shortcut(x)
        out = F.relu(preact)
        if self.is_last:
            return out, preact
        return out

class ResNet(nn.Module):
    """ResNet encoder: a 3x3 stem, four residual stages and global pooling.

    Args:
        block: Residual block class. It is called as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` attribute.
        num_blocks: Four integers, the number of blocks in each stage.
        in_channel: Number of channels of the input images.
        zero_init_residual: When true, the weight of the last batch norm of
            every ``block`` instance (``bn3`` if present, otherwise ``bn2``)
            is set to zero.
    """

    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # nn.Conv2d layers keep their default initialisation; only the
        # normalisation layers are reset to weight 1 and bias 0.
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            # The block is identified through the ``block`` argument and its
            # final batch norm is looked up by attribute name. This avoids
            # referring to block classes that are not defined in this file.
            for m in self.modules():
                if not isinstance(m, block):
                    continue
                last_norm = getattr(m, "bn3", None)
                if last_norm is None:
                    last_norm = getattr(m, "bn2", None)
                if last_norm is not None:
                    nn.init.constant_(last_norm.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produces.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, layer=100):
        """Return the pooled, flattened features of ``x``.

        ``layer`` is accepted but not used.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return out

### Define the network

def resnet18(**kwargs):
    """Build a ``ResNet`` from ``BasicBlock`` with two blocks per stage.

    Keyword arguments are forwarded unchanged to ``ResNet``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)

# Registry of available encoders: name -> [constructor, second value].
# ResNetModel reads the second value as ``dim_in``, the input width of its head.
# Deeper variants are listed for reference only; their constructors are not
# defined in this file:
#   'resnet34': [resnet34, 512],
#   'resnet50': [resnet50, 2048],
#   'resnet101': [resnet101, 2048],
model_dict = dict(
    resnet18=[resnet18, 512],
)

class ResNetModel(nn.Module):
    """backbone + projection head

    Args:
        name: Key into ``model_dict`` selecting the encoder.
        head: ``'linear'`` for a single masked linear layer, ``'mlp'`` for
            two masked linear layers with a ReLU in between.
        feat_dim: Width of the head output.

    Raises:
        NotImplementedError: If ``head`` is neither ``'linear'`` nor ``'mlp'``.
    """

    def __init__(self, name='resnet18', head='mlp', feat_dim=128):
        super().__init__()
        build_encoder, dim_in = model_dict[name]
        # The encoder is built before the head.
        self.encoder = build_encoder()

        if head == 'mlp':
            hidden_mask = torch.ones((dim_in, dim_in), requires_grad=False)
            output_mask = torch.ones((feat_dim, dim_in), requires_grad=False)
            self.head = nn.Sequential(
                CustomLinear(hidden_mask, dim_in, dim_in),
                nn.ReLU(inplace=True),
                CustomLinear(output_mask, dim_in, feat_dim),
            )
        elif head == 'linear':
            output_mask = torch.ones((feat_dim, dim_in), requires_grad=False)
            self.head = CustomLinear(output_mask, dim_in, feat_dim)
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        """Encode ``x``, apply the head and normalise along dim 1."""
        projected = self.head(self.encoder(x))
        return F.normalize(projected, dim=1)
    


def train_model(model, train_loader, criterion, optimizer, num_epochs=1):
    """Train ``model`` in place and print the mean batch loss per epoch.

    The model and every batch are moved to the module-level ``device``.

    Args:
        model: Network to optimise.
        train_loader: Iterable yielding ``(inputs, targets)`` batches; must
            support ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimiser stepped once per batch.
        num_epochs: Number of passes over ``train_loader``.
    """
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for batch_inputs, batch_targets in train_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_targets)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        # "\r" keeps the progress report on a single console line.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}", end="\r")

def evaluate_model(model, test_loader):
    """Run ``model`` over ``test_loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        tuple: ``(predictions, actuals)`` as NumPy arrays holding the model
        outputs and the targets, stacked batch after batch.
    """
    model.eval()
    predictions, actuals = [], []

    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_outputs = model(batch_inputs.to(device))
            # extend() adds one entry per sample of the batch.
            predictions.extend(batch_outputs.cpu().numpy())
            actuals.extend(batch_targets.cpu().numpy())

    return np.array(predictions), np.array(actuals)

# Per-trial test metrics, aggregated after the last trial.
modellosses = []
modelaccs = []

model = ResNetModel()

for i in range(num_trials):
    # Drop the previous network before building a freshly initialised one.
    del model
    model = ResNetModel()

    # Install the stored masks into every maskable layer.
    mask_layers = get_mask_layers(model)
    load_mask_from_disk(mask_layers, maskloc)

    model.to(device)
    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()

    train_model(model, train_loader, criterion, optimizer, num_epochs=epochs)

    test_outputs, test_labels = evaluate_model(model, test_loader)
    logits = torch.tensor(test_outputs)
    targets = torch.tensor(test_labels)

    # Accuracy: share of samples whose largest output matches the label.
    predicted = logits.max(dim=1).indices
    accuracy = (predicted == targets).sum() / len(test_labels)
    modelaccs.append(accuracy.item())
    loss = criterion(logits, targets)
    modellosses.append(loss.item())

    # The trailing spaces overwrite what is left of the "\r" progress line.
    print(f"Trial: {i+1}/{num_trials}                                                        ")
    print(f"Accuracy: {accuracy.item()}")
    print(f"Loss: {loss.item()}")

    del test_outputs
    del test_labels
    del logits
    del targets
    del predicted

print(f"{num_trials} Trials:  Avg Acc: {np.mean(modelaccs)}, Std Acc: {np.std(modelaccs)}, Avg. Loss: {np.mean(modellosses)}, Std Loss: {np.std(modellosses)}")