import copy
import logging
import pickle
import numpy as np
import torch
import wandb
from fedml_api.standalone.fedspa.client import Client
from fedml_api.standalone.fedspa.slim_util import model_difference


class FedSpaAPI(object):
    """Standalone simulation of FedSpa training with per-client (personal) masks.

    Every client in ``range(args.client_num_in_total)`` owns a personal mask.
    Each round, a sampled subset of clients trains the global weights multiplied
    by its mask. The server then applies the uniform average of the returned
    updates. Run statistics are collected in ``self.stat_info`` and may be
    pickled at the end of ``train``.
    """

    def __init__(self, dataset, device, args, model_trainer):
        """Store the federated dataset partitions and build the client workers.

        Args:
            dataset: 8-element sequence ``[train_data_num, test_data_num,
                train_data_global, test_data_global, train_data_local_num_dict,
                train_data_local_dict, test_data_local_dict, class_counts]``.
            device: device handed to every ``Client``.
            args: experiment configuration namespace.
            model_trainer: trainer shared by all clients. It is also used for
                sparsity, mask and FLOP computations.
        """
        self.device = device
        self.args = args
        [train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_counts] = dataset
        self.train_global = train_data_global
        self.test_global = test_data_global
        self.val_global = None
        self.train_data_num_in_total = train_data_num
        self.test_data_num_in_total = test_data_num
        self.client_list = []
        self.train_data_local_num_dict = train_data_local_num_dict
        self.train_data_local_dict = train_data_local_dict
        self.test_data_local_dict = test_data_local_dict
        self.class_counts = class_counts
        self.model_trainer = model_trainer
        self._setup_clients(train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer)
        self.init_stat_info()

    def _setup_clients(self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer):
        """Create ``args.client_num_per_round`` reusable ``Client`` workers.

        Worker ``i`` starts with client ``i``'s data. ``train`` later re-binds
        each worker to a sampled client via ``update_local_dataset``.
        """
        logging.info("############setup_clients (START)#############")
        for client_idx in range(self.args.client_num_per_round):
            c = Client(client_idx, train_data_local_dict[client_idx], test_data_local_dict[client_idx],
                       train_data_local_num_dict[client_idx], self.args, self.device, model_trainer)
            self.client_list.append(c)
        logging.info("############setup_clients (END)#############")

    def train(self):
        """Run ``args.comm_round`` communication rounds.

        Each round:

        * record the spread of the personal masks;
        * sample clients;
        * train each sampled client on the global weights multiplied by that
          client's mask, and store the mask the client returns;
        * aggregate the returned updates into the global weights.

        Evaluation runs every ``args.frequency_of_the_test`` rounds and always
        at the last round. After training, the final masks are stored
        (optionally), the average inference FLOPs are computed, and
        ``stat_info`` is persisted.
        """
        params = self.model_trainer.get_trainable_params()
        if self.args.uniform:
            sparsities = self.model_trainer.calculate_sparsities(params, distribution="uniform")
        else:
            sparsities = self.model_trainer.calculate_sparsities(params)

        if not self.args.different_initial:
            # All clients start from (deep copies of) the same initial mask.
            shared_mask = self.model_trainer.init_masks(params, sparsities)
            mask_pers = [copy.deepcopy(shared_mask) for _ in range(self.args.client_num_in_total)]
        else:
            # Every client gets its own independently initialised mask.
            mask_pers = [copy.deepcopy(self.model_trainer.init_masks(params, sparsities))
                         for _ in range(self.args.client_num_in_total)]
        w_global = self.model_trainer.get_model_params()
        last_round = self.args.comm_round - 1

        for round_idx in range(self.args.comm_round):
            logging.info("################Communication round : {}".format(round_idx))
            self.record_mask_difference(mask_pers)
            w_locals = []
            # Random client sampling; _client_sampling logs the chosen indexes.
            client_indexes = self._client_sampling(round_idx, self.args.client_num_in_total,
                                                   self.args.client_num_per_round)
            for idx, client in enumerate(self.client_list):
                # Re-bind this worker to the sampled client's data.
                client_idx = client_indexes[idx]
                client.update_local_dataset(client_idx, self.train_data_local_dict[client_idx],
                                            self.test_data_local_dict[client_idx],
                                            self.train_data_local_num_dict[client_idx])
                client_mask = mask_pers[client_idx]
                w_per = copy.deepcopy(w_global)
                for name in client_mask:
                    w_per[name] = w_global[name] * client_mask[name]
                new_mask, update, training_flops, num_comm_params = client.train(
                    w_per, copy.deepcopy(client_mask), round_idx)
                # The aggregate receives the mask the client trained with,
                # not the new mask it returned.
                w_locals.append((copy.deepcopy(client_mask), copy.deepcopy(update),
                                 copy.deepcopy(client.local_sample_number)))

                mask_pers[client_idx] = copy.deepcopy(new_mask)
                self.stat_info["sum_training_flops"] += training_flops
                # 2 x communication involving update and download
                self.stat_info["sum_comm_params"] += num_comm_params

            w_global = self._strict_avg_aggregate(w_global, w_locals)

            # Evaluate periodically and always at the last round.
            if round_idx == last_round or round_idx % self.args.frequency_of_the_test == 0:
                self._local_test_on_all_clients(w_global, mask_pers, round_idx)

        if self.args.save_masks:
            self.stat_info["final_masks"] = [
                {name: mask[name].data.bool() for name in mask} for mask in mask_pers
            ]

        self.record_avg_inference_flops(w_global, mask_pers)
        self.record_information()

    def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
        """Return the client indexes that participate in ``round_idx``.

        When every client participates, ``[0, ..., client_num_in_total - 1]``
        is returned. Otherwise, ``min(client_num_per_round, client_num_in_total)``
        distinct indexes are drawn. The global NumPy RNG is re-seeded with
        ``round_idx`` first, so every run selects the same clients in a given
        round. This also resets NumPy's global random state.
        """
        if client_num_in_total == client_num_per_round:
            client_indexes = [client_index for client_index in range(client_num_in_total)]
        else:
            num_clients = min(client_num_per_round, client_num_in_total)
            np.random.seed(round_idx)  # make sure for each comparison, we are selecting the same clients each round
            client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False)
        logging.info("client_indexes = %s" % str(client_indexes))
        return client_indexes

    def _strict_avg_aggregate(self, w_server, w_locals):
        """Return a copy of ``w_server`` minus the uniform mean of the client updates.

        Args:
            w_server: mapping from parameter name to current global value.
            w_locals: list of ``(mask, update, sample_num)`` tuples built in
                ``train``. Only ``update`` is used, and every client gets the
                same weight ``1 / len(w_locals)``; sample counts are ignored.

        Returns:
            A new dict; ``w_server`` itself is not modified. If ``w_locals`` is
            empty, an unchanged copy of ``w_server`` is returned.
        """
        logging.info("strict_avg")
        w_server = copy.deepcopy(w_server)
        if not w_locals:
            return w_server
        weight = 1 / len(w_locals)
        _, first_update, _ = w_locals[0]
        update = {}
        for k in first_update.keys():
            for i, (_, local_model_update, _) in enumerate(w_locals):
                if i == 0:
                    update[k] = local_model_update[k] * weight
                else:
                    update[k] += local_model_update[k] * weight
        for name in w_server:
            # NOTE(review): the averaged update is subtracted, which presumably
            # means client updates are (old - new) deltas. Confirm in Client.train.
            w_server[name] -= update[name]
        return w_server

    @staticmethod
    def _append_metrics(metrics, result):
        """Append one ``local_test`` result (test_total/test_correct/test_loss) to ``metrics``."""
        metrics['num_samples'].append(copy.deepcopy(result['test_total']))
        metrics['num_correct'].append(copy.deepcopy(result['test_correct']))
        metrics['losses'].append(copy.deepcopy(result['test_loss']))

    @staticmethod
    def _mean_client_metrics(metrics):
        """Return ``(mean accuracy, mean loss)`` over the clients recorded in ``metrics``.

        Each client contributes equally, whatever its sample count. The
        accuracy term is ``num_correct / num_samples`` and the loss term is
        ``losses / num_samples``. ``metrics`` must hold at least one client.
        """
        num_clients = len(metrics['num_samples'])
        acc = sum(correct / total
                  for correct, total in zip(metrics['num_correct'], metrics['num_samples'])) / num_clients
        loss = sum(np.array(client_loss) / np.array(total)
                   for client_loss, total in zip(metrics['losses'], metrics['num_samples'])) / num_clients
        return acc, loss

    def _local_test_on_all_clients(self, w_global, mask_pers, round_idx):
        """Evaluate each client's personalised model on its local test split.

        Clients without test data are skipped, and only the first client is
        evaluated when ``args.ci == 1``. Averages are taken over the clients
        actually evaluated. Results are logged and appended to
        ``stat_info["test_acc"]``. When ``args.global_test`` is set, the
        unmasked global model is also evaluated on every client, and its
        results go to ``stat_info["global_model_acc"]``.
        """
        logging.info("################local_test_on_all_clients : {}".format(round_idx))

        test_metrics = {'num_samples': [], 'num_correct': [], 'losses': []}
        test_global = {'num_samples': [], 'num_correct': [], 'losses': []}
        # A single Client worker is re-bound to each client's data in turn.
        client = self.client_list[0]

        for client_idx in range(self.args.client_num_in_total):
            # Note: for datasets like "fed_CIFAR100" and "fed_shakespheare",
            # the training client number is larger than the testing client number.
            if self.test_data_local_dict[client_idx] is None:
                continue
            client.update_local_dataset(0, self.train_data_local_dict[client_idx],
                                        self.test_data_local_dict[client_idx],
                                        self.train_data_local_num_dict[client_idx])
            w_per = copy.deepcopy(w_global)
            for name in mask_pers[client_idx]:
                w_per[name] = w_global[name] * mask_pers[client_idx][name]

            # Personalised (masked) model on the client's test data.
            self._append_metrics(test_metrics, client.local_test(w_per, True))

            if self.args.global_test:
                # Unmasked global model on the same test data.
                self._append_metrics(test_global, client.local_test(w_global, True))

            # The CI environment is CPU-based and RNN training is too slow
            # there, so only one client is tested to catch programming errors.
            if self.args.ci == 1:
                break

        if not test_metrics['num_samples']:
            logging.warning("no client with local test data at round {}; skipping evaluation".format(round_idx))
            return

        test_acc, test_loss = self._mean_client_metrics(test_metrics)
        stats = {'test_acc': test_acc, 'test_loss': test_loss}
        wandb.log({"Test/Acc": test_acc, "round": round_idx})
        wandb.log({"Test/Loss": test_loss, "round": round_idx})
        logging.info(stats)
        self.stat_info["test_acc"].append(test_acc)

        if self.args.global_test:
            test_global_acc, test_global_loss = self._mean_client_metrics(test_global)
            stats = {'global_acc': test_global_acc, 'global_loss': test_global_loss}
            logging.info(stats)
            self.stat_info["global_model_acc"].append(test_global_acc)

    def record_mask_difference(self, mask_pers):
        """Record how far the personal masks are from their element-wise mean.

        Builds the mean mask over all clients, then sums
        ``model_difference(mask, mean)`` over the clients. The total is sent
        to wandb, logged, and appended to ``stat_info["mask_difference"]``.
        """
        mean = {}
        for name in mask_pers[0]:
            # NOTE(review): the in-place division assumes the masks have a
            # floating dtype. Confirm against init_masks.
            mean[name] = torch.zeros_like(mask_pers[0][name])
            for mask in mask_pers:
                mean[name] += mask[name]
            mean[name] /= len(mask_pers)
        distance = 0
        for mask in mask_pers:
            distance += model_difference(mask, mean)
        wandb.log({"mask_distance": distance})
        self.stat_info["mask_difference"].append(distance)
        logging.info("distance{}".format(distance))

    def record_information(self):
        """Pickle ``stat_info`` to ``../../results/<dataset>/<identity>``.

        This only happens for runs with more than 50 communication rounds.
        The file is closed even if pickling fails.
        """
        if self.args.comm_round > 50:
            path = "../../results/" + self.args.dataset + "/" + self.args.identity
            with open(path, 'wb') as output:
                pickle.dump(self.stat_info, output)

    def record_avg_inference_flops(self, w_global, mask_pers=None):
        """Store the mean per-client inference FLOPs in ``stat_info["avg_inference_flops"]``.

        Without ``mask_pers``, every client is counted on ``w_global``.
        Otherwise, client ``i`` is counted on a dict that holds only the
        parameters named in ``mask_pers[i]``, each multiplied by its mask. If
        there are no clients, the stored value is left unchanged.
        """
        inference_flops = []
        for client_idx in range(self.args.client_num_in_total):
            if mask_pers is None:
                inference_flops.append(self.model_trainer.count_inference_flops(w_global))
            else:
                client_mask = mask_pers[client_idx]
                w_per = {name: w_global[name] * client_mask[name] for name in client_mask}
                inference_flops.append(self.model_trainer.count_inference_flops(w_per))
        if inference_flops:
            self.stat_info["avg_inference_flops"] = sum(inference_flops) / len(inference_flops)

    def init_stat_info(self):
        """Reset ``self.stat_info`` to empty counters and lists for a new run."""
        self.stat_info = {}
        self.stat_info["label_num"] = self.class_counts
        self.stat_info["sum_comm_params"] = 0
        self.stat_info["sum_training_flops"] = 0
        self.stat_info["avg_inference_flops"] = 0
        self.stat_info["test_acc"] = []
        self.stat_info["final_masks"] = []
        self.stat_info["mask_difference"] = []
        self.stat_info["global_model_acc"] = []


