# https://proceedings.neurips.cc/paper_files/paper/2020/file/b704ea2c39778f07c617f6b7ce480e9e-Paper.pdf

from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .Buffer import Buffer
from .base import *
from .utils_model import backbone


class DER(ContinualLearning):
    """Dark Experience Replay (DER) continual learner.

    The wrapped backbone produces ``cls_output_dim * num_tasks`` outputs. Each
    task uses its own contiguous slice of ``cls_output_dim`` columns for the
    supervised loss. Replayed examples are regularized by matching the full
    output vector to the logits recorded when they were stored, using an MSE
    penalty weighted by ``lambda_``.

    Reference implementation:
    https://github.com/aimagelab/mammoth/blob/master/models/der.py
    """

    def __init__(
        self,
        encoder: nn.Module,
        lr: float = 0.001,
        temperature: float = 0.07,
        lambda_: float = 0.1,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        buffer_size: int = 2000,
        input_size: int = 64,
        z_dim: int = 512,
        dataset_name: str = "celeba",
        device="cuda",
        **kwargs
    ) -> None:
        """Build the backbone and the replay buffer.

        Args:
            encoder: Feature extractor, wrapped by ``backbone``.
            lr: Learning rate, forwarded to the ``ContinualLearning`` base.
            temperature: Stored on the instance; not used within this class.
            lambda_: Weight of the replay (logit-matching) MSE term.
            cls_output_dim: Output width of a single task's head.
            num_tasks: Number of tasks; total output width is
                ``cls_output_dim * num_tasks``.
            buffer_size: Capacity of the replay buffer.
            input_size: Forwarded to ``Buffer``.
            z_dim: Forwarded to ``backbone``.
            dataset_name: Forwarded to ``Buffer``.
            device: Forwarded to ``Buffer``.
            **kwargs: Accepted and ignored.
        """
        # All task heads are concatenated into one output vector; compute_loss
        # selects the current task's columns.
        encoder = backbone(encoder, cls_output_dim=cls_output_dim * num_tasks, z_dim=z_dim)
        super().__init__(encoder, lr, num_tasks, cls_output_dim)
        self.lambda_ = lambda_
        # The buffer stores the full (all-heads) logit vector per example.
        self.buffer = Buffer(
            capacity=buffer_size,
            input_size=input_size,
            total_logits_dim=cls_output_dim * num_tasks,
            device=device,
            dataset_name=dataset_name,
        )
        self.cls_output_dim = cls_output_dim

        # NOTE(review): not referenced anywhere in this class; kept for callers.
        self.temperature = temperature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the encoder's outputs for ``x`` (all task heads concatenated)."""
        return self.encoder(x)

    def _task_logits(self, outputs: torch.Tensor, task_id: int) -> torch.Tensor:
        """Return the ``cls_output_dim`` columns of ``outputs`` owned by ``task_id``."""
        start = task_id * self.cls_output_dim
        return outputs[:, start : start + self.cls_output_dim]

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_id,
    ) -> float:
        """Run one full DER training step and return the summed loss value.

        Despite its name, this method also zeroes the gradients, steps the
        optimizer and adds the current batch to the replay buffer.

        Args:
            inputs: Batch fed to the network for the supervised loss.
            labels: Targets for ``loss_func`` on the current task's head.
            not_aug_inputs: Version of the batch that is stored in the buffer
                (presumably the un-augmented images; verify against caller).
            loss_func: Supervised criterion, e.g. cross-entropy.
            transform: Passed to ``Buffer.get_data`` when sampling replay data.
            task_id: Index of the current task's output head.

        Returns:
            Python float equal to the supervised loss plus, if the buffer is
            non-empty, the weighted replay MSE.
        """
        self.optimizer.zero_grad()
        outputs = self.forward(inputs)

        # Supervised loss only on the current task's slice of the outputs.
        loss = loss_func(self._task_logits(outputs, task_id), labels)
        loss.backward()
        tot_loss = loss.item()

        if not self.buffer.is_empty():
            # Replay as many stored examples as there are in the current batch.
            buf_inputs, buf_logits, _, _ = self.buffer.get_data(
                inputs.size(0), transform=transform
            )
            buf_outputs = self.forward(buf_inputs)
            # DER regularizer: match the full output vector to the stored logits.
            loss_mse = self.lambda_ * F.mse_loss(buf_outputs, buf_logits)
            # Gradients accumulate with those from loss.backward() above.
            loss_mse.backward()
            tot_loss += loss_mse.item()

        self.optimizer.step()
        # `outputs` was computed before the optimizer step, so the buffer gets
        # pre-update logits; detach() keeps the autograd graph out of the buffer.
        self.buffer.add_data(examples=not_aug_inputs, logits=outputs.detach())
        return tot_loss
