from tokenize import group
import torch, copy, time, random, warnings, os
import numpy as np

from torch.utils.data import DataLoader
from tqdm import tqdm
from .utils import *
from ray import tune
import torch.nn as nn

# ---------------------------------------------------------------------------
# Model / runtime settings
# ---------------------------------------------------------------------------
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# NOTE(review): presumably a workaround that lets duplicate OpenMP runtimes load
# in one process instead of aborting — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# ---------------------------------------------------------------------------

class Server(object):
    """
    Federated-learning server that runs FedFB: clients train with per-group
    re-weighted losses and the server adapts the group weights between rounds.
    """

    def __init__(self, model, dataset_info, args, seed = 123, num_workers = 4, ret = False,
                 train_prn = False, metric = "Demographic disparity", select_round = False,
                 batch_size = 128, print_every = 1, fraction_clients = 1, Z = 2, prn = True, trial = False):
        """
        Server execution.

        Parameters
        ----------
        model: torch.nn.Module object. Its forward pass must return a pair
            ``(outputs, logits)``; predictions are ``argmax(outputs, dim=1)``.
            Wrapped in ``nn.DataParallel`` when more than one GPU is visible.

        dataset_info: a list of three objects.
            - train_dataset: Dataset object.
            - test_dataset: Dataset object.
            - clients_idx: a list of lists, with each sublist containing the indexes of the training samples in one client.
                    the length of the list is the number of clients.

        args: namespace; ``args.local_batch_size`` (None -> use ``batch_size``) is read here.

        seed: random seed.

        num_workers: number of workers (stored only).

        ret: boolean value. If true, ``FedFB`` returns the accuracy, fairness measure and model,
            and batch-loss printing is disabled; else ``FedFB`` returns None.

        train_prn: boolean value. If true (and ``ret`` is false), print the batch loss in local epochs.

        metric: accepted for backward compatibility but ignored: fairness is always
            measured as "Equal Opportunity Difference".

        select_round: boolean value. If true, ``FedFB`` evaluates the global weights from the
            round with the lowest training-set disparity instead of the last round.

        batch_size: a positive integer, used when ``args.local_batch_size`` is None.

        print_every: a positive integer (stored only).

        fraction_clients: float from 0 to 1 (stored only; ``FedFB`` uses every client each round).

        Z: number of sensitive groups.

        prn: boolean value. If true, print per-round progress and the final results.

        trial: boolean value. If true, checkpoint and report per-round metrics to ray tune.
        """
        self.model = model
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        self.model.to(DEVICE)

        self.seed = seed
        self.num_workers = num_workers

        self.ret = ret
        self.prn = prn
        self.train_prn = False if ret else train_prn

        # NOTE(review): ``metric`` is not consulted; fairness is always equal opportunity.
        self.metric = 'Equal Opportunity Difference'
        self.disparity = equal_opportunity_difference

        self.batch_size = batch_size if args.local_batch_size is None else args.local_batch_size
        self.print_every = print_every
        self.fraction_clients = fraction_clients

        self.train_dataset, self.test_dataset, self.clients_idx = dataset_info
        self.num_clients = len(self.clients_idx)
        self.Z = Z

        self.trial = trial
        self.select_round = select_round

        # Use the effective batch size so ``args.local_batch_size`` is honoured here too.
        self.trainloader = self.train_val(self.train_dataset, self.batch_size)

    def train_val(self, dataset, batch_size, idxs_train_full = None, split = False):
        """
        Return a shuffled DataLoader over a subset of ``dataset``.

        idxs_train_full: indices to include; None means every sample of ``dataset``.
        split: unused; kept for backward compatibility (no validation split is made).

        Re-seeds torch's global RNG with ``self.seed`` first.
        """
        torch.manual_seed(self.seed)

        # ``is None`` rather than ``== None``: comparing an index array with ``==`` is
        # element-wise and would make this condition ambiguous.
        if idxs_train_full is None:
            idxs_train_full = np.arange(len(dataset))

        return DataLoader(DatasetSplit(dataset, idxs_train_full),
                          batch_size=batch_size, shuffle=True)

    def FedFB(self, args, bits = False):
        """
        Train ``self.model`` with FedFB for ``args.num_rounds`` global rounds.

        Every round, each client runs ``Client.fb2_update`` (group-weighted local
        training with the current weights ``lbd``); the server averages the returned
        weights in proportion to each client's ``nc``, evaluates every client on its
        own training data to obtain the fairness signal ``f_z``, and moves
        ``lbd[(1, 1)]`` by ``alpha / sqrt(round + 1)`` in the direction of the sign
        of ``f_z[1]``.

        Parameters
        ----------
        args: namespace providing num_rounds, local_epochs, lr, optimizer, alpha, momentum.
        bits: forwarded to ``Client.inference``; when truthy each client quantises ``f_z``.

        Returns
        -------
        ``(test_acc, disparity, model)`` when ``self.ret`` is True, otherwise None.

        Notes
        -----
        The multiplier update only adjusts groups z = 0 and z = 1, i.e. it assumes a
        binary sensitive attribute.
        """
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)

        num_rounds = args.num_rounds
        local_epochs = args.local_epochs
        learning_rate = args.lr
        optimizer = args.optimizer
        alpha = args.alpha
        momentum = args.momentum

        train_loss, train_accuracy = [], []
        start_time = time.time()
        n_train = len(self.train_dataset)

        # Lowest-disparity global weights, tracked only when ``select_round`` is set.
        best_fairness = float('inf')
        best_state = None

        # m_yz[(y, z)]: number of training samples with label y and sensitive attribute z.
        m_yz = {}
        for y in [0, 1]:
            for z in range(self.Z):
                m_yz[(y, z)] = ((self.train_dataset.y == y) & (self.train_dataset.sen == z)).sum()
        neg_frac, pos_frac = self._label_fractions(m_yz, n_train, self.Z)

        # lbd[(y, z)]: group weight used by the clients' weighted loss. Every label-0
        # group gets the whole label-0 fraction; the label-1 fraction is split evenly
        # over the Z sensitive groups.
        lbd = {}
        for z in range(self.Z):
            lbd[(0, z)] = neg_frac
        for z in range(self.Z):
            lbd[(1, z)] = pos_frac / self.Z

        for round_ in tqdm(range(num_rounds)):
            local_weights, local_losses, nc = [], [], []
            if self.prn: print(f'\n | Global Training Round : {round_+1} |\n')

            # ---- local updates ----------------------------------------------
            self.model.train()
            for idx in range(self.num_clients):
                local_model = Client(dataset=self.train_dataset,
                                     idxs=self.clients_idx[idx], batch_size=self.batch_size,
                                     option="FB-Variant1",
                                     seed=self.seed, prn=self.train_prn, Z=self.Z)
                w, loss, nc_ = local_model.fb2_update(
                    model=copy.deepcopy(self.model), global_round=round_,
                    learning_rate=learning_rate, local_epochs=local_epochs,
                    optimizer=optimizer, momentum=momentum, m_yz=m_yz, lbd=lbd)
                nc.append(nc_)
                local_weights.append(copy.deepcopy(w))
                local_losses.append(copy.deepcopy(loss))
            if self.prn: print("nc", nc)

            # ---- aggregation: average weights in proportion to each client's nc
            total_nc = sum(nc)
            weights = weighted_average_weights(local_weights, nc, total_nc)
            self.model.load_state_dict(weights)

            loss_avg = sum(local_losses) / len(local_losses)
            train_loss.append(loss_avg)

            # ---- training-set evaluation -------------------------------------
            # n_yz[(y, z)]: correctly classified training samples with label y in group z.
            n_yz = {(y, z): 0 for z in range(self.Z) for y in [0, 1]}
            f_z = {z: 0 for z in range(1, self.Z)}

            self.model.eval()
            for c in range(self.num_clients):
                local_model = Client(dataset=self.train_dataset,
                                     idxs=self.clients_idx[c], batch_size=self.batch_size,
                                     option="FB-Variant1",
                                     seed=self.seed, prn=self.train_prn, Z=self.Z)
                acc, loss, n_yz_c, acc_loss, fair_loss, f_z_c, m_yz_c = local_model.inference(
                    model=self.model, train=True, bits=bits, truem_yz=m_yz)

                for yz in n_yz:
                    n_yz[yz] += n_yz_c[yz]
                # Global fairness signal: clients' f_z weighted by their share of nc.
                for z in range(1, self.Z):
                    f_z[z] += f_z_c[z] * nc[c] / total_nc

            train_accuracy.append(sum(n_yz.values()) / n_train)
            n_yz = self._rates(n_yz, m_yz)

            # ---- multiplier update (binary sensitive attribute) ---------------
            step = alpha / (round_ + 1) ** .5
            if f_z[1] >= 0:
                lbd[(1, 1)] += step
            else:
                lbd[(1, 1)] -= step
            # Project onto [0, pos_frac] and give the remainder to group z = 0.
            lbd[(1, 1)] = max(0, min(lbd[(1, 1)], pos_frac))
            lbd[(1, 0)] = pos_frac - lbd[(1, 1)]

            if self.trial:
                with tune.checkpoint_dir(round_) as checkpoint_dir:
                    path = os.path.join(checkpoint_dir, "checkpoint")
                    torch.save(self.model.state_dict(), path)
                # NOTE(review): ``loss`` is the last client's inference loss, not ``loss_avg``.
                tune.report(loss = loss, accuracy = train_accuracy[-1], disp = self.disparity(n_yz), iteration = round_+1)

            if self.select_round:
                disp = self.disparity(n_yz)
                if best_fairness > disp:
                    best_fairness = disp
                    best_state = copy.deepcopy(self.model.state_dict())

            # Per-round test monitoring exists only to be printed.
            if self.prn:
                test_acc, test_n_yz, test_m_yz = self.test_inference(self.model, self.test_dataset)
                rd = self.disparity(self._rates(test_n_yz, test_m_yz))
                print('test_acc', test_acc)
                print('fairness', rd)

        # With ``select_round``, evaluate the weights from the fairest round.
        if self.select_round and best_state is not None:
            self.model.load_state_dict(best_state)

        # Test inference after completion of training
        test_acc, test_n_yz, test_m_yz = self.test_inference(self.model, self.test_dataset)
        rd = self.disparity(self._rates(test_n_yz, test_m_yz))

        if self.prn:
            print(f' \n Results after {num_rounds} global rounds of training:')
            print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))
            # Compute fairness metric
            print("|---- Test "+ self.metric+": {:.4f}".format(rd))
            print('\n Total Run Time: {0:0.4f} sec'.format(time.time()-start_time))

        if self.ret: return test_acc, rd, self.model

    def inference(self, option = 'unconstrained', penalty = 100, model = None, validloader = None):
        """
        Evaluate ``model`` on ``validloader`` with ``loss_func(option, ...)``.

        ``model`` defaults to ``self.model``. ``validloader`` defaults to
        ``self.validloader`` if such an attribute has been attached, otherwise to
        ``self.trainloader`` (``__init__`` builds only a training loader).

        Returns
        -------
        (accuracy, summed loss, n_yz, mean acc_loss per batch, mean fair_loss per batch, extra)
        where n_yz[(y, z)] counts samples of group z *predicted* as y, and ``extra`` is
        loss_yz (per-group "FB_inference" loss, filled only for option "FairBatch") when
        option is "FairBatch" or "FB-Variant1", else None.
        """
        if model is None:
            model = self.model
        if validloader is None:
            validloader = getattr(self, 'validloader', None)
            if validloader is None:
                validloader = self.trainloader
        model.eval()
        loss, total, correct, fair_loss, acc_loss, num_batch = 0.0, 0.0, 0.0, 0.0, 0.0, 0
        n_yz, loss_yz = {}, {}
        for y in [0, 1]:
            for z in range(self.Z):
                loss_yz[(y, z)] = 0
                n_yz[(y, z)] = 0

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for features, labels, sensitive in validloader:
                features, labels = features.to(DEVICE), labels.type(torch.LongTensor).to(DEVICE)
                sensitive = sensitive.type(torch.LongTensor).to(DEVICE)

                outputs, logits = model(features)
                outputs, logits = outputs.to(DEVICE), logits.to(DEVICE)

                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1).to(DEVICE)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)
                num_batch += 1

                for yz in n_yz:
                    in_group = (labels == yz[0]) & (sensitive == yz[1])
                    n_yz[yz] += torch.sum((pred_labels == yz[0]) & (sensitive == yz[1])).item()
                    if option == "FairBatch":
                        # the objective function has no lagrangian term
                        loss_yz_, _, _ = loss_func("FB_inference", logits[in_group].to(DEVICE),
                                                   labels[in_group].to(DEVICE),
                                                   outputs[in_group].to(DEVICE), sensitive[in_group].to(DEVICE),
                                                   penalty)
                        loss_yz[yz] += loss_yz_

                batch_loss, batch_acc_loss, batch_fair_loss = loss_func(option, logits,
                                                                        labels, outputs, sensitive, penalty)
                loss += batch_loss.item()
                acc_loss += batch_acc_loss.item()
                fair_loss += batch_fair_loss.item()

        accuracy = correct / total
        if option in ["FairBatch", "FB-Variant1"]:
            return accuracy, loss, n_yz, acc_loss / num_batch, fair_loss / num_batch, loss_yz
        return accuracy, loss, n_yz, acc_loss / num_batch, fair_loss / num_batch, None

    def test_inference(self, model = None, test_dataset = None):
        """
        Return the test accuracy and per-group counts.

        ``model`` defaults to ``self.model`` and ``test_dataset`` to ``self.test_dataset``.
        Re-seeds numpy, ``random`` and torch with ``self.seed``.

        Returns
        -------
        (accuracy, n_yz, m_yz) where, for y in {0, 1} and each group z,
        n_yz[(y, z)] counts correctly classified samples with label y in group z and
        m_yz[(y, z)] counts all samples with label y in group z.
        """
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)

        if model is None: model = self.model
        if test_dataset is None: test_dataset = self.test_dataset

        model.eval()
        total, correct = 0.0, 0.0
        n_yz, m_yz = {}, {}
        for y in [0, 1]:
            for z in range(self.Z):
                n_yz[(y, z)] = 0
                m_yz[(y, z)] = 0

        testloader = DataLoader(test_dataset, batch_size=self.batch_size,
                                shuffle=False)

        with torch.no_grad():
            for features, labels, sensitive in testloader:
                # Keep every tensor on DEVICE: ``.type(torch.LongTensor)`` after ``.to(DEVICE)``
                # would move the labels back to the CPU and break comparisons on CUDA.
                features = features.to(DEVICE)
                labels = labels.to(DEVICE, dtype=torch.long)
                sensitive = sensitive.to(DEVICE)

                outputs, _ = model(features)

                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

                for y, z in n_yz:
                    in_group = (sensitive == z) & (labels == y)
                    n_yz[(y, z)] += torch.sum(in_group & (pred_labels == y)).item()
                    m_yz[(y, z)] += torch.sum(in_group).item()

        accuracy = correct / total

        return accuracy, n_yz, m_yz

    def ufl_inference(self, models, test_dataset = None):
        """
        Evaluate a client-size-weighted ensemble of per-client models.

        ``models[c]`` is paired with ``self.clients_idx[c]``; each model's outputs are
        normalised per sample and weighted by that client's number of training samples.
        ``test_dataset`` defaults to ``self.test_dataset``. Re-seeds numpy, ``random``
        and torch with ``self.seed``.

        Returns
        -------
        (accuracy, n_yz) with n_yz[(y, z)] the number of correctly classified samples
        with label y in group z.
        """
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)

        if test_dataset is None: test_dataset = self.test_dataset

        total, correct = 0.0, 0.0
        n_yz = {(y, z): 0 for y in [0, 1] for z in range(self.Z)}

        for model in models:
            model.eval()

        client_sizes = [len(idxs) for idxs in self.clients_idx]

        testloader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)
        with torch.no_grad():
            for features, labels, sensitive in testloader:
                features = features.to(DEVICE)
                labels = labels.to(DEVICE, dtype=torch.long)
                sensitive = sensitive.to(DEVICE)

                # Ensemble prediction; built from the model outputs so it lives on
                # DEVICE and has the models' number of classes.
                client_outputs = (models[c](features)[0] for c in range(self.num_clients))
                outputs = self._ensemble_probs(client_outputs, client_sizes)

                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1).to(DEVICE)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)

                for y, z in n_yz:
                    n_yz[(y, z)] += torch.sum((sensitive == z) & (pred_labels == y) & (labels == y)).item()

        accuracy = correct / total

        return accuracy, n_yz

    @staticmethod
    def _rates(counts, totals):
        """Return ``{key: counts[key] / totals[key]}`` for every key of ``counts``."""
        return {key: counts[key] / totals[key] for key in counts}

    @staticmethod
    def _label_fractions(m_yz, n_train, Z):
        """Return (fraction of samples with label 0, fraction with label 1), summed over the Z groups."""
        neg_frac = sum(m_yz[(0, z)] for z in range(Z)) / n_train
        pos_frac = sum(m_yz[(1, z)] for z in range(Z)) / n_train
        return neg_frac, pos_frac

    @staticmethod
    def _ensemble_probs(client_outputs, client_sizes):
        """
        Average model outputs after normalising each row to sum to one.

        Output ``c`` (shape (batch, classes)) is weighted by ``client_sizes[c]`` and the
        weighted sum is divided by ``sum(client_sizes)``.
        """
        combined = 0
        for output, size in zip(client_outputs, client_sizes):
            combined = combined + output / output.sum(dim=1, keepdim=True) * size
        return combined / sum(client_sizes)

class Client(object):
    """One federated client: trains and evaluates on its own subset of the training data."""

    def __init__(self, dataset, idxs, batch_size, option, seed = 0, prn = True, penalty = 500, Z = 2):
        """
        Parameters
        ----------
        dataset: the full training Dataset.
        idxs: indices of this client's samples in ``dataset``.
        batch_size: local batch size.
        option: loss option passed to ``loss_func`` during ``inference``
            ("FairBatch" / "FB-Variant1" additionally enable per-group losses).
        seed: seed used to re-seed the RNGs in ``train_val`` and ``fb2_update``.
        prn: if true, print batch losses during ``fb2_update``.
        penalty: penalty passed to ``loss_func``.
        Z: number of sensitive groups.
        """
        self.seed = seed
        self.dataset = dataset
        self.idxs = idxs
        self.option = option
        self.prn = prn
        self.Z = Z
        self.trainloader = self.train_val(dataset, list(idxs), batch_size)
        self.penalty = penalty
        self.disparity = DPDisparity

    def train_val(self, dataset, idxs, batch_size):
        """
        Build a shuffled training DataLoader over ``dataset`` restricted to ``idxs``.

        Also stores the subset as ``self.train_dataset``. No validation split is made.
        Re-seeds torch's global RNG with ``self.seed`` first.
        """
        torch.manual_seed(self.seed)

        self.train_dataset = DatasetSplit(dataset, idxs)

        trainloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)

        return trainloader

    def fb2_update(self, model, global_round, learning_rate, local_epochs, optimizer, momentum,lbd, m_yz):
        """
        Run ``local_epochs`` epochs of group-reweighted training on this client.

        In every batch, each sample with label y and sensitive attribute z gets weight
        ``lbd[(y, z)] / (number of such samples in the batch)``; samples outside every
        ``lbd`` group keep weight 1. The loss is ``weighted_loss(logits, labels, v, False)``.

        Parameters
        ----------
        model: model trained in place; its forward pass returns ``(outputs, logits)``.
        global_round: round index, used only for logging.
        learning_rate: optimizer learning rate.
        local_epochs: number of passes over the local loader.
        optimizer: 'sgd' (uses ``momentum``) or 'adam' (weight decay 1e-4).
        momentum: SGD momentum.
        lbd: dict of group weights keyed by (y, z).
        m_yz: unused; accepted for interface compatibility.

        Returns
        -------
        (state_dict, mean of the per-epoch mean batch losses, nc) where nc is the sum
        of the weights assigned to group members over every batch and epoch.

        Raises
        ------
        ValueError: if ``optimizer`` is neither 'sgd' nor 'adam'.
        """
        model.train()
        epoch_loss = []
        nc = 0

        # set seed
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)

        if optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                        momentum=momentum)
        elif optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                         weight_decay=1e-4)
        else:
            raise ValueError(f"unsupported optimizer {optimizer!r}; expected 'sgd' or 'adam'")

        for i in range(local_epochs):
            batch_loss = []
            for batch_idx, (features, labels, sensitive) in enumerate(self.trainloader):
                features = features.to(DEVICE)
                labels = labels.to(DEVICE, dtype=torch.long)
                sensitive = sensitive.to(DEVICE)
                _, logits = model(features)

                # Per-sample loss weights: 1 by default, lbd[(y, z)] spread evenly
                # over the members of group (y, z) present in this batch.
                v = torch.ones(len(labels)).type(torch.DoubleTensor).to(DEVICE)
                for y, z in lbd:
                    members = torch.where((labels == y) & (sensitive == z))[0]
                    if len(members) != 0:
                        v[members] = lbd[(y, z)] / len(members)
                    nc += v[members].sum().item()

                loss = weighted_loss(logits, labels, v, False)

                optimizer.zero_grad()
                # Skip both backward and step on a NaN loss so no update is applied.
                if not np.isnan(loss.item()):
                    loss.backward()
                    optimizer.step()

                if self.prn and (100. * batch_idx / len(self.trainloader)) % 50 == 0:
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tBatch Loss: {:.6f}'.format(
                        global_round + 1, i, batch_idx * len(features),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

        # weight, loss
        return model.state_dict(), sum(epoch_loss) / len(epoch_loss), nc

    def inference(self, model, train = False, bits = False, truem_yz = None):
        """
        Evaluate ``model`` on this client's training loader.

        ``train`` and ``truem_yz`` are accepted for interface compatibility but unused.

        Returns
        -------
        For option "FairBatch" / "FB-Variant1":
            (accuracy, summed loss, n_yz, mean acc_loss per batch, mean fair_loss per batch, f_z, m_yz)
        otherwise:
            (accuracy, summed loss, n_yz, mean acc_loss per batch, mean fair_loss per batch, m_yz)
        where n_yz[(y, z)] counts correctly classified samples with label y in group z,
        m_yz[(y, z)] counts all samples with label y in group z, and for z >= 1
        f_z[z] = loss_yz[(1, z)] / m_yz[(1, z)] - loss_yz[(1, 0)] / m_yz[(1, 0)],
        with loss_yz the per-group "standard" loss summed over batches.

        bits: when truthy, each f_z[z] is quantised with ``_quantize``.
        """
        model.eval()
        loss, total, correct, fair_loss, acc_loss, num_batch = 0.0, 0.0, 0.0, 0.0, 0.0, 0
        n_yz, loss_yz, m_yz, f_z = {}, {}, {}, {}
        for y in [0, 1]:
            for z in range(self.Z):
                loss_yz[(y, z)] = 0
                n_yz[(y, z)] = 0
                m_yz[(y, z)] = 0

        per_group_loss = self.option in ["FairBatch", "FB-Variant1"]

        # Evaluation only: without no_grad the accumulated loss_yz tensors would keep
        # every batch's autograd graph alive.
        with torch.no_grad():
            for features, labels, sensitive in self.trainloader:
                features, labels = features.to(DEVICE), labels.type(torch.LongTensor).to(DEVICE)
                sensitive = sensitive.type(torch.LongTensor).to(DEVICE)

                outputs, logits = model(features)
                outputs, logits = outputs.to(DEVICE), logits.to(DEVICE)

                _, pred_labels = torch.max(outputs, 1)
                pred_labels = pred_labels.view(-1).to(DEVICE)
                correct += torch.sum(torch.eq(pred_labels, labels)).item()
                total += len(labels)
                num_batch += 1

                for yz in n_yz:
                    in_group = (labels == yz[0]) & (sensitive == yz[1])
                    n_yz[yz] += torch.sum(in_group & (pred_labels == yz[0])).item()
                    m_yz[yz] += torch.sum(in_group).item()

                    if per_group_loss:
                        # the objective function has no lagrangian term
                        loss_yz_, _, _ = loss_func("standard", logits[in_group].to(DEVICE),
                                                   labels[in_group].to(DEVICE),
                                                   outputs[in_group].to(DEVICE), sensitive[in_group].to(DEVICE),
                                                   self.penalty)
                        loss_yz[yz] += loss_yz_

                batch_loss, batch_acc_loss, batch_fair_loss = loss_func(self.option, logits,
                                                                        labels, outputs, sensitive, self.penalty)
                loss += batch_loss.item()
                acc_loss += batch_acc_loss.item()
                fair_loss += batch_fair_loss.item()

        accuracy = correct / total
        if not per_group_loss:
            return accuracy, loss, n_yz, acc_loss / num_batch, fair_loss / num_batch, m_yz

        for z in range(1, self.Z):
            f_z[z] = loss_yz[(1, z)] / m_yz[(1, z)] - loss_yz[(1, 0)] / m_yz[(1, 0)]
        if bits:
            for z in range(1, self.Z):
                f_z[z] = self._quantize(f_z[z].item(), bits, self.Z)
        return accuracy, loss, n_yz, acc_loss / num_batch, fair_loss / num_batch, f_z, m_yz

    @staticmethod
    def _quantize(value, bits, Z):
        """
        Snap ``value`` down to a grid of ``2**bits // (Z - 1)`` evenly spaced points
        on [-2, 2]; values outside the grid are clamped to its end points.
        """
        bins = np.linspace(-2, 2, 2**bits // (Z - 1))
        # np.digitize returns 0 below the first edge; without the clamp the index
        # would be -1 and negative outliers would wrap around to +2.
        pos = np.clip(np.digitize(value, bins) - 1, 0, len(bins) - 1)
        return bins[pos]
