import copy
import logging
import math

import torch
from torch import nn

from fedml_api.distributed.fedssl.sup_ssl_opt.DittoSupOpt import DittoSupOptimizer
from fedml_api.distributed.fedssl.utils import get_global_model_path, \
    get_personalized_model_path
from fedml_api.model.cv.ssl import resnet18_cifar, meta_resnet18

try:
    from fedml_core.trainer.model_trainer import ModelTrainer
except ImportError:
    from FedML.fedml_core.trainer.model_trainer import ModelTrainer


class SSLEvaluationTrainer(ModelTrainer):
    def __init__(self, model, args=None, device=None):
        """Create a trainer whose backbone is loaded lazily.

        Args:
            model: forwarded to the base-class constructor.
            args: run configuration, forwarded to the base-class constructor.
            device: device handle stored on the instance as ``self.device``.
        """
        super().__init__(model, args)

        # Model handles. ``self.model`` is reset to None after the base-class
        # constructor ran; get_model_params()/set_model_params() fill in the
        # model, the global model and the local model on first use.
        self.model = None
        self.global_model = None
        self.local_model = None
        self.personalized_model = None
        self.is_backbone_loaded = False

        # Bookkeeping.
        self.device = device
        self.client_index = 0
        self.averaged_loss = 0.0

    def get_model_params(self):
        if not self.is_backbone_loaded:
            logging.info("client_index = %d" % self.client_index)
            self.model = self._load_ssl_pretrained_model(self.args, self.client_index)
            self.global_model = self.model
            self.local_model = copy.deepcopy(self.model)
            self.is_backbone_loaded = True
        return self.global_model.fc.cpu().state_dict()  # send global model

    def set_model_params(self, model_parameters):
        if not self.is_backbone_loaded:
            logging.info("client_index = %d" % self.client_index)
            self.model = self._load_ssl_pretrained_model(self.args, self.client_index)
            self.global_model = self.model
            self.local_model = copy.deepcopy(self.model)
            self.is_backbone_loaded = True
        self.global_model.fc.load_state_dict(model_parameters)  # update only global

    def get_global_model(self):
        """Return the global model (None until the backbone has been loaded)."""
        return self.global_model

    def get_personalized_model(self):
        """Return the local model (note: ``self.local_model``, not ``self.personalized_model``)."""
        return self.local_model

    def get_averaged_loss(self):
        """Return ``self.averaged_loss`` (0.0 until the Ditto training path has updated it)."""
        return self.averaged_loss

    def update_index(self, client_index):
        """Set the client index that is used when the backbone checkpoint is loaded."""
        self.client_index = client_index

    def adjust_learning_rate(self, optimizer, initial_lr, round_index, total_round):
        """Decay the learning rate based on schedule"""
        lr = initial_lr
        # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * round_index / total_round))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def get_backbone(self, backbone_name, num_cls=10):
        """Build and return the backbone registered under ``backbone_name``.

        ``meta_resnet18`` is used when ``self.args.pssl_optimizer`` is
        ``"perFedAvg"``, ``resnet18_cifar`` otherwise. The only registered
        name is ``'resnet18_cifar'``; any other name raises ``KeyError``.

        Args:
            backbone_name: key of the backbone to build.
            num_cls: accepted but not used by this method.
        """
        use_meta_variant = self.args.pssl_optimizer == "perFedAvg"
        build = meta_resnet18 if use_meta_variant else resnet18_cifar
        registry = {'resnet18_cifar': build()}
        return registry[backbone_name]

    def train(self, train_data, client_idx, device, args):
        if args.pssl_optimizer == "FedAvg":
            self._train_with_local_sgd(train_data, client_idx, device, args)
        else:
            self._train_with_per_sup_Ditto(train_data, client_idx, device, args)

    def _train_with_local_sgd(self, train_data, client_idx, device, args):
        self.client_index = client_idx

        model = self.model
        model.to(device)
        model.train()
        criterion = nn.CrossEntropyLoss().to(device)
        parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
        assert len(parameters) == 2  # fc.weight, fc.bias
        optimizer = torch.optim.SGD(parameters,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd)
        self.adjust_learning_rate(optimizer, args.lr, args.round_idx, args.comm_round)
        epoch_loss = []
        for epoch in range(args.epochs):
            batch_loss = []
            for batch_idx, (x, labels) in enumerate(train_data):
                x, labels = x.to(device), labels.to(device)
                optimizer.zero_grad()
                log_probs = model(x)
                loss = criterion(log_probs, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            if len(batch_loss) > 0:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))
                logging.info('(Trainer_ID {}. Local Training Epoch: {} \tLoss: {:.6f}'.format(self.id,
                                                                                              epoch,
                                                                                              sum(epoch_loss) / len(
                                                                                                  epoch_loss)))

    def test(self, test_data, device, args):
        if self.args.pssl_optimizer == "FedAvg" or self.args.pssl_optimizer == "perFedAvg":
            model = self.global_model
        else:
            model = self.local_model

        model.eval()
        model.to(device)

        metrics = {
            'test_correct': 0,
            'test_loss': 0,
            'test_precision': 0,
            'test_recall': 0,
            'test_total': 0
        }

        criterion = nn.CrossEntropyLoss().to(device)
        with torch.no_grad():
            for batch_idx, (x, target) in enumerate(test_data):
                x = x.to(device)
                target = target.to(device)
                pred = model(x)
                loss = criterion(pred, target)
                if args.dataset == "stackoverflow_lr":
                    predicted = (pred > .5).int()
                    correct = predicted.eq(target).sum(axis=-1).eq(target.size(1)).sum()
                    true_positive = ((target * predicted) > .1).int().sum(axis=-1)
                    precision = true_positive / (predicted.sum(axis=-1) + 1e-13)
                    recall = true_positive / (target.sum(axis=-1) + 1e-13)
                    metrics['test_precision'] += precision.sum().item()
                    metrics['test_recall'] += recall.sum().item()
                else:
                    _, predicted = torch.max(pred, -1)
                    correct = predicted.eq(target).sum()

                metrics['test_correct'] += correct.item()
                metrics['test_loss'] += loss.item() * target.size(0)
                metrics['test_total'] += target.size(0)

            logging.info("testing...")
        return metrics

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        pass

    def _load_ssl_pretrained_model(self, args, client_idx):
        """Build the backbone, freeze it, and load pre-trained weights into it.

        Steps, in order:
          1. build the backbone via ``get_backbone(args.model, args.num_class)``;
          2. freeze every parameter except ``fc.weight`` / ``fc.bias``;
          3. re-initialise the ``fc`` layer (normal weights, zero bias);
          4. load a checkpoint: the global one for ``perFedAvg``, ``FedAvg``
             and ``FedAvg_LocalAdaptation`` (read from its ``'state_dict'``
             entry), otherwise the per-client one (used as the state dict
             directly);
          5. copy over the entries whose key starts with ``model`` and does
             not contain ``fc``, with ``'model.'`` removed from the key.

        Side effect: sets ``args.start_epoch`` to 0.

        Returns:
            The model with the backbone weights loaded and a fresh ``fc`` head.
        """
        head_param_names = ('fc.weight', 'fc.bias')
        global_checkpoint_optimizers = ("perFedAvg", "FedAvg", "FedAvg_LocalAdaptation")

        model = self.get_backbone(args.model, args.num_class)
        # logging.info(model)
        # freeze all layers but the last fc
        for param_name, param in model.named_parameters():
            if param_name in head_param_names:
                continue
            param.requires_grad = False
        # init the fc layer
        model.fc.weight.data.normal_(mean=0.0, std=0.01)
        model.fc.bias.data.zero_()
        # optimize only the linear classifier
        trainable = [p for p in model.parameters() if p.requires_grad]
        assert len(trainable) == 2  # fc.weight, fc.bias

        if args.pssl_optimizer in global_checkpoint_optimizers:
            path = get_global_model_path(args)
            logging.info("=> loading checkpoint '{}'".format(path))
            checkpoint = torch.load(path, map_location="cpu")
            state_dict = checkpoint['state_dict']
        else:
            path = get_personalized_model_path(args, client_idx)
            logging.info(path)
            logging.info("=> loading checkpoint '{}'".format(path))
            state_dict = torch.load(path, map_location="cpu")

        # Keep backbone entries only; the fc head stays freshly initialised.
        backbone_state = {
            key.replace('model.', ''): tensor
            for key, tensor in state_dict.items()
            if key.startswith('model') and 'fc' not in key
        }
        logging.info("***************")
        args.start_epoch = 0
        load_report = model.load_state_dict(backbone_state, strict=False)
        # Everything except the head must have been found in the checkpoint.
        assert set(load_report.missing_keys) == {"fc.weight", "fc.bias"}

        logging.info("client index = %d" % self.client_index)
        logging.info("=> loaded pre-trained model '{}'".format(path))
        return model

    def _train_with_per_sup_Ditto(self, train_data, client_idx, device, args):
        """Train the local and global models with ``DittoSupOptimizer``.

        Batches are collected into windows of ``args.accumulation_steps`` and
        each full window is handed to ``per_optimizer.step``. The mean of the
        per-epoch mean losses is stored in ``self.averaged_loss``.

        Fixes over the previous version:
          * the window check was ``if (batch_idx + 1) % accumulation_steps:``,
            which is true for every batch that does NOT complete a window (and
            never true for ``accumulation_steps == 1``); it now steps exactly
            when a window is complete;
          * an epoch (or a whole call) without any optimizer step no longer
            raises ``ZeroDivisionError``; ``self.averaged_loss`` is then left
            unchanged;
          * the optimizer receives the ``device`` argument the models were
            moved to, instead of ``self.device``.

        Args:
            train_data: sized iterable of ``(x, labels)`` batches.
            client_idx: not used by this method.
            device: device both models are moved to.
            args: provides ``lr``, ``momentum``, ``wd``, ``pssl_lambda``,
                ``client_optimizer``, ``accumulation_steps``, ``comm_round``
                and ``personal_local_epochs``.
        """
        self.local_model = self.local_model.to(device)
        self.global_model = self.global_model.to(device)
        self.local_model.train()
        self.global_model.train()

        kwargs = {
            'lr': args.lr,
            'momentum': args.momentum,
            'wd': args.wd,
            'lambda': args.pssl_lambda,
            'client_optimizer': args.client_optimizer,
            'accumulation_steps': args.accumulation_steps,
            'is_first_order': False,
        }
        # Use the device the models actually live on (see the .to(device) above).
        per_optimizer = DittoSupOptimizer(device, self.local_model, self.global_model, kwargs=kwargs)

        self.adjust_learning_rate(per_optimizer.g_optimizer, args.lr, self.args.round_idx, args.comm_round)
        self.adjust_learning_rate(per_optimizer.l_optimizer, args.lr, self.args.round_idx, args.comm_round)

        epoch_loss = []
        for epoch in range(args.personal_local_epochs):
            batch_loss = []
            x_accumulator = []
            for batch_idx, (x, labels) in enumerate(train_data):
                x_accumulator.append((x, labels))
                # Step only once a full accumulation window has been collected.
                if (batch_idx + 1) % args.accumulation_steps != 0:
                    continue
                local_loss = per_optimizer.step(x_accumulator)
                x_accumulator.clear()
                batch_loss.append(local_loss.item())
                logging.info('(Trainer_ID {}. Epoch: {}. Batch Index = {}/{}. \tLoss: {:.6f}'.format(self.id,
                                                                                                     epoch,
                                                                                                     batch_idx,
                                                                                                     len(train_data),
                                                                                                     local_loss.item()))
            # NOTE(review): batches left over in an incomplete trailing window
            # are not stepped — confirm whether DittoSupOptimizer.step should
            # be called on a partial window.

            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))
        if epoch_loss:
            self.averaged_loss = sum(epoch_loss) / len(epoch_loss)
