import logging

from .knn_evaluation import KNNValidation
from .utils import transform_tensor_to_list


class FedSSLTrainer:
    """Client-side driver for local self-supervised training and evaluation.

    Selects the data loaders belonging to one client (``client_index``) and
    delegates training / testing to ``model_trainer``. kNN accuracy is computed
    through a ``KNNValidation`` monitor pointed at the client's kNN test loader.
    """

    def __init__(self, client_index, train_data_local_dict, train_data_local_num_dict, test_data_local_dict,
                 train_data_num, device, args, model_trainer):
        self.trainer = model_trainer

        self.train_data_local_dict = train_data_local_dict
        self.train_data_local_num_dict = train_data_local_num_dict
        self.test_data_local_dict = test_data_local_dict
        self.all_train_data_num = train_data_num
        self._load_client_data(client_index)

        self.device = device
        self.args = args

        self.knn_monitor = KNNValidation(self.args, val_dataloader=self.test_local_knn)
        # Neighbour count for kNN evaluation: 1 for cifar10, 10 for every other dataset.
        self.knn_k_value = 1 if self.args.dataset == "cifar10" else 10

    def _load_client_data(self, client_index):
        """Point this trainer at the loaders and sample count of ``client_index``."""
        self.client_index = client_index
        # Train/test dict entries are unpacked as (loader, kNN loader) pairs.
        self.train_local, self.train_local_knn = self.train_data_local_dict[client_index]
        self.local_sample_number = self.train_data_local_num_dict[client_index]
        self.test_local, self.test_local_knn = self.test_data_local_dict[client_index]

    @staticmethod
    def _unpack_metrics(metrics):
        """Return ``(correct, total, loss)`` from a ``trainer.test`` metrics dict."""
        return metrics['test_correct'], metrics['test_total'], metrics['test_loss']

    def _knn_top1(self, model):
        """Return the kNN top-1 accuracy of ``model.encoder`` on this client's kNN test loader."""
        # The monitor was built for the initial client; re-point it in case
        # update_dataset() has switched clients since.
        self.knn_monitor.update_val_dataloader(self.test_local_knn)
        return self.knn_monitor.eval(model.encoder, self.device, K=self.knn_k_value)

    def update_model(self, weights):
        """Load ``weights`` into the wrapped model trainer."""
        self.trainer.set_model_params(weights)

    def update_dataset(self, client_index):
        """Switch this trainer (and the wrapped model trainer) to ``client_index``."""
        logging.info("client_index = %d", client_index)
        self.client_index = client_index
        self.trainer.update_index(client_index)
        self._load_client_data(client_index)

    def train(self, round_idx):
        """Run one round of local training.

        Stores ``round_idx`` on ``self.args`` before delegating to the trainer.

        Returns:
            ``(weights, local_sample_number, averaged_loss)``. When
            ``args.is_mobile == 1`` the weights are passed through
            ``transform_tensor_to_list`` first.
        """
        self.args.round_idx = round_idx
        self.trainer.train(self.train_local, self.client_index, self.device, self.args)

        weights = self.trainer.get_model_params()
        averaged_loss = self.trainer.get_averaged_loss()

        # Mobile clients receive plain lists instead of tensors.
        if self.args.is_mobile == 1:
            weights = transform_tensor_to_list(weights)
        return weights, self.local_sample_number, averaged_loss

    def test_global_model_on_local_data(self, round_idx):
        """Return the kNN top-1 accuracy of the global model's encoder on local test data.

        ``round_idx`` is accepted for interface compatibility and is unused.
        """
        logging.info("test_global_model_on_local_data")
        return self._knn_top1(self.trainer.get_global_model())

    def test_personalized_model_on_local_data(self, round_idx):
        """Return the personalized model's accuracy on local test data.

        With ``args.ssl_is_linear_eval == 1`` this is the classification accuracy
        from ``trainer.test``. If there are no test samples, it returns 0.0 and
        logs a warning. Otherwise the model is first fine-tuned according to
        ``args.pssl_optimizer``, then evaluated with kNN on the personalized
        model's encoder. ``round_idx`` is unused.
        """
        if self.args.ssl_is_linear_eval == 1:
            test_tot_correct, test_num_sample, _ = self._unpack_metrics(
                self.trainer.test(self.test_local, self.device, self.args))
            if not test_num_sample:
                # Avoid ZeroDivisionError for clients with an empty test split.
                logging.warning("evaluation locally. no test samples for client %s", self.client_index)
                return 0.0
            val_top1_acc = test_tot_correct / test_num_sample
            logging.info("evaluation locally. personalized_accuracy = %f", val_top1_acc)
            return val_top1_acc

        if self.args.pssl_optimizer == "perFedAvg":
            self.trainer.finetune_personal_model(self.train_local, self.device)
        elif self.args.pssl_optimizer == "FedAvg_LocalAdaptation":
            self.trainer.finetune_global_model(self.train_local, self.device)

        logging.info("test_personalized_model_on_local_data")
        return self._knn_top1(self.trainer.get_personalized_model())

    def test(self, round_idx=None):
        """Evaluate the trainer's model on both the local train and test loaders.

        Returns:
            ``(train_correct, train_loss, train_total, test_correct, test_loss, test_total)``.
        """
        train_tot_correct, train_num_sample, train_loss = self._unpack_metrics(
            self.trainer.test(self.train_local, self.device, self.args))
        test_tot_correct, test_num_sample, test_loss = self._unpack_metrics(
            self.trainer.test(self.test_local, self.device, self.args))

        return train_tot_correct, train_loss, train_num_sample, test_tot_correct, test_loss, test_num_sample
