import numpy as np

from typing import List, Tuple, Dict, Union

class SamplerBase:
    """Common base for worker samplers.

    Stores the workers to sample from together with a seeded NumPy random
    generator. Concrete subclasses override ``__call__`` to produce a sample.
    """

    def __init__(self,
                 workers: np.ndarray,
                 seed: int = 42) -> None:
        """
        Args:
            workers: np.ndarray - the workers to sample from
            seed: int - seed for random generator
        """
        rng = np.random.default_rng(seed)
        self.workers = workers
        self.cpu_generator = rng

    def __call__(self) -> np.ndarray:
        """Return one sample of workers; subclasses must override this."""
        raise NotImplementedError()


class FullSampler(SamplerBase):
    """Degenerate sampler: every call selects all workers."""

    def __call__(self) -> np.ndarray:
        # The stored workers object itself is handed back (no copy is made).
        everyone = self.workers
        return everyone


class NonUniformSampler(SamplerBase):
    """Draw a worker at random using fixed, caller-supplied probabilities."""

    def __init__(self,
                 workers: np.ndarray,
                 probabilities: np.ndarray,
                 seed: int = 42) -> None:
        """
        Args:
            workers: np.ndarray - the workers to sample from
            probabilities: np.ndarray - probability of drawing each entry of
                ``workers`` (passed as ``p`` to ``Generator.choice``)
            seed: int - seed for random generator
        """
        self.probabilities = probabilities
        super().__init__(workers, seed)

    def __call__(self) -> np.ndarray:
        # NOTE(review): no ``size`` is passed, so ``Generator.choice`` returns a
        # single worker (a scalar) rather than an array, despite the annotation.
        rng = self.cpu_generator
        weights = self.probabilities
        return rng.choice(self.workers, replace=True, p=weights)


class NiceSampler(SamplerBase):
    """Tau-nice sampling: pick ``batch_size`` distinct workers uniformly at random."""

    def __init__(self,
                 workers: np.ndarray,
                 batch_size: int,
                 seed: int = 42,
                 *,
                 replace: bool = False) -> None:
        """
        Args:
            workers: np.ndarray - the workers to sample from
            batch_size: int - number of workers returned per call
            seed: int - seed for random generator
            replace: bool - if True, sample with replacement (the same worker
                may appear several times in one batch). Defaults to False so
                that each call returns a uniformly random subset of workers.
        """
        super().__init__(workers, seed)
        self.batch_size = batch_size
        self.replace = replace

    def __call__(self) -> np.ndarray:
        """Return ``batch_size`` workers drawn uniformly at random.

        Raises:
            ValueError: (from ``Generator.choice``) if sampling without
                replacement and ``batch_size`` exceeds the number of workers.
        """
        # Without replacement, a nice sampling never contains duplicate workers.
        return self.cpu_generator.choice(self.workers, self.batch_size, replace=self.replace)


class BlockSampler(SamplerBase):
    """Sample a single block per call and return all of its workers.

    Block ids are treated as non-negative integer indices: entry ``i`` of
    ``blocks_probabilities`` is the probability of drawing block ``i``.
    """

    def __init__(self,
                 workers: np.ndarray,
                 blocks_mapping: Dict[int, int],
                 blocks_probabilities: Union[np.ndarray, None] = None,
                 seed: int = 42) -> None:
        """
        Args:
            workers: np.ndarray - the workers to sample from
            blocks_mapping: dict[int, int] - mapping from worker to block id;
                every worker in ``workers`` must be a key
            blocks_probabilities: np.ndarray - probabilities of each block,
                indexed by block id. ``None`` or a 0-d array (e.g.
                ``np.array(None)``) means "weight each block by its share of
                ``workers``".
            seed: int - seed for random generator

        Raises:
            KeyError: if a worker in ``workers`` is missing from ``blocks_mapping``.
        """
        super().__init__(workers, seed)
        self.blocks_mapping = blocks_mapping

        # Group workers by block once (keeping their order in ``workers``)
        # instead of rescanning every worker on each call.
        members: Dict[int, List[int]] = {}
        for worker in self.workers:
            members.setdefault(self.blocks_mapping[worker], []).append(worker)

        if np.ndim(blocks_probabilities) == 0:
            # Size by the largest block id rather than the number of distinct
            # ids: this keeps index == block id for non-contiguous ids (unused
            # ids get probability 0), so the vector still sums to 1.
            blocks_number = max(blocks_mapping.values()) + 1 if blocks_mapping else 0
            block_sizes = np.array([len(members.get(block, [])) for block in range(blocks_number)])
            self.blocks_probabilities = np.array([block_size / len(self.workers) for block_size in block_sizes])
        else:
            # Accept any array-like (e.g. a list); ndarrays are kept as-is.
            self.blocks_probabilities = np.asarray(blocks_probabilities)

        # Block ids with no workers map to an empty array, as before.
        self._blocks = [np.array(members.get(block, []))
                        for block in range(len(self.blocks_probabilities))]

    def __call__(self) -> np.ndarray:
        """Draw one block id using ``blocks_probabilities`` and return its workers."""
        sampled_block = self.cpu_generator.choice(len(self.blocks_probabilities), p=self.blocks_probabilities)
        # Return a copy so callers cannot mutate the cached block membership.
        return self._blocks[sampled_block].copy()


class StratifiedSampler(SamplerBase):
    """Sample one worker uniformly at random from every block (stratum).

    Each call returns exactly one worker per block, ordered by ascending block id.
    """

    def __init__(self,
                 workers: np.ndarray,
                 blocks_mapping: Dict[int, int],
                 seed: int = 42) -> None:
        """
        Args:
            workers: np.ndarray - the workers to sample from
            blocks_mapping: dict[int, int] - mapping from worker to block id;
                every worker in ``workers`` must be a key
            seed: int - seed for random generator
        """
        super().__init__(workers, seed)
        self.blocks_mapping = blocks_mapping
        # Iterate over the actual ids (sorted) instead of range(count), so
        # non-contiguous ids such as {0, 2} are grouped correctly.
        block_ids = sorted(set(blocks_mapping.values()))
        self.blocks_number = len(block_ids)
        members: Dict[int, List[int]] = {block_id: [] for block_id in block_ids}
        for worker in self.workers:
            members[self.blocks_mapping[worker]].append(worker)
        # A list of 1-D arrays, not one 2-D array: blocks may differ in size,
        # and np.array on ragged nested lists raises in NumPy >= 1.24.
        self.blocks = [np.array(members[block_id]) for block_id in block_ids]

    def __call__(self) -> np.ndarray:
        """Return one uniformly drawn worker from each block.

        Raises:
            ValueError: (from ``Generator.choice``) if some block has no workers.
        """
        return np.array([self.cpu_generator.choice(block) for block in self.blocks])