import logging
import math
from copy import deepcopy

import torch
from torch import nn

from fedml_api.distributed.fedper.perFedAvgSSLOpt import PerFedAvgSSLOptimizer

try:
    from fedml_core.trainer.model_trainer import ModelTrainer
except ImportError:
    from FedML.fedml_core.trainer.model_trainer import ModelTrainer


class SSLperFedAvgTrainer(ModelTrainer):
    """Client-side trainer for personalized FedAvg driven by a PerFedAvgSSLOptimizer.

    ``self.model`` is the global model exchanged with the server through
    ``get_model_params`` / ``set_model_params``. Local training is delegated
    batch-by-batch to ``self.pssl_optimizer``. ``finetune_personal_model``
    produces a separate fine-tuned copy stored in ``self.personalized_model``.
    """

    def __init__(self, model, args=None, device=None):
        """Create the trainer and its PerFedAvgSSLOptimizer.

        Args:
            model: the global ``nn.Module`` shared with the server.
            args: run configuration. Despite the ``None`` default it is required:
                ``lr``, ``momentum``, ``wd``, ``pssl_lambda``, ``client_optimizer``
                and ``accumulation_steps`` are read from it here.
            device: device handed to the PerFedAvgSSLOptimizer.
        """
        super().__init__(model, args)

        self.model = model

        self.device = device
        self.client_index = -1

        kwargs = {
            'lr': args.lr,
            'momentum': args.momentum,
            'wd': args.wd,
            'lambda': args.pssl_lambda,
            'client_optimizer': args.client_optimizer,
            'accumulation_steps': args.accumulation_steps,
            'is_first_order': False,
            'global_lr': args.lr,
            # The local (personalization) learning rate is deliberately 10x the global one.
            'local_lr': args.lr * 10,
        }

        self.pssl_optimizer = PerFedAvgSSLOptimizer(self.device, self.model, kwargs=kwargs)

        self.averaged_loss = 0.0
        # Set by finetune_personal_model(); None until that method has run once,
        # so get_personalized_model() never raises AttributeError.
        self.personalized_model = None

    def get_model_params(self):
        """Return the global model's state dict (this is what is sent to the server).

        Note: ``self.model.cpu()`` moves the model to CPU in place; training and
        testing move it back to their device before use.
        """
        return self.model.cpu().state_dict()  # send global model

    def set_model_params(self, model_parameters):
        """Load ``model_parameters`` into the global model only (the personalized copy is untouched)."""
        self.model.load_state_dict(model_parameters)  # update only global

    def get_global_model(self):
        """Return the global model instance (not a copy)."""
        return self.model

    def get_personalized_model(self):
        """Return the model produced by the last ``finetune_personal_model`` call, or None."""
        return self.personalized_model

    def train(self, train_data, client_idx, device, args):
        """Run one round of local training for client ``client_idx``.

        The learning rate of ``self.pssl_optimizer.optimizer`` is first reset by a
        cosine schedule over ``args.comm_round`` rounds, then the local epochs run.
        """
        self.client_index = client_idx
        # NOTE(review): round_idx is read from self.args while the other settings
        # come from the ``args`` argument — presumably the same object; confirm.
        self.adjust_learning_rate(self.pssl_optimizer.optimizer, args.lr, self.args.round_idx, args.comm_round)
        self.train_personal_model(train_data, device, args)

    def get_averaged_loss(self):
        """Return the mean per-epoch loss recorded by the last ``train_personal_model`` call."""
        return self.averaged_loss

    def update_client_index(self, client_index):
        """Set the index of the client this trainer currently acts for."""
        self.client_index = client_index

    def adjust_learning_rate(self, optimizer, initial_lr, round_index, total_round):
        """Decay the learning rate based on schedule.

        Sets every param group of ``optimizer`` to
        ``initial_lr * 0.5 * (1 + cos(pi * round_index / total_round))``.
        """
        lr = initial_lr
        # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * round_index / total_round))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def train_personal_model(self, train_data, device, args):
        """Train the global model for ``args.personal_local_epochs`` epochs on ``train_data``.

        Each batch is handed to ``self.pssl_optimizer.step(batch_idx, (x, labels))``,
        which performs the update and returns the batch loss (read via ``.item()``).
        The mean over epochs of the per-epoch mean batch loss is stored in
        ``self.averaged_loss``. If no batches were processed (empty loader or zero
        epochs), it is set to 0.0 instead of dividing by zero.
        """
        self.model = self.model.to(device)
        self.model.train()

        epoch_loss = []
        for epoch in range(args.personal_local_epochs):
            batch_loss = []
            for batch_idx, (x, labels) in enumerate(train_data):
                x = x.to(device)
                labels = labels.to(device)
                loss_value = self.pssl_optimizer.step(batch_idx, (x, labels)).item()
                batch_loss.append(loss_value)
                # self.id is not defined in this class; presumably set by the ModelTrainer base.
                logging.info('(Trainer_ID {}. Epoch: {}. Batch Index = {}/{}. \tLoss: {:.6f}'.format(self.id,
                                                                                                     epoch,
                                                                                                     batch_idx,
                                                                                                     len(train_data),
                                                                                                     loss_value))
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        if epoch_loss:
            self.averaged_loss = sum(epoch_loss) / len(epoch_loss)
        else:
            logging.warning('Trainer_ID %s: no training batches were processed; averaged loss set to 0.0',
                            self.id)
            self.averaged_loss = 0.0

    def finetune_personal_model(self, train_data, device):
        """Build ``self.personalized_model`` by one SGD step on a copy of the global model.

        Cross-entropy gradients are accumulated over up to
        ``self.pssl_optimizer.accumulation_steps`` batches, then a single
        optimizer step is taken (lr = ``pssl_optimizer.local_lr``, with the
        optimizer's momentum and weight decay). If the loader runs out before
        that many batches, the gradient accumulated so far is still applied.
        The global model is not modified.
        """
        p_model = deepcopy(self.model).to(device)
        p_model.train()

        accumulation_steps = self.pssl_optimizer.accumulation_steps
        opt = torch.optim.SGD(p_model.parameters(), lr=self.pssl_optimizer.local_lr,
                              momentum=self.pssl_optimizer.momentum,
                              weight_decay=self.pssl_optimizer.wd)
        criterion = nn.CrossEntropyLoss().to(device)
        # Defensive: make sure accumulation starts from clean gradients.
        opt.zero_grad()

        pending_grads = False
        for batch_idx, (x, labels) in enumerate(train_data):
            x = x.to(device)
            labels = labels.to(device)
            logits = p_model(x)
            loss = criterion(logits, labels)
            loss.backward()
            pending_grads = True

            logging.info('(Trainer_ID {}. Finetune Batch Index = {}/{}. \tLoss: {:.6f}'.format(self.id,
                                                                                               batch_idx,
                                                                                               self.pssl_optimizer.accumulation_steps,
                                                                                               loss.item()))

            if (batch_idx + 1) % accumulation_steps == 0:
                opt.step()
                opt.zero_grad()
                pending_grads = False
                break  # only a single personalization step is taken

        if pending_grads:
            # Fewer batches than accumulation_steps: without this, no step would
            # ever be taken and the "personalized" model would equal the global one.
            opt.step()
            opt.zero_grad()

        self.personalized_model = p_model

    def test(self, test_data, device, args):
        """Evaluate ``self.model`` on ``test_data`` and return summed metrics.

        Returns a dict with ``test_correct``, ``test_loss`` (loss summed over
        samples), ``test_precision``/``test_recall`` (only accumulated for
        ``args.dataset == "stackoverflow_lr"``), and ``test_total``.
        """
        # NOTE(review): this evaluates the global model, not self.personalized_model — confirm intended.
        model = self.model

        model.eval()
        model.to(device)

        metrics = {
            'test_correct': 0,
            'test_loss': 0,
            'test_precision': 0,
            'test_recall': 0,
            'test_total': 0
        }

        criterion = nn.CrossEntropyLoss().to(device)
        with torch.no_grad():
            for batch_idx, (x, target) in enumerate(test_data):
                x = x.to(device)
                target = target.to(device)
                pred = model(x)
                loss = criterion(pred, target)
                if args.dataset == "stackoverflow_lr":
                    # Multi-label case: a sample is correct only if every label matches.
                    predicted = (pred > .5).int()
                    correct = predicted.eq(target).sum(axis=-1).eq(target.size(1)).sum()
                    true_positive = ((target * predicted) > .1).int().sum(axis=-1)
                    precision = true_positive / (predicted.sum(axis=-1) + 1e-13)
                    recall = true_positive / (target.sum(axis=-1) + 1e-13)
                    metrics['test_precision'] += precision.sum().item()
                    metrics['test_recall'] += recall.sum().item()
                else:
                    _, predicted = torch.max(pred, -1)
                    correct = predicted.eq(target).sum()

                metrics['test_correct'] += correct.item()
                metrics['test_loss'] += loss.item() * target.size(0)
                metrics['test_total'] += target.size(0)

            logging.info("testing...")
        return metrics

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        """Server-side testing is not implemented; returns False (falsy, like the previous None)."""
        return False
