import copy
import logging
import math
import random
import time
from copy import deepcopy

import numpy as np
import torch
import wandb

from .knn_evaluation import KNNValidation
from .utils import transform_list_to_tensor, save_global_model


class FedSSLAggregator(object):
    """Server-side aggregator for federated self-supervised learning.

    Collects per-worker uploads (model parameters, sample counts, training loss
    and accuracies), averages the parameters weighted by sample count
    (optionally followed by the pFedMe server-side interpolation), logs
    round-level metrics, and periodically evaluates the global model's encoder
    with a KNN monitor, saving the model whenever the KNN accuracy improves.
    """

    def __init__(self, train_global, test_global, all_train_data_num,
                 train_data_local_dict, test_data_local_dict, train_data_local_num_dict, worker_num, device,
                 args, model_trainer):
        self.trainer = model_trainer

        self.args = args
        self.train_global = train_global
        self.test_global = test_global
        # Reads self.args and self.test_global, so it must run after both are set.
        self.val_global = self._generate_validation_set()
        self.all_train_data_num = all_train_data_num

        self.train_data_local_dict = train_data_local_dict
        self.test_data_local_dict = test_data_local_dict
        self.train_data_local_num_dict = train_data_local_num_dict

        self.worker_num = worker_num
        self.device = device
        # Per-worker uploads, keyed by worker index in range(worker_num).
        self.model_dict = dict()
        self.sample_num_dict = dict()
        self.averaged_loss_dict = dict()

        # for personalization
        self.per_acc_dict = dict()
        self.global_model_on_local_accuracy_dict = dict()

        # True once a worker has uploaded in the current round.
        self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.worker_num)}

        # Previous round's aggregated parameters; only used by the pFedMe server update.
        self.shadow_params = None

        self.best_acc = 0.0

        self.knn_monitor = KNNValidation(self.args)
        self.knn_k_value = 1 if self.args.dataset == "cifar10" else 10

    def get_global_model_params(self):
        """Return the global model parameters held by the trainer."""
        return self.trainer.get_model_params()

    def set_global_model_params(self, model_parameters):
        """Install ``model_parameters`` as the trainer's global model parameters."""
        self.trainer.set_model_params(model_parameters)

    def add_local_trained_result(self, index, model_params, sample_num, averaged_loss,
                                 personalized_accuracy, global_model_on_local_accuracy):
        """Record one worker's upload for the current round and mark it as received."""
        logging.info("add_model. index = %d" % index)
        self.model_dict[index] = model_params
        self.sample_num_dict[index] = sample_num
        self.averaged_loss_dict[index] = averaged_loss

        # for personalization
        self.per_acc_dict[index] = personalized_accuracy
        self.global_model_on_local_accuracy_dict[index] = global_model_on_local_accuracy

        self.flag_client_model_uploaded_dict[index] = True

    def check_whether_all_receive(self):
        """Return True iff every worker has uploaded this round.

        When it returns True, all upload flags are reset to False so the next
        round starts fresh. When it returns False, the flags are left as they are.
        """
        flags = self.flag_client_model_uploaded_dict
        if not all(flags[idx] for idx in range(self.worker_num)):
            return False
        for idx in range(self.worker_num):
            flags[idx] = False
        return True

    def aggregate(self):
        """Average the uploaded worker parameters, weighted by sample count.

        The result is installed as the global model through
        ``set_global_model_params`` and returned. When
        ``args.pssl_optimizer == "pFedMe"``, from the second call on, the
        average is interpolated with the previous round's result::

            new = (1 - pfedme_beta) * previous + pfedme_beta * average

        The uploaded dicts in ``self.model_dict`` are not modified by the
        averaging. With ``args.is_mobile == 1`` each entry is still replaced by
        the output of ``transform_list_to_tensor``.

        Returns:
            The aggregated parameter mapping, of the same mapping type as the uploads.
        """
        start_time = time.time()
        model_list = []
        training_num = 0

        for idx in range(self.worker_num):
            if self.args.is_mobile == 1:
                self.model_dict[idx] = transform_list_to_tensor(self.model_dict[idx])
            model_list.append((self.sample_num_dict[idx], self.model_dict[idx]))
            training_num += self.sample_num_dict[idx]

        logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict)))

        # Per-worker weights are independent of the parameter key; compute them once.
        weights = [sample_num / training_num for sample_num, _ in model_list]
        first_params = model_list[0][1]
        # A shallow copy keeps the mapping type (and any state_dict metadata) without
        # aliasing worker 0's upload; every value is replaced below.
        averaged_params = copy.copy(first_params)
        for k in first_params.keys():
            acc = first_params[k] * weights[0]  # fresh tensor, safe to update in place
            for w, (_, local_model_params) in zip(weights[1:], model_list[1:]):
                acc += local_model_params[k] * w
            averaged_params[k] = acc

        # update the global model which is cached at the server side
        if self.args.pssl_optimizer == "pFedMe":
            shadow_params = getattr(self, "shadow_params", None)
            if shadow_params is not None:
                beta = self.args.pfedme_beta
                for k in averaged_params.keys():
                    averaged_params[k].data = (1 - beta) * shadow_params[k].data \
                                              + beta * averaged_params[k].data
            self.shadow_params = deepcopy(averaged_params)

        self.set_global_model_params(averaged_params)

        end_time = time.time()
        logging.info("aggregate time cost: %.3f s" % (end_time - start_time))
        return averaged_params

    def client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
        """Choose the client indices that participate in round ``round_idx``.

        When every client participates, this returns ``list(range(client_num_in_total))``.
        Otherwise it returns a NumPy array of ``min(client_num_per_round, client_num_in_total)``
        distinct indices drawn without replacement.
        """
        if client_num_in_total == client_num_per_round:
            client_indexes = list(range(client_num_in_total))
        else:
            num_clients = min(client_num_per_round, client_num_in_total)
            # Seeding with the round index selects the same clients each round across runs,
            # so different methods are compared on identical participants.
            # NOTE: this reseeds NumPy's *global* RNG, so later np.random calls are affected too.
            np.random.seed(round_idx)
            client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False)
        logging.info("client_indexes = %s" % str(client_indexes))
        return client_indexes

    def _generate_validation_set(self, num_samples=10000):
        """Return the loader used for server-side validation.

        For ``stackoverflow*`` datasets this is a DataLoader over a random subset of at
        most ``num_samples`` examples from ``self.test_global.dataset``. The subset is
        drawn with Python's global ``random`` module and no seed is set here. For every
        other dataset, ``self.test_global`` is returned as-is.
        """
        if self.args.dataset.startswith("stackoverflow"):
            test_data_num = len(self.test_global.dataset)
            sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num))
            subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices)
            sample_testset = torch.utils.data.DataLoader(subset, batch_size=self.args.batch_size)
            return sample_testset
        else:
            return self.test_global

    def _mean_over_workers(self, per_worker_values):
        """Return the mean of ``per_worker_values[idx]`` over ``idx`` in ``range(worker_num)``."""
        # Start from 0.0 and add in index order, matching a plain accumulate-then-divide loop.
        total = sum((per_worker_values[idx] for idx in range(self.worker_num)), 0.0)
        return total / self.worker_num

    def test_on_server_for_all_clients(self, round_idx):
        """Log round-level metrics and periodically run KNN evaluation of the global encoder.

        Steps:
          * Log a cosine-decayed value of ``args.lr`` for this round. The value is only
            logged here; this method does not apply it to any optimizer.
          * Log the mean worker training loss.
          * Log a mean accuracy under the key ``averaged_personalized_acc``, only when it
            is positive. If the optimizer is not FedAvg, or linear evaluation is on, the
            mean comes from the personalized accuracies. Otherwise it comes from the
            global-model-on-local accuracies.
          * If linear evaluation is off, then every ``frequency_of_the_test`` rounds and in
            the last round, evaluate the global model's encoder with the KNN monitor. The
            model is saved whenever the KNN accuracy beats the best seen so far.
        """
        # cosine lr schedule
        lr = self.args.lr
        lr *= 0.5 * (1. + math.cos(math.pi * round_idx / self.args.comm_round))
        wandb.log({"SSL-Train/Learning Rate": lr, "round": round_idx})

        # averaged training loss
        averaged_training_loss = self._mean_over_workers(self.averaged_loss_dict)
        logging.info("averaged_training_loss = %f" % averaged_training_loss)
        wandb.log({"SSL-Train/Loss": averaged_training_loss, "round": round_idx})

        # averaged personalized accuracy (both sources are reported under the same key)
        if self.args.pssl_optimizer != "FedAvg" or self.args.ssl_is_linear_eval == 1:
            averaged_personalized_acc = self._mean_over_workers(self.per_acc_dict)
        else:
            averaged_personalized_acc = self._mean_over_workers(self.global_model_on_local_accuracy_dict)
        if averaged_personalized_acc > 0.0:
            logging.info("averaged_personalized_acc = %f" % averaged_personalized_acc)
            wandb.log({"averaged_personalized_acc": averaged_personalized_acc, "round": round_idx})

        if self.args.ssl_is_linear_eval:
            return
        # use the global model as the indicator
        if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1:
            logging.info("start KNN evaluation...")
            val_top1_acc = self.knn_monitor.eval(self.trainer.get_global_model().encoder, self.device, K=self.knn_k_value)
            logging.info("Self-Supervised-Federated-Training/Accuracy = %f" % val_top1_acc)
            wandb.log({"Self-Supervised-Federated-Training/Accuracy": val_top1_acc, "round": round_idx})
            logging.info("finish KNN evaluation...")

            # save the best model
            if val_top1_acc > self.best_acc:
                self.best_acc = val_top1_acc
                save_global_model(self.args, self.trainer.get_global_model())
