from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import ContinualLearning
from .Buffer import Buffer
from .utils_model import backbone


class ER(ContinualLearning):
    """Experience Replay (ER) continual-learning model.

    Keeps a buffer of previously seen examples and, on every training step,
    adds a replay loss computed on samples drawn from that buffer to the loss
    of the current batch. The encoder is built with
    ``cls_output_dim * num_tasks`` outputs, and each task is scored on its own
    contiguous slice of ``cls_output_dim`` logits.
    """

    def __init__(
        self,
        encoder: nn.Module,
        lr: float = 0.001,
        buffer_size: int = 500,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        input_size: int = 64,
        z_dim: int = 512,
        dataset_name: str = "celeba",
        device: str = "cuda",
        **kwargs
    ) -> None:
        """
        Build the multi-head backbone and the replay buffer.

        Args:
            encoder: Base network handed to ``backbone``.
            lr: Learning rate forwarded to the base class.
            buffer_size: Capacity of the replay buffer.
            cls_output_dim: Number of logits per task.
            num_tasks: Number of tasks; the backbone is requested with
                ``cls_output_dim * num_tasks`` outputs in total.
            input_size: Forwarded to the buffer.
            z_dim: Forwarded to ``backbone``.
            dataset_name: Forwarded to the buffer.
            device: Device used for the buffer and for replay tensors.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        total_logits_dim = cls_output_dim * num_tasks
        encoder = backbone(encoder, cls_output_dim=total_logits_dim, z_dim=z_dim)
        super().__init__(encoder, lr, num_tasks, cls_output_dim)
        self.buffer = Buffer(
            capacity=buffer_size,
            device=device,
            input_size=input_size,
            total_logits_dim=total_logits_dim,
            dataset_name=dataset_name,
        )
        self.device = device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the encoder on ``x`` and return its output (all task heads).
        """
        return self.encoder(x)

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_id,
    ) -> float:
        """
        Run one optimisation step on the current batch plus a replay batch.

        The loss of the current batch is computed on the logit slice that
        belongs to ``task_id``. If the buffer is not empty, a batch of the
        same size is drawn from it, each replayed sample is scored on the
        logit slice of the task it was stored with, and that loss is added.
        The method then back-propagates, steps ``self.optimizer`` and stores
        the current (non-augmented) batch in the buffer.

        Args:
            inputs: Current batch fed to the encoder.
            labels: Targets for ``inputs``, passed to ``loss_func``.
            not_aug_inputs: Version of the batch that is stored in the buffer.
            loss_func: Callable ``loss_func(logits, targets)`` returning a
                scalar tensor.
            transform: Forwarded to ``self.buffer.get_data``.
            task_id: Index of the current task; selects the logit slice.

        Returns:
            The total loss of this step as a Python float.
        """
        real_batch_size = inputs.shape[0]
        self.optimizer.zero_grad()

        # Score the current batch on the head that belongs to `task_id`.
        outputs = self.encoder(inputs)
        head_start = task_id * self.cls_output_dim
        head_end = head_start + self.cls_output_dim
        loss = loss_func(outputs[:, head_start:head_end], labels)

        if not self.buffer.is_empty():
            buf_inputs, _, buf_labels, buf_task_id = self.buffer.get_data(
                real_batch_size, transform=transform
            )
            buf_outputs = self.encoder(buf_inputs)

            # Each replayed sample must be scored on the same slice layout as
            # the current task above: logits [t * C, t * C + C) for task t.
            # NOTE(review): assumes `buf_task_id` is a 1-D tensor holding the
            # raw task ids passed to `add_data` below -- confirm in Buffer.
            head_starts = buf_task_id.to(torch.int64) * self.cls_output_dim
            indices_range = torch.stack(
                [head_starts + i for i in range(self.cls_output_dim)], dim=1
            ).to(self.device)
            # Defensive guard kept from the previous implementation; it is a
            # no-op for valid task ids now that the indices are scaled.
            indices_range = indices_range.clamp(max=buf_outputs.size(1) - 1)
            buf_outputs_sliced = buf_outputs.gather(1, indices_range)

            prev_task_loss = loss_func(
                buf_outputs_sliced,
                buf_labels.to(device=self.device, dtype=torch.long),
            )
            # Out-of-place add so the autograd graph tensor is not mutated.
            loss = loss + prev_task_loss

        loss.backward()
        self.optimizer.step()

        # Add new data to the buffer
        self.buffer.add_data(
            examples=not_aug_inputs,
            labels=labels[:real_batch_size],
            task=[task_id] * real_batch_size,
        )

        return loss.item()
