import warnings
import numpy as np
import torch
import random

def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int:
    """
    Compute the buffer slot for the next example of ``task`` in ring order.

    Each task owns a contiguous region of ``buffer_portion_size`` slots that
    begins at ``task * buffer_portion_size``. Within that region, slots are
    reused cyclically as more examples arrive.

    Args:
        num_seen_examples: how many examples of this task were seen so far
        buffer_portion_size: number of slots reserved for each task
        task: the task identifier

    Returns:
        The absolute buffer index at which to store the new example.
    """
    region_start = task * buffer_portion_size
    offset_in_region = num_seen_examples % buffer_portion_size
    return region_start + offset_in_region

class RingBuffer:
    """Fixed-capacity replay buffer split into equal per-task ring regions.

    Storage is divided into ``n_tasks`` contiguous portions of
    ``capacity // n_tasks`` slots each. Task ``t`` owns slots
    ``[t * portion, (t + 1) * portion)``. Once a task's portion is full,
    new examples for that task overwrite its slots cyclically (see ``ring``).

    Args:
        capacity: total number of slots across all tasks.
        device: torch device on which the storage tensors are allocated.
        n_tasks: number of tasks sharing the buffer.
    """

    def __init__(self, capacity, device='cuda', n_tasks: int = 1):
        self.capacity = capacity
        self.device = device
        self.n_tasks = n_tasks
        # Any remainder of capacity / n_tasks is left unused.
        self.buffer_portion_size = capacity // n_tasks
        # NOTE(review): storage is hard-coded for 3x224x224 inputs and 2-way
        # logits; labels use torch's default (float) dtype — confirm intent.
        self.data = torch.empty((capacity, 3, 224, 224), device=device)
        self.logits = torch.empty((capacity, 2), device=device)
        self.labels = torch.empty((capacity,), device=device)
        # Examples ever offered per task; may exceed buffer_portion_size.
        self.num_seen_examples = [0] * n_tasks

    def add_data(self, examples: torch.Tensor,
                 logits: torch.Tensor = None,
                 labels: torch.Tensor = None,
                 task: int = 0) -> None:
        """
        Add data to the buffer using the ring buffer strategy.

        Args:
            examples: Tensor of examples to add (one per leading index).
            logits: Optional tensor of logits corresponding to the examples.
            labels: Optional tensor of labels corresponding to the examples.
            task: Task identifier for the current data, in ``[0, n_tasks)``.

        Raises:
            IndexError: if ``task`` is outside ``[0, n_tasks)``.
        """
        # A negative task would otherwise index from the end of the list and
        # write into another task's region.
        if not 0 <= task < self.n_tasks:
            raise IndexError(
                f"task {task} out of range for buffer with {self.n_tasks} tasks"
            )

        # Ensure data is on the correct device
        examples = examples.to(self.device)
        if logits is not None:
            logits = logits.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)

        for i in range(len(examples)):
            # Slot within this task's region, in ring order.
            index = ring(self.num_seen_examples[task], self.buffer_portion_size, task)
            self.data[index] = examples[i]
            if logits is not None:
                self.logits[index] = logits[i]
            if labels is not None:
                self.labels[index] = labels[i]

            # Update the count of seen examples for the specific task
            self.num_seen_examples[task] += 1

    def _valid_indices(self):
        """Return the slot indices that currently hold written examples."""
        indices = []
        for task, seen in enumerate(self.num_seen_examples):
            start = task * self.buffer_portion_size
            # add_data fills each region from its start, so the first
            # min(seen, portion) slots of the region are the occupied ones.
            filled = min(seen, self.buffer_portion_size)
            indices.extend(range(start, start + filled))
        return indices

    def sample(self, n):
        """Draw up to ``n`` stored examples uniformly without replacement.

        Only occupied slots (across all task regions) are eligible. If fewer
        than ``n`` examples are stored, all of them are returned.

        Returns:
            A ``(data, logits, labels)`` tuple of tensors indexed by the
            sampled slots. Logits/labels for slots that were written without
            them hold uninitialized values.
        """
        valid = self._valid_indices()
        chosen = random.sample(valid, min(n, len(valid)))
        index = torch.tensor(chosen, dtype=torch.long, device=self.data.device)
        return self.data[index], self.logits[index], self.labels[index]

    def __len__(self):
        """Number of examples currently stored (never exceeds capacity)."""
        return sum(min(seen, self.buffer_portion_size)
                   for seen in self.num_seen_examples)

    def is_empty(self):
        """Return True if no example has been added for any task."""
        return sum(self.num_seen_examples) == 0

    def get_data(self, batch_size, transform):
        """Sample up to ``batch_size`` examples and apply ``transform`` to the data.

        ``transform`` is called once on the whole stacked data tensor; logits
        and labels are returned untransformed.
        """
        data, logits, labels = self.sample(batch_size)
        data = transform(data)
        return data, logits, labels

# Example usage:
# buffer = RingBuffer(capacity=1000, device='cuda', n_tasks=10)
# buffer.add_data(examples_tensor, logits_tensor, labels=labels_tensor, task=0)
# data, logits, labels = buffer.get_data(batch_size=32, transform=your_transform_function)
