import torch
from copy import deepcopy

from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.trainers.context import CtxVar
from federatedscope.gfl.loss.vat import VATLoss
from federatedscope.core.trainers import GeneralTorchTrainer


class FLITTrainer(GeneralTorchTrainer):
    """Trainer that re-weights the per-sample loss against a model snapshot.

    At the start of every fit a deep copy of ``ctx.model`` is stored as
    ``ctx.global_model`` (see ``record_initialization_global``). During the
    batch forward pass the loss of the local model is compared with the loss
    of that snapshot on the same batch, and the result is turned into a
    weight that scales the local loss.

    Only ``CrossEntropyLoss`` and ``MSELoss`` criteria are supported.

    NOTE(review): the ``mean(dim=0)`` / ``loss.mean()`` calls suggest the
    criterion is expected to return per-sample losses (``reduction='none'``)
    -- confirm against the configuration.
    """
    def register_default_hooks_train(self):
        """Register the parent's training hooks plus the hooks that reset
        the weight denominator and manage the model snapshot.
        """
        super(FLITTrainer, self).register_default_hooks_train()
        # Same hooks, triggers and registration order as before.
        for hook, trigger in (
            (record_initialization_local, 'on_fit_start'),
            (del_initialization_local, 'on_fit_end'),
            (record_initialization_global, 'on_fit_start'),
            (del_initialization_global, 'on_fit_end'),
        ):
            self.register_hook_in_train(new_hook=hook,
                                        trigger=trigger,
                                        insert_pos=-1)

    def register_default_hooks_eval(self):
        """Register the parent's evaluation hooks plus the hooks that reset
        the weight denominator and manage the model snapshot.
        """
        super(FLITTrainer, self).register_default_hooks_eval()
        # Same hooks, triggers and registration order as before.
        for hook, trigger in (
            (record_initialization_local, 'on_fit_start'),
            (del_initialization_local, 'on_fit_end'),
            (record_initialization_global, 'on_fit_start'),
            (del_initialization_global, 'on_fit_end'),
        ):
            self.register_hook_in_eval(new_hook=hook,
                                       trigger=trigger,
                                       insert_pos=-1)

    def _hook_on_batch_forward(self, ctx):
        """Run the forward pass and compute the re-weighted batch loss.

        Writes ``ctx.loss_batch``, ``ctx.batch_size``, ``ctx.y_true`` and
        ``ctx.y_prob``, and updates the running ``ctx.weight_denomaitor``.

        Arguments:
            ctx: trainer context providing ``data_batch``, ``device``,
                ``model``, ``global_model``, ``criterion`` and
                ``weight_denomaitor``.

        Raises:
            ValueError: if the criterion is neither ``CrossEntropyLoss``
                nor ``MSELoss``.
        """
        batch = ctx.data_batch.to(ctx.device)
        pred = ctx.model(batch)
        ctx.global_model.to(ctx.device)
        # The snapshot's loss is only ever used detached, so there is no
        # need to build an autograd graph for its forward pass.
        with torch.no_grad():
            predG = ctx.global_model(batch)

        criterion_name = ctx.criterion._get_name()
        if criterion_name == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif criterion_name == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(
                f'FLIT trainer not support {ctx.criterion._get_name()}.')
        # A batch with a single label can collapse to a 0-dim tensor;
        # restore the batch dimension.
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        lossGlobalLabel = ctx.criterion(predG, label)
        lossLocalLabel = ctx.criterion(pred, label)

        # Local loss plus a penalty wherever the local model is worse than
        # the snapshot.
        weightloss = lossLocalLabel + torch.relu(lossLocalLabel -
                                                 lossGlobalLabel.detach())
        batch_mean = weightloss.mean(dim=0, keepdim=True).detach()
        if ctx.weight_denomaitor is None:
            # First batch of the fit: seed the running denominator.
            ctx.weight_denomaitor = batch_mean
        else:
            # Exponential moving average over the batches of this fit.
            factor_ema = self.cfg.flitplus.factor_ema
            ctx.weight_denomaitor = (factor_ema * ctx.weight_denomaitor +
                                     (1 - factor_ema) * batch_mean)

        # ``weightloss`` is not detached here, so gradients also flow
        # through the weight itself.
        weight = (1 - torch.exp(-weightloss /
                                (ctx.weight_denomaitor + 1e-7)) +
                  1e-7)**self.cfg.flitplus.tmpFed
        loss = weight * lossLocalLabel
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


class FLITPlusTrainer(FLITTrainer):
    """FLIT trainer whose sample weights and loss also include a VAT term.

    In addition to the supervised loss comparison done by ``FLITTrainer``,
    a ``VATLoss`` is evaluated for both the local model and the snapshot
    ``ctx.global_model``. The VAT terms are skipped (set to zero) when
    ``ctx.cur_mode == 'test'``.

    NOTE(review): ``VATLoss`` is defined outside this file; it is presumably
    a virtual adversarial training regularizer -- confirm its contract.
    """
    def _hook_on_batch_forward(self, ctx):
        """Run the forward pass and compute the re-weighted batch loss.

        Writes ``ctx.loss_batch``, ``ctx.batch_size``, ``ctx.y_true`` and
        ``ctx.y_prob``, and updates the running ``ctx.weight_denomaitor``.

        Arguments:
            ctx: trainer context providing ``data_batch``, ``device``,
                ``cur_mode``, ``model``, ``global_model``, ``criterion``
                and ``weight_denomaitor``.

        Raises:
            ValueError: if the criterion is neither ``CrossEntropyLoss``
                nor ``MSELoss``.
        """
        # LDS should be calculated before the forward for cross entropy
        batch = ctx.data_batch.to(ctx.device)
        ctx.global_model.to(ctx.device)
        if ctx.cur_mode == 'test':
            # Placeholders live on the working device so that they combine
            # with the supervised losses without cross-device mixing.
            lossLocalVAT = torch.tensor(0., device=ctx.device)
            lossGlobalVAT = torch.tensor(0., device=ctx.device)
        else:
            vat_loss = VATLoss()  # xi, and eps
            # Copies of the models and the criterion are handed to VATLoss,
            # so the originals are not touched by it.
            lossLocalVAT = vat_loss(deepcopy(ctx.model), batch,
                                    deepcopy(ctx.criterion))
            lossGlobalVAT = vat_loss(deepcopy(ctx.global_model), batch,
                                     deepcopy(ctx.criterion))

        pred = ctx.model(batch)
        # The snapshot's supervised loss is only ever used detached, so
        # there is no need to build an autograd graph for this forward pass.
        with torch.no_grad():
            predG = ctx.global_model(batch)

        criterion_name = ctx.criterion._get_name()
        if criterion_name == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif criterion_name == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(
                f'FLITPLUS trainer not support {ctx.criterion._get_name()}.')
        # A batch with a single label can collapse to a 0-dim tensor;
        # restore the batch dimension.
        if len(label.size()) == 0:
            label = label.unsqueeze(0)
        lossGlobalLabel = ctx.criterion(predG, label)
        lossLocalLabel = ctx.criterion(pred, label)

        # Each term is the local value plus a penalty wherever the local
        # model is worse than the snapshot.
        weightloss_loss = lossLocalLabel + torch.relu(lossLocalLabel -
                                                      lossGlobalLabel.detach())
        weightloss_vat = (lossLocalVAT +
                          torch.relu(lossLocalVAT - lossGlobalVAT.detach()))
        weightloss = self.cfg.flitplus.lambdavat * \
            weightloss_vat + weightloss_loss

        batch_mean = weightloss.mean(dim=0, keepdim=True).detach()
        if ctx.weight_denomaitor is None:
            # First batch of the fit: seed the running denominator.
            ctx.weight_denomaitor = batch_mean
        else:
            # Exponential moving average over the batches of this fit.
            factor_ema = self.cfg.flitplus.factor_ema
            ctx.weight_denomaitor = (factor_ema * ctx.weight_denomaitor +
                                     (1 - factor_ema) * batch_mean)

        # ``weightloss`` is not detached here, so gradients also flow
        # through the weight itself.
        weight = (1 - torch.exp(-weightloss /
                                (ctx.weight_denomaitor + 1e-7)) +
                  1e-7)**self.cfg.flitplus.tmpFed
        loss = weight * (lossLocalLabel +
                         self.cfg.flitplus.weightReg * lossLocalVAT)
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


class FedFocalTrainer(GeneralTorchTrainer):
    """Trainer that re-weights each sample's loss by its own magnitude.

    The per-sample loss is normalized by a running mean
    (``ctx.weight_denomaitor``) and mapped to a weight that scales the loss
    before averaging. No model snapshot is involved.

    Only ``CrossEntropyLoss`` and ``MSELoss`` criteria are supported.

    NOTE(review): the ``mean(dim=0)`` / ``loss.mean()`` calls suggest the
    criterion is expected to return per-sample losses (``reduction='none'``)
    -- confirm against the configuration.
    """
    def register_default_hooks_train(self):
        """Register the parent's training hooks plus the hooks that reset
        the running weight denominator around each fit.
        """
        super(FedFocalTrainer, self).register_default_hooks_train()
        # Same hooks, triggers and registration order as before.
        for hook, trigger in (
            (record_initialization_local, 'on_fit_start'),
            (del_initialization_local, 'on_fit_end'),
        ):
            self.register_hook_in_train(new_hook=hook,
                                        trigger=trigger,
                                        insert_pos=-1)

    def register_default_hooks_eval(self):
        """Register the parent's evaluation hooks plus the hooks that reset
        the running weight denominator around each fit.
        """
        super(FedFocalTrainer, self).register_default_hooks_eval()
        # Same hooks, triggers and registration order as before.
        for hook, trigger in (
            (record_initialization_local, 'on_fit_start'),
            (del_initialization_local, 'on_fit_end'),
        ):
            self.register_hook_in_eval(new_hook=hook,
                                       trigger=trigger,
                                       insert_pos=-1)

    def _hook_on_batch_forward(self, ctx):
        """Run the forward pass and compute the re-weighted batch loss.

        Writes ``ctx.loss_batch``, ``ctx.batch_size``, ``ctx.y_true`` and
        ``ctx.y_prob``, and updates the running ``ctx.weight_denomaitor``.

        Arguments:
            ctx: trainer context providing ``data_batch``, ``device``,
                ``model``, ``criterion`` and ``weight_denomaitor``.

        Raises:
            ValueError: if the criterion is neither ``CrossEntropyLoss``
                nor ``MSELoss``.
        """
        batch = ctx.data_batch.to(ctx.device)
        pred = ctx.model(batch)

        criterion_name = ctx.criterion._get_name()
        if criterion_name == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif criterion_name == 'MSELoss':
            label = batch.y.float()
        else:
            # The message previously named the FLIT trainer by mistake.
            raise ValueError(
                f'FedFocal trainer not support '
                f'{ctx.criterion._get_name()}.')
        # A batch with a single label can collapse to a 0-dim tensor;
        # restore the batch dimension.
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        lossLocalLabel = ctx.criterion(pred, label)
        weightloss = lossLocalLabel
        batch_mean = weightloss.mean(dim=0, keepdim=True).detach()
        if ctx.weight_denomaitor is None:
            # First batch of the fit: seed the running denominator.
            ctx.weight_denomaitor = batch_mean
        else:
            # Exponential moving average over the batches of this fit.
            factor_ema = self.cfg.flitplus.factor_ema
            ctx.weight_denomaitor = (factor_ema * ctx.weight_denomaitor +
                                     (1 - factor_ema) * batch_mean)

        # ``weightloss`` is not detached here, so gradients also flow
        # through the weight itself.
        weight = (1 - torch.exp(-weightloss /
                                (ctx.weight_denomaitor + 1e-7)) +
                  1e-7)**self.cfg.flitplus.tmpFed
        loss = weight * lossLocalLabel
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


class FedVATTrainer(GeneralTorchTrainer):
    """Trainer that adds a VAT term to the re-weighted per-sample loss.

    A ``VATLoss`` is evaluated for the local model (skipped when
    ``ctx.cur_mode == 'test'``) and enters both the sample weight and the
    final loss. No model snapshot is involved.

    Only ``CrossEntropyLoss`` and ``MSELoss`` criteria are supported.

    NOTE(review): ``VATLoss`` is defined outside this file; it is presumably
    a virtual adversarial training regularizer -- confirm its contract.
    """
    def register_default_hooks_train(self):
        """Register the parent's training hooks plus the hooks that reset
        the running weight denominator around each fit.
        """
        super(FedVATTrainer, self).register_default_hooks_train()
        # Same hooks, triggers and registration order as before.
        for hook, trigger in (
            (record_initialization_local, 'on_fit_start'),
            (del_initialization_local, 'on_fit_end'),
        ):
            self.register_hook_in_train(new_hook=hook,
                                        trigger=trigger,
                                        insert_pos=-1)

    def register_default_hooks_eval(self):
        """Register the parent's evaluation hooks plus the hooks that reset
        the running weight denominator around each fit.
        """
        super(FedVATTrainer, self).register_default_hooks_eval()
        # Same hooks, triggers and registration order as before.
        for hook, trigger in (
            (record_initialization_local, 'on_fit_start'),
            (del_initialization_local, 'on_fit_end'),
        ):
            self.register_hook_in_eval(new_hook=hook,
                                       trigger=trigger,
                                       insert_pos=-1)

    def _hook_on_batch_forward(self, ctx):
        """Run the forward pass and compute the re-weighted batch loss.

        Writes ``ctx.loss_batch``, ``ctx.batch_size``, ``ctx.y_true`` and
        ``ctx.y_prob``, and updates the running ``ctx.weight_denomaitor``.

        Arguments:
            ctx: trainer context providing ``data_batch``, ``device``,
                ``cur_mode``, ``model``, ``criterion`` and
                ``weight_denomaitor``.

        Raises:
            ValueError: if the criterion is neither ``CrossEntropyLoss``
                nor ``MSELoss``.
        """
        batch = ctx.data_batch.to(ctx.device)
        if ctx.cur_mode == 'test':
            # Placeholder lives on the working device so that it combines
            # with the supervised loss without cross-device mixing.
            lossLocalVAT = torch.tensor(0., device=ctx.device)
        else:
            vat_loss = VATLoss()  # xi, and eps
            # Copies of the model and the criterion are handed to VATLoss,
            # so the originals are not touched by it.
            lossLocalVAT = vat_loss(deepcopy(ctx.model), batch,
                                    deepcopy(ctx.criterion))

        pred = ctx.model(batch)
        criterion_name = ctx.criterion._get_name()
        if criterion_name == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif criterion_name == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(
                f'FedVAT trainer not support {ctx.criterion._get_name()}.')
        # A batch with a single label can collapse to a 0-dim tensor;
        # restore the batch dimension.
        if len(label.size()) == 0:
            label = label.unsqueeze(0)
        lossLocalLabel = ctx.criterion(pred, label)
        weightloss = lossLocalLabel + self.cfg.flitplus.lambdavat * \
            lossLocalVAT

        batch_mean = weightloss.mean(dim=0, keepdim=True).detach()
        if ctx.weight_denomaitor is None:
            # First batch of the fit: seed the running denominator.
            ctx.weight_denomaitor = batch_mean
        else:
            # Exponential moving average over the batches of this fit.
            factor_ema = self.cfg.flitplus.factor_ema
            ctx.weight_denomaitor = (factor_ema * ctx.weight_denomaitor +
                                     (1 - factor_ema) * batch_mean)

        # ``weightloss`` is not detached here, so gradients also flow
        # through the weight itself.
        weight = (1 - torch.exp(-weightloss /
                                (ctx.weight_denomaitor + 1e-7)) +
                  1e-7)**self.cfg.flitplus.tmpFed
        loss = weight * (lossLocalLabel +
                         self.cfg.flitplus.weightReg * lossLocalVAT)
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


def record_initialization_local(ctx):
    """Reset the running weight denominator at the start of a fit

    ``ctx.weight_denomaitor`` is set to ``None``; the batch-forward hooks in
    this file treat ``None`` as "not initialized yet" and seed the value
    from the first batch they process.

    Arguments:
        ctx: the trainer context on which the attribute is stored
    """
    ctx.weight_denomaitor = None


def del_initialization_local(ctx):
    """Clear the running weight denominator to avoid memory leakage

    Sets ``ctx.weight_denomaitor`` back to ``None`` so the tensor
    accumulated during the fit is no longer referenced by the context.

    Arguments:
        ctx: the trainer context on which the attribute is stored
    """
    ctx.weight_denomaitor = None


def record_initialization_global(ctx):
    """Snapshot the current model as the global reference model

    A deep copy of ``ctx.model`` is moved to the CPU and stored on the
    context as ``ctx.global_model``; ``ctx.model`` itself is left untouched.

    Arguments:
        ctx: the trainer context holding ``model``
    """
    cpu = torch.device("cpu")
    snapshot = deepcopy(ctx.model)
    # ``to`` moves the copy in place; keep it on the CPU until it is needed.
    snapshot.to(cpu)
    ctx.global_model = snapshot


def del_initialization_global(ctx):
    """Clear the global model snapshot to avoid memory leakage

    Sets ``ctx.global_model`` to ``None`` so the copy created by
    ``record_initialization_global`` is no longer referenced by the context.

    Arguments:
        ctx: the trainer context on which the attribute is stored
    """
    ctx.global_model = None
