#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt
from ucimlrepo import fetch_ucirepo 
from sklearn.model_selection import train_test_split

from torch.utils.data import DataLoader, TensorDataset
import itertools


# In[ ]:


# NN model

class LogisticRegressionNN(nn.Module):
    """Logistic-regression classifier: one affine layer followed by a sigmoid.

    The single ``nn.Linear(input_size, 1)`` layer is stored as ``self.fc``, so
    its parameters appear as ``fc.weight`` / ``fc.bias`` in ``state_dict()``.
    """

    def __init__(self, input_size):
        super().__init__()
        self.fc = nn.Linear(input_size, 1)

    def forward(self, x):
        """Return ``sigmoid(fc(x))``; the last dimension of the output is 1."""
        logits = self.fc(x)
        return torch.sigmoid(logits)

# calculating the loss under the classifier

def loss_classifier(predictions, labels):
    """Return the mean binary cross-entropy between `predictions` and `labels`.

    `predictions` are expected to be probabilities in [0, 1] (e.g. sigmoid
    outputs) with the same shape as `labels`; a fresh ``nn.BCELoss`` with
    mean reduction is built on every call.
    """
    criterion = nn.BCELoss(reduction="mean")
    return criterion(predictions, labels)


# calculating the loss for the dataset
def loss_dataset(model, dataset, loss_f):
    """Compute the loss of `model` on `dataset`.

    `dataset` is iterated for ``(features, s, labels)`` batches; the
    sensitive attribute `s` is ignored. The result is the unweighted mean of
    the per-batch values returned by ``loss_f(model(features), labels)``
    (gradients are not disabled, so a tensor loss keeps its autograd graph).

    Raises:
        ValueError: if `dataset` yields no batches (the mean is undefined).
    """
    total = 0
    n_batches = 0
    for features, _, labels in dataset:
        total += loss_f(model(features), labels)
        n_batches += 1

    if n_batches == 0:
        raise ValueError("loss_dataset: `dataset` yielded no batches")
    return total / n_batches

# calculating the accuracy for the dataset
def accuracy_dataset(model, dataset):
    """Return the accuracy of `model` on `dataset`, as a percentage.

    `dataset` is iterated for ``(features, s, labels)`` batches (the sensitive
    attribute is ignored) and must expose ``.dataset`` whose length is the
    total number of samples, e.g. a ``DataLoader``. Model outputs are turned
    into hard predictions with ``round()``.
    """
    n_correct = 0
    for features, _, labels in dataset:
        hard_preds = model(features).round()
        hits = hard_preds.view(-1, 1) == labels.view(-1, 1)
        n_correct += torch.sum(hits).item()
    return 100 * n_correct / len(dataset.dataset)

# calculating the fairness for the dataset
def fairness_dataset(model, dataset):
    """Compute the group-fairness gaps of `model` on `dataset`.

    `dataset` is iterated for ``(features, s, labels)`` batches, where `s` is
    the sensitive attribute. Model outputs are turned into hard predictions
    with ``round()``. Only groups (distinct values of `s`) holding more than
    1% of the samples are compared.

    Returns:
        ``(dp_gap, eo_gap)`` where
        - dp_gap is max - min over groups of the positive-prediction rate
          (demographic-parity gap);
        - eo_gap is max - min over groups of the true-positive rate
          (equal-opportunity gap). A group with no positive label has no TPR
          and is left out instead of injecting NaN.
        A gap is 0.0 when no group contributes a rate.
    """
    def to_flat_numpy(values):
        # .cpu() so CUDA tensors can be converted as well.
        return torch.as_tensor(values).detach().cpu().numpy().ravel()

    pred_chunks, s_chunks, target_chunks = [], [], []
    # Metrics only: no autograd graph is needed.
    with torch.no_grad():
        for features, s, labels in dataset:
            pred_chunks.append(to_flat_numpy(model(features).round()))
            s_chunks.append(to_flat_numpy(s))
            target_chunks.append(to_flat_numpy(labels))

    preds = np.concatenate(pred_chunks)
    targets = np.concatenate(target_chunks)
    sensitive = np.concatenate(s_chunks)
    correct = preds == targets

    ppr_list, tpr_list = [], []
    for s_value in np.unique(sensitive):
        in_group = sensitive == s_value
        # Skip tiny groups (<= 1% of samples): their rates are too noisy.
        if np.mean(in_group) <= 0.01:
            continue
        ppr_list.append(np.mean(preds[in_group]))
        positives = np.logical_and(in_group, targets == 1)
        # TPR is undefined without positives; a NaN here would make
        # max()/min() return NaN or depend on group order.
        if positives.any():
            tpr_list.append(np.mean(correct[positives]))

    dp_gap = max(ppr_list) - min(ppr_list) if ppr_list else 0.0
    eo_gap = max(tpr_list) - min(tpr_list) if tpr_list else 0.0
    return dp_gap, eo_gap



# train for one epoch; output the avg. batch loss and the DP / EO fairness gaps
def _proximal_term(model, reference_params):
    """Return sum over parameters of ||w - w0||^2 (FedProx penalty before mu/2)."""
    return sum(
        torch.sum((param - ref) ** 2)
        for param, ref in zip(model.parameters(), reference_params)
    )


def train_step(model, model_0, mu: float, optimizer, train_data, loss_f):
    """Train `model` on one epoch of `train_data`.

    Each batch is ``(features, s, labels)``; `s` (the sensitive attribute) is
    only used for the fairness metrics. When `mu` is non-zero, the FedProx
    proximal term ``mu / 2 * sum ||w - w0||^2`` (w0 = parameters of
    `model_0`) is added to each batch loss before backpropagation.

    Returns:
        ``(avg_loss, dp_gap, eo_gap)``:
        - avg_loss is the mean per-batch loss, as a detached tensor.
        - dp_gap and eo_gap are the demographic-parity and equal-opportunity
          gaps of the rounded predictions made during the epoch (before each
          batch's optimizer step).
        Groups holding <= 1% of samples are ignored. Groups without positive
        labels have no TPR and are skipped. A gap is 0.0 when no group
        contributes.

    Raises:
        ValueError: if `train_data` yields no batches.
    """
    # Frozen anchor weights for the proximal term (only needed when mu != 0).
    reference = [p.detach() for p in model_0.parameters()] if mu else None

    total_loss = 0
    n_batches = 0
    pred_chunks, s_chunks, target_chunks = [], [], []

    for features, s, labels in train_data:
        predictions = model(features)
        loss = loss_f(predictions, labels)
        if mu:
            loss = loss + mu / 2 * _proximal_term(model, reference)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach so the running sum does not keep every batch's graph alive.
        total_loss += loss.detach()
        n_batches += 1
        # .cpu() so CUDA tensors can be converted to numpy as well.
        pred_chunks.append(predictions.detach().round().cpu().numpy().ravel())
        s_chunks.append(torch.as_tensor(s).detach().cpu().numpy().ravel())
        target_chunks.append(torch.as_tensor(labels).detach().cpu().numpy().ravel())

    if n_batches == 0:
        raise ValueError("train_step: `train_data` yielded no batches")

    preds = np.concatenate(pred_chunks)
    targets = np.concatenate(target_chunks)
    sensitive = np.concatenate(s_chunks)
    correct = preds == targets

    ppr_list, tpr_list = [], []
    for s_value in np.unique(sensitive):
        in_group = sensitive == s_value
        # Skip tiny groups (<= 1% of samples): their rates are too noisy.
        if np.mean(in_group) <= 0.01:
            continue
        ppr_list.append(np.mean(preds[in_group]))
        positives = np.logical_and(in_group, targets == 1)
        # TPR is undefined without positives; a NaN would poison max()/min().
        if positives.any():
            tpr_list.append(np.mean(correct[positives]))

    dp_gap = max(ppr_list) - min(ppr_list) if ppr_list else 0.0
    eo_gap = max(tpr_list) - min(tpr_list) if tpr_list else 0.0

    return total_loss / n_batches, dp_gap, eo_gap


# local learn
def local_learning(model, mu: float, optimizer, train_data, epochs: int, loss_f):
    """Train `model` in place for `epochs` epochs on one client's data.

    Each epoch is one ``train_step`` call. The returned values come from the
    LAST epoch only.

    Returns:
        ``(local_loss, dp_gap, eo_gap)`` where local_loss is a Python float.

    Raises:
        ValueError: if `epochs` < 1 (no epoch would run, so there is nothing
            to report).
    """
    if epochs < 1:
        raise ValueError(f"local_learning: `epochs` must be >= 1, got {epochs}")

    # Snapshot of the incoming weights, passed to train_step as `model_0`.
    model_0 = deepcopy(model)

    for _ in range(epochs):
        local_loss, dp_local_fairness, eo_local_fairness = train_step(
            model, model_0, mu, optimizer, train_data, loss_f)

    # float() works for CPU and CUDA tensors alike (.numpy() fails on CUDA).
    return float(local_loss), dp_local_fairness, eo_local_fairness

# reset model parameter
def set_to_zero_model_weights(model):
    """Set all the parameters of `model` to 0, in place.

    Uses ``zero_()`` rather than subtracting each tensor from itself: for
    NaN or +/-inf entries ``x - x`` is NaN, so they would not be cleared.
    """
    for param in model.parameters():
        param.data.zero_()

        
# aggregation
def average_models(model, clients_models_hist: list, weights: list):
    """Create the new global model as a weighted sum of the clients' parameters.

    Parameters:
        - `model`: template; it is deep-copied and left unchanged.
        - `clients_models_hist`: one entry per client. Each entry is the list
          of that client's parameter tensors, in ``model.parameters()`` order.
        - `weights`: one weight per client (e.g. its share of the samples).

    Returns:
        A new model whose i-th parameter is
        ``sum_k weights[k] * clients_models_hist[k][i]``.

    Raises:
        ValueError: if the number of clients and weights differ.
    """
    if len(clients_models_hist) != len(weights):
        raise ValueError(
            f"average_models: got {len(clients_models_hist)} client models "
            f"but {len(weights)} weights")

    new_model = deepcopy(model)
    new_params = list(new_model.parameters())
    # zero_() (not x - x) so NaN/inf values in the template cannot leak through.
    for param in new_params:
        param.data.zero_()

    for client_params, weight in zip(clients_models_hist, weights):
        for idx, param in enumerate(new_params):
            param.data.add_(client_params[idx].data * weight)

    return new_model


# In[ ]:


def _evaluate_global_model(model, training_sets, testing_sets, weights, loss_f):
    """Evaluate the global `model` on every client and aggregate with `weights`.

    Returns ``(losses, accs, fairness, summary)``:
    - losses, accs and fairness are the per-client lists stored in the
      histories.
    - summary is ``(server_loss, server_acc, dp_gap, eo_gap)``, the client
      metrics weighted by `weights`. Test-set metrics are aligned with the
      training-set weights by index.
    The test-set aggregates are None when `testing_sets` is empty.
    """
    losses = [float(loss_dataset(model, dl, loss_f).detach()) for dl in training_sets]
    accs = [accuracy_dataset(model, dl) for dl in testing_sets]
    fairness = [fairness_dataset(model, dl) for dl in testing_sets]

    server_loss = sum(w * loss for w, loss in zip(weights, losses))
    if testing_sets:
        server_acc = sum(w * acc for w, acc in zip(weights, accs))
        dp_gap = sum(w * gaps[0] for w, gaps in zip(weights, fairness))
        eo_gap = sum(w * gaps[1] for w, gaps in zip(weights, fairness))
    else:
        server_acc = dp_gap = eo_gap = None
    return losses, accs, fairness, (server_loss, server_acc, dp_gap, eo_gap)


def _print_round(round_idx, server_loss, server_acc, dp_gap, eo_gap):
    """Print the weighted server metrics for one communication round."""
    print(f'====> i: {round_idx} Loss: {server_loss} Server Test Accuracy: {server_acc} DP Server Fairness: {dp_gap} EO Server Fairness: {eo_gap}')


def FedAvg(model, training_sets:list, n_iter:int, testing_sets:list, mu=0, 
    file_name="test", epochs = 50, lr=10**-2, decay=1):
    """Federated training where all clients take part in every round.

    `mu` is forwarded to the clients' local training as the FedProx
    proximal coefficient; mu=0 corresponds to plain FedAvg.

    Parameters:
        - `model`: common structure used by the clients and the server.
        - `training_sets`: list of the training loaders. Index k is the
            training set of client k. Each must expose ``.dataset``; the
            clients' aggregation weights are proportional to its length.
        - `n_iter`: number of iterations (rounds) the server will run.
        - `testing_sets`: list of the testing loaders, aligned with
            `training_sets` by index. If [] (or None), test accuracy and
            fairness are not computed.
        - `mu`: regularization term for FedProx. mu=0 for FedAvg.
        - `file_name`: currently unused.
        - `epochs`: number of local epochs each client runs per round.
        - `lr`: learning rate of each client's Adam optimizer.
        - `decay`: multiplies the learning rate after each round.

    Returns:
        ``(model, loss_hist, acc_hist, fairness_hist)``:
        - model is the final global model.
        - Each history has one entry for the initial model plus one per
          round. Each entry is a per-client list: training loss, test
          accuracy (%), and test ``(dp_gap, eo_gap)`` respectively.

    Raises:
        ValueError: if `testing_sets` is non-empty but has fewer entries
            than `training_sets`.
    """
    loss_f = loss_classifier

    if not testing_sets:
        testing_sets = []
    n_clients = len(training_sets)
    if testing_sets and len(testing_sets) < n_clients:
        raise ValueError(
            f"FedAvg: {len(testing_sets)} testing sets for {n_clients} clients")

    # Aggregation weights: each client's share of the training samples.
    n_samples = sum(len(dl.dataset) for dl in training_sets)
    weights = [len(dl.dataset) / n_samples for dl in training_sets]
    print("Clients' weights:", weights)

    losses, accs, fairness, summary = _evaluate_global_model(
        model, training_sets, testing_sets, weights, loss_f)
    loss_hist, acc_hist, fairness_hist = [losses], [accs], [fairness]
    _print_round(0, *summary)

    for i in range(n_iter):
        clients_params = []
        for train_dl in training_sets:
            # Each client starts from the current global model.
            local_model = deepcopy(model)
            local_optimizer = optim.Adam(local_model.parameters(), lr=lr)
            local_learning(local_model, mu, local_optimizer, train_dl, epochs, loss_f)
            clients_params.append([p.detach() for p in local_model.parameters()])

        # average_models deep-copies `model`, so no extra copy is needed here.
        model = average_models(model, clients_params, weights=weights)

        losses, accs, fairness, summary = _evaluate_global_model(
            model, training_sets, testing_sets, weights, loss_f)
        loss_hist.append(losses)
        acc_hist.append(accs)
        fairness_hist.append(fairness)
        _print_round(i + 1, *summary)

        # Decrease (or keep) the learning rate for the next round.
        lr *= decay

    return model, loss_hist, acc_hist, fairness_hist


# In[ ]:




