# https://proceedings.neurips.cc/paper_files/paper/2020/file/b704ea2c39778f07c617f6b7ce480e9e-Paper.pdf

from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .Buffer import Buffer
from .base import *
from .utils_model import backbone


class DERPP(ContinualLearning):
    """Dark Experience Replay++ (DER++) continual learner.

    Each call to :meth:`compute_loss` optimises::

        CE(current-task head of net(x), y)
        + alpha * MSE(net(buf_x1), stored_logits)
        + beta  * CE(stored-task head of net(buf_x2), buf_y)

    ``buf_x1`` and ``buf_x2`` come from two separate ``Buffer.get_data``
    calls.

    The network is multi-head. The backbone is built with
    ``cls_output_dim * num_tasks`` outputs, and task ``t`` owns the logit
    columns ``[t * cls_output_dim, (t + 1) * cls_output_dim)``.

    NOTE(review): ``self.encoder``, ``self.optimizer`` and
    ``self.cls_output_dim`` are read here but not set in this class.
    They are assumed to be provided by ``ContinualLearning.__init__``.
    """

    def __init__(
        self,
        encoder: nn.Module,
        lr: float = 0.001,
        alpha: float = 0.1,
        beta: float = 0.1,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        input_size: int = 64,
        buffer_size: int = 2000,
        z_dim: int = 512,
        dataset_name: str = "celeba",
        device: str = "cuda",
        **kwargs
    ) -> None:
        """Build the multi-head backbone and the replay buffer.

        Args:
            encoder: feature extractor. It is wrapped by ``backbone`` with
                ``cls_output_dim * num_tasks`` outputs.
            lr: learning rate, forwarded to the base class.
            alpha: weight of the logit-matching (MSE) replay term.
            beta: weight of the label (``loss_func``) replay term.
            cls_output_dim: number of output columns per task head.
            num_tasks: number of task heads.
            input_size: forwarded to ``Buffer``.
            buffer_size: replay buffer capacity.
            z_dim: forwarded to ``backbone``.
            dataset_name: forwarded to ``Buffer``.
            device: device for the buffer and for index/label tensors built
                in :meth:`compute_loss`.
            **kwargs: accepted for interface compatibility and ignored.
        """
        encoder = backbone(encoder, cls_output_dim=cls_output_dim * num_tasks, z_dim=z_dim)
        super().__init__(encoder, lr, num_tasks, cls_output_dim)
        self.lambda_alpha = alpha
        self.lambda_beta = beta
        self.buffer = Buffer(
            capacity=buffer_size,
            device=device,
            input_size=input_size,
            total_logits_dim=cls_output_dim * num_tasks,
            dataset_name=dataset_name,
        )
        self.device = device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped backbone's output for ``x`` (all task heads)."""
        return self.encoder(x)

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_id,
    ) -> float:
        """Run one DER++ optimisation step, then store the batch for replay.

        Side effects: zeroes the gradients, backpropagates, calls
        ``self.optimizer.step()``, and adds the batch to ``self.buffer``.

        Args:
            inputs: (augmented) input batch for the current task.
            labels: targets for the current task. They are scored against
                the ``cls_output_dim`` columns of this task's head, so they
                must be task-local class indices.
            not_aug_inputs: un-augmented version of ``inputs``. This is
                what gets stored in the buffer.
            loss_func: classification loss applied to head-sliced logits,
                e.g. a cross-entropy module (assumed; not enforced here).
            transform: forwarded to ``Buffer.get_data`` for replayed samples.
            task_id: index of the current task / head.

        Returns:
            Sum of the scalar losses as a Python float. This is the
            current-task loss, plus the weighted replay losses when the
            buffer is non-empty.
        """
        self.optimizer.zero_grad()
        tot_loss = 0.0

        # Current task: supervise only this task's head.
        head_start = task_id * self.cls_output_dim
        outputs = self.forward(inputs)
        outputs_sliced = outputs[:, head_start : head_start + self.cls_output_dim]
        loss = loss_func(outputs_sliced, labels)
        loss.backward()
        tot_loss += loss.item()

        batch_size = inputs.size(0)

        if not self.buffer.is_empty():
            # Replay term 1: match the full logits stored when the samples
            # were added.
            buf_inputs, buf_logits, _, _ = self.buffer.get_data(
                batch_size, transform=transform
            )
            buf_outputs = self.forward(buf_inputs)
            loss_mse = self.lambda_alpha * F.mse_loss(buf_outputs, buf_logits)

            # Replay term 2: supervised loss on a second buffer sample. This
            # sample needs its own forward pass; reusing the first sample's
            # outputs would pair logits with the wrong labels.
            buf_inputs_ce, _, buf_labels, buf_task_id = self.buffer.get_data(
                batch_size, transform=transform
            )
            buf_outputs_ce = self.forward(buf_inputs_ce)

            # Per-sample head columns: task t -> [t*C, t*C + C), matching the
            # slice used for the current task above.
            buf_task_id = torch.as_tensor(
                buf_task_id, dtype=torch.long, device=self.device
            ).reshape(-1, 1)
            offsets = torch.arange(self.cls_output_dim, device=self.device)
            head_indices = buf_task_id * self.cls_output_dim + offsets
            buf_outputs_sliced = buf_outputs_ce.gather(1, head_indices)
            loss_ce = self.lambda_beta * loss_func(
                buf_outputs_sliced,
                buf_labels.to(device=self.device, dtype=torch.long),
            )

            # Both replay terms share one backward; their gradients add to
            # the current-task gradients already accumulated above.
            replay_loss = loss_mse + loss_ce
            replay_loss.backward()
            tot_loss += replay_loss.item()

        self.optimizer.step()

        # Store raw inputs with the logits produced before this step's update.
        self.buffer.add_data(
            examples=not_aug_inputs,
            labels=labels,
            logits=outputs.detach(),
            task=[task_id] * batch_size,
        )

        return tot_loss
