from tokenize import group
import torch, copy, time, random, warnings, os
import numpy as np

from torch.utils.data import DataLoader
from tqdm import tqdm
from .utils import *
from ray import tune
import torch.nn as nn

################## MODEL SETTING ########################
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen with some MKL/PyTorch installs -- confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
#########################################################

class FairFed_Server(object):
    def __init__(self, model, dataset_info,args, seed = 123, num_workers = 4, ret = False,
                train_prn = False, metric = 'Equal Opportunity Difference', select_round = False,
                batch_size = None, print_every = 1, fraction_clients = 1, Z = 2, prn = True, trial = False):
        """
        Set up the FairFed server.

        Parameters
        ----------
        model: torch.nn.Module object. It is wrapped in nn.DataParallel when more than
            one CUDA device is visible and then moved to DEVICE.

        dataset_info: a list of three objects.
            - train_dataset: Dataset object.
            - test_dataset: Dataset object.
            - clients_idx: a list of lists, with each sublist contains the indexs of the training samples in one client.
                    the length of the list is the number of clients.

        args: object exposing `local_batch_size`; read only when `batch_size` is None.

        seed: random seed.

        num_workers: number of workers (stored on the instance).

        ret: boolean value. If true, FairFed returns the accuracy, the fairness measure and the model;
            it also forces train_prn off.

        train_prn: boolean value. If true, print the batch loss in local epochs.

        metric: accepted for backward compatibility but currently ignored; the server
            always reports "Equal Opportunity Difference".

        select_round: stored on the instance.

        batch_size: a positive integer, or None to fall back to args.local_batch_size.

        print_every: a positive integer (stored on the instance).

        fraction_clients: float from 0 to 1 (stored on the instance).

        Z: number of sensitive-attribute groups.

        prn: boolean value. If true, print the round headers and the final summary.

        trial: stored on the instance.
        """

        self.model = model
        if torch.cuda.device_count()>1:
            self.model = nn.DataParallel(self.model)
        self.model.to(DEVICE)

        self.seed = seed
        self.num_workers = num_workers

        self.ret = ret
        self.prn = prn
        self.train_prn = False if ret else train_prn

        # Only Equal Opportunity Difference is wired up; `metric` is not consulted.
        self.metric = 'Equal Opportunity Difference'
        self.disparity = equal_opportunity_difference

        # Always define self.batch_size: an explicit value wins, otherwise use
        # the configured local batch size. Leaving it unset when the caller
        # supplies a value breaks every later read of self.batch_size.
        self.batch_size = args.local_batch_size if batch_size is None else batch_size
        self.print_every = print_every
        self.fraction_clients = fraction_clients

        self.train_dataset, self.test_dataset, self.clients_idx = dataset_info
        self.num_clients = len(self.clients_idx)
        self.Z = Z

        self.trial = trial
        self.select_round = select_round

        # Use the resolved batch size; passing a raw None to DataLoader would
        # disable batching and yield single, unbatched samples.
        self.trainloader = self.train_val(self.train_dataset, self.batch_size)
    
    def train_val(self, dataset, batch_size, idxs_train_full = None, split = False):
        """
        Build a shuffled training DataLoader over `dataset`.

        No validation split is performed: every index in `idxs_train_full`
        (or every sample of `dataset` when it is None) goes into the loader.

        Parameters
        ----------
        dataset: Dataset object to draw samples from.

        batch_size: batch size handed to the DataLoader.

        idxs_train_full: optional sequence of sample indexes to keep.

        split: unused; kept so existing callers do not break.

        Returns
        -------
        DataLoader over DatasetSplit(dataset, idxs_train_full) with shuffle=True.
        """
        torch.manual_seed(self.seed)

        # `is None` rather than `== None`: comparing a numpy index array with
        # `==` is element-wise and makes the `if` raise ValueError.
        if idxs_train_full is None:
            idxs_train_full = np.arange(len(dataset))

        return DataLoader(DatasetSplit(dataset, idxs_train_full),
                          batch_size=batch_size, shuffle=True)
   
    def FairFed(self, args, bits = False):
        """
        Run the FairFed training loop and evaluate on the test set.

        Each global round every client trains a copy of the current model via
        Client.ff_update and reports its weights, its average loss, its
        contribution `m` to the global fairness gap and its local gap `F`.
        Client k is then weighted by exp(-alpha * |F_k - F_global|) * n_k / total,
        the weights are normalised to sum to one, and the aggregated weights
        are loaded into the server model.

        Parameters
        ----------
        args: object exposing num_rounds, local_epochs, lr, optimizer, alpha and momentum.

        bits: unused; kept for interface compatibility.

        Returns
        -------
        (test_acc, rd, self.model) when self.ret is true, otherwise None.
        """
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)

        num_rounds = args.num_rounds
        local_epochs = args.local_epochs
        learning_rate = args.lr
        optimizer = args.optimizer
        alpha = args.alpha
        momentum = args.momentum

        train_loss = []
        start_time = time.time()

        # the number of samples whose label is y and sensitive attribute is z
        total_yz = {}
        for y in [0,1]:
            for z in range(self.Z):
                total_yz[(y,z)] = ((self.train_dataset.y == y) & (self.train_dataset.sen == z)).sum()

        # Sum over every (y, z) group so the count is right for any Z,
        # not just the binary case.
        total = sum(total_yz.values())

        for round_ in tqdm(range(num_rounds)):
            local_weights, local_losses, local_F = [], [], []
            F_global = 0
            if self.prn: print(f'\n | Global Training Round : {round_+1} |\n')

            self.model.train()

            for idx in range(self.num_clients):
                # load local model
                local_model = Client(dataset=self.train_dataset,
                                     idxs=self.clients_idx[idx], batch_size = self.batch_size,
                                     option = "FB-Variant1",
                                     seed = self.seed, prn = self.train_prn, Z = self.Z)
                # update local model on a private copy of the global model
                w, loss, m, F = local_model.ff_update(
                                model=copy.deepcopy(self.model), global_round=round_,
                                learning_rate = learning_rate, local_epochs = local_epochs,
                                optimizer = optimizer, momentum=momentum, total_yz = total_yz)

                # The per-client contributions add up to the global fairness gap.
                F_global += m
                local_F.append(F)
                local_weights.append(copy.deepcopy(w))
                local_losses.append(copy.deepcopy(loss))

            # Clients whose local gap is far from the global one are down-weighted;
            # the remaining factor is the client's share of the training data.
            client_weights = np.array([
                np.exp(-alpha*np.abs(local_F[idx]-F_global))*len(self.clients_idx[idx])/total
                for idx in range(self.num_clients)
            ])
            client_weights = client_weights/np.sum(client_weights)

            weights = weighted_average_weights(local_weights, client_weights, 1)
            self.model.load_state_dict(weights)

            loss_avg = sum(local_losses) / len(local_losses)
            train_loss.append(loss_avg)

            # Actually apply the format spec instead of printing it verbatim.
            print('Round {:3d}'.format(round_))

        # Test inference after completion of training
        test_acc, n_yz, test_m_yz = self.test_inference(self.model, self.test_dataset)
        for yz in n_yz:
            if test_m_yz[yz] != 0:
                n_yz[yz] = n_yz[yz]/test_m_yz[yz]
            else:
                # An empty test group must not crash the run after all the
                # training work; report a zero rate and say so.
                warnings.warn("Test set has no samples for group (y, z) = {}; "
                              "its rate is reported as 0.".format(yz))
                n_yz[yz] = 0.0

        rd = self.disparity(n_yz)

        if self.prn:
            print(f' \n Results after {num_rounds} global rounds of training:')
            print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

            # Compute fairness metric
            print("|---- Test "+ self.metric+": {:.4f}".format(rd))

            print('\n Total Run Time: {0:0.4f} sec'.format(time.time()-start_time))

        if self.ret: return test_acc, rd, self.model


    def inference(self, option = 'unconstrained', penalty = 100, model = None):
        """
        Evaluate a model on the server's training loader.

        Parameters
        ----------
        option: loss variant forwarded to loss_func for the per-batch loss.
            Per-group losses are accumulated only when it equals "FairBatch".

        penalty: forwarded unchanged to loss_func.

        model: model to evaluate; self.model is used when omitted.

        Returns
        -------
        A 6-tuple:
            accuracy  - fraction of correctly predicted samples,
            loss      - sum of the per-batch losses,
            n_yz      - dict (y, z) -> number of samples predicted y with sensitive value z,
            acc_loss  - mean per-batch accuracy loss,
            fair_loss - mean per-batch fairness loss,
            loss_yz   - dict of accumulated per-group losses for the options
                        "FairBatch" / "FB-Variant1", otherwise None.
        """
        net = self.model if model is None else model
        net.eval()

        group_keys = [(y, z) for y in [0, 1] for z in range(self.Z)]
        n_yz = {key: 0 for key in group_keys}
        loss_yz = {key: 0 for key in group_keys}

        loss = acc_loss = fair_loss = 0.0
        total = correct = 0.0
        num_batch = 0

        for features, labels, sensitive in self.trainloader:
            features = features.to(DEVICE)
            labels = labels.type(torch.LongTensor).to(DEVICE)
            sensitive = sensitive.type(torch.LongTensor).to(DEVICE)

            # Forward pass: the model yields a pair that is unpacked as
            # (outputs, logits).
            outputs, logits = net(features)
            outputs = outputs.to(DEVICE)
            logits = logits.to(DEVICE)

            # Predicted class = index of the largest output per sample.
            pred_labels = torch.max(outputs, 1)[1].view(-1).to(DEVICE)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
            num_batch += 1

            for y, z in group_keys:
                n_yz[(y, z)] += torch.sum((pred_labels == y) & (sensitive == z)).item()

                if option != "FairBatch":
                    continue

                # Loss restricted to the samples whose true label is y and
                # whose sensitive attribute is z.
                members = (labels == y) & (sensitive == z)
                group_loss, _, _ = loss_func("FB_inference",
                                             logits[members].to(DEVICE),
                                             labels[members].to(DEVICE),
                                             outputs[members].to(DEVICE),
                                             sensitive[members].to(DEVICE),
                                             penalty)
                loss_yz[(y, z)] += group_loss

            batch_loss, batch_acc_loss, batch_fair_loss = loss_func(
                option, logits, labels, outputs, sensitive, penalty)
            loss += batch_loss.item()
            acc_loss += batch_acc_loss.item()
            fair_loss += batch_fair_loss.item()

        accuracy = correct/total
        per_group = loss_yz if option in ["FairBatch", "FB-Variant1"] else None
        return accuracy, loss, n_yz, acc_loss / num_batch, fair_loss / num_batch, per_group

    def test_inference(self, model = None, test_dataset = None):
        """
        Evaluate a model on a test dataset.

        Parameters
        ----------
        model: model to evaluate; self.model is used when omitted.

        test_dataset: dataset to evaluate on; self.test_dataset is used when omitted.

        Returns
        -------
        A 3-tuple:
            accuracy - fraction of correctly predicted samples,
            n_yz     - dict (y, z) -> number of samples with label y and sensitive
                       value z that were predicted correctly,
            m_yz     - dict (y, z) -> number of samples with label y and sensitive value z.
        """
        # set seed
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)

        if model is None: model = self.model
        if test_dataset is None: test_dataset = self.test_dataset

        model.eval()
        total, correct = 0.0, 0.0
        n_yz = {}
        m_yz = {}
        for y in [0,1]:
            for z in range(self.Z):
                n_yz[(y,z)] = 0
                m_yz[(y,z)] = 0

        testloader = DataLoader(test_dataset, batch_size=self.batch_size,
                                shuffle=False)

        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for features, labels, sensitive in testloader:
                features = features.to(DEVICE)
                # All bookkeeping below happens on the CPU, so the labels and
                # the sensitive attribute are kept there as well.
                labels = labels.type(torch.LongTensor)
                sensitive = sensitive.cpu()

                # Inference: the second element of the model output is used
                # for the prediction.
                _, outputs = model(features)

                # Prediction
                pred_labels = torch.argmax(outputs, dim=1).view(-1).cpu()
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

                for y, z in n_yz:
                    in_group = (sensitive == z) & (labels == y)
                    n_yz[(y,z)] += torch.sum(in_group & (pred_labels == y)).item()
                    m_yz[(y,z)] += torch.sum(in_group).item()

        accuracy = correct/total
        # Diagnostic output honours the server's print switch.
        if self.prn:
            print(correct, total)
            print(n_yz)

        return accuracy, n_yz, m_yz


class Client(object):
    def __init__(self, dataset, idxs, batch_size, option, seed = 0, prn = True, penalty = 500, Z = 2):
        """
        Hold one client's share of the training data.

        dataset    : full training dataset.
        idxs       : indexes of the samples that belong to this client.
        batch_size : batch size of the local training loader.
        option     : loss variant name forwarded to loss_func.
        seed       : random seed used when building the loader and when training.
        prn        : print switch (stored on the instance).
        penalty    : penalty value forwarded to loss_func.
        Z          : number of sensitive-attribute groups.
        """
        # Plain configuration values.
        self.seed, self.Z = seed, Z
        self.option, self.prn, self.penalty = option, prn, penalty

        # Data owned by this client.
        self.dataset, self.idxs = dataset, idxs
        self.train_dataset = DatasetSplit(dataset, idxs)

        # train_val also (re)assigns self.train_dataset from the index list.
        self.trainloader = self.train_val(dataset, list(idxs), batch_size)

    def train_val(self, dataset, idxs, batch_size):
        """
        Build the shuffled local training loader.

        All of `idxs` is used for training (no validation part is carved
        out). The resulting DatasetSplit is also stored in self.train_dataset.
        """
        # Seed torch first so the loader is set up from a known RNG state.
        torch.manual_seed(self.seed)

        self.train_dataset = DatasetSplit(dataset, idxs)
        return DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)

    def ff_update(self, model, global_round, learning_rate, local_epochs, optimizer, momentum, total_yz, beta=0.01):
        """
        Measure the incoming model's fairness on the local data, then train it locally.

        Steps:
          1. Evaluate `model` on the local data to count, per (label y, sensitive z)
             group, the samples (m_yz) and the correctly predicted samples (n_yz).
          2. Derive the local gap F_local = n(1,0)/m(1,0) - n(1,1)/m(1,1) and this
             client's contribution to the global gap,
             m_global = n(1,0)/total_yz(1,0) - n(1,1)/total_yz(1,1).
             A term whose denominator is zero is skipped.
          3. Train for `local_epochs` epochs with per-sample weights
             lbd_yz[(y, z)] / (group size inside the batch). After each epoch the
             weights of the two y == 1 groups are shifted by `beta` towards the
             group with the larger average loss.

        Parameters
        ----------
        model: model to evaluate and train; it is modified in place.

        global_round: unused; kept for interface compatibility.

        learning_rate: learning rate of the local optimizer.

        local_epochs: number of local epochs.

        optimizer: 'sgd' or 'adam'. Any non-string value is used as-is as the optimizer object.

        momentum: momentum used by SGD.

        total_yz: dict (y, z) -> number of training samples of that group over all clients.

        beta: step size of the group-weight update.

        Returns
        -------
        (model.state_dict(), mean epoch loss, m_global, F_local)

        Raises
        ------
        ValueError: if `optimizer` is a string other than 'sgd' or 'adam'.
        """
        # ---- 1. Fairness statistics of the incoming model --------------------
        model.eval()
        n_yz, m_yz = {}, {}
        for y in [0,1]:
            for z in range(self.Z):
                n_yz[(y,z)] = 0
                m_yz[(y,z)] = 0

        # Only counts are needed here, so no autograd graph is built.
        with torch.no_grad():
            for features, labels, sensitive in self.trainloader:
                features, labels = features.to(DEVICE), labels.type(torch.LongTensor).to(DEVICE)
                sensitive = sensitive.type(torch.LongTensor).to(DEVICE)

                outputs, _ = model(features)
                outputs = outputs.to(DEVICE)

                pred_labels = torch.argmax(outputs, dim=1).view(-1).to(DEVICE)

                for y, z in n_yz:
                    in_group = (labels == y) & (sensitive == z)
                    n_yz[(y,z)] += torch.sum(in_group & (pred_labels == y)).item()
                    m_yz[(y,z)] += torch.sum(in_group).item()

        # ---- 2. Local gap and contribution to the global gap -----------------
        F_local = 0
        if m_yz[(1,0)] != 0:
            F_local += n_yz[(1,0)]/m_yz[(1,0)]
        if m_yz[(1,1)] != 0:
            F_local -= n_yz[(1,1)]/m_yz[(1,1)]

        # Same guard as for F_local: an empty global group must not inject
        # inf/nan into the server's aggregation weights.
        m_global = 0
        if total_yz[(1,0)] != 0:
            m_global += n_yz[(1,0)]/total_yz[(1,0)]
        if total_yz[(1,1)] != 0:
            m_global -= n_yz[(1,1)]/total_yz[(1,1)]

        # ---- 3. Local training -----------------------------------------------
        # set seed
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)

        num_local = len(self.train_dataset)
        local_count_yz = {}
        for y in [0,1]:
            for z in range(self.Z):
                local_count_yz[(y,z)] = ((self.train_dataset.y == y) & (self.train_dataset.sen == z)).sum()

        # Initial group weights: the empirical share of each group.
        lbd_yz = {yz: local_count_yz[yz]/num_local for yz in local_count_yz}

        # Total weight shared by the two y == 1 groups; it stays constant
        # while beta moves weight between them.
        positive_mass = (local_count_yz[(1,0)]+local_count_yz[(1,1)])/num_local

        # Set optimizer for the local updates
        if optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                        momentum=momentum)
        elif optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                        weight_decay=1e-4)
        elif isinstance(optimizer, str):
            # Fail here with a clear message instead of later on `str.zero_grad`.
            raise ValueError(f"Unsupported optimizer {optimizer!r}; expected 'sgd' or 'adam'.")

        epoch_loss = []
        for _ in range(local_epochs):
            # Re-enter training mode every epoch: the evaluation pass at the
            # end of the previous epoch leaves the model in eval mode.
            model.train()
            batch_loss = []
            for features, labels, sensitive in self.trainloader:
                features = features.to(DEVICE)
                labels = labels.to(DEVICE, dtype=torch.long)
                sensitive = sensitive.to(DEVICE)
                _, logits = model(features)

                # Per-sample weights: every member of group (y, z) present in
                # the batch gets lbd_yz[(y, z)] / (group size in the batch).
                v = torch.ones(len(labels)).type(torch.DoubleTensor).to(DEVICE)
                for (y, z), lbd in lbd_yz.items():
                    members = torch.where((labels == y) & (sensitive == z))[0]
                    if len(members) != 0:
                        v[members] = lbd/len(members)

                loss = weighted_loss(logits, labels, v, False)

                optimizer.zero_grad()
                # Skip the whole update on a NaN loss; stepping without fresh
                # gradients could still move the weights via momentum or
                # weight decay.
                if not np.isnan(loss.item()):
                    loss.backward()
                    optimizer.step()

                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

            # Per-group loss of the updated model on the local data.
            model.eval()
            loss_yz = {yz: 0 for yz in n_yz}
            with torch.no_grad():
                for features, labels, sensitive in self.trainloader:
                    features, labels = features.to(DEVICE), labels.type(torch.LongTensor).to(DEVICE)
                    sensitive = sensitive.type(torch.LongTensor).to(DEVICE)

                    outputs, logits = model(features)
                    outputs, logits = outputs.to(DEVICE), logits.to(DEVICE)

                    for yz in n_yz:
                        members = (labels == yz[0]) & (sensitive == yz[1])
                        loss_yz_, _, _ = loss_func("standard", logits[members].to(DEVICE),
                                                   labels[members].to(DEVICE),
                                                   outputs[members].to(DEVICE),
                                                   sensitive[members].to(DEVICE),
                                                   self.penalty)
                        loss_yz[yz] += loss_yz_

            # Without samples in one of the y == 1 groups the average loss is
            # undefined (division by zero), so the weights are left untouched.
            if local_count_yz[(1,0)] != 0 and local_count_yz[(1,1)] != 0:
                loss_yz[(1,0)] = loss_yz[(1,0)]/local_count_yz[(1,0)]
                loss_yz[(1,1)] = loss_yz[(1,1)]/local_count_yz[(1,1)]
                if loss_yz[(1,0)]-loss_yz[(1,1)] > 0:
                    # Group (1, 0) is doing worse: give it more weight, capped
                    # at the total weight of the positive class.
                    lbd_yz[(1,0)] = lbd_yz[(1,0)]+beta
                    if lbd_yz[(1,0)] > positive_mass:
                        lbd_yz[(1,0)] = positive_mass
                else:
                    # Otherwise take weight away from it, floored at zero.
                    lbd_yz[(1,0)] = lbd_yz[(1,0)]-beta
                    if lbd_yz[(1,0)] < 0:
                        lbd_yz[(1,0)] = 0
                lbd_yz[(1,1)] = positive_mass-lbd_yz[(1,0)]

        # weight, loss
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss), m_global, F_local


    def inference(self, model, train = False, bits = False, truem_yz = None):
        """
        Evaluate `model` on this client's training loader.

        Parameters
        ----------
        model: model to evaluate.

        train: unused; kept for interface compatibility.

        bits: when truthy, the per-group loss gaps f_z are quantised onto
            2**bits // (Z - 1) evenly spaced levels in [-2, 2].

        truem_yz: unused; kept for interface compatibility.

        Returns
        -------
        For the options "FairBatch" and "FB-Variant1" a 7-tuple
            (accuracy, loss, n_yz, acc_loss, fair_loss, f_z, m_yz),
        otherwise the 6-tuple
            (accuracy, loss, n_yz, acc_loss, fair_loss, m_yz), where
            accuracy  - fraction of correctly predicted samples,
            loss      - sum of the per-batch losses,
            n_yz      - dict (y, z) -> correctly predicted samples with label y and sensitive value z,
            acc_loss  - mean per-batch accuracy loss,
            fair_loss - mean per-batch fairness loss,
            f_z       - dict z -> average y == 1 loss of group z minus that of group 0 (z >= 1),
            m_yz      - dict (y, z) -> number of samples with label y and sensitive value z.
        """

        model.eval()
        loss, total, correct, fair_loss, acc_loss, num_batch = 0.0, 0.0, 0.0, 0.0, 0.0, 0
        n_yz, loss_yz, m_yz, f_z = {}, {}, {}, {}
        for y in [0,1]:
            for z in range(self.Z):
                loss_yz[(y,z)] = 0
                n_yz[(y,z)] = 0
                m_yz[(y,z)] = 0

        for features, labels, sensitive in self.trainloader:
            features, labels = features.to(DEVICE), labels.type(torch.LongTensor).to(DEVICE)
            sensitive = sensitive.type(torch.LongTensor).to(DEVICE)

            # Inference
            outputs, logits = model(features)
            outputs, logits = outputs.to(DEVICE), logits.to(DEVICE)

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1).to(DEVICE)
            bool_correct = torch.eq(pred_labels, labels)
            correct += torch.sum(bool_correct).item()
            total += len(labels)
            num_batch += 1

            for yz in n_yz:
                in_group = (labels == yz[0]) & (sensitive == yz[1])
                n_yz[yz] += torch.sum(in_group & (pred_labels == yz[0])).item()
                m_yz[yz] += torch.sum(in_group).item()

                if self.option in ["FairBatch", "FB-Variant1"]:
                    # Loss restricted to the group, without any lagrangian term.
                    loss_yz_, _, _ = loss_func("standard", logits[in_group].to(DEVICE),
                                               labels[in_group].to(DEVICE),
                                               outputs[in_group].to(DEVICE),
                                               sensitive[in_group].to(DEVICE),
                                               self.penalty)
                    loss_yz[yz] += loss_yz_

            batch_loss, batch_acc_loss, batch_fair_loss = loss_func(self.option, logits,
                                                        labels, outputs, sensitive, self.penalty)
            loss, acc_loss, fair_loss = (loss + batch_loss.item(),
                                         acc_loss + batch_acc_loss.item(),
                                         fair_loss + batch_fair_loss.item())

        accuracy = correct/total
        if self.option in ["FairBatch", "FB-Variant1"]:
            for z in range(1, self.Z):
                f_z[z] = loss_yz[(1,z)]/m_yz[(1,z)]-loss_yz[(1,0)]/m_yz[(1,0)]
            if bits:
                bins = np.linspace(-2, 2, 2**bits // (self.Z - 1))
                for z in range(1, self.Z):
                    # np.digitize returns 0 for values below the first edge;
                    # clamp so they map to the lowest level instead of
                    # wrapping around to bins[-1], the highest one.
                    bin_idx = max(np.digitize(f_z[z].item(), bins) - 1, 0)
                    f_z[z] = bins[bin_idx]
            return accuracy, loss, n_yz, acc_loss / num_batch, fair_loss / num_batch, f_z, m_yz
        else:
            return accuracy, loss, n_yz, acc_loss / num_batch, fair_loss / num_batch, m_yz
