from cmath import log
import copy
import logging
import random

import numpy as np
import torch
import wandb
import os
from .client import Client
from .my_model_trainer_classification import MyModelTrainer as MyModelTrainerCLS
from .my_model_trainer_nwp import MyModelTrainer as MyModelTrainerNWP
from .my_model_trainer_tag_prediction import MyModelTrainer as MyModelTrainerTAG
import logging
import pickle
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from sklearn.cluster import KMeans

class FedAvgAPI(object):
    """Clustered FedAvg simulation driver.

    Holds one ``Client`` per user in ``args.users`` and a shared model
    trainer. ``train`` starts as plain FedAvg and, from round 1 on, splits the
    clients into two groups (k-means on per-client test accuracy) and keeps
    one aggregated model per group.
    """

    # NOTE(review): hard-coded membership of group 0, used only when
    # evaluating in ``_local_test_on_all_clients``; training uses the groups
    # detected by ``_detect`` instead. Presumably mirrors the data partition
    # — confirm.
    _EVAL_GROUP_0_USERS = (1, 10, 20, 30, 40, 50)
    # NOTE(review): weights used to blend the group-0 and group-1 models into
    # one model in ``_local_test_on_all_clients_global``; 6/51 and 45/51
    # presumably reflect 6 of 51 clients being in group 0 — confirm.
    _GLOBAL_MIX_WEIGHTS = (6 / 51, 45 / 51)

    def __init__(self, args, device, dataset, model, model_trainer=None):
        """Unpack ``dataset``, pick a model trainer and create the clients.

        Args:
            args: run configuration object (attributes such as ``dataset``,
                ``users``, ``comm_round`` are read throughout this class).
            device: passed through to every ``Client``.
            dataset: 9-element sequence ``[train_data_num, test_data_num,
                train_data_global, test_data_global,
                train_data_local_num_dict, train_data_local_dict,
                test_data_local_dict, val_data_local_dict, class_num]``.
            model: model wrapped by the default trainer.
            model_trainer: optional trainer; when ``None`` one is chosen from
                ``args.dataset``.
        """
        self.device = device
        self.args = args
        [
            train_data_num,
            test_data_num,
            train_data_global,
            test_data_global,
            train_data_local_num_dict,
            train_data_local_dict,
            test_data_local_dict,
            val_data_local_dict,
            class_num,
        ] = dataset
        self.train_global = train_data_global
        self.test_global = test_data_global
        self.val_global = None
        self.train_data_num_in_total = train_data_num
        self.test_data_num_in_total = test_data_num

        self.client_list = []
        self.train_data_local_num_dict = train_data_local_num_dict
        self.train_data_local_dict = train_data_local_dict
        self.test_data_local_dict = test_data_local_dict
        self.val_data_local_dict = val_data_local_dict

        logging.info("model = {}".format(model))
        if model_trainer is None:
            if args.dataset == "stackoverflow_lr":
                model_trainer = MyModelTrainerTAG(model)
            elif args.dataset in ["fed_shakespeare", "stackoverflow_nwp"]:
                model_trainer = MyModelTrainerNWP(model)
            else:
                # default model trainer is for classification problem
                model_trainer = MyModelTrainerCLS(model)
        self.model_trainer = model_trainer
        logging.info("self.model_trainer = {}".format(self.model_trainer))

        self._setup_clients(
            train_data_local_num_dict,
            train_data_local_dict,
            test_data_local_dict,
            val_data_local_dict,
            self.model_trainer,
        )

    def _setup_clients(
        self,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        val_data_local_dict,
        model_trainer,
    ):
        """Create one ``Client`` per user id in ``args.users``, in that order.

        All clients share the same ``model_trainer`` instance.
        """
        logging.info("############setup_clients (START)#############")
        for client_idx in self.args.users:
            c = Client(
                client_idx,
                train_data_local_dict[client_idx],
                test_data_local_dict[client_idx],
                val_data_local_dict[client_idx],
                train_data_local_num_dict[client_idx],
                self.args,
                self.device,
                model_trainer,
            )
            self.client_list.append(c)
        logging.info("############setup_clients (END)#############")

    def train(self):
        """Run ``args.comm_round`` rounds of two-group clustered FedAvg.

        * Round 0: every sampled client trains from the single global model,
          which is then FedAvg-aggregated.
        * Round 1: clients still train from that global model. Afterwards
          they are split by ``_detect`` (k-means on the per-client test
          accuracies of the latest evaluation), and each group is aggregated
          into its own model: ``w_global_0`` for the positions returned by
          ``_detect``, ``w_global_1`` for all others.
        * Round >= 2: each client trains from the model of the group it was
          assigned to in the previous round; groups are then re-detected.

        Group membership is positional: it indexes the accuracy list returned
        by ``_local_test_on_all_clients`` and this round's ``w_locals``. It
        therefore assumes every client participates every round and has test
        data.
        """
        logging.info("self.model_trainer = {}".format(self.model_trainer))
        w_global = self.model_trainer.get_model_params()
        w_global_0 = w_global_1 = w_global
        group_0 = set()  # positions of the clients that train from w_global_0
        loc_acc_global = None  # per-client test accuracy from the last evaluation

        num_rounds = self.args.comm_round
        num_clients = self.args.client_num_in_total
        plot_aggregate = np.full([num_rounds, num_clients], np.nan)
        plot_global_fairness = np.empty([num_rounds])
        plot_global_acc = np.empty([num_rounds])
        plot_local_acc_using_global_model = np.empty([num_rounds, num_clients])
        # 0/1: which group model each client trained from (rounds >= 2 only)
        identity = np.full([num_rounds, num_clients], np.nan)
        w_save = []

        for round_idx in range(num_rounds):
            logging.info("################Communication round : {}".format(round_idx))

            client_indexes = self._client_sampling(
                round_idx, self.args.client_num_in_total, self.args.client_num_per_round
            )
            logging.info("client_indexes = " + str(client_indexes))

            w_locals = []
            for idx, _client_idx in enumerate(client_indexes):
                # NOTE(review): client_list is ordered like args.users, so this
                # picks the sampled client only under full participation.
                client = self.client_list[idx]
                if round_idx > 1:
                    # group_0 holds positions (see _detect), so test idx
                    in_group_0 = idx in group_0
                    w_start = w_global_0 if in_group_0 else w_global_1
                    identity[round_idx, idx] = 0 if in_group_0 else 1
                else:
                    w_start = w_global
                w = client.train(copy.deepcopy(w_start))
                w_locals.append((client.get_sample_number(), copy.deepcopy(w)))
                w_save.append(copy.deepcopy(w))

            if round_idx < 1:
                w_global = self._aggregate(w_locals, round_idx)
                w_global_0 = w_global_1 = w_global
            else:
                group_0 = set(self._detect(loc_acc_global))
                w_p0_local = [w for i, w in enumerate(w_locals) if i in group_0]
                w_p1_local = [w for i, w in enumerate(w_locals) if i not in group_0]
                # k-means may put every client into one cluster (e.g. equal
                # accuracies); keep the previous model of an empty group
                # instead of crashing in _aggregate.
                if w_p0_local:
                    w_global_0 = self._aggregate(w_p0_local, round_idx)
                else:
                    logging.warning("round %s: group 0 is empty, keeping its previous model", round_idx)
                if w_p1_local:
                    w_global_1 = self._aggregate(w_p1_local, round_idx)
                else:
                    logging.warning("round %s: group 1 is empty, keeping its previous model", round_idx)

            # save information
            if round_idx % self.args.save_epoches == 0:
                torch.save(
                    self.model_trainer.model.state_dict(),
                    os.path.join(self.args.run_folder, "%s_at_%s.pt" % (self.args.save_model_name, round_idx)),
                )  # check the fedavg model name
                with open("%s/%s_locals_at_%s.pt" % (self.args.run_folder, self.args.save_model_name, round_idx), 'wb') as f:
                    pickle.dump(w_save, f, protocol=pickle.HIGHEST_PROTOCOL)

            if round_idx == num_rounds - 1 or round_idx % self.args.frequency_of_the_test == 0:
                # per-client evaluation with each client's group model
                plot_gap_each_global, dp_gap_test_global, test_accuracy, loc_acc_global = self._local_test_on_all_clients(
                    round_idx, w_global, w_global_0, w_global_1
                )
                # pooled evaluation of the blended model overrides the values above
                dp_gap_test_global, test_accuracy = self._local_test_on_all_clients_global(
                    round_idx, w_global, w_global_0, w_global_1
                )

                plot_aggregate[round_idx] = plot_gap_each_global
                plot_global_fairness[round_idx] = dp_gap_test_global
                plot_global_acc[round_idx] = test_accuracy
                plot_local_acc_using_global_model[round_idx, :] = loc_acc_global

    def _detect(self, loc_acc_global):
        """Split clients into two groups by k-means (k=2) on their accuracy.

        Args:
            loc_acc_global: one scalar accuracy per client.

        Returns:
            list[int]: positions in ``loc_acc_global`` of the clients in the
            smaller cluster (on a tie, the cluster labelled 0).
        """
        accuracies = np.asarray(loc_acc_global).reshape(-1, 1)
        # NOTE: no random_state is passed, so the split may vary between runs.
        labels = KMeans(n_clusters=2).fit(accuracies).labels_
        minority = 0 if np.count_nonzero(labels == 1) >= np.count_nonzero(labels == 0) else 1
        return [i for i, label in enumerate(labels) if label == minority]

    def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
        """Return the user ids taking part in round ``round_idx``.

        All of ``args.users`` when every client participates; otherwise a
        sample without replacement seeded by ``round_idx`` (so every run picks
        the same clients per round), after which numpy's global RNG is reseeded
        with ``args.random_seed``.
        """
        if client_num_in_total == client_num_per_round:
            client_indexes = self.args.users
        else:
            num_clients = min(client_num_per_round, client_num_in_total)
            np.random.seed(round_idx)  # same clients each round across comparisons
            client_indexes = np.random.choice(self.args.users, num_clients, replace=False)
            np.random.seed(self.args.random_seed)
        logging.info("client_indexes = %s" % str(client_indexes))
        return client_indexes

    def _generate_validation_set(self, num_samples=10000):
        """Placeholder: builds nothing and returns False (``val_global`` stays None)."""
        return False

    def _aggregate(self, w_locals, round_idx):
        """Sample-count-weighted average of client parameters (FedAvg).

        Args:
            w_locals: list of ``(sample_num, params)`` pairs; every ``params``
                mapping must contain the keys of the first one.
            round_idx: unused; kept for interface compatibility.

        Returns:
            A shallow copy of the first client's mapping (same mapping type)
            whose values are the weighted averages. Inputs are not modified.

        Raises:
            ValueError: if ``w_locals`` is empty.
        """
        if not w_locals:
            raise ValueError("cannot aggregate an empty list of local models")
        training_num = sum(sample_num for sample_num, _ in w_locals)

        averaged_params = copy.copy(w_locals[0][1])
        for k in list(averaged_params):
            total = None
            for local_sample_number, local_model_params in w_locals:
                w = local_sample_number / training_num
                contribution = local_model_params[k] * w
                total = contribution if total is None else total + contribution
            averaged_params[k] = total
        return averaged_params

    def _aggregate_noniid_avg(self, w_locals):
        """Uniform (unweighted) average of client parameters.

        The sample-weighted aggregate can hurt model performance in Non-IID
        settings, so every client counts equally here.

        Args:
            w_locals: list of ``(sample_num, params)`` pairs; ``sample_num`` is
                ignored.

        Returns:
            A shallow copy of the first client's mapping holding the averages;
            inputs are not modified.
        """
        averaged_params = copy.copy(w_locals[0][1])
        for k in list(averaged_params):
            temp_w = [local_w[k] for (_, local_w) in w_locals]
            averaged_params[k] = sum(temp_w) / len(temp_w)
        return averaged_params

    @staticmethod
    def _group_fairness_metrics(targets, preds, sensitive, min_group_frac=0.01):
        """Pooled accuracy and group-fairness gaps of binary predictions.

        Groups are the distinct values of ``sensitive[:, 1]`` (sex, 1
        attribute); groups holding at most ``min_group_frac`` of the samples
        are ignored.

        Returns:
            tuple: ``(accuracy, dp_gap, eo_gap)`` where ``dp_gap`` is the spread
            of the positive-prediction rate across groups and ``eo_gap`` the
            larger spread of per-group TNR or TPR.
        """
        correct = preds == targets
        groups = sensitive[:, 1]  # sex, 1 attribute
        ppr, tnr, tpr = [], [], []
        for s_value in np.unique(groups):
            in_group = groups == s_value
            if np.mean(in_group) > min_group_frac:
                ppr.append(np.mean(preds[in_group]))
                tnr.append(np.mean(correct[np.logical_and(targets == 0, in_group)]))
                tpr.append(np.mean(correct[np.logical_and(targets == 1, in_group)]))
        eo_gap = max(max(tnr) - min(tnr), max(tpr) - min(tpr))
        dp_gap = max(ppr) - min(ppr)
        return np.mean(correct), dp_gap, eo_gap

    @staticmethod
    def _append_client_metrics(metrics, local_metrics):
        """Append one client's ``local_test`` metric dict to the ``metrics`` lists."""
        metrics["num_samples"].append(copy.deepcopy(local_metrics["test_total"]))
        metrics["num_correct"].append(copy.deepcopy(local_metrics["test_correct"]))
        metrics["losses"].append(copy.deepcopy(local_metrics["test_loss"]))
        metrics["eo_gap"].append(copy.deepcopy(local_metrics["eo_gap"]))
        metrics["dp_gap"].append(copy.deepcopy(local_metrics["dp_gap"]))

    def _local_test_on_all_clients_global(self, round_idx, w_global, w_global_0, w_global_1):
        """Evaluate a blend of the two group models on all clients' test data.

        The blend is ``w0 * w_global_0 + w1 * w_global_1`` with weights from
        ``_GLOBAL_MIX_WEIGHTS``, built in a shallow copy of ``w_global`` (which
        only provides the key set and is not modified). Clients whose test
        data is ``None`` are skipped.

        Returns:
            tuple: ``(dp_gap, accuracy)`` over the pooled test predictions.
        """
        weight_0, weight_1 = self._GLOBAL_MIX_WEIGHTS
        w_mixed = copy.copy(w_global)
        for key in w_global:
            w_mixed[key] = weight_0 * w_global_0[key] + weight_1 * w_global_1[key]
        self.model_trainer.set_model_params(w_mixed)

        targets, preds, sensitive = [], [], []
        for idx, client_idx in enumerate(self.args.users):
            if self.test_data_local_dict[client_idx] is None:
                continue
            client = self.client_list[idx]
            _, test_target_list, test_pred_list, test_s_list = client.local_test(True)
            targets.extend(test_target_list.tolist())
            preds.extend(test_pred_list.tolist())
            sensitive.extend(test_s_list.tolist())

        test_global_acc, dp_gap_test_global, _ = self._group_fairness_metrics(
            np.array(targets), np.array(preds), np.array(sensitive)
        )
        return dp_gap_test_global, test_global_acc

    def _local_test_on_all_clients(self, round_idx, w_global, w_global_0, w_global_1):
        """Evaluate every client on its train and test split with its group model.

        Round 0 uses ``w_global`` for everyone; later rounds use ``w_global_0``
        for users in ``_EVAL_GROUP_0_USERS`` and ``w_global_1`` otherwise.
        Clients whose test data is ``None`` are skipped. Logs aggregate
        accuracy, loss and mean per-client fairness gaps, and sends
        accuracy/loss to wandb when ``args.enable_wandb`` is set.

        Returns:
            tuple: (per-client test ``dp_gap`` list, pooled test DP gap, pooled
            test accuracy, per-client test accuracy list). The lists follow
            ``args.users`` order without the skipped clients.
        """
        logging.info("################local_test_on_all_clients : {}".format(round_idx))

        metric_keys = ("num_samples", "num_correct", "losses", "eo_gap", "dp_gap")
        train_metrics = {key: [] for key in metric_keys}
        test_metrics = {key: [] for key in metric_keys}
        test_targets, test_preds, test_sensitive = [], [], []
        local_test_acc = []

        for idx, client_idx in enumerate(self.args.users):
            if self.test_data_local_dict[client_idx] is None:
                continue

            if round_idx >= 1:
                w_eval = w_global_0 if client_idx in self._EVAL_GROUP_0_USERS else w_global_1
            else:
                w_eval = w_global
            self.model_trainer.set_model_params(w_eval)

            client = self.client_list[idx]
            train_local_metrics, _, _, _ = client.local_test(False)
            self._append_client_metrics(train_metrics, train_local_metrics)

            test_local_metrics, test_target_list, test_pred_list, test_s_list = client.local_test(True)
            self._append_client_metrics(test_metrics, test_local_metrics)
            test_targets.extend(test_target_list.tolist())
            test_preds.extend(test_pred_list.tolist())
            test_sensitive.extend(test_s_list.tolist())
            local_test_acc.append(test_local_metrics["test_correct"] / test_local_metrics["test_total"])

        test_global_acc, dp_gap_test_global, _ = self._group_fairness_metrics(
            np.array(test_targets), np.array(test_preds), np.array(test_sensitive)
        )

        # average per-client gaps over the clients actually evaluated
        num_evaluated = len(test_metrics["dp_gap"])

        # test on training dataset
        train_acc = sum(train_metrics["num_correct"]) / sum(train_metrics["num_samples"])
        train_loss = sum(train_metrics["losses"]) / sum(train_metrics["num_samples"])
        train_dp_gap = sum(train_metrics["dp_gap"]) / num_evaluated
        train_eo_gap = sum(train_metrics["eo_gap"]) / num_evaluated

        # test on test dataset
        test_acc = sum(test_metrics["num_correct"]) / sum(test_metrics["num_samples"])
        test_loss = sum(test_metrics["losses"]) / sum(test_metrics["num_samples"])
        test_dp_gap = sum(test_metrics["dp_gap"]) / num_evaluated
        test_eo_gap = sum(test_metrics["eo_gap"]) / num_evaluated
        logging.info('dp_gap' + str(train_metrics["dp_gap"]))
        logging.info('Train acc: {} Train Loss: {}, Test acc: {} Test Loss: {}'.format(train_acc, train_loss, test_acc, test_loss))
        logging.info('Train dp gap: {} Train eo gap: {}, Test dp gap: {} Test eo gap: {}'.format(train_dp_gap, train_eo_gap, test_dp_gap, test_eo_gap))

        if self.args.enable_wandb:
            wandb.log({"Test/Acc": test_acc, "round": round_idx})
            wandb.log({"Test/Loss": test_loss, "round": round_idx})
            wandb.log({"Train/Acc": train_acc, "round": round_idx})
            wandb.log({"Train/Loss": train_loss, "round": round_idx})

        return test_metrics["dp_gap"], dp_gap_test_global, test_global_acc, local_test_acc

    def _local_test_on_validation_set(self, round_idx):
        """Evaluate client 0 on ``val_global`` and log dataset-specific metrics.

        Raises:
            Exception: for datasets other than ``stackoverflow_nwp`` and
                ``stackoverflow_lr``.
        """
        logging.info(
            "################local_test_on_validation_set : {}".format(round_idx)
        )

        if self.val_global is None:
            # NOTE(review): _generate_validation_set is a stub, so val_global
            # is still None afterwards.
            self._generate_validation_set()

        client = self.client_list[0]
        client.update_local_dataset(0, None, self.val_global, None)
        # test data; local_test returns (metrics, targets, preds, sensitive)
        test_metrics, _, _, _ = client.local_test(True)

        if self.args.dataset == "stackoverflow_nwp":
            test_acc = test_metrics["test_correct"] / test_metrics["test_total"]
            test_loss = test_metrics["test_loss"] / test_metrics["test_total"]
            stats = {"test_acc": test_acc, "test_loss": test_loss}
            if self.args.enable_wandb:
                wandb.log({"Test/Acc": test_acc, "round": round_idx})
                wandb.log({"Test/Loss": test_loss, "round": round_idx})
        elif self.args.dataset == "stackoverflow_lr":
            test_acc = test_metrics["test_correct"] / test_metrics["test_total"]
            test_pre = test_metrics["test_precision"] / test_metrics["test_total"]
            test_rec = test_metrics["test_recall"] / test_metrics["test_total"]
            test_loss = test_metrics["test_loss"] / test_metrics["test_total"]
            stats = {
                "test_acc": test_acc,
                "test_pre": test_pre,
                "test_rec": test_rec,
                "test_loss": test_loss,
            }
            if self.args.enable_wandb:
                wandb.log({"Test/Acc": test_acc, "round": round_idx})
                wandb.log({"Test/Pre": test_pre, "round": round_idx})
                wandb.log({"Test/Rec": test_rec, "round": round_idx})
                wandb.log({"Test/Loss": test_loss, "round": round_idx})
        else:
            raise Exception(
                "Unknown format to log metrics for dataset {}!".format(self.args.dataset)
            )

        logging.info(stats)

    def save(self):
        """Save the trainer model's state_dict to ``<run_folder>/<save_model_name>.pt``."""
        torch.save(self.model_trainer.model.state_dict(), os.path.join(self.args.run_folder, "%s.pt" % (self.args.save_model_name)))  # check the fedavg model name

