from abc import ABC, abstractmethod
from typing import List, Any
import random
import numpy as np
import warnings

class BaseSampler(ABC):
    """Abstract interface for strategies that pick a subset of examples.

    Concrete subclasses implement :meth:`sample` and usually overwrite
    ``desc`` with a short human-readable label for the strategy.
    """

    def __init__(self, sample_size: int):
        # Requested number of examples a call to ``sample`` should draw.
        self.sample_size = sample_size
        # Strategy label; left empty here, filled in by subclasses.
        self.desc = ""

    @abstractmethod
    def sample(self, examples: List[Any]) -> List[Any]:
        """Return a subset of ``examples`` chosen by the concrete strategy."""
        ...


class QuantileSampler(BaseSampler):
    """Stratified sampler over the mean activation of each example.

    Examples are sorted by ``np.mean(example.activation_values)`` and split
    into ``num_quantiles`` equal-count strata; the requested sample size is
    spread across the strata as evenly as possible.
    """

    def __init__(self, sample_size: int, num_quantiles: int = 10):
        """
        Args:
            sample_size: Number of examples to return (fewer only when the
                input itself has fewer examples).
            num_quantiles: Number of strata to split the sorted examples into.

        Raises:
            ValueError: If ``num_quantiles`` is less than 1.
        """
        if num_quantiles < 1:
            raise ValueError(f"num_quantiles must be >= 1, got {num_quantiles}")
        super().__init__(sample_size)
        self.num_quantiles = num_quantiles
        self.desc = f"quantile with {num_quantiles} quantiles"

    @staticmethod
    def _even_cuts(total: int, parts: int) -> List[int]:
        """Return ``parts + 1`` non-decreasing cut points splitting ``total`` evenly.

        Pure integer arithmetic, so there is no float truncation at the cuts.
        """
        return [(i * total) // parts for i in range(parts + 1)]

    def sample(self, examples: List[Any]) -> List[Any]:
        """Draw up to ``sample_size`` examples, stratified by mean activation.

        Returns ``min(sample_size, len(examples))`` distinct examples (by
        position), or ``[]`` for empty input / non-positive ``sample_size``.
        """
        if not examples or self.sample_size <= 0:
            return []

        ordered = sorted(examples, key=lambda x: np.mean(x.activation_values))
        n = len(ordered)
        bounds = self._even_cuts(n, self.num_quantiles)
        # Split the quota with the same cut rule so any remainder lands on
        # evenly spaced strata instead of piling onto the lowest ones.
        quotas = self._even_cuts(self.sample_size, self.num_quantiles)

        picked: List[int] = []
        for q in range(self.num_quantiles):
            lo, hi = bounds[q], bounds[q + 1]
            want = quotas[q + 1] - quotas[q]
            picked.extend(random.sample(range(lo, hi), min(want, hi - lo)))

        # Small strata may not be able to cover their quota; top up uniformly
        # from the not-yet-picked examples so the sample is not silently short.
        shortfall = min(self.sample_size, n) - len(picked)
        if shortfall > 0:
            taken = set(picked)
            leftover = [i for i in range(n) if i not in taken]
            picked.extend(random.sample(leftover, shortfall))

        return [ordered[i] for i in picked]


class TopSampler(BaseSampler):
    """Sample from the examples whose peak activation is in the top quantile."""

    def __init__(
        self,
        sample_size: int,
        top_quantile: float = 0.75,
        random_sample: bool = True
    ):
        """
        Samples top examples either deterministically or randomly.

        Each example is scored by ``np.max(example.activation_values)``; only
        examples scoring at or above the ``top_quantile`` quantile of those
        scores are eligible.

        Args:
            sample_size: Number of examples to sample
            top_quantile: Quantile threshold for considering top examples (0-1)
            random_sample: Whether to sample randomly from top examples (True)
                          or take the N highest-scoring ones deterministically (False)
        """
        super().__init__(sample_size)
        self.top_quantile = top_quantile
        self.random_sample = random_sample
        self.desc = f"top_q{top_quantile}_{'random' if random_sample else 'sorted'}"

    def sample(self, examples: List[Any]) -> List[Any]:
        """Return up to ``sample_size`` eligible examples.

        Emits a ``UserWarning`` (in both modes) when fewer than
        ``sample_size`` examples clear the threshold.
        """
        if not examples or self.sample_size <= 0:
            return []

        # Score each example once by its peak activation; the same score
        # drives both the eligibility threshold and the deterministic ranking.
        scores = [np.max(x.activation_values) for x in examples]
        threshold = np.quantile(scores, self.top_quantile)
        top = [(ex, s) for ex, s in zip(examples, scores) if s >= threshold]

        if len(top) < self.sample_size:
            warnings.warn(
                f"Only {len(top)} examples meet top {self.top_quantile} quantile, "
                f"but requested {self.sample_size} samples",
                stacklevel=2,
            )

        if self.random_sample:
            candidates = [ex for ex, _ in top]
            return random.sample(candidates, min(len(candidates), self.sample_size))

        # Highest score first; list.sort is stable (also with reverse=True),
        # so ties keep their input order.
        top.sort(key=lambda pair: pair[1], reverse=True)
        return [ex for ex, _ in top[:self.sample_size]]


class TopAndRandomSampler(BaseSampler):
    """Mix top-activating examples (via TopSampler) with uniform random ones."""

    def __init__(self, sample_size: int, ratio: float = 0.5, **top_sampler_args):
        """
        Args:
            sample_size: Total number of examples to return.
            ratio: Fraction of ``sample_size`` requested from TopSampler; the
                resulting count is clamped to ``[0, sample_size]``.
            **top_sampler_args: Extra keyword arguments forwarded to
                TopSampler (e.g. ``top_quantile``, ``random_sample``).
        """
        super().__init__(sample_size)
        self.ratio = ratio
        self.desc = f"top and random with ratio {ratio}"
        self.top_sampler_args = top_sampler_args

    def sample(self, examples: List[Any]) -> List[Any]:
        """Return top picks followed by random picks from the rest.

        The random part fills whatever the top part did not supply, so the
        result has ``min(sample_size, len(examples))`` items.
        """
        if not examples or self.sample_size <= 0:
            return []

        # Clamp so a ratio outside [0, 1] cannot over-fill or go negative.
        n_top = min(max(int(self.sample_size * self.ratio), 0), self.sample_size)
        top = TopSampler(sample_size=n_top, **self.top_sampler_args).sample(examples)

        # Exclude already-picked items by identity: O(n), and avoids calling
        # == on example objects (which may hold arrays or compare by value).
        taken = {id(ex) for ex in top}
        remaining = [ex for ex in examples if id(ex) not in taken]

        # Size the random part from what TopSampler actually returned, so a
        # thin top quantile does not leave the sample short.
        random_sampler = RandomSampler(sample_size=self.sample_size - len(top))
        random_samples = random_sampler.sample(remaining)

        return top + random_samples


class RandomSampler(BaseSampler):
    """Draw examples uniformly at random, without replacement."""

    def __init__(self, sample_size: int):
        super().__init__(sample_size)
        self.desc = "random"

    def sample(self, examples: List[Any]) -> List[Any]:
        """Return ``min(sample_size, len(examples))`` randomly chosen examples.

        Empty input or a non-positive ``sample_size`` yields ``[]``.
        """
        nothing_to_draw = not examples or self.sample_size <= 0
        if nothing_to_draw:
            return []

        draw_count = min(self.sample_size, len(examples))
        return random.sample(examples, draw_count)