# An abstraction over the Torch NN neural network class.
# We use it to define our own networks where needed; it
# provides a thin abstraction layer for data loading and training.

import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torchvision
from torch import nn
import os
import matplotlib.pyplot as plt
import numpy as np
import copy
import torch.nn.functional as F


# MNIST train/test splits: stored under "data", downloaded when missing,
# and converted to tensors on access.
minst_training_data = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor())
minst_test_data = datasets.MNIST(root="data", train=False, download=True, transform=ToTensor())

# Mean and std values handed to transforms.Normalize by both CIFAR pipelines.
_cifar_mean = (0.4914, 0.4822, 0.4465)
_cifar_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop (after 4 pixels of padding) and a
# random horizontal flip, followed by tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_cifar_mean, _cifar_std),
    ]
)

# Test pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_cifar_mean, _cifar_std),
    ]
)

# CIFAR-10 train/test splits using the pipelines above.
cifar_trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
cifar_testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Label names, listed in a fixed order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

#mergeDataSet = torch.utils.data.ConcatDataset([minst_training_data, minst_test_data])
# minst_training_data , minst_test_data = torch.utils.data.random_split(minst_training_data, 
#                                                                       [20000, 40000], 
#                                                                       generator=torch.Generator().manual_seed(42))

def generateTestDataLoader(dataset, batch_size=64):
    """
    Wraps a dataset in a shuffling dataloader.

    Input: dataset- dataset to iterate over
    batch_size- number of samples per batch (integer)

    Returns: DataLoader over dataset with shuffling enabled
    """
    loader_options = {"batch_size": batch_size, "shuffle": True}
    return DataLoader(dataset, **loader_options)

def twoClass(number_of_nodes,training_data, batch_size=10,class_num=2):
    """
    Splits training data between nodes by label group, where the group of
    a label (and of a node number) is its remainder modulo class_num.

    A sample is given to the first node of its label group that still
    holds fewer than len(training_data)/number_of_nodes samples. When every
    node of that group is full, the sample is given to a node drawn with
    np.random.randint instead. Each sample is assigned to at most one node.
    Samples whose group has no node at all are left out.

    Input: number_of_nodes- number of subsets (integer)
    training_data- training dataset exposing a .targets sequence of labels
    batch_size- number of samples per batch (integer)
    class_num- modulus used to group labels and nodes (integer)

    Returns: trainloaders_dict- dictionary of shuffling dataloaders,
    indexed by node number.
    """
    node_to_label_idx = {i: [] for i in range(number_of_nodes)}

    # Nodes serving each label group, in ascending node order.
    nodes_by_group = {}
    for i in range(number_of_nodes):
        nodes_by_group.setdefault(i % class_num, []).append(i)

    # Per-node capacity; kept as a float, exactly as the comparison below needs.
    maxi = len(training_data) / number_of_nodes

    for idx, target in enumerate(training_data.targets):
        candidates = nodes_by_group.get(int(target) % class_num, [])
        if not candidates:
            # No node serves this label group, so the sample is not used.
            continue

        for i in candidates:
            if len(node_to_label_idx[i]) < maxi:
                node_to_label_idx[i].append(idx)
                break
        else:
            # Every node of the group is full: hand the sample to a random
            # node. (Previously the random node was looked up but the sample
            # was never appended, so it was silently dropped.)
            overflow_node = np.random.randint(0, number_of_nodes)
            node_to_label_idx[overflow_node].append(idx)

    trainloaders_dict = {}
    for i in range(number_of_nodes):
        trainset = torch.utils.data.Subset(training_data, node_to_label_idx[i])
        trainloaders_dict[i] = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    return trainloaders_dict

    

def generateTrainDataloadersNonIid(number_of_nodes, training_data, batch_size=10, num_classes=10):
    """
    Builds one dataloader per node where node i only sees the samples
    whose label equals i % num_classes. Nodes that map to the same label
    receive the same sample indices.

    Input: number_of_nodes- number of subsets (integer)
    training_data- training dataset exposing a .targets sequence of labels
    batch_size- number of samples per batch (integer)
    num_classes- number of distinct labels to cycle through (integer)

    Returns: trainloaders_dict- dictionary of shuffling dataloaders,
    indexed by node number.
    """
    # Group sample indices by label in a single pass over the targets,
    # instead of comparing every sample against every node.
    label_to_idx = {}
    for idx, target in enumerate(training_data.targets):
        # Tensor / numpy scalars are unwrapped so they hash like plain numbers.
        label = target.item() if hasattr(target, "item") else target
        label_to_idx.setdefault(label, []).append(idx)

    trainloaders_dict = {}
    for i in range(number_of_nodes):
        # Copy so nodes sharing a label do not share one mutable index list.
        node_indices = list(label_to_idx.get(i % num_classes, []))
        trainset = torch.utils.data.Subset(training_data, node_indices)
        trainloaders_dict[i] = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    return trainloaders_dict


def generateTrainDataloaders(number_of_nodes, training_data, batch_size=64):
    """
    Randomly partitions the training data into equally sized subsets,
    one per node, and wraps each subset in a shuffling dataloader.

    The sample indices are shuffled with np.random.shuffle and cut into
    consecutive slices of len(training_data) // number_of_nodes entries;
    indices left over after the last slice are not used. The size of each
    subset is printed as it is created.

    Input: number_of_nodes- number of subsets (integer)
    training_data- training dataset
    batch_size- number of samples per batch (integer)

    Returns: trainloaders_dict- dictionary of dataloaders,
    indexed by node number.
    """
    shuffled = np.arange(0, len(training_data))
    np.random.shuffle(shuffled)
    per_node = len(shuffled) // number_of_nodes

    trainloaders_dict = {}
    for node in range(number_of_nodes):
        start = node * per_node
        shard = torch.utils.data.Subset(training_data, shuffled[start:start + per_node])
        print(len(shard))
        trainloaders_dict[node] = DataLoader(shard, batch_size=batch_size, shuffle=True)

    return trainloaders_dict

def l1(model, l1_lambda=.001, device=None):
    """
    Returns the regularizer to be added to the loss function.

    Input: model- neural network whose parameters are penalised (nn.Module)
    l1_lambda- weight of the l1 penalty (float)
    device- device the penalty is accumulated on. When None, the device of
    the model's first parameter is used (CPU if the model has no
    parameters), so the function also works on machines without CUDA.

    Returns: l1_lambda times the sum of the l1 norms of all parameters
    (torch.tensor on device)
    """
    params = list(model.parameters())
    if device is None:
        device = params[0].device if params else torch.device('cpu')

    norm = torch.tensor(0., device=device)
    for p in params:
        # Move each partial norm onto the accumulator's device so parameters
        # living elsewhere do not cause a device mismatch.
        norm += torch.norm(p, 1).to(device)

    return l1_lambda*norm


def trainLoopWeights(dataloader, model, optimizer, loss_fn = nn.CrossEntropyLoss(), id=None, quiet=True):
    """
    Performs one pass of training on model with an added l1 penalty and
    records a snapshot of the weights every 25 batches.

    Batches are moved to the device of the model's first parameter, and the
    l1 penalty is computed on that same device, so the loop does not depend
    on CUDA being available.

    Input: dataloader- dataloader yielding (X, y) batches
    model- model being trained (nn.Module)
    optimizer- algorithm used to train (torch.optim object)
    loss_fn- loss function that the optimizer optimizes over
    id- node identification, only used in the progress message (integer)
    quiet- when False, prints the loss every 25 batches (boolean)

    Returns: model- updated model (same object that was passed in)
    weights- deep copies of model.state_dict() taken every 25 batches (list)
    """
    size = len(dataloader.dataset)

    # Train on whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    weights = []
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and regularised loss
        pred = model(X)
        loss = loss_fn(pred, y) + l1(model, device=device)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 25 == 0:
            if not quiet:
                current = batch * len(X)
                print(f"Node number: {id} loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")
            # Deep copy: state_dict() returns references to the live tensors.
            weights.append(copy.deepcopy(model.state_dict()))

    return model, weights


def trainLoop(dataloader, model, optimizer, loss_fn = nn.CrossEntropyLoss(), id=None, reg="l1", device='cpu', E=4, clip=100):
    """
    Performs E passes of training on model with gradient clipping.

    The model is not moved by this function; batches are moved to device,
    so the model is expected to already be on that device.

    Input: dataloader- dataloader yielding (X, y) batches
    model- model being trained (nn.Module)
    optimizer- algorithm used to train (torch.optim object)
    loss_fn- loss function that the optimizer optimizes over
    id- node identification (integer, currently unused)
    reg- l1 penalty reported with the loss: a number is used as the l1
    weight, "l1" uses the default weight of l1(), and "" or None
    disables the penalty
    device- device the batches are moved to
    E- number of passes over the dataloader (integer)
    clip- maximum infinity-norm of the gradients (float)

    Returns: model- updated model (same object that was passed in)
    value- loss of the last processed batch plus the l1 penalty (float);
    only the penalty when no batch was processed

    Raises: ValueError if reg is a string other than "l1" or ""
    """
    # Defined up front so an empty dataloader or E=0 still returns a value.
    loss = torch.tensor(0., device=device)

    for _ in range(E):
        for X, y in dataloader:
            # Compute prediction and loss
            pred = model(X.to(device))
            loss = loss_fn(pred, y.to(device))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip, norm_type='inf')
            optimizer.step()

    # reg used to be passed straight through as the l1 weight, which raised
    # a TypeError for the string values the signature defaults to.
    if reg is None or (isinstance(reg, str) and reg == ""):
        norm_val = torch.tensor(0., device=device)
    elif isinstance(reg, str):
        if reg != "l1":
            raise ValueError(f"unsupported regularization: {reg!r}")
        norm_val = l1(model, device=device)
    else:
        norm_val = l1(model, l1_lambda=reg, device=device)

    return model, (loss+norm_val).item()


def testLoop(dataloader, model, loss_fn = nn.CrossEntropyLoss(), queit=False, device='mps'):
    """
    Tests a given model.

    Input: dataloader- mapping from nodes to data (dictionary)
    model- model being trained (NeuralNetwork object)
    loss_fn- loss function that the optimizer optimizes over
    quiet- boolean indiciating if accuracy should be printed (bolean)

    Returns: model- updated model (NeuralNetwork object)
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    model.to(device)
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X.to(device))
            test_loss += loss_fn(pred, y.to(device)).to(device).item()
            correct += (pred.argmax(1).to(device) == y.to(device)).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    if not queit:
      print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return 100*correct, test_loss

def evaluteNetwork(train_dataloader, test_dataloder, model, optimizer, loss_fn = nn.CrossEntropyLoss(), id=None, reg="", epochs=1, device=None):
    """
    Trains a network for a number of epochs and then tests it once.

    Input: train_dataloader- dataloader with the training batches
    test_dataloder- dataloader with the test batches
    model- model being trained (nn.Module)
    optimizer- algorithm used to train (torch.optim object)
    loss_fn- loss function that the optimizer optimizes over
    id- node identification, forwarded to trainLoop (integer)
    reg- regularization to use: "" or None for none, "l1" for an l1
    penalty with weight .001 (the default weight of l1()), or a number
    that is used directly as the l1 weight
    epochs- number of trainLoop calls to run (integer)
    device- device forwarded to both trainLoop and testLoop; when None,
    each of them keeps its own default device

    Returns: None
    """
    # trainLoop uses reg as the numeric l1 weight, so translate the string
    # forms accepted here into numbers before forwarding.
    if reg is None or (isinstance(reg, str) and reg == ""):
        l1_lambda = 0.
    elif isinstance(reg, str) and reg == "l1":
        l1_lambda = .001
    else:
        l1_lambda = reg

    device_kwargs = {} if device is None else {"device": device}

    for epoch in range(epochs):
        print("epoch number: ", epoch)
        trainLoop(train_dataloader, model, optimizer, loss_fn = loss_fn, id=id, reg=l1_lambda, **device_kwargs)

    testLoop(test_dataloder,model,loss_fn,queit=False, **device_kwargs)



class LightMnist(nn.Module):
    """
    Single-layer classifier: flattens the input and maps 784 features
    to 10 logits with one linear layer.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        only_layer = nn.Linear(784, 10)
        self.linear_relu_stack = nn.Sequential(only_layer)

    def forward(self, x):
        """Returns the logits for a batch x."""
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)


class FullMnist(nn.Module):
    """
    Fully connected classifier: flattens the input and applies
    784 -> 512 -> 512 -> 10 linear layers with ReLU between them.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        layers = [
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Returns the logits for a batch x."""
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)


class CNNMnist(nn.Module):
    """
    Small convolutional classifier: two 5x5 convolutions with 50 channels,
    each followed by 2x2 max pooling and ReLU, then two linear layers.

    The first linear layer expects 1250 flattened features, which is what
    the convolutional part produces for 32x32 inputs (50 channels of 5x5).

    Input: num_channels- number of input channels (integer)
    num_classes- number of output logits (integer)
    batch_norm- when True the second convolution is followed by
    BatchNorm2d, otherwise by Dropout2d (boolean)
    """

    def __init__(self, num_channels=3,num_classes=10,batch_norm=True):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 50, kernel_size=5)
        self.conv2 = nn.Conv2d(50, 50, kernel_size=5)
        self.conv2_norm = nn.BatchNorm2d(50) if batch_norm else nn.Dropout2d()
        self.fc1 = nn.Linear(1250, 500)
        self.fc2 = nn.Linear(500, num_classes)

    def forward(self, x):
        """Returns the logits for a batch x."""
        first = F.relu(F.max_pool2d(self.conv1(x), 2))

        second = self.conv2_norm(self.conv2(first))
        second = F.relu(F.max_pool2d(second, 2))

        # Collapse channels, height and width into one feature axis.
        flat = second.view(-1, second.shape[1:].numel())

        hidden = F.relu(self.fc1(flat))
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)
    
# class ResNet(nn.mMdule):
#     def

# class L1Regularizer():
#     def __init__(self,l1_lambda=""):
#         self.l1_lambda=.01
#         if l1_lambda ==
