import logging
import math
from copy import deepcopy

import torch

from fedml_api.distributed.fedssl.opt.FedAvgSSLOpt import FedAvgSSLOptimizer

try:
    from fedml_core.trainer.model_trainer import ModelTrainer
except ImportError:
    from FedML.fedml_core.trainer.model_trainer import ModelTrainer


class SSLFedAvgTrainer(ModelTrainer):
    """FedML ``ModelTrainer`` for federated self-supervised learning with FedAvg.

    Local optimisation is delegated to a ``FedAvgSSLOptimizer`` built from the
    run arguments. Only the global model's ``state_dict`` is exchanged via
    ``get_model_params`` / ``set_model_params``. ``finetune_global_model``
    additionally produces a per-client personalized copy of the global model.
    """

    def __init__(self, model, args=None, device=None):
        """Build the trainer and its SSL optimizer.

        Args:
            model: SSL model. ``finetune_global_model`` calls it as ``model(x1, x2)``
                and treats the result as a scalar loss.
            args: run configuration. Despite the ``None`` default it is required:
                ``lr``, ``momentum``, ``wd``, ``pssl_lambda``, ``client_optimizer`` and
                ``accumulation_steps`` are read from it here.
            device: device passed to ``FedAvgSSLOptimizer``.
        """
        super().__init__(model, args)
        self.device = device
        self.client_index = -1

        # Hyper-parameters forwarded to the SSL optimizer (note the key renames:
        # ``pssl_lambda`` -> ``lambda``).
        kwargs = dict()
        kwargs['lr'] = args.lr
        kwargs['momentum'] = args.momentum
        kwargs['wd'] = args.wd
        kwargs['lambda'] = args.pssl_lambda
        kwargs['client_optimizer'] = args.client_optimizer
        kwargs['accumulation_steps'] = args.accumulation_steps

        self.pssl_optimizer = FedAvgSSLOptimizer(self.device, self.model, kwargs=kwargs)

        # Mean training loss of the most recent ``train_global_model`` call.
        self.averaged_loss = 0.0
        # Set by ``finetune_global_model``; None until then.
        self.personalized_model = None

    def get_model_params(self):
        """Return the global model's ``state_dict`` for upload to the server.

        Note: ``Module.cpu()`` moves ``self.model`` to the CPU in place; ``train``
        moves it back to the training device.
        """
        return self.model.cpu().state_dict()  # send global model

    def set_model_params(self, model_parameters):
        """Load aggregated global weights; the personalized model is untouched."""
        self.model.load_state_dict(model_parameters)  # update only global

    def get_global_model(self):
        """Return the (shared) global model."""
        return self.model

    def get_personalized_model(self):
        """Return the last model built by ``finetune_global_model`` (or None)."""
        return self.personalized_model

    def update_index(self, client_index):
        """Record which client this trainer currently represents."""
        self.client_index = client_index

    def train(self, train_data, client_idx, device, args):
        """Run one round of local training for ``client_idx``.

        The learning rate is first set by a cosine schedule over the
        communication rounds, then ``train_global_model`` runs the local epochs.

        NOTE(review): the round index comes from ``self.args.round_idx`` while
        ``lr`` and ``comm_round`` come from ``args``. Presumably the caller keeps
        ``self.args.round_idx`` up to date (or ``args is self.args``) — confirm.
        """
        self.client_index = client_idx
        self.adjust_learning_rate(self.pssl_optimizer.optimizer, args.lr, self.args.round_idx, args.comm_round)
        self.train_global_model(train_data, device, args)

    def get_averaged_loss(self):
        """Return the mean loss recorded by the last ``train_global_model`` call."""
        return self.averaged_loss

    def adjust_learning_rate(self, optimizer, initial_lr, round_index, total_round):
        """Decay the learning rate based on schedule.

        Cosine schedule: ``lr = initial_lr * 0.5 * (1 + cos(pi * round_index / total_round))``,
        written into every parameter group of ``optimizer``.
        """
        lr = initial_lr
        # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * round_index / total_round))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def train_global_model(self, train_data, device, args):
        """Train the global model for ``args.personal_local_epochs`` epochs.

        Each batch is ``((x1, x2), labels)``: two views of the same samples; the
        labels are unused. The update itself is performed by
        ``self.pssl_optimizer.step``, which returns the batch loss as a tensor.

        Sets ``self.averaged_loss`` to the mean, over epochs that saw at least one
        batch, of the per-epoch mean batch loss (0.0 if no batch was seen).
        """
        model = self.model
        model.to(device)
        model.train()
        num_batches = len(train_data)
        epoch_loss = []
        for epoch in range(args.personal_local_epochs):
            batch_loss = []
            for batch_idx, ((x1, x2), _labels) in enumerate(train_data):
                x1, x2 = x1.to(device), x2.to(device)
                loss_value = self.pssl_optimizer.step(batch_idx, x1, x2).item()
                batch_loss.append(loss_value)
                logging.info('(Trainer_ID %s. Epoch: %s. Batch Index = %s/%s. \tLoss: %.6f',
                             self.id, epoch, batch_idx, num_batches, loss_value)
            # An empty loader yields no batches; skip the epoch instead of
            # dividing by zero.
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))
        # Also covers ``personal_local_epochs == 0``.
        self.averaged_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0

    def finetune_global_model(self, train_data, device):
        """Build ``self.personalized_model`` by briefly finetuning a copy of the global model.

        A deep copy of ``self.model`` is trained with plain SGD. The lr, momentum
        and weight decay are read from ``self.pssl_optimizer``. Gradients of the
        first ``accumulation_steps`` batches are accumulated into a single
        optimizer step. If the loader is shorter than that, the step is taken on
        whatever was accumulated. ``self.model`` itself is not modified.
        """
        self.personalized_model = None
        p_model = deepcopy(self.model).to(device)
        p_model.train()

        opt = torch.optim.SGD(p_model.parameters(), lr=self.pssl_optimizer.lr,
                              momentum=self.pssl_optimizer.momentum,
                              weight_decay=self.pssl_optimizer.wd)
        # Defensive: make sure accumulation starts from clean gradients on the copy.
        opt.zero_grad()

        accumulation_steps = self.pssl_optimizer.accumulation_steps
        has_pending_grads = False
        for batch_idx, ((x1, x2), _labels) in enumerate(train_data):
            x1, x2 = x1.to(device), x2.to(device)
            loss = p_model(x1, x2)
            loss.backward()
            has_pending_grads = True

            logging.info('(Trainer_ID %s. Finetune Batch Index = %s/%s. \tLoss: %.6f',
                         self.id, batch_idx, accumulation_steps, loss.item())

            if (batch_idx + 1) % accumulation_steps == 0:
                opt.step()
                opt.zero_grad()
                has_pending_grads = False
                break
        # Without this, a loader with fewer than ``accumulation_steps`` batches
        # would leave the copy unchanged and silently drop the gradients.
        if has_pending_grads:
            opt.step()
            opt.zero_grad()
        self.personalized_model = p_model

    def test(self, test_data, device, args=None):
        """No-op: local evaluation is not implemented for this trainer."""
        pass

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        """No-op: server-side evaluation is not implemented.

        NOTE(review): despite the ``-> bool`` annotation this returns None.
        """
        pass
