import torch
from federatedscope.register import register_trainer
from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.trainers.context import CtxVar
from federatedscope.core.trainers.enums import LIFECYCLE, MODE
from federatedscope.core.trainers.utils import move_to


class CLTrainer(GeneralTorchTrainer):
    def __init__(self,
                 model,
                 data,
                 device,
                 config,
                 only_for_eval=False,
                 monitor=None):
        super(CLTrainer, self).__init__(model, data, device, config,
                                        only_for_eval, monitor)
        self.batches_aug_data_1, self.batches_aug_data_2 = torch.empty(
            1), torch.empty(1)
        self.z1, self.z2 = torch.empty(1), torch.empty(1)
        self.num_samples = 0
        self.local_loss_ratio = 1
        self.global_loss_ratio = 5

    def get_train_pred_embedding(self):
        model = self.ctx.model.to(self.ctx.device)
        x1, x2 = self.batches_aug_data_1.to(
            self.ctx.device), self.batches_aug_data_2.to(self.ctx.device)
        z1, z2 = model(x1, x2)
        self.batches_aug_data_1, self.batches_aug_data_2 = [], []
        self.z1, self.z2 = z1, z2
        self.ctx.model.to(torch.device('cpu'))

        return [self.z1, self.z2]

    def _hook_on_batch_forward(self, ctx):
        x, label = [move_to(_, ctx.device) for _ in ctx.data_batch]
        x1, x2 = x[0], x[1]
        if ctx.cur_mode in [MODE.TRAIN]:
            self.batches_aug_data_1 = x1
            self.batches_aug_data_2 = x2
        z1, z2 = ctx.model(x1, x2)
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar((z1, z2), LIFECYCLE.BATCH)
        ctx.loss_batch = CtxVar(ctx.criterion(z1, z2), LIFECYCLE.BATCH)
        ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH)

    def _hook_on_batch_backward(self, ctx):
        ctx.optimizer.zero_grad()
        ctx.loss_task = ctx.loss_task * self.local_loss_ratio
        ctx.loss_task.backward()
        if ctx.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
                                           ctx.grad_clip)

        ctx.optimizer.step()
        if ctx.scheduler is not None:
            ctx.scheduler.step()

    def _hook_on_batch_end(self, ctx):
        # update statistics
        ctx.num_samples += ctx.batch_size
        if ctx.cur_mode in [MODE.TRAIN]:
            self.num_samples = ctx.num_samples
        ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size
        ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))
        # cache label for evaluate
        ctx.ys_true.append(ctx.y_true.detach().cpu().numpy())
        ctx.ys_prob.append(ctx.y_prob[0].detach().cpu().numpy())

    def train_with_global_loss(self, loss):

        self.ctx.model = self.ctx.model.to(self.ctx.device)

        loss = loss.requires_grad_() * self.global_loss_ratio
        loss.backward()

        self.ctx.optimizer.step()

        return self.ctx.model.state_dict()


class LPTrainer(GeneralTorchTrainer):
    def __init__(self,
                 model,
                 data,
                 device,
                 config,
                 only_for_eval=False,
                 monitor=None):
        super(LPTrainer, self).__init__(model, data, device, config,
                                        only_for_eval, monitor)

        if config.federate.restore_from != '':
            self.load_model(config.federate.restore_from)


def call_cl_trainer(trainer_type):
    if trainer_type == 'cltrainer':
        trainer_builder = CLTrainer
        return trainer_builder
    elif trainer_type == 'lptrainer':
        trainer_builder = LPTrainer
        return trainer_builder


register_trainer('cltrainer', call_cl_trainer)
