# https://github.com/aimagelab/mammoth/blob/master/models/gss.py

from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import ContinualLearning
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from typing import Tuple
from .utils_model import backbone


class Buffer:
    """
    The memory buffer for rehearsal methods, adapted for the GSS model.

    Storage tensors (``examples``, ``labels``, ``tasks``) are allocated lazily
    by ``init_tensors`` with ``capacity`` rows each. Every stored row also has
    a score in ``scores``; ``functional_reservoir`` uses those scores to decide
    which row a new sample may overwrite once the buffer is full.
    """

    def __init__(
        self, capacity: int, device: str, minibatch_size: int, model=None
    ) -> None:
        """
        Args:
            capacity: number of rows allocated for each stored attribute.
            device: device on which the storage tensors are allocated.
            minibatch_size: upper bound on how many stored rows ``add_data``
                samples when scoring an incoming batch.
            model: object exposing ``get_grads(inputs, labels)``; it is called
                by ``get_grad_score``.
        """
        self.capacity = capacity
        self.device = device
        self.num_seen_examples = 0
        self.attributes = ["examples", "labels", "tasks"]  # Added 'tasks' here
        self.model = model
        self.minibatch_size = minibatch_size
        # Maps buffer row index (plain int) -> gradient returned by the model.
        self.cache = {}
        self.fathom = 0
        self.fathom_mask = None
        self.reset_fathom()
        self.scores = None

    def _num_stored(self) -> int:
        """
        Number of buffer rows that currently hold real data.
        """
        return min(self.num_seen_examples, self.examples.shape[0])

    def reset_fathom(self) -> None:
        """
        Resets the fathom and its mask for sequential data retrieval.

        The mask is a random permutation of the filled rows; the fathom is the
        read position inside that permutation.
        """
        self.fathom = 0
        if hasattr(self, "examples"):
            filled = self._num_stored()
        else:
            filled = self.num_seen_examples
        self.fathom_mask = torch.randperm(filled)

    def get_grad_score(
        self,
        batch_x: torch.Tensor,
        batch_y: torch.Tensor,
        X: torch.Tensor,
        Y: torch.Tensor,
        indices: torch.Tensor,
    ) -> float:
        """
        Computes the gradient score between the current batch and the buffer data.

        Args:
            batch_x, batch_y: incoming inputs and labels, scored as one batch.
            X, Y: buffer samples to compare against, iterated item by item.
            indices: buffer row index of every item in ``X``; used as the key
                of the per-row gradient cache.

        Returns:
            The largest value of ``cosine_similarity + 1`` between the batch
            gradient and any of the per-sample buffer gradients.
        """
        g = self.model.get_grads(batch_x, batch_y)
        G = []

        for x, y, idx in zip(X, Y, indices):
            # Normalise the key: tensor elements hash by identity, so using
            # them directly would never produce a cache hit.
            key = int(idx)
            grd = self.cache.get(key)
            if grd is None:
                grd = self.model.get_grads(x.unsqueeze(0), y.unsqueeze(0))
                self.cache[key] = grd
            G.append(grd)

        G = torch.cat(G).to(g.device)
        c_score = 0.0
        grads_at_a_time = 5

        # Compare in small chunks rather than against all of G at once.
        for start in range(0, G.shape[0], grads_at_a_time):
            chunk = G[start : start + grads_at_a_time]
            similarity = F.cosine_similarity(g, chunk, dim=1).max().item() + 1
            c_score = max(c_score, similarity)

        return c_score

    def functional_reservoir(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        batch_c: float,
        bigX: torch.Tensor = None,
        bigY: torch.Tensor = None,
        indices: torch.Tensor = None,
    ) -> Tuple[int, float]:
        """
        Implements a modified reservoir sampling strategy for gradient-based sample selection.

        Args:
            x, y: a single candidate sample and its label.
            batch_c: score of the batch the candidate belongs to.
            bigX, bigY, indices: buffer samples (and their row indices) used to
                score the candidate on its own when the buffer is full.

        Returns:
            ``(index, score)`` where ``index`` is the row to write to, or
            ``(-1, 0)`` when the candidate must not be stored.
        """
        # While there is free room every sample is appended.
        if self.num_seen_examples < self.capacity:
            return self.num_seen_examples, batch_c

        if batch_c < 1:
            single_c = self.get_grad_score(
                x.unsqueeze(0), y.unsqueeze(0), bigX, bigY, indices
            )
            s = self.scores.detach().cpu().numpy().astype(np.float64)
            total = s.sum()
            # With an all-zero score vector s / total would be NaN and
            # np.random.choice would raise; fall back to a uniform draw.
            probabilities = s / total if total > 0 else None
            i = int(np.random.choice(self.capacity, p=probabilities))
            rand = np.random.rand()

            denominator = s[i] + single_c
            if denominator > 0 and rand < s[i] / denominator:
                return i, single_c

        return -1, 0

    def init_tensors(
        self,
        examples: torch.Tensor,
        labels: torch.Tensor,
        tasks: torch.Tensor = None,
    ) -> None:
        """
        Initializes the buffer tensors based on the input examples, labels, and tasks.

        A storage tensor is only created for arguments that are not ``None``
        and that do not exist on the buffer yet. ``scores`` is (re)allocated
        as zeros on every call.
        """
        provided = {"examples": examples, "labels": labels, "tasks": tasks}
        for attr_str in self.attributes:
            attr = provided[attr_str]
            if attr is not None and not hasattr(self, attr_str):
                # "labels" and "tasks" hold integer ids, the rest is float data.
                typ = (
                    torch.int64
                    if attr_str.endswith(("els", "sks"))
                    else torch.float32
                )
                setattr(
                    self,
                    attr_str,
                    torch.zeros(
                        (self.capacity, *attr.shape[1:]),
                        dtype=typ,
                        device=self.device,
                    ),
                )

        self.scores = torch.zeros(
            (self.capacity), dtype=torch.float32, device=self.device
        )

    def add_data(
        self,
        examples: torch.Tensor,
        labels: torch.Tensor = None,
        tasks: torch.Tensor = None,
    ) -> None:
        """
        Adds data to the buffer along with labels and tasks according to the modified reservoir strategy.

        Raises:
            ValueError: if ``labels`` is ``None``; labels are needed to compute
                the gradient scores.
        """
        if labels is None:
            raise ValueError(
                "labels are required to compute GSS gradient scores"
            )

        if not hasattr(self, "examples"):
            self.init_tensors(examples, labels, tasks)

        if self.num_seen_examples > 0:
            sampled = self.get_data(
                min(self.minibatch_size, self.num_seen_examples),
                give_index=True,
                random=True,
            )
            # get_data returns (examples, labels[, tasks], indices); pick by
            # position so a buffer created without tasks works as well.
            bigX, bigY, indices = sampled[0], sampled[1], sampled[-1]
            c = self.get_grad_score(examples, labels, bigX, bigY, indices)
        else:
            bigX, bigY, indices = None, None, None
            c = 0.1

        for i in range(examples.shape[0]):
            index, score = self.functional_reservoir(
                examples[i], labels[i], c, bigX, bigY, indices
            )
            self.num_seen_examples += 1

            if index >= 0:
                self.examples[index] = examples[i].to(self.device)
                self.labels[index] = labels[i].to(self.device)
                if tasks is not None:
                    self.tasks[index] = tasks[i].to(
                        self.device
                    )  # Handle tasks similarly
                self.scores[index] = score

                # The cached gradient belongs to the sample just overwritten.
                self.cache.pop(int(index), None)

    def drop_cache(self) -> None:
        """
        Clears the cache used for storing gradients.
        """
        self.cache = {}

    def get_data(
        self,
        size: int,
        transform=None,
        give_index: bool = False,
        random: bool = False,
    ) -> Tuple:
        """
        Samples a batch of at most ``size`` items from the filled part of the buffer.

        Args:
            size: requested number of items; capped at the number of filled rows.
            transform: optional callable applied to every selected example
                (each example is moved to CPU before the call).
            give_index: when true, the selected row indices are appended to
                the returned tuple.
            random: when true, rows are drawn without replacement; otherwise
                the next rows of the fathom permutation are returned.

        Returns:
            ``(examples, <one entry per stored attribute after examples>)``
            plus the indices when ``give_index`` is set.
        """
        stored = self._num_stored()
        size = min(size, stored)

        if random:
            choice = np.random.choice(stored, size=size, replace=False)
        else:
            # Rebuild the permutation if the buffer changed size since the
            # last reset; a stale mask would index out of range.
            if self.fathom_mask is None or len(self.fathom_mask) != stored:
                self.reset_fathom()
            stop = min(self.fathom + size, stored)
            choice = self.fathom_mask[self.fathom : stop]
            self.fathom += len(choice)

            # Wrap around once every filled row has been served.
            if self.fathom >= stored:
                self.fathom = 0

        selected = self.examples[choice]
        if transform is not None:
            selected = torch.stack(
                [transform(ee.cpu()) for ee in selected]
            ).to(self.device)
        ret_tuple = (selected,)

        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr[choice],)

        if give_index:
            ret_tuple += (choice,)

        return ret_tuple

    def is_empty(self) -> bool:
        """
        Returns true if the buffer is empty, false otherwise.
        """
        return self.num_seen_examples == 0

    def get_all_data(self, transform=None) -> Tuple:
        """
        Returns all items in the buffer.

        All ``capacity`` rows are returned, including rows that have not been
        written yet. ``transform`` is an optional callable applied to every
        example (each example is moved to CPU before the call).
        """
        if transform is None:
            # Copy so the caller never aliases the storage tensor.
            all_examples = self.examples.clone()
        else:
            all_examples = torch.stack(
                [transform(ee.cpu()) for ee in self.examples]
            ).to(self.device)
        ret_tuple = (all_examples,)

        for attr_str in self.attributes[1:]:
            if hasattr(self, attr_str):
                attr = getattr(self, attr_str)
                ret_tuple += (attr,)

        return ret_tuple

    def empty(self) -> None:
        """
        Empties the buffer, resetting all attributes.

        Scores, the gradient cache and the fathom state are reset too, since
        they all refer to rows that no longer exist.
        """
        for attr_str in self.attributes:
            if hasattr(self, attr_str):
                delattr(self, attr_str)
        self.num_seen_examples = 0
        self.scores = None
        self.cache = {}
        self.reset_fathom()


class GSS(ContinualLearning):
    """
    Gradient-based sample selection (GSS) for online continual learning.

    Each training step optimises the encoder on the current batch plus a
    replay batch read from a ``Buffer``, and afterwards offers the current
    batch to that buffer.

    NOTE(review): ``self.optimizer`` and ``self.cls_output_dim`` are used but
    not assigned in this class; presumably ``ContinualLearning.__init__`` sets
    them — confirm against the base class.
    """

    NAME = "gss"

    def __init__(
        self,
        encoder: nn.Module,
        lr: float = 0.001,
        buffer_size: int = 500,
        gss_minibatch_size: int = 50,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        input_size: int = 64,
        batch_num: int = 1,
        dataset_name: str = "celeba",
        z_dim: int = 512,
        device: str = "cuda",
        **kwargs
    ) -> None:
        """
        Gradient-based sample selection for online continual learning.

        Args:
            encoder: network passed to ``backbone`` to build the model.
            lr: learning rate forwarded to the base class.
            buffer_size: capacity of the rehearsal buffer.
            gss_minibatch_size: number of buffer samples used both for replay
                and for scoring incoming batches.
            cls_output_dim: number of logits per task; the backbone is built
                with ``cls_output_dim * num_tasks`` outputs.
            num_tasks: number of tasks.
            batch_num: number of optimisation steps per ``compute_loss`` call.
            z_dim: forwarded to ``backbone``.
            device: device used for the buffer and for tensors created here.
            input_size, dataset_name, **kwargs: accepted but not used here.

        Raises:
            ValueError: if ``batch_num`` is smaller than 1, because
                ``compute_loss`` would have no loss to return.
        """
        import warnings

        if batch_num < 1:
            raise ValueError("batch_num must be at least 1")

        encoder = backbone(
            encoder, cls_output_dim=cls_output_dim * num_tasks, z_dim=z_dim
        )
        super().__init__(encoder, lr, num_tasks, cls_output_dim)
        self.buffer = Buffer(
            capacity=buffer_size,
            device=device,
            minibatch_size=gss_minibatch_size,
            model=self,
        )
        self.batch_num = batch_num
        self.device = device
        self.gss_minibatch_size = gss_minibatch_size
        # Gradient scoring always uses cross entropy, whatever loss function
        # is passed to compute_loss; actually emit the warning.
        warnings.warn("Assuming cross entropy loss", stacklevel=2)
        self.loss_func = (
            nn.CrossEntropyLoss()
        )  # TODO: assuming it is cross entropy

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.
        """
        return self.encoder(x)

    def get_grads(
        self, inputs: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute gradients for the given inputs and labels.

        The encoder is put in eval mode, the cross entropy loss on its full
        output is back-propagated, and a detached copy of
        ``self.encoder.get_grads()`` is returned with at least two dimensions.
        Optimizer gradients are cleared before and after the computation.
        """
        self.encoder.eval()
        self.optimizer.zero_grad()
        try:
            outputs = self.encoder(inputs)
            loss = self.loss_func(outputs, labels)
            loss.backward()
            grads = self.encoder.get_grads().clone().detach()
        finally:
            # Never leave scoring gradients or eval mode behind, even when
            # the forward/backward pass fails.
            self.optimizer.zero_grad()
            # NOTE(review): train mode is forced unconditionally (original
            # behaviour), even if the encoder was in eval mode on entry.
            self.encoder.train()
        if grads.dim() == 1:
            grads = grads.unsqueeze(0)
        return grads

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_id,
    ) -> float:
        """
        Compute the loss with gradient-based sample selection.

        Runs ``batch_num`` optimisation steps on the current batch (plus a
        replay batch when the buffer is not empty), then adds the
        non-augmented inputs to the buffer.

        Args:
            inputs: current (augmented) inputs.
            labels: labels of ``inputs``.
            not_aug_inputs: inputs stored in the buffer.
            loss_func: loss applied to the task-specific logits.
            transform: passed to ``Buffer.get_data`` for the replay samples.
            task_id: index of the current task; selects the logit slice.

        Returns:
            The loss value of the last optimisation step as a Python number.
        """

        real_batch_size = inputs.shape[0]

        # Drop cache and reset fathom
        self.buffer.drop_cache()
        self.buffer.reset_fathom()

        head_start = task_id * self.cls_output_dim
        head_stop = head_start + self.cls_output_dim

        for _ in range(self.batch_num):
            self.optimizer.zero_grad()
            outputs = self.encoder(inputs)
            outputs_sliced = outputs[:, head_start:head_stop]
            loss = loss_func(outputs_sliced, labels)
            if not self.buffer.is_empty():
                buf_inputs, buf_labels, buf_tasks = self.buffer.get_data(
                    size=self.gss_minibatch_size, transform=transform
                )
                buf_outputs = self.encoder(buf_inputs)
                # Columns of each replay sample's own head: the same
                # task * cls_output_dim + i layout used for the current task.
                head_offsets = torch.arange(
                    self.cls_output_dim, device=buf_outputs.device
                )
                indices_range = (
                    buf_tasks.to(
                        device=buf_outputs.device, dtype=torch.int64
                    ).unsqueeze(1)
                    * self.cls_output_dim
                    + head_offsets
                )
                # Defensive guard kept from the original implementation.
                indices_range = indices_range.clamp(
                    max=buf_outputs.size(1) - 1
                )
                buf_outputs_sliced = buf_outputs.gather(1, indices_range)
                prev_task_loss = loss_func(
                    buf_outputs_sliced,
                    buf_labels.to(device=self.device, dtype=torch.int64),
                )
                loss = loss + prev_task_loss

            loss.backward()
            self.optimizer.step()

        # Add new data to the buffer
        self.buffer.add_data(
            examples=not_aug_inputs,
            labels=labels[:real_batch_size],
            tasks=torch.full(
                (real_batch_size,),
                int(task_id),
                dtype=torch.int64,
                device=self.device,
            ),
        )

        return loss.item()
