from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import ContinualLearning
from .Buffer import Buffer
from .utils_model import backbone


class FDR(ContinualLearning):
    """Function Distance Regularization (FDR) continual learner.

    Keeps a buffer of examples from finished tasks together with the logits the
    network produced for them when the task ended, and penalises the current
    network when its softmax outputs on those examples drift away from the
    stored ones.

    Adapted from
    https://github.com/aimagelab/mammoth/blob/master/models/fdr.py
    """

    def __init__(
        self,
        encoder: nn.Module,
        lr: float = 0.001,
        buffer_size: int = 500,
        alpha: float = 0.1,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        input_size: int = 64,
        z_dim: int = 512,
        dataset_name: str = "celeba",
        device: str = "cuda",
        **kwargs
    ) -> None:
        """Wrap the encoder and create the replay buffer.

        Args:
            encoder: Network passed to ``backbone`` to build the model that is
                actually trained.
            lr: Learning rate, forwarded to the base-class constructor.
            buffer_size: Capacity of the replay buffer.
            alpha: Weight of the function-distance regularization term.
            cls_output_dim: Number of output logits per task.
            num_tasks: Number of tasks; the network is built with
                ``cls_output_dim * num_tasks`` output logits in total.
            input_size: Forwarded to ``Buffer``.
            z_dim: Forwarded to ``backbone``.
            dataset_name: Forwarded to ``Buffer``.
            device: Device on which batches and the buffer are placed.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        total_logits_dim = cls_output_dim * num_tasks
        encoder = backbone(encoder, cls_output_dim=total_logits_dim, z_dim=z_dim)
        super().__init__(encoder, lr, num_tasks, cls_output_dim)
        self.buffer = Buffer(
            capacity=buffer_size,
            device=device,
            input_size=input_size,
            total_logits_dim=total_logits_dim,
            dataset_name=dataset_name,
        )
        self.alpha = alpha
        self.device = device
        self.soft = nn.Softmax(dim=1)
        # Number of tasks already finalized through ``end_task``.
        self.current_task = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def end_task(
        self,
        dataloader: torch.utils.data.DataLoader,
        task_name: str,
        task_id: int,
    ) -> None:
        """Finalize a task: shrink old buffer content and store new examples.

        The per-task quota is ``buffer.capacity // (current_task + 1)``. When
        earlier tasks are already stored, the buffer is emptied and refilled
        with a truncated selection of its previous content; then examples of
        the task that just ended are appended together with the logits the
        network currently produces for them.

        Args:
            dataloader: Training loader of the task that just ended. Each batch
                is indexed as a mapping with an ``"image"`` entry and an entry
                under ``task_name``.
            task_name: Key under which a batch holds the labels of this task.
            task_id: Task key stored alongside every newly added example.
        """
        # ``capacity // 1 == capacity``, so the first task needs no special case.
        examples_per_task = self.buffer.capacity // (self.current_task + 1)

        if self.current_task > 0:
            # Pull everything out, then refill with the truncated selection.
            buf_x, buf_log, buf_tl, buf_tk = self.buffer.get_all_data()
            self.buffer.empty()

            # NOTE: only works if every task has the same number of unique
            # labels and task keys.
            # NOTE(review): the quota is applied to each (label, task key)
            # group, so a task with several distinct labels can retain more
            # than ``examples_per_task`` rows in total - confirm this is
            # intended and how ``Buffer.add_data`` behaves once it is full.
            for label in buf_tl.unique():
                for task_key in buf_tk.unique():
                    group = (buf_tl == label) & (buf_tk == task_key)
                    group_x = buf_x[group]
                    keep = min(group_x.shape[0], examples_per_task)
                    self.buffer.add_data(
                        examples=group_x[:keep],
                        logits=buf_log[group][:keep],
                        labels=buf_tl[group][:keep],
                        task=buf_tk[group][:keep],
                    )

        stored = 0
        with torch.no_grad():
            for batch in dataloader:
                remaining = examples_per_task - stored
                # Stop before running the network once the quota is filled.
                if remaining <= 0:
                    break

                inputs = batch["image"].to(self.device)
                cur_task_y = batch[task_name].long().to(self.device)

                # Logits of the current network become the regression targets
                # for the regularization term in ``compute_loss``.
                outputs = self.encoder(inputs)

                kept_inputs = inputs[:remaining]
                n_kept = kept_inputs.shape[0]
                self.buffer.add_data(
                    examples=kept_inputs,
                    logits=outputs[:remaining],
                    labels=cur_task_y[:remaining],
                    # One task key per stored row; a batch may be smaller than
                    # the remaining quota, so size this from the rows kept.
                    task=[task_id] * n_kept,
                )
                stored += n_kept

        self.current_task += 1

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_id,
    ) -> float:
        """Run one optimization step and return the scalar loss.

        The task loss is computed on the slice of logits that belongs to
        ``task_id``. When the buffer is not empty, a batch of stored examples
        is drawn and ``alpha`` times the mean L2 distance between the softmax
        of the current outputs and the softmax of the stored logits is added.

        Args:
            inputs: Current training batch.
            labels: Targets passed to ``loss_func`` for ``inputs``.
            not_aug_inputs: Currently unused; kept for interface compatibility.
            loss_func: Callable computing the task loss from
                ``(sliced_logits, labels)``.
            transform: Forwarded to ``buffer.get_data``.
            task_id: Index of the current task, used to select its logits.

        Returns:
            The total loss of this step as a Python float.
        """
        real_batch_size = inputs.shape[0]
        self.optimizer.zero_grad()

        # Task loss on the logits belonging to the current task's head.
        outputs = self.encoder(inputs)
        head_start = task_id * self.cls_output_dim
        head_end = head_start + self.cls_output_dim
        loss = loss_func(outputs[:, head_start:head_end], labels)

        if not self.buffer.is_empty():
            buf_inputs, buf_logits, _, _ = self.buffer.get_data(
                batch_size=real_batch_size, transform=transform
            )
            buf_outputs = self.encoder(buf_inputs)

            # Function-distance term: per-example L2 norm over dim 1 of the
            # difference between current and stored softmax outputs.
            drift = self.soft(buf_outputs) - self.soft(buf_logits)
            regularization_loss = torch.norm(drift, 2, 1).mean()
            loss = loss + self.alpha * regularization_loss

        loss.backward()
        self.optimizer.step()

        return loss.item()
