import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler


class RandomSampler(Sampler):
    """
    Iteration-based random sampler with optional distributed and weighted sampling.

    The sampler produces a fixed number of indices, ``num_iter * batch_size * num_replicas``,
    instead of a single pass over the dataset:

    - Without weights, indices come from consecutive random permutations of the dataset.
      Pass ``e`` is seeded with ``seed + e`` and its trailing partial global batch is
      dropped, so a global batch never spans two permutations.
    - With weights, indices are drawn with ``torch.multinomial`` (with or without
      replacement), seeded with ``seed``.
    - Each replica then takes every ``num_replicas``-th index starting at its ``rank``.
    - ``restore_iter`` skips the indices of already-consumed iterations so that a
      resumed run continues with the same sequence.

    The index sequence depends only on the constructor arguments and the current
    seed/weights, so every replica computes the same global sequence.
    """

    def __init__(self, dataset=None, batch_size=0, num_iter=None, restore_iter=0,
                 weights=None, replacement=True, seed=0, shuffle=True, num_replicas=None, rank=None):
        """
        Initializes the sampler.

        Args:
            dataset (Sized): The dataset to sample from; only its length is used, and
                only when ``weights`` is None.
            batch_size (int): Per-replica batch size.
            num_iter (int): Number of iterations (global batches) to generate indices
                for. Required, even though the signature default is None.
            restore_iter (int, optional): Number of iterations to skip at the start.
            weights (Tensor, optional): Sampling weights passed to ``torch.multinomial``.
            replacement (bool, optional): Whether weighted sampling uses replacement.
            seed (int, optional): Base seed for the random number generators.
            shuffle (bool, optional): Must be True; unshuffled sampling is not implemented.
            num_replicas (int, optional): Overrides the replica count that is otherwise
                taken from ``torch.distributed`` (or 1 when it is not initialized).
            rank (int, optional): Overrides the rank that is otherwise taken from
                ``torch.distributed`` (or 0 when it is not initialized).

        Raises:
            TypeError: If ``num_iter`` is None.
        """

        # The base class never uses a data source. Older PyTorch releases still require
        # the positional argument while newer ones warn when it is not None, so try the
        # argument-free form first and fall back to passing None.
        try:
            super(RandomSampler, self).__init__()
        except TypeError:
            super(RandomSampler, self).__init__(None)

        if num_iter is None:
            raise TypeError('num_iter is required and must be an int, got None.')

        # is_initialized() is only safe to call when the distributed package is built in.
        self.dist = dist.is_available() and dist.is_initialized()
        if self.dist:
            self.num_replicas = dist.get_world_size()
            self.rank = dist.get_rank()
        else:
            self.num_replicas = 1
            self.rank = 0
        # Explicit arguments take precedence over the process-group values.
        if num_replicas is not None:
            self.num_replicas = num_replicas
        if rank is not None:
            self.rank = rank

        self.dataset = dataset
        # From here on ``batch_size`` is the *global* batch size (all replicas combined).
        self.batch_size = batch_size * self.num_replicas
        self.num_samples = num_iter * self.batch_size
        self.restore = restore_iter * self.batch_size
        self.weights = weights
        self.replacement = replacement
        self.seed = seed
        self.shuffle = shuffle

    def _permuted_indices(self):
        """
        Builds ``num_samples`` indices from consecutive seeded permutations.

        Returns:
            list[int]: Exactly ``self.num_samples`` dataset indices.

        Raises:
            ValueError: If the dataset is smaller than one global batch.
        """

        n = len(self.dataset)
        # Drop the trailing partial global batch of every pass.
        usable = n - n % self.batch_size
        if usable == 0:
            raise ValueError(
                'Dataset of length {} is smaller than the global batch size {}; '
                'cannot form a full batch.'.format(n, self.batch_size))

        indices = []
        epoch = 0
        # Each pass contributes only ``usable`` indices, so keep adding passes until
        # enough indices exist instead of estimating the pass count from ``n``.
        while len(indices) < self.num_samples:
            g = torch.Generator()
            g.manual_seed(self.seed + epoch)
            indices.extend(torch.randperm(n, generator=g)[:usable].tolist())
            epoch += 1
        del indices[self.num_samples:]
        return indices

    def __iter__(self):
        """
        Returns an iterator over this replica's share of the sampled indices.

        The global index sequence is generated deterministically from the current seed
        (and weights, if set). The first ``restore`` indices are skipped and the rest is
        strided by ``num_replicas`` starting at ``rank``.

        Raises:
            NotImplementedError: If ``shuffle`` is False.
            ValueError: If unweighted sampling is requested from a dataset smaller than
                one global batch.
        """

        if not self.shuffle:
            raise NotImplementedError('No shuffle has not been implemented.')

        # Nothing to draw (e.g. batch_size or num_iter is 0); also avoids a modulo by zero.
        if self.num_samples <= 0:
            return iter([])

        if self.weights is None:
            indices = self._permuted_indices()
        else:
            g = torch.Generator()
            g.manual_seed(self.seed)
            indices = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=g).tolist()

        # Skip restored iterations, then take every num_replicas-th index for this rank.
        return iter(indices[self.restore + self.rank:self.num_samples:self.num_replicas])

    def __len__(self):
        """
        Returns the number of indices this replica yields after skipping restored iterations.
        """

        # Clamp so that restoring past the end reports 0 instead of a negative length.
        return max(0, self.num_samples - self.restore) // self.num_replicas

    def set_epoch(self, epoch: int) -> None:
        """
        Sets the seed used by the next iteration to ``epoch``.

        Note that this replaces the seed given to the constructor rather than
        offsetting it.

        Args:
            epoch (int): The value to use as the new seed.
        """

        self.seed = epoch

    def set_weights(self, weights: torch.Tensor) -> None:
        """
        Sets the sampling weights used by the next iteration.

        Args:
            weights (torch.Tensor): The weights to set for the sampler.
        """

        self.weights = weights