from fedlearn.algorithm.FedBase import Server as BasicServer
from fedlearn.models.CGFFedclient import CGFFedClient
from fedlearn.utils.metric import Metrics
from tqdm import tqdm
import time
import numpy as np
import torch
import copy

class Server(BasicServer):
    def __init__(self, dataset, options, name=''):
        """Set up clients, metrics, aggregation settings and the global dual variables.

        Args:
            dataset: Forwarded to the base server and to ``setup_clients``.
            options: Configuration mapping. This constructor reads the keys
                ``aggr``, ``beta``, ``simple_average``, ``fairness_constraints``,
                ``dual_params_lr``, ``dual_params_bound``, ``pretrain_round``
                and ``fairness_type``.
            name: Prefix of the run name; the per-round and total client
                counts are appended to it.

        Raises:
            ValueError: If the configured fairness measure is not supported
                for the configured ``fairness_type``.
        """
        super().__init__(dataset, options, name)

        self.clients = self.setup_clients(CGFFedClient, dataset, options)
        assert len(self.clients) > 0
        print('>>> Initialize {} clients in total'.format(len(self.clients)))

        # System metrics and aggregation settings.
        name_parts = [name, f'wn{self.clients_per_round}', f'tn{len(self.clients)}']
        self.name = '_'.join(name_parts)
        self.metrics = Metrics(self.clients, options)
        self.aggr = options['aggr']
        self.beta = options['beta']
        self.simple_average = options['simple_average']

        # CGF hyper-parameters.
        self.fairness_constraints = options['fairness_constraints']
        self.fair_metric = self.fairness_constraints['fairness_measure']
        self.dual_params_lr = self.options['dual_params_lr']
        self.dual_params_bound = self.options['dual_params_bound']
        self.global_fair_constraint = self.fairness_constraints['global']
        self.local_fair_constraint = self.fairness_constraints['local']
        self.pretrain_round = self.options['pretrain_round']
        self.Alabel = [int(a) for a in self.data_info['A_info'].keys()]
        self.Ylabel = [int(y) for y in self.data_info['Y_info'].keys()]

        # Global dual variables, all starting at 0.1.
        fairness_type = self.options['fairness_type']
        if fairness_type == 'groupwise':
            if self.fair_metric == 'DP':
                self.global_lamb = torch.tensor([0.1, 0.1], requires_grad=False)
            elif self.fair_metric == 'EO':
                # First row is for Y=1, second row for Y=0.
                self.global_lamb = torch.tensor([[0.1, 0.1],
                                                 [0.1, 0.1]], requires_grad=False)
            else:
                numclass = 'binary'
                raise ValueError(f'Not support {self.fair_metric} in {numclass} classification!')
        elif fairness_type == 'subgroup':
            if self.fair_metric not in ('DP', 'EOP'):
                numclass = 'multi-class'
                raise ValueError(f'Not support {self.fair_metric} in {numclass} classification!')
            # DP and EOP use the same layout: one pair of multipliers per
            # (sensitive attribute, label) cell.
            shape = (len(self.Alabel), len(self.Ylabel))
            self.global_lamb_1 = torch.ones(*shape, requires_grad=False) * 0.1
            self.global_lamb_2 = torch.ones(*shape, requires_grad=False) * 0.1
        
    def train(self):
        """Run ``num_round`` rounds of CGF training, then evaluate and save metrics.

        The global model is evaluated on every second round. On the rounds in
        between, the stats of the most recent evaluation are recorded again.
        After the last round the model is evaluated once more, the result is
        kept in ``self.stats`` and the tracked metrics are written out.
        """
        print('>>> Select {} clients for aggregation per round \n'.format(self.clients_per_round))
        if self.gpu:
            self.latest_model = self.latest_model.to(self.device)

        # A pre-training phase over `self.pretrain_round` rounds that called
        # `self.iterate()` used to run here; it is currently disabled.

        eval_interval = 2
        for self.current_round in tqdm(range(self.num_round)):
            model_norm = self.latest_model.norm()
            tqdm.write('>>> Round {}, latest model.norm = {}'.format(self.current_round, model_norm))

            # Evaluate the latest model on train and eval data.
            # NOTE(review): on rounds that skip evaluation the previous
            # `stats` object is recorded again -- confirm this is intended.
            is_eval_round = self.current_round % eval_interval == 0
            if is_eval_round:
                stats = self.test(self.clients, self.current_round)
            self.metrics.update_model_stats(self.current_round, stats)

            self.iterate_CFG()
            # The loop target is re-bound on the next iteration; this bump only
            # matters after the final round, where it leaves the counter one
            # past the last trained round for the final evaluation below.
            self.current_round += 1

        # Final evaluation of the trained model.
        self.stats = self.test(self.clients, self.current_round)
        self.metrics.update_model_stats(self.num_round, self.stats)

        # Save tracked information.
        self.metrics.write()

    def iterate_CFG(self):
        selected_clients = self.select_clients(self.clients, self.current_round)

        # Do local update for the selected clients
        solns, stats = [], []
        local_global_lambda_info = []
        for c in selected_clients:
            # Communicate the latest global model
            c.set_params(self.latest_model)
            if self.options['fairness_type'] == 'groupwise':
                c.global_lamb = self.global_lamb
            elif self.options['fairness_type'] == 'subgroup':
                c.global_lamb_1 = self.global_lamb_1
                c.global_lamb_2 = self.global_lamb_2

            # Solve local and personal minimization
            soln, stat = c.CGF_train()
            
            if self.print:
                tqdm.write('>>> Round: {: >2d} local acc | CID:{}| loss {:>.4f} | Acc {:>5.2f}% | Time: {:>.2f}s'.format(
                    self.current_round, c.cid, stat['loss'], stat['acc'] * 100, stat['time']
                    ))
                        
            # Add solutions and stats
            solns.append(soln)
            stats.append(stat)

            if hasattr(c, 'local_fair_confusion_matrix'):
                local_global_lambda_info.append(c.local_fair_confusion_matrix)
    
        self.latest_model = self.aggregate(solns, seed = self.current_round, stats = stats)
        self.global_lamb_update(local_global_lambda_info)
        return True
    
    def global_lamb_update(self, local_global_lambda_info):
        if self.options['fairness_type'] == 'groupwise':
            assert len(list(self.data_info['A_info'].keys())) == 2
            assert len(list(self.data_info['Y_info'].keys())) == 2
            if self.fair_metric == "DP" :
                confusion_matrix = local_global_lambda_info
                total_pred_Y1_A1 = 0
                total_pred_Y1_A0 = 0
                total_A1 = 0
                total_A0 = 0
                for i in range(len(confusion_matrix)):
                    total_pred_Y1_A1 += sum(v for y_dict in confusion_matrix[i][1].values() for p, v in y_dict.items() if p == 1)
                    total_pred_Y1_A0 += sum(v for y_dict in confusion_matrix[i][0].values() for p, v in y_dict.items() if p == 1)
                    total_A1 += sum(v for y_dict in confusion_matrix[i][1].values() for v in y_dict.values())
                    total_A0 += sum(v for y_dict in confusion_matrix[i][0].values() for v in y_dict.values())
                global_fair_measure = total_pred_Y1_A1 / total_A1 - total_pred_Y1_A0/ total_A0
                self.global_lamb[0] += self.dual_params_lr * (global_fair_measure - self.global_fair_constraint)
                self.global_lamb[1] += self.dual_params_lr * (-global_fair_measure - self.global_fair_constraint)
                self.global_lamb = torch.clamp(self.global_lamb, min=0.0, max=self.dual_params_bound)

                # print(f'global_lambda: {self.global_lamb}, global fair value: {global_fair_measure}')
            elif self.fair_metric == "EO" :
                confusion_matrix = local_global_lambda_info
                total_pred_Y1_A1_Y1 = 0
                total_pred_Y1_A0_Y1 = 0
                total_pred_Y1_A1_Y0 = 0
                total_pred_Y1_A0_Y0 = 0
                total_A1_Y1 = 0
                total_A0_Y1 = 0
                total_A1_Y0 = 0
                total_A0_Y0 = 0
                for i in range(len(confusion_matrix)):
                    total_pred_Y1_A1_Y1 += confusion_matrix[i][1][1][1]
                    total_pred_Y1_A0_Y1 += confusion_matrix[i][0][1][1]
                    total_pred_Y1_A1_Y0 += confusion_matrix[i][1][0][1]
                    total_pred_Y1_A0_Y0 += confusion_matrix[i][0][0][1]
                    total_A1_Y1 += sum(confusion_matrix[i][1][1].values())
                    total_A0_Y1 += sum(confusion_matrix[i][0][1].values())
                    total_A1_Y0 += sum(confusion_matrix[i][1][0].values())
                    total_A0_Y0 += sum(confusion_matrix[i][0][0].values())
                global_fair_measure_Y1 = total_pred_Y1_A1_Y1 / total_A1_Y1 - total_pred_Y1_A0_Y1 / total_A0_Y1
                global_fair_measure_Y0 = total_pred_Y1_A1_Y0 / total_A1_Y0 - total_pred_Y1_A0_Y0 / total_A0_Y0
                self.global_lamb[0][0] += self.dual_params_lr * (global_fair_measure_Y1 - self.global_fair_constraint)
                self.global_lamb[0][1] += self.dual_params_lr * (-global_fair_measure_Y1 - self.global_fair_constraint)
                self.global_lamb[1][0] += self.dual_params_lr * (global_fair_measure_Y0 - self.global_fair_constraint)
                self.global_lamb[1][1] += self.dual_params_lr * (-global_fair_measure_Y0 - self.global_fair_constraint)
                self.global_lamb = torch.clamp(self.global_lamb, min=0.0, max=self.dual_params_bound)
                # self.global_lamb[1] = torch.clamp(self.global_lamb[1], min=0.0, max=self.dual_params_bound)

                # print(f'global_lambda: {self.global_lamb}')
                # print(f'global fair value_Y1: {global_fair_measure_Y1}, global fair value_Y0: {global_fair_measure_Y0}')
        elif self.options['fairness_type'] == 'subgroup':
            if self.fair_metric == "DP" :
                confusion_matrix = local_global_lambda_info
                num_samples = 0
                total_pred_Y = {y: 0 for y in self.Ylabel} #n_py
                total_pred_Y_A = {a:{y: 0 for y in self.Ylabel} for a in self.Alabel} #n_{a,py}
                total_A = {a: 0 for a in self.Alabel} #n_a
                for ii in range(len(confusion_matrix)):
                    for y in self.Ylabel:
                        total_pred_Y[y] += sum(v for d1 in confusion_matrix[ii].values() for d2 in d1.values() for p, v in d2.items() if p == y)
                    for aa in self.Alabel:
                        total_A[aa] += sum(v for y_dict in confusion_matrix[ii][aa].values() for v in y_dict.values())
                        for yy in self.Ylabel:
                            total_pred_Y_A[aa][yy] += sum(v for y_dict in confusion_matrix[ii][aa].values() for p, v in y_dict.items() if p == yy)
                            for pred_y in self.Ylabel:
                                num_samples += confusion_matrix[ii][aa][yy][pred_y]
                global_DP_org = {a: {y: total_pred_Y_A[a][y] / total_A[a] - total_pred_Y[y]/num_samples for y in self.Ylabel} for a in self.Alabel }
                for a_prime in self.Alabel:
                    for y in self.Ylabel:
                        self.global_lamb_1[a_prime][y] += self.dual_params_lr * (global_DP_org[a_prime][y]- self.global_fair_constraint)
                        self.global_lamb_2[a_prime][y] += self.dual_params_lr * (- global_DP_org[a_prime][y]- self.global_fair_constraint)
                self.global_lamb_1 = torch.clamp(self.global_lamb_1, min=0.0, max=self.dual_params_bound)
                self.global_lamb_2 = torch.clamp(self.global_lamb_2, min=0.0, max=self.dual_params_bound)
            if self.fair_metric == "EOP" :
                confusion_matrix = local_global_lambda_info
                num_samples = 0
                total_Y = {y:0 for y in self.Ylabel} # n_y
                total_Y_pred_Y = {y: {pred_y:0 for pred_y in self.Ylabel} for y in self.Ylabel} #n_y_py
                total_Y_pred_Y_A = {a:{y: {pred_y:0 for pred_y in self.Ylabel} for y in self.Ylabel}for a in self.Alabel} #n_a_y_py
                total_Y_A = {a:{y: 0 for y in self.Ylabel} for a in self.Alabel} #n_{a,y}
                for i in range(len(confusion_matrix)):
                    for y in self.Ylabel:
                        total_Y[y] += sum(confusion_matrix[i][a_temp][y][pred_y] for a_temp in self.Alabel for pred_y in self.Ylabel)
                        for a in self.Alabel:
                            total_Y_A[a][y] += sum([confusion_matrix[i][a][y][y_temp] for y_temp in self.Ylabel])
                        for pred_y in self.Ylabel:
                            total_Y_pred_Y[y][pred_y] += sum([confusion_matrix[i][a_temp][y][pred_y] for a_temp in self.Alabel])
                            for aa in self.Alabel:
                                total_Y_pred_Y_A[aa][y][pred_y] += confusion_matrix[i][aa][y][pred_y]
                global_EOP_org = {a: {y: total_Y_pred_Y_A[a][y][y] / total_Y_A[a][y] - total_Y_pred_Y[y][y]/total_Y[y] for y in self.Ylabel} for a in self.Alabel }
                for a_prime in self.Alabel:
                    for y in self.Ylabel:
                        self.global_lamb_1[a_prime][y] += self.dual_params_lr * (global_EOP_org[a_prime][y]- self.global_fair_constraint)
                        self.global_lamb_2[a_prime][y] += self.dual_params_lr * (- global_EOP_org[a_prime][y]- self.global_fair_constraint)
                self.global_lamb_1 = torch.clamp(self.global_lamb_1, min=0.0, max=self.dual_params_bound)
                self.global_lamb_2 = torch.clamp(self.global_lamb_2, min=0.0, max=self.dual_params_bound)
                

    def test(self, clients, current_round, ensemble=False):
        """Evaluate the given clients and summarise the results globally.

        For each of the ``train`` and ``test`` splits, the summed per-client
        ``accs`` and ``losses`` are divided by the total number of samples,
        and a global fairness value is obtained from ``global_fairness``.
        The summary is printed through ``print_result``.

        Args:
            clients: Clients to evaluate.
            current_round: Round index, forwarded to ``print_result``.
            ensemble: Forwarded to ``local_test``.

        Returns:
            Dict with ``train_loss``, ``train_acc``, ``train_fairness``,
            ``test_acc``, ``test_loss``, ``test_fairness`` and ``time``
            (seconds spent in ``local_test``).
        """
        started = time.time()
        client_stats, _ids = self.local_test(clients, ensemble)
        finished = time.time()

        summary = {}
        for split in ('train', 'test'):
            split_stats = client_stats[split]
            acc = sum(split_stats['accs']) / sum(split_stats['num_samples'])
            loss = sum(split_stats['losses']) / sum(split_stats['num_samples'])
            fairness = self.global_fairness(split_stats, split)
            summary[split] = (acc, loss, fairness)

        train_acc, train_loss, train_fairness = summary['train']
        test_acc, test_loss, test_fairness = summary['test']

        model_stats = {
            'train_loss': train_loss,
            'train_acc': train_acc,
            'train_fairness': train_fairness,
            'test_acc': test_acc,
            'test_loss': test_loss,
            'test_fairness': test_fairness,
            'time': finished - started,
        }

        # Results are printed unconditionally (not gated on `self.print`).
        self.print_result(current_round, model_stats, client_stats)

        return model_stats
    
    def local_test(self, clients, ensemble=False):
        assert self.latest_model is not None

        stats_dict = {'num_samples':[], 'accs':[], 'losses':[],'local_fair_measure':[],'local_confusion':[]}
        client_stats= {'train':copy.deepcopy(stats_dict), 
                       'test' :copy.deepcopy(stats_dict)}

        for c in clients:
            if self.test_local:
                c.set_params(c.local_params)
            else:
                c.set_params(self.latest_model)
            client_test_dict = c.local_eval(ensemble)

            for dataset, test_dict in client_test_dict.items():
                # dataset is train and test
                client_stats[dataset]['num_samples'].append(test_dict['num'])
                client_stats[dataset]['accs'].append(test_dict['acc'])
                client_stats[dataset]['losses'].append(test_dict['loss'])
                client_stats[dataset]['local_fair_measure'].append(test_dict['local_fair_measure'])
                client_stats[dataset]['local_confusion'].append(test_dict['local_confusion'])

        ids = [c.cid for c in clients]

        return client_stats, ids
    
    def iterate(self):
        selected_clients = self.select_clients(self.clients, self.current_round)

        # Do local update for the selected clients
        solns, stats = [], []
        for c in selected_clients:
            # Communicate the latest global model
            c.set_params(self.latest_model)

            # Solve local and personal minimization
            soln, stat = c.local_train()
            
            if self.print:
                tqdm.write('>>> Round: {: >2d} local acc | CID:{}| loss {:>.4f} | Acc {:>5.2f}% | Time: {:>.2f}s'.format(
                    self.current_round, c.cid, stat['loss'], stat['acc'] * 100, stat['time']
                    ))
                        
            # Add solutions and stats
            solns.append(soln)
            stats.append(stat)
    
        self.latest_model = self.aggregate(solns, seed = self.current_round, stats = stats)
        return True

    def aggregate(self, solns, seed, stats):
        averaged_solution = torch.zeros_like(self.latest_model)

        num_samples, chosen_solns = [info[0] for info in solns], [info[1] for info in solns]
        if self.aggr == 'mean':  
            if self.simple_average:
                num = 0
                for num_sample, local_soln in zip(num_samples, chosen_solns):
                    num += 1
                    averaged_solution += local_soln
                averaged_solution /= num
            else:      
                selected_sample = 0
                for num_sample, local_soln in zip(num_samples, chosen_solns):
                    # print(num_sample)
                    averaged_solution += num_sample * local_soln
                    selected_sample += num_sample
                    # print("local_soln:{},num_sample:{}".format(local_soln, num_sample))
                averaged_solution = averaged_solution / selected_sample
                # print(averaged_solution)

        elif self.aggr == 'median':
            stack_solution = torch.stack(chosen_solns)
            averaged_solution = torch.median(stack_solution, dim = 0)[0]
        elif self.aggr == 'krum':
            f = int(len(chosen_solns) * 0)
            dists = torch.zeros(len(chosen_solns), len(chosen_solns))
            scores = torch.zeros(len(chosen_solns))
            for i in range(len(chosen_solns)):
                for j in range(i, len(chosen_solns)):
                    dists[i][j] = torch.norm(chosen_solns[i] - chosen_solns[j], p = 2)
                    dists[j][i] = dists[i][j]
            for i in range(len(chosen_solns)):
                d = dists[i]
                d, _ = d.sort()
                scores[i] = d[:len(chosen_solns) - f - 1].sum()
            averaged_solution = chosen_solns[torch.argmin(scores).item()]
                
        averaged_solution = (1 - self.beta) * self.latest_model + self.beta * averaged_solution
        return averaged_solution.detach()
    