import torch.nn as nn

from models.utils_model import backbone, LogitsBatchIterator
from .LWP import LwP
from .base import *

import torch
from torch.optim import SGD
from models.base import ContinualLearning

# https://github.com/aimagelab/mammoth/blob/master/models/lwf.py

def smooth(logits, temp, dim):
    """Temperature-soften a non-negative tensor and renormalise it along ``dim``.

    Every element is raised to the power ``1 / temp`` and the result is
    divided by its sum along ``dim``, so each slice along ``dim`` sums to
    one again.

    Args:
        logits: Tensor to soften. Despite the name, the call site visible in
            this module (``LwF.compute_loss``) passes softmax outputs, i.e.
            non-negative values; negative entries combined with a fractional
            exponent would produce NaN.
        temp: Temperature; the exponent applied is ``1 / temp``.
        dim: Dimension along which the result is normalised.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    powered = logits ** (1 / temp)
    # keepdim=True keeps the reduced axis in place so the division broadcasts
    # along `dim` for any choice of `dim` and any tensor rank.
    return powered / torch.sum(powered, dim, keepdim=True)


def modified_kl_div(old, new):
    """Return the cross-entropy of ``new`` with respect to ``old``, averaged over rows.

    Computes ``-mean_i(sum_j old[i, j] * log(new[i, j]))`` where the inner sum
    runs over dimension 1 and the mean runs over everything that remains.

    Args:
        old: Target distribution(s); reduced along dimension 1.
        new: Predicted distribution(s), same shape as ``old``.

    Returns:
        A zero-dimensional tensor holding the averaged cross-entropy.
    """
    # torch.xlogy(old, new) equals old * log(new) but yields 0 wherever
    # old == 0, so a probability that underflowed to zero in both
    # distributions contributes 0 instead of 0 * -inf = NaN.
    return -torch.mean(torch.sum(torch.xlogy(old, new), 1))


class LwF(ContinualLearning):
    """Learning without Forgetting (LwF) continual-learning model.

    The encoder is wrapped by ``backbone`` with ``cls_output_dim * num_tasks``
    outputs requested. Task ``t`` is trained on the output columns
    ``[t * cls_output_dim, (t + 1) * cls_output_dim)``, while the columns of
    all earlier tasks are regularised towards the responses recorded at the
    start of the task (see ``begin_task`` and ``compute_loss``).

    NOTE(review): ``self.encoder``, ``self.optimizer`` and
    ``self.cls_output_dim`` are read by the methods below but not assigned in
    this class; they are presumably set by ``ContinualLearning.__init__`` --
    confirm against the base class.
    """

    def __init__(
        self,
        encoder: nn.Module,
        lr=0.001,
        lambda_: float = 0.1,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        n_epochs=10,
        z_dim=512,
        device="cuda",
        **kwargs
    ):
        """Wrap ``encoder`` and store the training hyper-parameters.

        Args:
            encoder: Feature extractor handed to ``backbone``.
            lr: Learning rate; forwarded to the base class and stored.
            lambda_: Weight of the distillation term (stored as ``alpha``).
            cls_output_dim: Number of output columns owned by each task.
            num_tasks: Number of tasks the output layer is sized for.
            n_epochs: Stored on the instance; no method of this class reads it.
            z_dim: Forwarded to ``backbone``.
            device: Device that batches are moved to in ``begin_task`` and
                that the recorded responses are moved to in ``compute_loss``.
            **kwargs: Accepted and ignored.
        """
        encoder = backbone(encoder, cls_output_dim=cls_output_dim * num_tasks, z_dim=z_dim)
        super().__init__(encoder, lr, num_tasks, cls_output_dim)
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)
        self.n_epochs = n_epochs
        self.lr = lr
        self.device = device
        # Iterator over the responses recorded by begin_task; stays None until
        # begin_task is called with task_id > 0.
        self.logits = None
        self.alpha = lambda_

    def begin_task(
        self,
        dataloader: torch.utils.data.DataLoader,
        task_name: str,
        task_id: int,
        criterion: torch.nn.modules.loss._Loss,
    ) -> None:
        """Record the current model's outputs on the new task's data.

        For every task after the first, the encoder is run in eval mode and
        without gradients over ``dataloader``; the outputs are concatenated on
        the CPU and wrapped in a ``LogitsBatchIterator`` so ``compute_loss``
        can pull one batch of recorded outputs per training step. The encoder
        is switched back to train mode before returning.

        Args:
            dataloader: Loader yielding dicts with an ``"image"`` entry.
            task_name: Unused by this method.
            task_id: Zero-based index of the task about to be trained.
            criterion: Unused by this method.
        """
        self.encoder.eval()
        if task_id > 0:
            # NOTE: the reference (mammoth) implementation warms up the new
            # task's classifier columns at this point; that step is disabled
            # in this version.
            recorded = []
            with torch.no_grad():
                for sample in dataloader:
                    image = sample["image"].to(self.device)
                    recorded.append(self.encoder(image).cpu())
            # Serve the recorded outputs batch by batch, like the dataloader.
            # NOTE(review): compute_loss pairs these with training batches
            # purely by position, so the loader presumably has to yield
            # samples in the same order during training -- confirm.
            self.logits = LogitsBatchIterator(torch.cat(recorded), dataloader.batch_size)
        self.encoder.train()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the encoder's output for ``x``."""
        return self.encoder(x)

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_id,
    ) -> float:
        """Run one optimisation step and return the scalar loss.

        The loss is ``loss_func`` on the current task's output columns plus,
        when recorded responses exist, ``alpha`` times the cross-entropy
        between the temperature-softened recorded and current probabilities
        over all earlier tasks' columns.

        Args:
            inputs: Batch fed to the model.
            labels: Targets for the current task's columns.
            not_aug_inputs: Unused by this method.
            loss_func: Criterion applied to the current task's columns.
            transform: Unused by this method.
            task_id: Zero-based index of the task being trained.

        Returns:
            The total loss of this step as a Python float.
        """
        self.optimizer.zero_grad()
        outputs = self.forward(inputs)

        n_past_classes = task_id * self.cls_output_dim
        n_seen_classes = n_past_classes + self.cls_output_dim
        loss = loss_func(outputs[:, n_past_classes:n_seen_classes], labels)
        if self.logits is not None:
            # Recorded responses live on the CPU; softmax them over the old
            # columns, then move the result next to the live outputs.
            old_probs = self.soft(next(self.logits)[:, :n_past_classes]).to(self.device)
            new_probs = self.soft(outputs[:, :n_past_classes])
            # Both distributions are softened with temperature 2 along dim 1.
            distillation = modified_kl_div(
                smooth(old_probs, 2, 1),
                smooth(new_probs, 2, 1),
            )
            loss = loss + self.alpha * distillation
        loss.backward()
        self.optimizer.step()
        return loss.item()
