import random

import torch


class Reservoir:
    """
    A class that implements a Reservoir Sampling algorithm for continual learning.

    Keeps at most ``buffer_size`` samples out of everything passed to
    :meth:`add`: the first ``buffer_size`` samples are appended, and the
    sample with 0-based stream index ``N`` afterwards replaces slot
    ``randint(0, N)`` when that slot is ``< buffer_size``.

    ``self.buffer`` maps each input key (plus ``"labels"``) to a tensor whose
    first dimension indexes buffer slots. All entries share the same slot
    order, so a slot's inputs and label always belong to the same sample.
    """

    def __init__(self, buffer_size: int, device: torch.device):
        """
        Args:
            buffer_size: Maximum number of samples kept in the reservoir.
            device: Device that :meth:`sample` moves returned tensors to.
        """
        self.buffer = {"labels": None}
        self.buffer_size = buffer_size
        self.device = device
        # 0-based index of the most recently seen sample; -1 = nothing seen yet.
        self.N = -1

    def __len__(self):
        """Return the number of samples currently stored."""
        labels = self.buffer["labels"]
        return 0 if labels is None else len(labels)

    def add(self, inputs, labels):
        """Adds new data to the reservoir.

        Args:
            inputs: A tensor whose first dimension indexes samples (stored under
                the key ``"inputs"``), or a dict of such tensors stored under
                their own keys. A dict argument is not modified.
            labels: Label tensor; ``labels.shape[0]`` is the number of samples.
                It takes precedence over any ``"labels"`` key in ``inputs``.
        """
        # Copy so that adding "labels" does not mutate the caller's dict.
        new_vals = dict(inputs) if isinstance(inputs, dict) else {"inputs": inputs}
        new_vals["labels"] = labels

        for idx in range(labels.shape[0]):
            self.N += 1
            # Decide fill-vs-replace once per sample so every key is treated alike.
            size = len(self)
            if size < self.buffer_size:
                for key, val in new_vals.items():
                    # detach: the buffer must not keep autograd graphs alive.
                    row = val[idx].detach().unsqueeze(0)
                    if size == 0:
                        # clone: a lone row would alias the caller's tensor, and the
                        # in-place replacements below would then write into it.
                        self.buffer[key] = row.clone()
                    else:
                        self.buffer[key] = torch.cat((self.buffer[key], row), 0)
            else:
                # Draw ONE slot per sample so inputs and labels stay paired.
                slot = random.randint(0, self.N)
                if slot < self.buffer_size:
                    for key, val in new_vals.items():
                        self.buffer[key][slot] = val[idx].detach()

    def sample(self, batch_size: int):
        """Samples a batch of data from the reservoir.

        Args:
            batch_size: Number of distinct stored samples to draw (without
                replacement).

        Returns:
            A ``(inputs, labels)`` pair: ``inputs`` maps every non-label buffer
            key to its selected rows, and ``labels`` is the selected labels
            flattened with ``view(-1)``. Both are moved to ``self.device``.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored samples.
        """
        if len(self) < batch_size:
            raise ValueError(
                "Batch size is larger than the current size of the reservoir."
            )
        with torch.no_grad():
            indices = random.sample(range(len(self)), k=batch_size)
            batch = {
                key: val[indices].to(self.device)
                for key, val in self.buffer.items()
                if key != "labels"
            }
            return batch, self.buffer["labels"][indices].view(-1).to(self.device)


class GreedyBalancingSampler:
    """
    Greedy Balancing Sampler implementation

    A sample of class ``c`` is stored only while class ``c`` holds fewer than
    ``buffer_size / len(counter)`` buffer slots (or ``buffer_size`` while the
    buffer is empty). Once the buffer is full, storing a sample evicts a
    random stored sample of the class that currently holds the most slots.
    """

    def __init__(self, buffer_size: int, device: torch.device):
        """
        Args:
            buffer_size: Maximum number of samples kept in the buffer.
            device: Device that :meth:`sample` moves returned tensors to.
        """
        self.buffer = {"inputs": None, "labels": None}
        self.buffer_size = buffer_size
        self.device = device
        # counter[c] = number of samples of class c currently STORED in the buffer.
        # Its length (largest label seen + 1, at least 1) is the class count used
        # for the per-class quota.
        self.counter = [0]

    def __len__(self):
        """Return the number of samples currently stored."""
        labels = self.buffer["labels"]
        return 0 if labels is None else len(labels)

    def add(self, inputs, labels):
        """Adds new data to the reservoir.

        Args:
            inputs: Tensor whose first dimension indexes samples.
            labels: Tensor of non-negative integer class labels, one per sample.

        Raises:
            ValueError: If a label is negative.
        """
        for x, y in zip(inputs, labels):
            label = int(y)
            if label < 0:
                raise ValueError(f"Labels must be non-negative integers, got {label}.")
            if label >= len(self.counter):
                self.counter.extend([0] * (label + 1 - len(self.counter)))

            stored = len(self)
            quota = self.buffer_size / len(self.counter) if stored != 0 else self.buffer_size
            if self.counter[label] >= quota:
                # This class already has its share of the buffer; drop the sample.
                continue

            # detach: the buffer must not keep autograd graphs alive.
            x, y = x.detach(), y.detach()
            if stored >= self.buffer_size:
                # Buffer full: evict a random sample of the largest class. Because
                # counter mirrors the buffer, that class has at least one stored
                # sample, and it differs from `label` (its count exceeds the quota).
                max_label = self.counter.index(max(self.counter))
                candidates = torch.nonzero(
                    self.buffer["labels"].view(-1) == max_label
                ).view(-1)
                pick = torch.randint(0, len(candidates), (1,)).item()
                remove_index = int(candidates[pick])
                self.buffer["inputs"][remove_index] = x
                self.buffer["labels"][remove_index] = y
                self.counter[max_label] -= 1
            elif stored == 0:
                # clone: a lone row would alias the caller's tensor, and later
                # in-place evictions would then write into it.
                self.buffer["inputs"] = x.unsqueeze(0).clone()
                self.buffer["labels"] = y.unsqueeze(0).clone()
            else:
                self.buffer["inputs"] = torch.cat((self.buffer["inputs"], x.unsqueeze(0)), 0)
                self.buffer["labels"] = torch.cat((self.buffer["labels"], y.unsqueeze(0)), 0)

            # Count only samples that were actually stored.
            self.counter[label] += 1

    def sample(self, batch_size: int):
        """Samples a batch of data from the reservoir.

        Args:
            batch_size: Number of distinct stored samples to draw (without
                replacement).

        Returns:
            An ``(inputs, labels)`` pair of the selected rows. ``labels`` is
            flattened with ``view(-1)``. Both are moved to ``self.device``.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored samples.
        """
        if len(self) < batch_size:
            raise ValueError(
                "Batch size is larger than the current size of the reservoir."
            )
        with torch.no_grad():
            indices = random.sample(range(len(self)), k=batch_size)
            chosen_inputs = self.buffer["inputs"][indices].to(self.device)
            chosen_labels = self.buffer["labels"][indices].view(-1).to(self.device)
            return chosen_inputs, chosen_labels
