import copy
import logging
import math

from fedml_api.distributed.fedssl.opt.DittoOpt import DittoSSLOptimizer
from fedml_api.distributed.fedssl.utils import load_personal_model, save_personal_model

try:
    from fedml_core.trainer.model_trainer import ModelTrainer
except ImportError:
    from FedML.fedml_core.trainer.model_trainer import ModelTrainer


class SSLDittoTrainer(ModelTrainer):
    """Client-side trainer for Ditto-style personalized self-supervised learning.

    Keeps two models per client:

    * ``global_model`` -- the model exchanged with the server. ``self.model`` is
      an alias for it, so ``get_model_params``/``set_model_params`` only touch it.
    * ``local_model`` -- a deep copy used as the personalized model, persisted per
      client through ``load_personal_model``/``save_personal_model``.

    Both models are updated together by a ``DittoSSLOptimizer``.
    """

    def __init__(self, model, args=None, device=None):
        """Build the trainer and its Ditto optimizer.

        Args:
            model: Model used as the shared/global model; a deep copy of it
                becomes the personalized local model.
            args: Config namespace. Despite the ``None`` default it is required:
                ``lr``, ``momentum``, ``wd``, ``pssl_lambda``, ``ssl_method``,
                ``client_optimizer``, ``accumulation_steps`` and
                ``perFedAvg_is_first_order`` are read from it here.
            device: Device handed to ``DittoSSLOptimizer``.
        """
        super().__init__(model, args)

        self.global_model = model
        self.local_model = copy.deepcopy(model)
        # Alias: the ModelTrainer-facing ``model`` is the global model.
        self.model = self.global_model

        self.device = device
        self.client_index = -1

        kwargs = dict()
        kwargs['lr'] = args.lr
        kwargs['momentum'] = args.momentum
        kwargs['wd'] = args.wd
        # NOTE(review): presumably the Ditto regularization strength pulling the
        # personalized model toward the global one -- confirm in DittoSSLOptimizer.
        kwargs['lambda'] = args.pssl_lambda
        kwargs["ssl_method"] = args.ssl_method
        kwargs['client_optimizer'] = args.client_optimizer
        kwargs['accumulation_steps'] = args.accumulation_steps
        kwargs['is_first_order'] = args.perFedAvg_is_first_order

        self.pssl_optimizer = DittoSSLOptimizer(self.device, self.local_model, self.global_model, kwargs=kwargs)
        # Mean training loss of the most recent ``train_personal_model`` call.
        self.averaged_loss = 0.0

    def get_model_params(self):
        """Return the global model's state dict (the part sent to the server).

        Calls ``.cpu()`` on the global model first, so (for an ``nn.Module``) the
        global model itself is left on the CPU afterwards.
        """
        return self.model.cpu().state_dict()  # send global model

    def set_model_params(self, model_parameters):
        """Load server-provided parameters into the global model only."""
        self.model.load_state_dict(model_parameters)  # update only global

    def get_global_model(self):
        """Return the shared/global model."""
        return self.global_model

    def get_personalized_model(self):
        """Return this client's personalized (local) model."""
        return self.local_model

    def update_index(self, client_index):
        """Set the client index used as the personalized-model checkpoint key."""
        self.client_index = client_index

    def train(self, train_data, client_idx, device, args):
        """Train this client's personalized and global models for one round.

        Applies the cross-round cosine learning-rate schedule to both Ditto
        optimizers, loads the client's saved personalized model, trains, and then
        saves the personalized model back.

        Args:
            train_data: Iterable of ``((x1, x2), labels)`` batches (two views per
                sample, as unpacked in ``train_personal_model``).
            client_idx: Index of the client being trained; used as the key for the
                personalized-model checkpoint.
            device: Device to train on.
            args: Config namespace providing ``lr``, ``comm_round``,
                ``personal_local_epochs`` and ``accumulation_steps``. The current
                round is read from ``self.args.round_idx``, which is presumably
                updated by the caller before each round -- confirm.
        """
        self.client_index = client_idx
        # cross-round learning rate scheduler
        self.adjust_learning_rate(self.pssl_optimizer.g_optimizer, args.lr, self.args.round_idx, args.comm_round)
        self.adjust_learning_rate(self.pssl_optimizer.l_optimizer, args.lr, self.args.round_idx, args.comm_round)
        load_personal_model(self.args, self.local_model, self.client_index)
        self.train_personal_model(train_data, device, args)
        save_personal_model(self.args, self.local_model, self.client_index)

    def get_averaged_loss(self):
        """Return the mean loss of the most recent personal training run."""
        return self.averaged_loss

    def adjust_learning_rate(self, optimizer, initial_lr, round_index, total_round):
        """Decay the learning rate based on schedule.

        Sets every param group of ``optimizer`` to
        ``initial_lr * 0.5 * (1 + cos(pi * round_index / total_round))``,
        i.e. a cosine decay from ``initial_lr`` at round 0 to 0 at ``total_round``.

        Raises:
            ZeroDivisionError: if ``total_round`` is 0.
        """
        lr = initial_lr
        # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * round_index / total_round))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def train_personal_model(self, train_data, device, args):
        """Run ``args.personal_local_epochs`` epochs of Ditto SSL training.

        Batches are buffered and passed to ``self.pssl_optimizer.step`` as a list
        of ``(x1, x2)`` pairs once every ``args.accumulation_steps`` batches.
        Labels are ignored. Trailing batches that do not complete a full
        accumulation group at the end of an epoch are discarded.

        The mean step loss per epoch is averaged over epochs into
        ``self.averaged_loss``. If no optimizer step happens at all, it is set to
        0.0 and a warning is logged instead of dividing by zero.
        """
        self.local_model = self.local_model.to(device)
        self.global_model = self.global_model.to(device)
        self.local_model.train()
        self.global_model.train()

        accumulation_steps = args.accumulation_steps
        epoch_loss = []
        for epoch in range(args.personal_local_epochs):
            batch_loss = []
            x_accumulator = []
            for batch_idx, ((x1, x2), _labels) in enumerate(train_data):
                x_accumulator.append((x1, x2))
                # Step only once a full group of `accumulation_steps` batches has
                # been collected. (A bare `% accumulation_steps` test is inverted:
                # it steps on every batch except the group boundaries, and never
                # when accumulation_steps == 1.)
                if (batch_idx + 1) % accumulation_steps != 0:
                    continue
                local_loss = self.pssl_optimizer.step(x_accumulator)
                x_accumulator.clear()
                # Read the scalar once; `.item()` may force a device sync.
                loss_value = local_loss.item()
                batch_loss.append(loss_value)
                logging.info('(Trainer_ID %s. Epoch: %s. Batch Index = %s/%s. \tLoss: %.6f',
                             self.id, epoch, batch_idx, len(train_data), loss_value)

            # Skip epochs with no completed step (e.g. empty data, or fewer
            # batches than accumulation_steps) instead of dividing by zero.
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        if epoch_loss:
            self.averaged_loss = sum(epoch_loss) / len(epoch_loss)
        else:
            logging.warning('Trainer_ID %s: no optimizer step was taken; averaged loss set to 0.0', self.id)
            self.averaged_loss = 0.0

    def test(self, test_data, device, args=None):
        """Not implemented for this trainer; does nothing and returns None."""
        pass

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        """Not implemented for this trainer; does nothing and returns None."""
        pass
