import copy
import logging
import math

import torch
import torch.nn as nn

from fedml_api.distributed.fedper.DittoOpt import DittoSSLOptimizer
from fedml_api.distributed.fedper.utils import load_personal_model, save_personal_model

try:
    from fedml_core.trainer.model_trainer import ModelTrainer
except ImportError:
    from FedML.fedml_core.trainer.model_trainer import ModelTrainer


class SSLDittoTrainer(ModelTrainer):
    """Client-side trainer that maintains a global model and a personal model.

    ``global_model`` is the model object handed to the constructor; it is the
    one exchanged through ``get_model_params`` / ``set_model_params``.
    ``local_model`` starts as a deep copy of it and is loaded/saved per client
    around every call to ``train``. Both are updated by a single
    ``DittoSSLOptimizer`` instance.

    NOTE(review): ``self.args`` and ``self.id`` are read in this class but not
    assigned here; they are presumably set by ``ModelTrainer`` -- confirm.
    """

    def __init__(self, model, args=None, device=None):
        """Create the personal copy of ``model`` and the joint optimizer.

        Args:
            model: network used as the global model; deep-copied to create the
                personal model.
            args: configuration object. Although it defaults to ``None``, the
                attributes ``lr``, ``momentum``, ``wd``, ``pssl_lambda``,
                ``client_optimizer`` and ``accumulation_steps`` are read from
                it unconditionally.
            device: device handed to the optimizer wrapper.
        """
        super().__init__(model, args)

        self.global_model = model
        self.local_model = copy.deepcopy(model)
        # ``self.model`` is an alias of the global model, so parameter exchange
        # and ``test`` operate on the global weights only.
        self.model = self.global_model

        self.device = device
        self.client_index = -1

        kwargs = {
            'lr': args.lr,
            'momentum': args.momentum,
            'wd': args.wd,
            'lambda': args.pssl_lambda,
            'client_optimizer': args.client_optimizer,
            'accumulation_steps': args.accumulation_steps,
            'is_first_order': False,
        }

        self.per_optimizer = DittoSSLOptimizer(self.device, self.local_model, self.global_model, kwargs=kwargs)

        self.averaged_loss = 0.0

    def get_model_params(self):
        """Return the global model's ``state_dict`` after moving the model to CPU."""
        return self.model.cpu().state_dict()  # send global model

    def set_model_params(self, model_parameters):
        """Load ``model_parameters`` into the global model; the personal model is untouched."""
        self.model.load_state_dict(model_parameters)  # update only global

    def get_global_model(self):
        """Return the global model object."""
        return self.global_model

    def get_personalized_model(self):
        """Return the personal (local) model object."""
        return self.local_model

    def train(self, train_data, client_idx, device, args):
        """Run one round of training for client ``client_idx``.

        Sets the learning rate of both inner optimizers from the cosine
        schedule, loads the client's personal model, trains, and saves the
        personal model again.

        Args:
            train_data: iterable of ``(x, labels)`` batches.
            client_idx: index of the client being trained; stored in
                ``self.client_index`` and passed to the load/save helpers.
            device: device the models are moved to for training.
            args: configuration object; ``lr`` and ``comm_round`` are read
                here. The current round is read from ``self.args.round_idx``.
        """
        # cross-round learning rate scheduler
        self.client_index = client_idx
        self.adjust_learning_rate(self.per_optimizer.g_optimizer, args.lr, self.args.round_idx, args.comm_round)
        self.adjust_learning_rate(self.per_optimizer.l_optimizer, args.lr, self.args.round_idx, args.comm_round)
        load_personal_model(self.args, self.local_model, self.client_index)
        self.train_personal_model(train_data, device, args)
        save_personal_model(self.args, self.local_model, self.client_index)

    def get_averaged_loss(self):
        """Return the mean per-epoch loss recorded by the last training run."""
        return self.averaged_loss

    def adjust_learning_rate(self, optimizer, initial_lr, round_index, total_round):
        """Set every param group's lr to the cosine-decayed value for this round.

        The value is ``initial_lr * 0.5 * (1 + cos(pi * round_index / total_round))``.

        Args:
            optimizer: object exposing ``param_groups`` (a list of dicts).
            initial_lr: learning rate at round 0.
            round_index: current communication round.
            total_round: total number of rounds; must be non-zero.
        """
        lr = initial_lr
        # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * round_index / total_round))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def _step_and_log(self, pending, epoch, batch_idx, num_batches):
        """Apply one optimizer step to the accumulated batches and log its loss.

        Returns the loss of the step as a Python float.
        """
        loss_value = self.per_optimizer.step(pending).item()
        logging.info('(Trainer_ID {}. Epoch: {}. Batch Index = {}/{}. \tLoss: {:.6f}'.format(
            self.id, epoch, batch_idx, num_batches, loss_value))
        return loss_value

    def train_personal_model(self, train_data, device, args):
        """Train both models for ``args.personal_local_epochs`` epochs.

        Batches are collected into windows of ``args.accumulation_steps`` and
        each full window is passed to ``self.per_optimizer.step``. Batches
        left over at the end of an epoch are flushed as one shorter window so
        no data is silently dropped. ``self.averaged_loss`` is set to the mean
        of the per-epoch mean step losses; it is left unchanged when no step
        was taken at all (empty loader or zero epochs).

        Args:
            train_data: sized iterable of ``(x, labels)`` batches.
            device: device both models are moved to.
            args: configuration object providing ``personal_local_epochs``
                and ``accumulation_steps``.
        """
        self.local_model = self.local_model.to(device)
        self.global_model = self.global_model.to(device)
        self.local_model.train()
        self.global_model.train()

        epoch_loss = []
        for epoch in range(args.personal_local_epochs):
            batch_loss = []
            pending = []
            last_idx = -1
            for batch_idx, (x, labels) in enumerate(train_data):
                last_idx = batch_idx
                pending.append((x, labels))
                # Step only once a full accumulation window has been gathered.
                # (The previous truthiness test on the remainder was inverted:
                # it stepped on every batch EXCEPT the window boundary and
                # never stepped at all when accumulation_steps == 1.)
                if (batch_idx + 1) % args.accumulation_steps == 0:
                    batch_loss.append(self._step_and_log(pending, epoch, batch_idx, len(train_data)))
                    pending.clear()

            # Flush the incomplete trailing window, if any.
            # NOTE(review): assumes the optimizer wrapper accepts a list
            # shorter than accumulation_steps -- confirm in DittoSSLOptimizer.
            if pending:
                batch_loss.append(self._step_and_log(pending, epoch, last_idx, len(train_data)))
                pending.clear()

            # An empty loader yields no steps; skip instead of dividing by zero.
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        if epoch_loss:
            self.averaged_loss = sum(epoch_loss) / len(epoch_loss)

    def test(self, test_data, device, args):
        """Evaluate ``self.model`` (the global model) on ``test_data``.

        Args:
            test_data: iterable of ``(x, target)`` batches.
            device: device the model and batches are moved to.
            args: configuration object; ``args.dataset`` selects the metric
                branch.

        Returns:
            dict with the summed counters ``test_correct``, ``test_loss``
            (batch loss weighted by batch size), ``test_precision``,
            ``test_recall`` and ``test_total``. Precision and recall are only
            accumulated when ``args.dataset == "stackoverflow_lr"``.
        """
        model = self.model

        model.eval()
        model.to(device)

        metrics = {
            'test_correct': 0,
            'test_loss': 0,
            'test_precision': 0,
            'test_recall': 0,
            'test_total': 0
        }

        criterion = nn.CrossEntropyLoss().to(device)
        with torch.no_grad():
            for x, target in test_data:
                x = x.to(device)
                target = target.to(device)
                pred = model(x)
                loss = criterion(pred, target)
                if args.dataset == "stackoverflow_lr":
                    # Multi-label branch: threshold the outputs at 0.5; a
                    # sample counts as correct only if every label matches.
                    predicted = (pred > .5).int()
                    correct = predicted.eq(target).sum(axis=-1).eq(target.size(1)).sum()
                    true_positive = ((target * predicted) > .1).int().sum(axis=-1)
                    # The small epsilon avoids division by zero for samples
                    # with no predicted / no true labels.
                    precision = true_positive / (predicted.sum(axis=-1) + 1e-13)
                    recall = true_positive / (target.sum(axis=-1) + 1e-13)
                    metrics['test_precision'] += precision.sum().item()
                    metrics['test_recall'] += recall.sum().item()
                else:
                    # Single-label branch: arg-max over the last dimension.
                    _, predicted = torch.max(pred, -1)
                    correct = predicted.eq(target).sum()

                metrics['test_correct'] += correct.item()
                metrics['test_loss'] += loss.item() * target.size(0)
                metrics['test_total'] += target.size(0)

            logging.info("testing...")
        return metrics

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        """Server-side evaluation hook; not implemented by this trainer.

        Returns:
            ``False`` always, matching the declared ``bool`` return type
            (the previous implicit ``None`` was equally falsy).
        """
        return False
