from fedlearn.algorithm.FedBase import Server as BasicServer
from fedlearn.models.CGFFedclient_Post import CGFFedClient
from fedlearn.utils.metric import Metrics
from tqdm import tqdm
import time
import numpy as np
import torch
import copy
import os

class Server(BasicServer):
    def __init__(self, dataset, options, name=''):
        """Set up clients, bookkeeping and the initial fairness dual variables.

        Args:
            dataset: Forwarded unchanged to the base server and to
                ``setup_clients``.
            options: Run configuration mapping. Keys read directly here:
                ``aggr``, ``beta``, ``simple_average``, ``post_round``,
                ``fairness_constraints``, ``data_save_path`` and the optional
                ``data_exist`` (defaults to ``False``).
            name: Prefix used when composing ``self.name``.

        ``self.options``, ``self.data_info`` and ``self.clients_per_round`` are
        read but not assigned here; presumably the base class initializer
        provides them (verify against ``BasicServer``).
        """
        super().__init__(dataset, options, name)
        self.clients = self.setup_clients(CGFFedClient, dataset, options)
        assert len(self.clients) > 0
        print('>>> Initialize {} clients in total'.format(len(self.clients)))

        # Run name: <name>_wn<clients per round>_tn<total clients>.
        name_parts = [name, f'wn{self.clients_per_round}', f'tn{len(self.clients)}']
        self.name = '_'.join(name_parts)
        self.metrics = Metrics(self.clients, options)

        # Aggregation settings used by ``aggregate``.
        self.aggr = options['aggr']
        self.beta = options['beta']
        self.simple_average = options['simple_average']
        self.post_round = options['post_round']

        # CGF (fairness) settings.
        constraints = options['fairness_constraints']
        self.fairness_constraints = constraints
        self.fair_metric = constraints['fairness_measure']
        self.dual_params_lr = self.options['dual_params_lr']
        self.dual_params_bound = self.options['dual_params_bound']
        self.global_fair_constraint = constraints['global']
        self.local_fair_constraint = constraints['local']
        self.pretrain_round = self.options['pretrain_round']
        self.Alabel = [int(a) for a in self.data_info['A_info'].keys()]
        self.Ylabel = [int(y) for y in self.data_info['Y_info'].keys()]
        self.save_path = options['data_save_path']

        # Dual variables all start at 0.01; only the layout depends on the
        # fairness type / metric. Unlisted combinations create no attribute.
        if self.options['fairness_type'] == 'groupwise':
            if self.fair_metric == 'DP':
                self.global_lamb = torch.tensor([0.01, 0.01], requires_grad=False)
            elif self.fair_metric == 'EO':
                # first row for Y=1, second for Y=0
                self.global_lamb = torch.tensor(
                    [[0.01, 0.01],
                     [0.01, 0.01]],
                    requires_grad=False,
                )
        if self.options['fairness_type'] == 'subgroup':
            if self.fair_metric in ('DP', 'EOP'):
                # One multiplier per (A label, Y label) pair, in two sets.
                n_a, n_y = len(self.Alabel), len(self.Ylabel)
                self.global_lamb_1 = torch.ones(n_a, n_y, requires_grad=False) * 0.01
                self.global_lamb_2 = torch.ones(n_a, n_y, requires_grad=False) * 0.01

        self.data_exist = options.get('data_exist', False)
    
    
    def train(self):
        """Run pre-training (or load a cached model), then the post-processing rounds.

        Stage 1 obtains ``self.latest_model``: it is loaded from
        ``<save_path>pre_model.pth`` when ``data_exist`` is true, the file exists
        and ``options['generate']`` is false; otherwise any stale file is removed,
        the model is trained for ``num_round`` rounds and saved to that path.

        Stage 2 pushes the model to every client, asks each client to compute its
        validation/test scores, then runs ``post_round`` rounds of
        ``iterate_post``, evaluating with ``post_test`` before each round and once
        more at the end. The final evaluation is stored in ``self.stats``.

        The round counters are initialized before their loops so the trailing
        evaluations still work when ``num_round`` or ``post_round`` is 0.
        """
        print('>>> Select {} clients for aggregation per round \n'.format(self.clients_per_round))

        pre_model_path = self.save_path + 'pre_model.pth'

        # Load the pre-trained model when a cached one may be reused.
        # The explicit ``== True`` / ``== False`` comparisons are kept on purpose:
        # switching to truthiness would change which option values are accepted.
        if self.data_exist == True and os.path.exists(pre_model_path) and self.options['generate'] == False:
            self.latest_model = torch.load(pre_model_path)
        else:
            if os.path.exists(pre_model_path):
                os.remove(pre_model_path)

            if self.gpu:
                self.latest_model = self.latest_model.to(self.device)

            # Model pre-training. ``current_round`` is bound up front so that the
            # evaluation after the loop is valid even when ``num_round`` is 0.
            self.current_round = 0
            for round_idx in tqdm(range(self.num_round)):
                self.current_round = round_idx
                tqdm.write('>>> Training Round {}, latest model.norm = {}'.format(self.current_round, self.latest_model.norm()))
                # Evaluate every 5th round; round 0 always evaluates, so ``stats``
                # is bound before its first use. Rounds in between record the
                # most recent evaluation again.
                if round_idx % 5 == 0:
                    stats = self.test(self.clients, self.current_round)
                self.metrics.update_model_stats(self.current_round, stats)
                self.iterate()
                self.current_round = round_idx + 1
            self.test(self.clients, self.current_round)
            torch.save(self.latest_model, pre_model_path)

        # Get per-sample scores from the (pre-)trained model on every client.
        for c in self.clients:
            c.set_params(self.latest_model)
            c.get_val_test_score() # get score in self.val_score, self.test_score

        # Post-processing. ``post_current_round`` is bound up front so that the
        # final evaluation is valid even when ``post_round`` is 0.
        self.post_current_round = 0
        for round_idx in tqdm(range(self.post_round)):
            self.post_current_round = round_idx
            tqdm.write('>>> Post-processing Round {}'.format(self.post_current_round))

            # Evaluate the current dual variables on val and test data every round.
            self.post_test(self.clients, self.post_current_round)

            self.iterate_post()
            self.post_current_round = round_idx + 1

        # Final evaluation after the last post-processing round.
        self.stats = self.post_test(self.clients, self.post_current_round)

        # print('Unified model test:')
        # unified_stats = self.test(self.clients, self.current_round)


    def iterate_post(self):
        selected_clients = self.select_clients(self.clients, self.post_current_round)

        # Do local update for the selected clients
        if self.options['fairness_type'] == 'groupwise':
            lamb_optim = []
            user_val_num = []
            for c in selected_clients:
            # Communicate the latest global model
                c.global_lamb = self.global_lamb
                local_lamb_optim, local_data_num = c.FFACT_post_train()
                lamb_optim.append(local_lamb_optim)
                user_val_num.append(local_data_num)

            self.global_lamb = sum([lamb_optim[i] * (user_val_num[i] / sum(user_val_num)) for i in range(len(lamb_optim))])
            self.global_lamb = torch.clamp(self.global_lamb, min=0.0, max=self.dual_params_bound).requires_grad_(False)

            print(f'\n global lamb:{self.global_lamb}.')

        elif self.options['fairness_type'] == 'subgroup':
            lamb_optim_1, lamb_optim_2 = [],[]
            user_val_num = []
            for c in selected_clients:
            # Communicate the latest global model
                c.global_lamb_1 = self.global_lamb_1
                c.global_lamb_2 = self.global_lamb_2
                local_lamb_optim_1, local_lamb_optim_2, local_data_num = c.FFACT_post_train()
                lamb_optim_1.append(local_lamb_optim_1)
                lamb_optim_2.append(local_lamb_optim_2)
                user_val_num.append(local_data_num)

            self.global_lamb_1 = sum([lamb_optim_1[i] * (user_val_num[i] / sum(user_val_num)) for i in range(len(lamb_optim_1))])
            self.global_lamb_1 = torch.clamp(self.global_lamb_1, min=0.0, max=self.dual_params_bound).requires_grad_(False)

            self.global_lamb_2 = sum([lamb_optim_2[i] * (user_val_num[i] / sum(user_val_num)) for i in range(len(lamb_optim_2))])
            self.global_lamb_2 = torch.clamp(self.global_lamb_2, min=0.0, max=self.dual_params_bound).requires_grad_(False)

            print(f'\n global lamb_1:{self.global_lamb_1}.')
            print(f' global lamb_2:{self.global_lamb_2}.')

        return True
    
    def post_test(self, clients, current_round):
        begin_time = time.time()
        stats_dict = {'num_samples':[], 'accs':[],'local_confusion':[],'local_fair_measure':[]}
        client_stats= {'val':copy.deepcopy(stats_dict),
                       'test' :copy.deepcopy(stats_dict)}
        for c in clients:
            if self.options['fairness_type'] == 'groupwise':
                c.global_lamb = self.global_lamb
            elif self.options['fairness_type'] == 'subgroup':
                c.global_lamb_1 = self.global_lamb_1
                c.global_lamb_2 = self.global_lamb_2

            client_eval_dict = c.local_post_eval()

            for dataset, test_dict in client_eval_dict.items():
                # dataset is val and test
                client_stats[dataset]['num_samples'].append(test_dict['num'])
                client_stats[dataset]['accs'].append(test_dict['acc'])
                client_stats[dataset]['local_fair_measure'].append(test_dict['local_fair_measure'])
                client_stats[dataset]['local_confusion'].append(test_dict['local_confusion'])

        end_time = time.time()

        val_acc = sum(client_stats['val']['accs']) / sum(client_stats['val']['num_samples'])
        val_fairness = self.global_fairness(client_stats['val'], 'val')
        
        test_acc = sum(client_stats['test']['accs']) / sum(client_stats['test']['num_samples'])
        test_fairness = self.global_fairness(client_stats['test'], 'test')

        # local_fair = [float(client_stats['test']['local_fair_measure'][c]) for c in range(self.num_users)]
        # avg_local_fair = sum(local_fair)/len(local_fair)

        model_stats = {'val_acc':val_acc, 'val_fairness':val_fairness,
                        'test_acc':test_acc, 'test_fairness':test_fairness,
                        'time': end_time - begin_time}
        
        # if self.print:
        #     self.print_result(current_round, model_stats, client_stats)
        self.post_print_result(current_round, model_stats, client_stats)

        return model_stats
    
    def post_print_result(self, current_round, model_stats, client_stats):
        """Write the post-processing evaluation summary via ``tqdm.write``.

        Always prints the aggregate validation line. For the ``groupwise`` and
        ``subgroup`` fairness types it then prints one validation line per
        client and, when ``current_round`` is a multiple of ``self.eval_round``,
        the aggregate and per-client test lines followed by the mean of each
        local fairness measure over clients. Other fairness types print nothing
        beyond the first line.

        Args:
            current_round: Round index shown in the output.
            model_stats: Aggregate results (``val_acc``, ``time``,
                ``val_fairness``, ``test_acc``, ``test_fairness``).
            client_stats: Per-split lists of per-client results, indexed in the
                same order as ``self.val_data['users']`` /
                ``self.test_data['users']``.
        """
        tqdm.write("\n >>> Round: {: >4d} / Acc: {:.3%} / Time: {:.2f}s /  Fairness(val): {} ".format(
            current_round, model_stats['val_acc'], model_stats['time'], model_stats['val_fairness']))

        # The two fairness types differ only in measure names and templates.
        fairness_type = self.options['fairness_type']
        if fairness_type == 'groupwise':
            notions = ("DP", "EO")
            val_line = '>>> Client: {} / Acc: {:.3%} / Fair_DP(val): {:.4f}/Fair_EO(val): {:.4f}'
            test_header = "\n = Test = round: {} / acc: {:.3%} / Fairness(test): {} "
            test_line = '=== Test  Client: {} / Acc: {:.3%} / Fair_DP:{:.4f} / Fair_EO:{:.4f}'
        elif fairness_type == 'subgroup':
            notions = ("multi_DP", "multi_EOP")
            val_line = '>>> Client: {} / Acc: {:.3%} / Fair_multi_DP(val): {:.4f}/Fair_multi_EOP(val): {:.4f}'
            test_header = "\n = Test = round: {} / acc: {:.3%} / Fairness: {}  "
            test_line = '=== Test  Client: {} / Acc: {:.3%} / Fair_multi_DP(test):{:.4f} / Fair_multi_EOP(test):{:.4f}'
        else:
            return

        first_notion, second_notion = notions

        def client_line(template, users, split, idx):
            # One formatted line for client ``idx`` on the given split.
            user = users[idx]
            split_stats = client_stats[split]
            accuracy = split_stats['accs'][idx] / split_stats['num_samples'][idx]
            measures = split_stats['local_fair_measure'][idx]
            return template.format(user, accuracy, measures[first_notion], measures[second_notion])

        for idx in range(self.num_users):
            tqdm.write(client_line(val_line, self.val_data['users'], 'val', idx))
        tqdm.write('=' * 102 + "\n")

        if current_round % self.eval_round != 0:
            return

        tqdm.write(test_header.format(
            current_round, model_stats['test_acc'], model_stats['test_fairness']))
        for idx in range(self.num_users):
            tqdm.write(client_line(test_line, self.test_data['users'], 'test', idx))

        # Mean of each local fairness measure over all clients (test split).
        M_fair = {}
        for fair_notion in notions:
            local_fair = [client_stats['test']['local_fair_measure'][idx][fair_notion]
                          for idx in range(self.num_users)]
            M_fair[fair_notion] = sum(local_fair) / len(local_fair)
        tqdm.write(f"=== Test local fair: {M_fair}")

    def test(self, clients, current_round, ensemble=False):
        """Evaluate the model on every client's train and test data and report it.

        Args:
            clients: Clients to evaluate.
            current_round: Round index, only forwarded to ``print_result``.
            ensemble: Forwarded to ``local_test``.

        Returns:
            Dict with ``train_loss``, ``train_acc``, ``train_fairness``,
            ``test_acc``, ``test_loss``, ``test_fairness`` and ``time`` (seconds
            spent in ``local_test``).
        """
        started = time.time()
        client_stats, _client_ids = self.local_test(clients, ensemble)
        elapsed = time.time() - started

        def pooled(split, key):
            # Sum of the per-client values divided by the total sample count.
            split_stats = client_stats[split]
            return sum(split_stats[key]) / sum(split_stats['num_samples'])

        train_acc = pooled('train', 'accs')
        train_loss = pooled('train', 'losses')
        train_fairness = self.global_fairness(client_stats['train'], 'train')

        test_acc = pooled('test', 'accs')
        test_loss = pooled('test', 'losses')
        test_fairness = self.global_fairness(client_stats['test'], 'test')

        model_stats = {
            'train_loss': train_loss,
            'train_acc': train_acc,
            'train_fairness': train_fairness,
            'test_acc': test_acc,
            'test_loss': test_loss,
            'test_fairness': test_fairness,
            'time': elapsed,
        }

        self.print_result(current_round, model_stats, client_stats)

        return model_stats
    
    def local_test(self, clients, ensemble=False):
        assert self.latest_model is not None

        stats_dict = {'num_samples':[], 'accs':[], 'losses':[],'local_fair_measure':[],'local_confusion':[]}
        client_stats= {'train':copy.deepcopy(stats_dict), 
                       'test' :copy.deepcopy(stats_dict)}

        for c in clients:
            if self.test_local:
                c.set_params(c.local_params)
            else:
                c.set_params(self.latest_model)
            client_test_dict = c.local_eval(ensemble)

            for dataset, test_dict in client_test_dict.items():
                # dataset is train and test
                client_stats[dataset]['num_samples'].append(test_dict['num'])
                client_stats[dataset]['accs'].append(test_dict['acc'])
                client_stats[dataset]['losses'].append(test_dict['loss'])
                client_stats[dataset]['local_fair_measure'].append(test_dict['local_fair_measure'])
                client_stats[dataset]['local_confusion'].append(test_dict['local_confusion'])

        ids = [c.cid for c in clients]

        return client_stats, ids
    
    def iterate(self):
        selected_clients = self.select_clients(self.clients, self.current_round)

        # Do local update for the selected clients
        solns, stats = [], []
        for c in selected_clients:
            # Communicate the latest global model
            c.set_params(self.latest_model)

            # Solve local and personal minimization
            soln, stat = c.local_train()
            
            if self.print:
                tqdm.write('>>> Round: {: >2d} local acc | CID:{}| loss {:>.4f} | Acc {:>5.2f}% | Time: {:>.2f}s'.format(
                    self.current_round, c.cid, stat['loss'], stat['acc'] * 100, stat['time']
                    ))
                        
            # Add solutions and stats
            solns.append(soln)
            stats.append(stat)
    
        self.latest_model = self.aggregate(solns, seed = self.current_round, stats = stats)
        return True

    def aggregate(self, solns, seed, stats):
        averaged_solution = torch.zeros_like(self.latest_model)

        num_samples, chosen_solns = [info[0] for info in solns], [info[1] for info in solns]
        if self.aggr == 'mean':  
            if self.simple_average:
                num = 0
                for num_sample, local_soln in zip(num_samples, chosen_solns):
                    num += 1
                    averaged_solution += local_soln
                averaged_solution /= num
            else:      
                selected_sample = 0
                for num_sample, local_soln in zip(num_samples, chosen_solns):
                    # print(num_sample)
                    averaged_solution += num_sample * local_soln
                    selected_sample += num_sample
                    # print("local_soln:{},num_sample:{}".format(local_soln, num_sample))
                averaged_solution = averaged_solution / selected_sample
                # print(averaged_solution)

        elif self.aggr == 'median':
            stack_solution = torch.stack(chosen_solns)
            averaged_solution = torch.median(stack_solution, dim = 0)[0]
        elif self.aggr == 'krum':
            f = int(len(chosen_solns) * 0)
            dists = torch.zeros(len(chosen_solns), len(chosen_solns))
            scores = torch.zeros(len(chosen_solns))
            for i in range(len(chosen_solns)):
                for j in range(i, len(chosen_solns)):
                    dists[i][j] = torch.norm(chosen_solns[i] - chosen_solns[j], p = 2)
                    dists[j][i] = dists[i][j]
            for i in range(len(chosen_solns)):
                d = dists[i]
                d, _ = d.sort()
                scores[i] = d[:len(chosen_solns) - f - 1].sum()
            averaged_solution = chosen_solns[torch.argmin(scores).item()]
                
        averaged_solution = (1 - self.beta) * self.latest_model + self.beta * averaged_solution
        return averaged_solution.detach()
    