"""
Setup and Dataset import
"""

import numpy as np
import matplotlib.pyplot as plt
import math
import os, sys
import random
import copy
import string

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision
import torchvision.transforms as transforms
from typing import Callable, List, Optional, Tuple
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import MultiStepLR, StepLR
import torch.optim as optim

from tqdm.notebook import tqdm

# Run on the first CUDA device when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Enable cuDNN's benchmark mode for the rest of the script.
torch.backends.cudnn.benchmark = True

"""Network Definition (Credit: https://github.com/erfunmirzaei/FedAvg-FedProx)"""

class Deep_CNN(nn.Module):
    """
    Small convolutional classifier: two conv -> ReLU -> max-pool stages followed
    by three fully connected layers.

    The architecture is the one used in the "Personalized Federated Learning:
    A Meta-Learning Approach" paper. The first linear layer expects 16 * 5 * 5
    flattened features, which is what a 3-channel 32x32 input produces after
    the two convolution/pooling stages.

    Args:
        num_classes: number of output logits produced by the final layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Convolutional feature extractor (two stages).
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Fully connected head: two hidden layers, then the classifier.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

        # Shared activation, reused after every hidden layer.
        self.relu = nn.ReLU()

        # Registered for completeness; forward() does not apply it.
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``."""
        # Two conv -> ReLU -> pool stages.
        x = self.pool1(self.relu(self.conv1(x)))
        x = self.pool2(self.relu(self.conv2(x)))

        # Keep the batch dimension, collapse everything else.
        x = torch.flatten(x, start_dim=1)

        # Hidden layers, then the (un-activated) classifier output.
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

"""NIRo Definition"""

def non_iid_metric(complete_array, k):
    """Collapse per-node label histograms into a single non-IID score.

    Args:
        complete_array: one sequence of per-class sample counts for each node.
        k: number of classes used to build the worst-case (single-class)
            histogram each node is compared against.

    Returns:
        ``(c**2 + d**2) / (c + d)`` evaluated as ``c/(d+c)*c + d/(d+c)*d``,
        where ``c`` is the label-skew term and ``d`` the quantity-skew term
        described in the comments below.
    """
    node_totals = [np.sum(node_counts) for node_counts in complete_array]
    grand_total = np.sum(node_totals)

    # Label skew: each node's class-count variance, normalised by the variance
    # of a histogram that puts all of the node's samples into one class and
    # weighted by the node's share of the whole dataset.
    weighted_label_skew = []
    for node_counts, node_total in zip(complete_array, node_totals):
        single_class = [0] * (k - 1)
        single_class.append(node_total - np.sum(single_class))
        weighted_label_skew.append(
            node_total / grand_total * (np.var(node_counts) / np.var(single_class))
        )
    class_specific_val = np.sum(weighted_label_skew)

    # Quantity skew: variance of the node sizes, normalised by the variance of
    # the most uneven split in which every node but one holds a single sample.
    most_uneven = [1] * (len(node_totals) - 1)
    most_uneven.append(grand_total - np.sum(most_uneven))
    dataset_specific_val = np.var(node_totals) / np.var(most_uneven)

    return class_specific_val/(dataset_specific_val + class_specific_val) * class_specific_val + dataset_specific_val/(dataset_specific_val + class_specific_val)*dataset_specific_val

"""Data-split"""

# The data-generation code in this section is adapted from NIID-Bench
# (https://github.com/Xtra-Computing/NIID-Bench/blob/main/partition.py); some portions
# of the training setup are adapted from https://github.com/erfunmirzaei/FedAvg-FedProx

# Convert images to tensors, then normalise each RGB channel with mean 0.5 and std 0.5.
transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

torch.manual_seed(14)

batch_size = 75

# CIFAR-10 train/test splits, downloaded into ./data when not already present.
trainset2 = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform2)

testset2 = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform2)

# One shared loader over the full test split, used by the evaluation sections.
testloader2 = DataLoader(
    testset2, batch_size=batch_size, shuffle=True, num_workers=2)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)



n_parties = 12
K = 10
min_size = 0

# Shuffle every training index once; contiguous slices of this permutation
# become the per-party shards below.
idxs = np.random.permutation(len(trainset2.targets))

# Uncomment for iid data-counts, but skewed data-labels.
# NOTE: this snippet refers to min_require_size, y_train, beta, N and idx_batch,
# which are not defined in the visible script and must be supplied first.
#while min_size < min_require_size:
#	    idx_batch = [[] for _ in range(n_parties)]
#	    for k in range(K):
#	        idx_k = np.where(y_train == k)[0]
#	        np.random.shuffle(idx_k)
#	        proportions = np.random.dirichlet(np.repeat(beta, n_parties))
#	        proportions = np.array([p * (len(idx_j) < N / n_parties) for p, idx_j in zip(proportions, idx_batch)])
#	        proportions = proportions / proportions.sum()
#	        proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
#	        idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
#	        min_size = min([len(idx_j) for idx_j in idx_batch])

#	for i in range(n_parties):
#		net_dataidx_map[i] = idx_batch[i]


# Draw the party sizes from a Dirichlet(0.07) distribution, redrawing until the
# smallest party's expected share is at least one sample.
while min_size < 1:
    proportions = np.random.dirichlet(np.full(n_parties, 0.07))
    proportions = proportions / proportions.sum()
    min_size = (proportions * len(idxs)).min()

# Cumulative proportions become split points; the last one is dropped because
# np.split produces the trailing shard on its own.
proportions = (proportions.cumsum() * len(idxs)).astype(int)[:-1]
print(proportions)
batch_idxs = np.split(idxs, proportions)

# party id -> plain list of dataset indices owned by that party
net_dataidx_map = {party: batch_idxs[party].tolist() for party in range(n_parties)}

subsets = [Subset(trainset2, net_dataidx_map[party]) for party in range(n_parties)]

# One shuffling loader per party.
trainloaders2 = [
    DataLoader(shard, batch_size=batch_size, shuffle=True, num_workers=2)
    for shard in subsets
]


# Build, for every party, the histogram of how many samples of each label it
# holds, then score the partition with non_iid_metric.
#
# The labels are read directly from the dataset's target list via the party's
# index list, so no images are loaded, transformed or kept in memory just to
# count labels (the previous version iterated every DataLoader to do this).
all_targets = np.asarray(trainset2.targets)
complete_array = [
    # minlength=K keeps a zero entry for labels a party does not own.
    np.bincount(all_targets[net_dataidx_map[i]], minlength=K).tolist()
    for i in range(n_parties)
]
# Use K rather than a literal so the class count is defined in one place.
quant = non_iid_metric(complete_array, K)
print(quant)

"""Protocol Definitions (Credit: https://github.com/erfunmirzaei/FedAvg-FedProx)"""

class FedAvg_ClientUpdate(object):
    """Local SGD training for one client plus the server-side accumulation step."""

    def __init__(self, train_loader):
        # Batches of (inputs, labels) belonging to this client.
        self.train_loader = train_loader
        # Loss minimised during local training.
        self.criterion = nn.CrossEntropyLoss()

    def update(self, net, max_epochs, lr):
        """Train ``net`` in place on this client's data.

        Args:
            net: model to train; its parameters are modified.
            max_epochs: number of passes over ``self.train_loader``.
            lr: learning rate for SGD (momentum is fixed at 0.9).

        Returns:
            ``(state_dict, mean_loss, mean_acc)`` where the two means are taken
            over the per-epoch averages of the per-batch loss and accuracy.
        """
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

        epoch_losses = []
        epoch_accs = []

        for _ in range(max_epochs):
            loss_total, acc_total = 0.0, 0.0
            num_batches = 0

            for inputs, targets in self.train_loader:
                # Move the batch to the module-level compute device.
                inputs = inputs.to(device)
                targets = targets.to(device)

                # Standard step: clear grads, forward, backward, update.
                optimizer.zero_grad()
                logits = net(inputs)
                loss = self.criterion(logits, targets)
                loss.backward()
                optimizer.step()

                loss_total += loss.item()

                # Fraction of correctly classified samples in this batch.
                hits = (logits.argmax(dim=1) == targets.long()).sum()
                acc_total += hits / targets.numel()

                num_batches += 1

            epoch_losses.append(loss_total / num_batches)
            epoch_accs.append(acc_total / num_batches)

        mean_loss = sum(epoch_losses) / len(epoch_losses)
        mean_acc = sum(epoch_accs) / len(epoch_accs)
        return net.state_dict(), mean_loss, mean_acc

    def FedAvg(self, w_dict, w_prime_dict, coeff):
        """Add ``coeff * w_prime_dict`` to ``w_dict`` entry by entry.

        ``w_dict`` is updated in place and also returned.
        """
        for name, accumulated in w_dict.items():
            w_dict[name] = accumulated + coeff * w_prime_dict[name]
        return w_dict

class FedProx_ClientUpdate(object):
    """Local FedProx training for one client plus the server-side accumulation step."""

    def __init__(self, train_loader):
        # Batches of (inputs, labels) belonging to this client.
        self.train_loader = train_loader

        # Data-fitting part of the FedProx objective.
        self.criterion = nn.CrossEntropyLoss()

    def loss_function(self, net, w_t, batch_train_outputs, train_labels, mu):
        """Return the FedProx objective: cross-entropy + (mu / 2) * ||w - w_t||^2.

        Args:
            net: model being trained; its live parameters enter the proximal
                term so that the term contributes gradients.
            w_t: mapping from parameter name to the round-start (global) value.
            batch_train_outputs: logits produced by ``net`` for the batch.
            train_labels: target class indices for the batch.
            mu: proximal coefficient.

        Returns:
            Scalar loss tensor.
        """
        data_loss = self.criterion(batch_train_outputs, train_labels)

        # The proximal term must be built from net.named_parameters():
        # net.state_dict() returns detached tensors, so a penalty computed from
        # it is a constant for autograd and never pulls the weights toward w_t.
        proximal = 0.0
        for name, param in net.named_parameters():
            anchor = w_t[name].detach().float()
            proximal = proximal + (param.float() - anchor).pow(2).sum()

        return data_loss + (mu / 2) * proximal

    def update(self, net, max_epochs, lr, mu):
        """Train ``net`` in place on this client's data with the FedProx loss.

        Args:
            net: model to train; its parameters are modified.
            max_epochs: number of passes over ``self.train_loader``.
            lr: learning rate for SGD (momentum is fixed at 0.9).
            mu: proximal coefficient forwarded to ``loss_function``.

        Returns:
            ``(state_dict, mean_loss, mean_acc)`` where the two means are taken
            over the per-epoch averages of the per-batch loss and accuracy.
        """
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
        # Snapshot of the weights received at the start of the round; the
        # proximal term measures drift away from this point.
        w_t = copy.deepcopy(net.state_dict())

        train_losses = []
        train_accs = []

        # Loop over local epochs
        for epoch in range(max_epochs):

            running_trainloss, train_acc = 0.0, 0.0
            train_cnt = 0

            for train_data, train_labels in self.train_loader:
                # Move the batch to the module-level compute device.
                train_data, train_labels = train_data.to(device), train_labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                batch_train_outputs = net(train_data)

                loss = self.loss_function(net, w_t, batch_train_outputs, train_labels, mu)
                loss.backward()
                optimizer.step()

                running_trainloss += loss.item()

                # Fraction of correctly classified samples in this batch.
                train_acc += torch.sum(torch.argmax(batch_train_outputs, dim=1) == train_labels.long()) / torch.numel(train_labels)

                train_cnt += 1

            train_losses.append(running_trainloss / train_cnt)
            train_accs.append(train_acc / train_cnt)

        return net.state_dict(), sum(train_losses) / len(train_losses), sum(train_accs) / len(train_accs)

    def FedAvg(self, w_dict, w_prime_dict, coeff):
        """Add ``coeff * w_prime_dict`` to ``w_dict`` entry by entry.

        ``w_dict`` is updated in place and also returned.
        """
        for k in w_dict.keys():
            w_dict[k] = w_dict[k] + coeff * w_prime_dict[k]
        return w_dict

"""FedAvg Training (Code Credit: https://github.com/erfunmirzaei/FedAvg-FedProx)"""

# Global model that the FedAvg rounds update.
Net6 = Deep_CNN(10)
Net6.to(device)

# Loss used when this model is evaluated.
criterion6 = nn.CrossEntropyLoss()

train_losses6 = []
validation_losses6 = []

train_accs6 = []
val_accs6 = []

max_rounds = 5          # number of communication rounds
C = 0.3                 # fraction of the n clients sampled each round
E = 5                   # local epochs per selected client
n = 12                  # total number of clients
learning_rate = 0.01
w_dict = copy.deepcopy(Net6.state_dict())

# Loop over communication rounds
for round in tqdm(range(max_rounds)):
    running_trainloss, running_valloss, train_acc = 0.0, 0.0, 0.0
    train_cnt, val_cnt = 0, 0

    # Pick this round's participants; every participant gets the same weight.
    clients = random.sample(list(np.arange(n)), k=int(C * n))
    coeffs = [1 / len(clients)] * n

    # Start the aggregate from zero so it can accumulate weighted client models.
    for name, tensor in w_dict.items():
        w_dict[name] = torch.zeros_like(tensor)

    # Local training on a private copy of the global model, then accumulation.
    for i in clients:
        cl_upd = FedAvg_ClientUpdate(trainloaders2[i])
        local_model = copy.deepcopy(Net6).to(device)
        w_prime_dict, train_loss, acc = cl_upd.update(local_model, E, learning_rate)
        w_dict = cl_upd.FedAvg(w_dict, w_prime_dict, coeffs[i])
        print("Cost of Train data for client %i in round %i for %i epochs: %f" %(i+1, round + 1, E, train_loss))
        running_trainloss += train_loss
        train_acc += acc
        train_cnt += 1

    # The averaged weights become the next round's global model.
    Net6.load_state_dict(w_dict)

"""Evaluation"""

# Evaluate the aggregated FedAvg model (Net6) on the test loader.
val_acc = 0
# Initialise the counters here so the evaluation does not rely on values left
# behind by the training loop.
running_valloss, val_cnt = 0.0, 0
with torch.no_grad():
    for val_data, val_labels in testloader2:
        # Move the batch to the module-level compute device.
        val_data, val_labels = val_data.to(device), val_labels.to(device)

        # Forward pass only; gradients are disabled.
        batch_outputs = Net6(val_data)

        # Use the FedAvg criterion: criterion5 belongs to the FedProx section
        # and is not defined yet when this code runs top to bottom.
        val_loss = criterion6(batch_outputs, val_labels.long())

        running_valloss += val_loss.item()

        # Per-batch accuracy; averaged over batches in the print below.
        val_acc += torch.sum(torch.argmax(batch_outputs, dim=1) == val_labels.long()) / torch.numel(val_labels)

        val_cnt += 1
    print("Acc of Validation data for FedAvg in round %i: %f" %(round + 1, val_acc / val_cnt * 100))

"""FedProx Training (Code Credit: https://github.com/erfunmirzaei/FedAvg-FedProx)"""

# Global model that the FedProx rounds update.
Net5 = Deep_CNN(10)
Net5.to(device)

# Loss used when this model is evaluated.
criterion5 = nn.CrossEntropyLoss()

train_losses5 = []
validation_losses5 = []

train_accs5 = []
val_accs5 = []

max_rounds = 5          # number of communication rounds
C = 0.3                 # fraction of the n clients sampled each round
E = 5                   # local epochs per selected client
mu = 3                  # proximal coefficient passed to the client update
n = 12                  # total number of clients
learning_rate = 0.01
w_dict = copy.deepcopy(Net5.state_dict())

# Loop over communication rounds
for round in tqdm(range(max_rounds)):
    running_trainloss, running_valloss, train_acc = 0.0, 0.0, 0.0
    train_cnt, val_cnt = 0, 0

    # Pick this round's participants; every participant gets the same weight.
    clients = random.sample(list(np.arange(n)), k=int(C * n))
    coeffs = [1 / len(clients)] * n

    # Start the aggregate from zero so it can accumulate weighted client models.
    for name, tensor in w_dict.items():
        w_dict[name] = torch.zeros_like(tensor)

    # Local training on a private copy of the global model, then accumulation.
    for i in clients:
        cl_upd = FedProx_ClientUpdate(trainloaders2[i])
        local_model = copy.deepcopy(Net5).to(device)
        w_prime_dict, train_loss, acc = cl_upd.update(local_model, E, learning_rate, mu)
        w_dict = cl_upd.FedAvg(w_dict, w_prime_dict, coeffs[i])
        print("Cost of Train data for client %i in round %i for %i epochs: %f" %(i+1, round + 1, E, train_loss))
        running_trainloss += train_loss
        train_acc += acc
        train_cnt += 1

    # The averaged weights become the next round's global model.
    Net5.load_state_dict(w_dict)

"""Evaluation"""

# Evaluate the aggregated FedProx model (Net5) on the test loader with
# gradient tracking switched off.
val_acc = 0
with torch.no_grad():
    for val_data, val_labels in testloader2:
        # Move the batch to the module-level compute device.
        val_data = val_data.to(device)
        val_labels = val_labels.to(device)

        # Forward pass only.
        batch_outputs = Net5(val_data)

        val_loss = criterion5(batch_outputs, val_labels.long())
        running_valloss += val_loss.item()

        # Per-batch accuracy; averaged over batches in the print below.
        val_acc += (batch_outputs.argmax(dim=1) == val_labels.long()).sum() / val_labels.numel()
        val_cnt += 1
    print("Acc of Validation data for FedProx in round %i: %f" %(round + 1, val_acc / val_cnt * 100))