import warnings
import numpy as np
import torch
import random


def reservoir(num_seen_examples: int, buffer_size: int) -> int:
    """
    Pick the buffer slot for the current example via reservoir sampling.

    Adapted from
    https://github.com/aimagelab/mammoth/blob/master/utils/buffer.py

    Args:
        num_seen_examples: how many examples were seen before this one
        buffer_size: the maximum number of slots in the buffer

    Returns:
        the slot index the current example should be written to, or -1 when
        the example is not selected and must be dropped
    """
    # Free slots remain: examples are stored in arrival order.
    if num_seen_examples < buffer_size:
        return num_seen_examples

    # Draw uniformly from [0, num_seen_examples]; only draws that land inside
    # the buffer range select a slot to overwrite.
    candidate = np.random.randint(0, num_seen_examples + 1)
    return candidate if candidate < buffer_size else -1


class Buffer:
    """
    Fixed-capacity replay buffer populated with reservoir sampling.

    Storage is pre-allocated as four tensors on ``device``: ``data``,
    ``logits``, ``labels`` and ``tasks``. Row ``i`` of each tensor belongs to
    the same stored example. ``labels`` and ``tasks`` are allocated without an
    explicit dtype, so they use torch's default dtype and assigned values are
    cast to it.

    Attributes:
        capacity: maximum number of stored examples
        device: device on which the storage tensors live
        dataset_name: selects the layout of ``data`` (see ``create_buffer``)
        current_size: number of valid rows currently stored
        num_seen_examples: number of examples offered to ``add_data`` since
            construction or the last ``empty()``
    """

    def __init__(self,
                 capacity,
                 input_size,
                 total_logits_dim,
                 dataset_name='celeba',
                 device='cuda', ):
        """
        Allocate an empty buffer.

        Args:
            capacity: maximum number of stored examples
            input_size: size used to build the per-example data shape
            total_logits_dim: width of each stored logits row
            dataset_name: one of 'celeba', 'mtfl', 'fairface', 'physiq'
            device: torch device for the storage tensors

        Raises:
            NotImplementedError: if ``dataset_name`` is not supported
        """
        self.capacity = capacity
        self.device = device
        self.dataset_name = dataset_name
        self.current_size = 0
        self.num_seen_examples = 0
        self.create_buffer(capacity, input_size, total_logits_dim)

    def create_buffer(self, capacity, input_size, total_logits_dim):
        """
        Allocate the (uninitialised) storage tensors.

        ``data`` has shape ``(capacity, 3, input_size, input_size)`` for
        'celeba', 'mtfl' and 'fairface', and ``(capacity, input_size, 6)`` for
        'physiq'.

        Raises:
            NotImplementedError: if ``self.dataset_name`` is not supported
        """
        if self.dataset_name in ['celeba', 'mtfl', 'fairface']:
            data_shape = (capacity, 3, input_size, input_size)
        elif self.dataset_name == 'physiq':
            data_shape = (capacity, input_size, 6)
        else:
            raise NotImplementedError(f"Dataset {self.dataset_name} not supported")

        self.data = torch.empty(data_shape, device=self.device)
        # total_logits_dim = num_cls * num_tasks
        self.logits = torch.empty((capacity, total_logits_dim), device=self.device)
        self.labels = torch.empty((capacity,), device=self.device)
        self.tasks = torch.empty((capacity,), device=self.device)

    def add_data(self, examples: torch.Tensor,
                 logits: torch.Tensor = None,
                 labels: torch.Tensor = None,
                 task: list = None):
        """
        Offer a batch of examples to the buffer.

        While free slots remain, examples are appended in order. Once the
        buffer is full, each example replaces a slot chosen by ``reservoir``
        or is dropped. Every supplied field (``logits``, ``labels``, ``task``)
        is written to the same slot as its example so rows stay aligned.
        Fields passed as ``None`` are left untouched in storage.

        Args:
            examples: batch of inputs; the first dimension indexes examples
            logits: optional per-example logits
            labels: optional per-example labels
            task: optional per-example task ids (list or tensor)
        """
        # Ensure data is on the correct device
        examples = examples.to(self.device)
        if logits is not None:
            logits = logits.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)
        if task is not None:
            # as_tensor accepts both lists and tensors without the
            # copy-construct warning torch.tensor emits for tensor input.
            task = torch.as_tensor(task).to(self.device)

        # TODO: the initial fill could be done with one slice assignment
        # instead of a per-example loop, since storage is pre-allocated.
        for i in range(len(examples)):
            filling = self.current_size < self.capacity
            if filling:
                # Buffer not full yet: take the next free slot.
                index = self.current_size
            else:
                # Buffer full: reservoir sampling picks a slot or -1 (drop).
                index = reservoir(self.num_seen_examples, self.capacity)

            if index >= 0:
                self.data[index] = examples[i]
                if logits is not None:
                    self.logits[index] = logits[i]
                if labels is not None:
                    self.labels[index] = labels[i]
                if task is not None:
                    # Must be updated on replacement too; otherwise the slot
                    # keeps the task id of the example that was evicted.
                    self.tasks[index] = task[i]

            if filling:
                self.current_size += 1

            # Update the count of seen examples
            self.num_seen_examples += 1

        return

    def sample(self, n):
        """
        Draw ``n`` stored examples at random.

        Sampling is without replacement when ``n <= len(self)`` and with
        replacement otherwise.

        Returns:
            tuple ``(data, logits, labels, tasks)`` of tensors with ``n`` rows

        Raises:
            IndexError: if ``n > 0`` and the buffer is empty
        """
        if n > 0 and self.current_size == 0:
            raise IndexError("Cannot sample from an empty buffer")

        valid_rows = range(self.current_size)
        if n > self.current_size:
            indices = random.choices(valid_rows, k=n)  # Resample with repetition
        else:
            indices = random.sample(valid_rows, n)
        return self.data[indices], self.logits[indices], self.labels[indices], self.tasks[indices]

    def empty(self):
        """Reset the counters; storage is kept and overwritten by later adds."""
        self.current_size = 0
        self.num_seen_examples = 0
        return

    def __len__(self):
        """Return the number of examples currently stored."""
        return self.current_size

    def is_empty(self):
        """Return True when no examples are stored."""
        return self.current_size == 0

    def get_data(self, batch_size, transform):
        """
        Sample ``batch_size`` examples and apply ``transform`` to the data.

        ``transform`` is called once on the whole sampled data tensor; the
        logits, labels and tasks are returned as sampled.
        """
        data, logits, labels, tasks = self.sample(batch_size)
        data = transform(data)
        return data, logits, labels, tasks

    def get_all_data(self):
        """Return the valid rows of ``(data, logits, labels, tasks)``."""
        size = self.current_size
        return self.data[:size], self.logits[:size], self.labels[:size], self.tasks[:size]

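# Example usage (argument values are illustrative):
# buffer = Buffer(capacity=1000, input_size=224, total_logits_dim=40, device='cuda')
# buffer.add_data(examples_tensor, logits=logits_tensor, labels=labels_tensor, task=task_ids)
# data, logits, labels, tasks = buffer.get_data(batch_size=32, transform=your_transform_function)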
