import logging
import statistics
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from fedml_api.model.cv.darts import utils
from torch import nn


class FedNASAggregator(object):

    def __init__(self, train_global, test_global, train_local_dict, test_local_dict, all_train_data_num, client_num,
                 model, device, args, client_num_in_total):
        """Store the datasets, model and configuration, and reset all bookkeeping.

        Every per-client / per-round container starts as an empty dict and
        every running average starts at 0.0. Two "uploaded" flag tables are
        pre-filled with False: one entry per index in range(client_num) for
        model uploads and one per index in range(client_num_in_total) for
        test-result uploads.
        """
        # ---- datasets --------------------------------------------------
        self.train_global, self.test_global = train_global, test_global
        self.train_local_dict, self.test_local_dict = train_local_dict, test_local_dict
        self.all_train_data_num = all_train_data_num

        # ---- topology / runtime ----------------------------------------
        self.client_num = client_num
        self.client_num_in_total = client_num_in_total
        self.device = device
        self.args = args
        self.model = model

        # ---- payloads received from clients ----------------------------
        self.pi_params_dict = {}
        self.model_dict = {}
        self.arch_dict = {}
        self.sample_num_dict = {}
        self.train_acc_dict = {}
        self.personalized_architecture = {}

        # ---- personalized-model accuracy -------------------------------
        self.acc_local_test_data_with_personalized_model_dict = {}
        self.acc_avg_personalized_model_on_all_clients = 0.0

        # ---- FLOPs and model-size accounting ---------------------------
        self.flops_local_model_all_clients = {}
        self.flops_local_model_this_round = {}
        self.model_size_local_model_all_clients = {}
        self.model_size_local_model_this_round = {}
        self.flops_avg_personalized_model_on_all_clients = 0.0
        self.model_size_avg_personalized_model_on_all_clients = 0.0

        # ---- global-model metrics --------------------------------------
        self.train_loss_dict = {}
        self.test_acc_dict = {}
        self.acc_avg_train_global = 0.0
        self.loss_avg_test_global = 0.0
        self.acc_avg_test_global = 0.0
        self.acc_avg_test_global_best = 0.0  # best value seen so far
        self.best_accuracy_different_cnn_counts = {}

        self.acc_local_test_data_with_global_model_dict = {}
        self.acc_avg_global_model_on_all_clients = 0.0

        self.acc_union_local_test_data = 0.0
        self.acc_union_local_test_data_best = 0.0  # best value seen so far
        self.best_acc_prev_round = 0.0
        self.clients_of_this_round = {}
        self.global_model_acc_after_finetunning = {}

        # ---- synchronisation flags -------------------------------------
        self.flag_client_model_uploaded_dict = dict.fromkeys(range(self.client_num), False)
        self.flag_client_test_result_uploaded_dict = dict.fromkeys(range(self.client_num_in_total), False)
        self.step_three_valid_acc = {}

    def get_device(self):
        """Return the device object that was passed to the constructor."""
        return self.device

    def get_model(self):
        """Return the server-side model held by this aggregator."""
        return self.model

    def get_local_datasets(self):
        """Return the ``(train_local_dict, test_local_dict)`` pair given at construction."""
        local_train, local_test = self.train_local_dict, self.test_local_dict
        return (local_train, local_test)

    def client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
        """Choose the client indexes that take part in round ``round_idx``.

        * If every client participates, return ``[0, ..., client_num_in_total - 1]``.
        * If ``args.client_sampling`` is exactly ``False``, return the first
          ``client_num_per_round`` indexes.
        * Otherwise seed NumPy's global RNG with ``round_idx`` and draw
          ``min(client_num_per_round, client_num_in_total)`` distinct indexes,
          returned as a NumPy array.
        """
        if client_num_in_total == client_num_per_round:
            client_indexes = list(range(client_num_in_total))
        elif self.args.client_sampling is False:
            client_indexes = list(range(client_num_per_round))
        else:
            # Seeding with the round index makes the selection reproducible,
            # so runs being compared pick the same clients in each round.
            np.random.seed(round_idx)
            sample_size = min(client_num_per_round, client_num_in_total)
            client_indexes = np.random.choice(range(client_num_in_total), sample_size, replace=False)
        logging.info("client_indexes = %s" % str(client_indexes))
        return client_indexes

    def add_local_trained_result(self, index, model_params, arch_params, sample_num, train_acc, train_loss, personalized_acc, client_index,
                                 flops, model_size, round_idx):
        """Record one worker's training upload and mark that worker as received.

        Values stored under ``index`` describe the current round; values
        stored under ``client_index`` accumulate across rounds. The
        personalized accuracy is only recorded when ``round_idx`` is a
        multiple of ``args.frequency_of_the_test``.
        """
        test_frequency = self.args.frequency_of_the_test
        logging.info("add_model. sample_num = %d, index = %d, round index = %d, freq = %d" % (sample_num, index, round_idx, test_frequency))

        # Per-round payload, keyed by the worker's index.
        self.sample_num_dict[index] = sample_num
        self.model_dict[index] = model_params
        self.arch_dict[index] = arch_params
        self.train_acc_dict[index] = train_acc
        self.train_loss_dict[index] = train_loss

        rounds_since_test = round_idx % test_frequency
        logging.info("Mod "+str(rounds_since_test))
        if rounds_since_test == 0:
            # Accuracy of the personalized model, stored both per worker index
            # and per client index.
            self.test_acc_dict[index] = personalized_acc
            self.acc_local_test_data_with_personalized_model_dict[client_index] = personalized_acc

        self.clients_of_this_round[index] = client_index

        # Cost metrics: one view across all clients, one for this round only.
        self.flops_local_model_all_clients[client_index] = flops
        self.model_size_local_model_all_clients[client_index] = model_size
        self.flops_local_model_this_round[index] = flops
        self.model_size_local_model_this_round[index] = model_size

        self.flag_client_model_uploaded_dict[index] = True

    def add_local_fine_tuned_result(self, index, model_params, arch_params, sample_num, train_acc, train_loss,
                                    acc_global_model_on_local_data, acc_personalized_model_on_local_data,
                                    pi_params, personalized_arch, client_index):
        """Record one client's fine-tuning results, keyed by ``client_index``.

        Stores the client's pi parameters, personalized architecture, and the
        accuracy of both the personalized and the global model on its local
        test data, then marks the client's test result as uploaded.

        Raises:
            AssertionError: if this client's result was already marked as
                uploaded (the values above are overwritten before the check).
        """
        # Accuracy on the client's local test data, for each model.
        self.acc_local_test_data_with_global_model_dict[client_index] = acc_global_model_on_local_data
        self.acc_local_test_data_with_personalized_model_dict[client_index] = acc_personalized_model_on_local_data

        # What the client searched / learned.
        self.personalized_architecture[client_index] = personalized_arch
        self.pi_params_dict[client_index] = pi_params

        already_uploaded = self.flag_client_test_result_uploaded_dict[client_index]
        assert already_uploaded is not True
        self.flag_client_test_result_uploaded_dict[client_index] = True

    def check_whether_all_processes_receive(self):
        """Return True once every worker index has uploaded its model.

        When all ``client_num`` flags are set they are reset to False for the
        next round; otherwise the flags are left untouched and False is
        returned.
        """
        slots = range(self.client_num)
        flags = self.flag_client_model_uploaded_dict
        if not all(flags[slot] for slot in slots):
            return False
        flags.update(dict.fromkeys(slots, False))
        return True

    def check_whether_all_clients_receive(self):
        """Return True once every client has uploaded its test result.

        When all ``client_num_in_total`` flags are set they are reset to
        False; otherwise the flags are left untouched and False is returned.
        """
        clients = range(self.client_num_in_total)
        flags = self.flag_client_test_result_uploaded_dict
        if not all(flags[client] for client in clients):
            return False
        flags.update(dict.fromkeys(clients, False))
        return True

    def aggregate(self):
        """Average the uploaded weights (and, in search stages, alphas) into the server model.

        The averaged weights are loaded into ``self.model``; in the
        ``"personalized_search"`` stage they are loaded with ``strict=False``.

        Returns:
            ``(averaged_weights, averaged_alphas)`` when ``args.stage`` is
            ``"fednas_search"`` or ``"backbone_cell_search"``; otherwise just
            ``averaged_weights``.
        """
        averaged_weights = self.__aggregate_weight()

        load_options = {"strict": False} if self.args.stage == "personalized_search" else {}
        self.model.load_state_dict(averaged_weights, **load_options)

        searching_architecture = (self.args.stage == "fednas_search"
                                  or self.args.stage == "backbone_cell_search")
        if not searching_architecture:
            return averaged_weights

        logging.info("Alpha aggregation called")
        averaged_alphas = self.__aggregate_alpha()
        self.__update_arch(averaged_alphas)
        return averaged_weights, averaged_alphas

    def __update_arch(self, alphas):
        logging.info("update_arch. server.")
        for a_g, model_arch in zip(alphas, self.model.arch_parameters()):
            model_arch.data.copy_(a_g.data)

    def __aggregate_weight(self):
        """Return the sample-weighted average of the uploaded model parameters.

        Each worker's parameters are weighted by its share of the samples
        reported in ``sample_num_dict``. The result is accumulated into (and
        returned as) worker 0's parameter dict, and ``model_dict`` is emptied
        afterwards.
        """
        logging.info("################aggregate weights############")
        began = time.time()

        worker_ids = range(self.client_num)
        uploads = [(self.sample_num_dict[worker], self.model_dict[worker]) for worker in worker_ids]
        total_samples = sum(self.sample_num_dict[worker] for worker in worker_ids)

        # Worker 0's dict doubles as the accumulator.
        _, averaged_params = uploads[0]
        for key in averaged_params.keys():
            for position, (sample_count, params) in enumerate(uploads):
                share = sample_count / total_samples
                if position:
                    averaged_params[key] += params[key] * share
                else:
                    # First contribution replaces worker 0's raw value.
                    averaged_params[key] = params[key] * share

        uploads.clear()
        del uploads
        self.model_dict.clear()

        finished = time.time()
        logging.info("aggregate weights time cost: %d" % (finished - began))
        return averaged_params

    def __aggregate_alpha(self):
        """Return the sample-weighted average of the uploaded architecture parameters.

        Each worker's alphas (``arch_dict[idx]``, a sequence) are weighted by
        that worker's share of the samples reported in ``sample_num_dict``.

        Returns:
            A new list with one averaged entry per position of worker 0's
            alpha sequence. The uploaded alphas themselves are not modified.

        Raises:
            IndexError: if ``client_num`` is 0.
            ZeroDivisionError: if the reported sample counts sum to 0.
        """
        logging.info("################aggregate alphas############")
        start_time = time.time()

        worker_ids = range(self.client_num)
        training_num = sum(self.sample_num_dict[idx] for idx in worker_ids)
        alpha_list = [(self.sample_num_dict[idx], self.arch_dict[idx]) for idx in worker_ids]

        (_, reference_alphas) = alpha_list[0]
        averaged_alphas = []
        for index, _ in enumerate(reference_alphas):
            # Build the average in a fresh value and store it in the result.
            # Rebinding the loop variable (as the previous code did) never
            # reached the returned sequence, so worker 0's alphas were
            # returned unchanged.
            averaged = None
            for local_sample_number, local_alpha in alpha_list:
                contribution = local_alpha[index] * (local_sample_number / training_num)
                averaged = contribution if averaged is None else averaged + contribution
            averaged_alphas.append(averaged)

        end_time = time.time()
        logging.info("aggregate alphas time cost: %d" % (end_time - start_time))
        return averaged_alphas

    def evaluation(self, round_idx):
        """Run the per-round evaluation by delegating to ``_local_statistics``."""
        self._local_statistics(round_idx)


    def _local_statistics(self, round_idx):
        """Log averaged personalized-model validation accuracy to the logger and wandb.

        Two averages are reported: one over ``test_acc_dict`` (labelled
        "clients of this round") and one over
        ``acc_local_test_data_with_personalized_model_dict`` (labelled
        "All Clients"). The latter is stored on the instance and compared
        with ``best_acc_prev_round`` to decide whether to log a "best" table.
        """
        def accuracy_table():
            # Single-cell table holding the whole per-client accuracy dict as text.
            table = wandb.Table(columns=["Accuracy"])
            table.add_data(str(self.acc_local_test_data_with_personalized_model_dict))
            return table

        # Personalized accuracy averaged over the entries of test_acc_dict.
        round_accuracies = self.test_acc_dict.values()
        logging.info("Client accuracies of this round "+str(round_accuracies))
        avg_valid_acc_this_round = sum(round_accuracies) / len(round_accuracies)
        logging.info('Round {:3d}, Averaged Validation Accuracy on Clients of this round {:.3f}'.format(
            round_idx, avg_valid_acc_this_round))
        wandb.log({
            "Averaged Validation Accuracy (clients of this round) ": avg_valid_acc_this_round,
            "Round": round_idx,
        })

        # Personalized accuracy averaged over every client recorded so far.
        all_accuracies = self.acc_local_test_data_with_personalized_model_dict.values()
        logging.info("client accuracies "+str(self.acc_local_test_data_with_personalized_model_dict))
        self.acc_avg_personalized_model_on_all_clients = sum(all_accuracies) / len(all_accuracies)
        logging.info('Round {:3d}, Averaged Validation Accuracy on All Clients {:.3f}'.format(
            round_idx, self.acc_avg_personalized_model_on_all_clients))
        wandb.log({
            "Averaged Validation Accuracy (All Clients)": self.acc_avg_personalized_model_on_all_clients,
            "Round": round_idx,
        })

        # Keep the accuracy list of the best round for later comparison with
        # other algorithms.
        if self.acc_avg_personalized_model_on_all_clients >= self.best_acc_prev_round:
            self.best_acc_prev_round = self.acc_avg_personalized_model_on_all_clients
            wandb.log({"Validation Accuracy (Local Adaptation (Step 2) (4 clients per round))": accuracy_table()})

        # The current round's list is always logged.
        wandb.log({"Validation Accuracy (Personalized Model) (4 clients per round, current round)": accuracy_table()})

    def _global_statistics(self, round_idx):
        """Log the global model's average train accuracy and its validation accuracy.

        The train accuracy is the mean of ``train_acc_dict`` and is stored in
        ``acc_avg_train_global``; the validation accuracy is whatever is
        currently held in ``acc_avg_test_global``.
        """
        # Train accuracy: mean over the uploaded per-worker values.
        uploaded_train_accs = self.train_acc_dict.values()
        self.acc_avg_train_global = sum(uploaded_train_accs) / len(uploaded_train_accs)
        train_message = 'Round {:3d}, Global Model - Average Train Accuracy {:.3f}'.format(
            round_idx, self.acc_avg_train_global)
        logging.info(train_message)
        wandb.log({"Global Model - Train Accuracy": self.acc_avg_train_global, "Round": round_idx})

        # Validation accuracy: reported as-is, not recomputed here.
        valid_message = 'Round {:3d}, Global Model - Average Validation Accuracy {:.3f}'.format(
            round_idx, self.acc_avg_test_global)
        logging.info(valid_message)
        wandb.log({"Global Model - Validation Accuracy": self.acc_avg_test_global, "Round": round_idx})


    def _test_on_local_data(self, client_idx, round_idx, test_data):
        self.model.eval()
        self.model.to(self.device)

        test_correct = 0.0
        test_loss = 0.0
        test_sample_number = 0.0

        # loss
        criterion = nn.CrossEntropyLoss().to(self.device)

        iteration_num = 0
        logging.info("Local Infer Starting")
        with torch.no_grad():
            for batch_idx, (x, target) in enumerate(test_data):
                iteration_num += 1
                x = x.to(self.device)
                target = target.to(self.device)
                pred = self.model(x)

                loss = criterion(pred, target)
                _, predicted = torch.max(pred, 1)
                correct = predicted.eq(target).sum()

                test_correct += correct.item()
                test_loss += loss.item() * target.size(0)
                test_sample_number += target.size(0)
                if iteration_num == 2 and self.args.is_debug_mode:
                    break
                logging.info("server test. client_idx = %d, round_idx = %d, batch_idx = %d/%d, correct = %d/%d, test_loss_avg = %s"
                             % (client_idx, round_idx, batch_idx, len(test_data), int(test_correct), int(test_sample_number),
                                test_loss / test_sample_number))

        acc = (test_correct / test_sample_number)
        loss = test_loss / test_sample_number
        return acc, loss

    def _infer(self, round_idx):
        self.model.eval()
        self.model.to(self.device)
        if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1:
            start_time = time.time()
            test_correct = 0.0
            test_loss = 0.0
            test_sample_number = 0.0
            test_data = self.test_global
            # loss
            criterion = nn.CrossEntropyLoss().to(self.device)

            iteration_num = 0
            logging.info("Global Infer Starting")
            with torch.no_grad():
                for batch_idx, (x, target) in enumerate(test_data):
                    iteration_num += 1
                    x = x.to(self.device)
                    target = target.to(self.device)
                    pred = self.model(x)

                    loss = criterion(pred, target)
                    _, predicted = torch.max(pred, 1)
                    correct = predicted.eq(target).sum()

                    test_correct += correct.item()
                    test_loss += loss.item() * target.size(0)
                    test_sample_number += target.size(0)
                    if iteration_num == 2 and self.args.is_debug_mode:
                        break
                    logging.info("server test. round_idx = %d, batch_idx = %d/%d, correct = %d/%d, test_loss_avg = %s"
                                 % (round_idx, batch_idx, len(test_data), test_correct, test_sample_number,
                                    test_loss / test_sample_number))

            self.acc_avg_test_global = (test_correct / test_sample_number)
            self.loss_avg_test_global = test_loss / test_sample_number
            logging.info("Validation Accuracy "+str(self.acc_avg_test_global))
            end_time = time.time()
            logging.info("server_infer time cost: %d" % (end_time - start_time))

    def _record_model_global_architecture(self, round_idx):
        """Log the searched genotype and track best accuracies in wandb.

        Logs the genotype returned by ``self.model.genotype()``, keeps the
        best ``acc_avg_test_global`` seen for each CNN count, and when the
        overall best accuracy improves, saves the model weights and
        architecture parameters to the paths configured in ``args``.
        """
        # The third value returned by genotype() is not used here.
        genotype, normal_cnn_count, _unused_reduce_count = self.model.genotype()
        current_acc = self.acc_avg_test_global

        # --- searched structure -----------------------------------------
        wandb.log({"cnn_count": normal_cnn_count, "Round": round_idx})

        logging.info("(n:%d)" % normal_cnn_count)
        logging.info('genotype = %s', genotype)
        wandb.log({"genotype": str(genotype), "round_idx": round_idx})

        arch_table = wandb.Table(columns=["Epoch", "Searched Architecture"])
        arch_table.add_data(str(round_idx), str(genotype))
        wandb.log({"Searched Architecture": arch_table})

        # --- best accuracy per CNN count --------------------------------
        count_key = normal_cnn_count * 10
        wandb.log({"searching_cnn_count(%s)" % count_key: current_acc, "epoch": round_idx})
        best_per_count = self.best_accuracy_different_cnn_counts
        # First sighting of this count, or a strict improvement on it.
        if count_key not in best_per_count or current_acc > best_per_count[count_key]:
            best_per_count[count_key] = current_acc
            wandb.run.summary["best_acc_for_cnn_structure(n:%d)" % normal_cnn_count] = current_acc
            wandb.run.summary["epoch_of_best_acc_for_cnn_structure(n:%d)" % normal_cnn_count] = round_idx

        # --- best accuracy overall --------------------------------------
        if current_acc > self.acc_avg_test_global_best:
            self.acc_avg_test_global_best = current_acc
            wandb.log({"global model - best_valid_accuracy": self.acc_avg_test_global_best, "Round": round_idx})
            wandb.run.summary["global model - best_valid_accuracy"] = self.acc_avg_test_global_best
            wandb.run.summary["global model - round_of_best_accuracy"] = round_idx
            logging.info("save model to %s" % self.args.path_of_best_global_model)
            # Weights first, then alphas; the model is moved to CPU before each save.
            cpu_weights = self.model.cpu().state_dict()
            torch.save(cpu_weights, self.args.path_of_best_global_model)
            cpu_alphas = self.model.cpu().arch_parameters()
            torch.save(cpu_alphas, self.args.path_of_best_global_arch_parameter)
        logging.info("finished record_model_global_architecture()")

    def save_best_global_model(self):
        """Hand the path in ``args.path_of_best_global_model`` to ``wandb.save``."""
        wandb.save(self.args.path_of_best_global_model)

    def personal_evaluation(self, round_idx):
        """Log and plot per-client accuracy of the personalized vs. the global model.

        Reads ``acc_local_test_data_with_personalized_model_dict`` and
        ``acc_local_test_data_with_global_model_dict`` (both keyed by client
        index), stores their means on the instance, and sends tables, scalar
        averages and four matplotlib figures to wandb.

        Accuracies are ordered and paired by client index rather than by dict
        insertion order, so the per-client gap always compares the same
        client's two values. The gap is computed only for clients present in
        both dicts. If either dict is empty, only the raw tables are logged.
        Every figure is closed after it has been logged.
        """
        personalized_by_client = self.acc_local_test_data_with_personalized_model_dict
        global_by_client = self.acc_local_test_data_with_global_model_dict

        logging.info("personal_evaluation")
        logging.info(personalized_by_client)
        logging.info(global_by_client)
        wandb_table_pe = wandb.Table(columns=["(Fine-Tunned) Global Model After step 2"])
        wandb_table_pe.add_data(str(global_by_client))
        wandb.log({"(Fine-Tunned) Global Model After step 2": wandb_table_pe})

        # Save the accuracy list for later comparison with other algorithms.
        wandb_table_p2 = wandb.Table(columns=["(Fine-Tunned FedNAS) (Step2)"])
        wandb_table_p2.add_data(str(personalized_by_client))
        wandb.log({"(Fine-Tunned FedNAS) (Step2)": wandb_table_p2})

        if not personalized_by_client or not global_by_client:
            # Averages and plots are undefined without data.
            logging.warning("personal_evaluation: no accuracies recorded, skipping averages and plots")
            return

        def spread(values):
            # statistics.stdev raises StatisticsError below two samples.
            return statistics.stdev(values) if len(values) > 1 else 0.0

        def summary(values):
            return " (Mean = " + str(round(statistics.mean(values), 3)) + ', STD = ' + str(
                round(spread(values), 3)) + ")"

        # Order by client index so bar positions and pairings are meaningful
        # regardless of the order in which the dicts were filled.
        personalized_clients = sorted(personalized_by_client)
        global_clients = sorted(global_by_client)
        acc_list_personalized = [personalized_by_client[c] for c in personalized_clients]
        acc_list_global = [global_by_client[c] for c in global_clients]

        # Averaged validation accuracy (personalized / global) across all clients.
        self.acc_avg_personalized_model_on_all_clients = sum(acc_list_personalized) / len(acc_list_personalized)
        self.acc_avg_global_model_on_all_clients = sum(acc_list_global) / len(acc_list_global)

        logging.info('Round {:3d}, Averaged Validation Personalized Accuracy on All Clients {:.3f}'.format(
            round_idx, self.acc_avg_personalized_model_on_all_clients))
        wandb.log(
            {"Averaged Validation Accuracy on All Clients (Step 2)": self.acc_avg_personalized_model_on_all_clients,
             "Round": round_idx})

        logging.info('Round {:3d}, Averaged Validation Global Accuracy on All Clients (Step 2) {:.3f}'.format(
            round_idx, self.acc_avg_global_model_on_all_clients))
        wandb.log({"Averaged Validation Global Accuracy on All Clients (Step 2)": self.acc_avg_global_model_on_all_clients,
                   "Round": round_idx})

        # Per-client improvement, only where both models were evaluated.
        diff = [personalized_by_client[c] - global_by_client[c]
                for c in personalized_clients if c in global_by_client]

        label_p = 'Personalized Model' + summary(acc_list_personalized)
        label_g = 'Global Model' + summary(acc_list_global)

        # Accuracy of each client (x: client index; y: accuracy).
        fig = plt.figure()
        plt.bar(personalized_clients, acc_list_personalized, alpha=0.5, label=label_p)
        plt.bar(global_clients, acc_list_global, alpha=0.5, label=label_g)
        plt.legend(prop={'size': 10})
        plt.ylabel('Validation Accuracy (Personalized Model)')
        plt.xlabel('Client Index')
        plt.title("Validation Accuracy (x: client index; y: accuracy) (Step 2)")
        wandb.log({"Validation Accuracy (x: client index; y - accuracy) (Step 2)": [wandb.Image(plt)]})
        plt.close(fig)

        if diff:
            # Sorted gaps; positions are ranks, not client indexes.
            fig = plt.figure()
            plt.bar(list(range(len(diff))), sorted(diff), color='g')
            plt.ylabel('Validation Accuracy Gap')
            plt.xlabel('Client Index')
            plt.title("Validation Accuracy Improvement (x: client index; y - accuracy)")
            wandb.log({"Validation Accuracy Improvement (x: client index; y - accuracy) (Step 2)": [wandb.Image(plt)]})
            plt.close(fig)

        # Distribution of accuracy (x: accuracy; y: client number).
        fig = plt.figure()
        plt.hist(acc_list_personalized, bins=20, alpha=0.5, label=label_p)
        plt.hist(acc_list_global, bins=20, alpha=0.5, label=label_g)
        plt.legend(prop={'size': 10})
        plt.title("Accuracy Distribution (x: accuracy; y: client number)")
        wandb.log({"Accuracy Distribution (x: accuracy; y: client number) (Step 2)": [wandb.Image(plt)]})
        plt.close(fig)

        if diff:
            fig = plt.figure()
            plt.hist(diff,
                     bins=[-1.00, -.90, -.80, -.70, -.60, -.50, -.40, -.30, -.20, -.10, 0, .10, .20, .30, .40, .50, .60,
                           .70, .80, .90, 1.00], color='g')
            plt.ylabel('Clients Numbers')
            plt.xlabel('Validation Accuracy Improvement')
            plt.text(1, 1, " Mean = " + str(statistics.mean(diff)))
            plt.text(1, 3, ' STD = ' + str(spread(diff)))
            plt.title("Accuracy Improvement Distribution (x: accuracy; y: client number)")
            wandb.log({"Accuracy Improvement Distribution (x: accuracy; y: client number) (Step 2)": [wandb.Image(plt)]})
            plt.close(fig)
