import math
import torch
import random
import torch.distributed as dist
from torch.utils.data import Sampler, DistributedSampler
from collections import defaultdict


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed
    training, with repeated augmentation.

    Every dataset index is repeated ``repetitions`` times back to back and the
    resulting list is dealt out round-robin over the ranks, so consecutive
    copies of a sample are handed to consecutive processes (GPUs).
    Heavily based on 'torch.utils.data.DistributedSampler'.

    This is borrowed from the DeiT Repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py

    Args:
        dataset: Sized dataset to draw indices from (only ``len(dataset)`` is used).
        num_replicas: Number of participating processes. Defaults to
            ``dist.get_world_size()``.
        rank: Index of the current process. Defaults to ``dist.get_rank()``.
        shuffle: If True, permute the dataset indices with a generator seeded
            by ``seed + epoch`` before repeating them.
        seed: Base seed for the permutation.
        repetitions: How many times each index is repeated.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` has to be inferred and the
            distributed package is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Length of the per-rank slice of the repeated index list.
        self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Number of indices actually yielded per rank: the dataset length rounded
        # down to a multiple of 256, divided by the number of replicas.  A dataset
        # with fewer than 256 items therefore yields nothing.
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle
        self.seed = seed
        self.repetitions = repetitions

    def __iter__(self):
        """Return an iterator over this rank's indices for the current epoch."""
        if self.shuffle:
            # Deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Repeat every index `repetitions` times, keeping the copies adjacent.
        indices = [ele for ele in indices for _ in range(self.repetitions)]

        # Add extra samples to make it evenly divisible.  The padding can be
        # longer than the list itself when the repeated dataset is smaller than
        # the number of replicas, so tile the list before truncating instead of
        # taking a single (possibly too short) prefix.
        padding_size = self.total_size - len(indices)
        if padding_size > 0 and indices:
            tiles = math.ceil(padding_size / len(indices))
            indices += (indices * tiles)[:padding_size]
        assert len(indices) == self.total_size

        # Subsample: rank r takes positions r, r + num_replicas, ...
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        """Return the number of indices yielded per epoch on this rank."""
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch that is added to ``seed`` for the next shuffle."""
        self.epoch = epoch


class FixedClassSampler(Sampler):
    """Sampler that yields a fixed number of randomly chosen indices per class.

    The selection is drawn once at construction time; every iteration yields
    the same indices in the same order (grouped by class, in the order the
    classes first appear in ``dataset.samples``).

    Args:
        dataset: Object exposing ``samples``, an iterable of ``(item, label)``
            pairs whose position is the dataset index.
        num_samples_per_class: Number of indices to draw from every class.
        seed: Seed for the selection.  When given, a private generator is used
            so the process-wide ``random`` state is left untouched.  When
            ``None``, the selection is drawn from the global ``random`` stream.

    Raises:
        ValueError: If any class has fewer than ``num_samples_per_class`` samples.
    """

    def __init__(self, dataset, num_samples_per_class, seed=None):
        self.dataset = dataset
        self.num_samples_per_class = num_samples_per_class
        self.seed = seed
        self.class_indices = self._get_class_indices()
        self.indices = self._sample_indices()

    def _get_class_indices(self):
        """Group dataset indices by label, preserving first-seen label order."""
        class_indices = defaultdict(list)
        for idx, (_, label) in enumerate(self.dataset.samples):
            class_indices[label].append(idx)
        return class_indices

    def _sample_indices(self):
        """Draw ``num_samples_per_class`` indices without replacement per class."""
        # A dedicated Random(seed) produces the same draws as seeding the module
        # functions, without resetting the RNG that the rest of the program
        # (augmentations, other samplers) relies on.
        rng = random if self.seed is None else random.Random(self.seed)

        selected_indices = []
        for label, indices in self.class_indices.items():
            if len(indices) < self.num_samples_per_class:
                raise ValueError(f"Not enough samples for class {label}")
            selected_indices.extend(rng.sample(indices, self.num_samples_per_class))
        return selected_indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


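# NOTE(review): originally flagged as "something wrong" without further detail.
# One visible limitation: each rank takes a contiguous slice of a single
# shuffled index list, so the number of samples per class inside one rank's
# slice is not fixed. DistributedFixedClassSampler below splits each class's
# samples across ranks instead.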
class FixedClassDistributedSampler(Sampler):
    """Distributed sampler over a fixed number of randomly chosen indices per class.

    ``num_samples_per_class`` indices are drawn from every class, the combined
    list is shuffled, and each rank yields one contiguous, equally sized slice
    of it.  Samples left over when the total is not divisible by the world size
    are dropped so that every rank has the same length.

    All ranks must build the *same* shuffled list for the slices to be
    disjoint, so the selection never depends on per-process RNG state: when
    ``seed`` is ``None`` a fixed fallback seed is used.

    Args:
        dataset: Object exposing ``samples``, an iterable of ``(item, label)``
            pairs whose position is the dataset index.
        num_samples_per_class: Number of indices to draw from every class.
        seed: Seed shared by all ranks.  ``None`` selects the fallback seed 0.
        num_replicas: Number of participating processes.  Defaults to
            ``dist.get_world_size()``.
        rank: Index of the current process.  Defaults to ``dist.get_rank()``.

    Raises:
        ValueError: If any class has fewer than ``num_samples_per_class`` samples.
    """

    # Used when no seed is supplied, so that every rank still agrees on the
    # selection and the shuffle order.
    _FALLBACK_SEED = 0

    def __init__(self, dataset, num_samples_per_class, seed=None, num_replicas=None, rank=None):
        self.dataset = dataset
        self.num_samples_per_class = num_samples_per_class
        self.seed = seed
        self.class_indices = self._get_class_indices()
        # Only query the process group for values the caller did not supply.
        self.rank = dist.get_rank() if rank is None else rank
        self.world_size = dist.get_world_size() if num_replicas is None else num_replicas
        self.total_samples = self.num_samples_per_class * len(self.class_indices)
        self.num_samples_per_rank = self.total_samples // self.world_size
        self.indices = self._sample_indices()

    def _get_class_indices(self):
        """Group dataset indices by label, preserving first-seen label order."""
        class_indices = defaultdict(list)
        for idx, (_, label) in enumerate(self.dataset.samples):
            class_indices[label].append(idx)
        return class_indices

    def _sample_indices(self):
        """Return the shuffled list of selected indices shared by all ranks."""
        seed = self._FALLBACK_SEED if self.seed is None else self.seed

        # Private generators keep the process-wide `random` state untouched.
        rng = random.Random(seed)
        selected_indices = []
        for label, indices in self.class_indices.items():
            if len(indices) < self.num_samples_per_class:
                raise ValueError(f"Not enough samples for class {label}")
            selected_indices.extend(rng.sample(indices, self.num_samples_per_class))

        # The shuffle restarts from the seed (rather than continuing the stream
        # used for sampling) so that seeded runs keep their previous ordering.
        random.Random(seed).shuffle(selected_indices)
        return selected_indices

    def __iter__(self):
        # Partition indices for the current rank
        start = self.rank * self.num_samples_per_rank
        end = start + self.num_samples_per_rank
        return iter(self.indices[start:end])

    def __len__(self):
        return self.num_samples_per_rank


class DistributedFixedClassSampler(Sampler):
    """Class-balanced distributed sampler over a fixed number of samples per class.

    ``num_samples_per_class`` indices are selected from every class and each
    class's selection is split across the ranks, so every rank receives
    ``num_samples_per_class // num_replicas`` indices of each class.  Any
    remainder of that division is dropped: all ranks end up with the same
    length and the same per-class count.

    The selection is made once at construction time; every iteration yields
    the same indices in the same order.

    Args:
        dataset: Object exposing ``samples``, an iterable of ``(item, label)``
            pairs whose position is the dataset index.
        num_samples_per_class: Number of indices selected from every class
            (before splitting across ranks).
        num_replicas: Number of participating processes.  Defaults to
            ``torch.distributed.get_world_size()``.
        rank: Index of the current process.  Defaults to
            ``torch.distributed.get_rank()``.
        seed: Seed shared by all ranks.  ``None`` selects the fallback seed 0 so
            that ranks still agree on the selection.
        shuffle: If True, pick each class's indices at random and shuffle the
            rank's final list; if False, take the first indices of each class
            and keep them grouped by class.

    Raises:
        ValueError: If a class has fewer than ``num_samples_per_class`` samples,
            if ``rank`` is outside ``[0, num_replicas)``, or if
            ``num_samples_per_class`` is positive but smaller than
            ``num_replicas`` (every rank would receive nothing).
    """

    # Used when no seed is supplied, so that every rank samples the same
    # per-class selection and the per-rank slices stay disjoint.
    _FALLBACK_SEED = 0

    def __init__(self, dataset, num_samples_per_class, num_replicas=None, rank=None, seed=None, shuffle=True):
        self.dataset = dataset
        self.num_samples_per_class = num_samples_per_class
        self.seed = seed
        self.shuffle = shuffle

        # Get class indices and verify we have enough samples per class
        self.class_indices = self._get_class_indices()
        self._verify_class_counts()

        # Distributed setup
        self.num_replicas = num_replicas if num_replicas is not None else torch.distributed.get_world_size()
        self.rank = rank if rank is not None else torch.distributed.get_rank()

        # Generate indices for this GPU that maintain class balance
        self.indices_per_gpu = self._get_indices_per_gpu()

    def _get_class_indices(self):
        """Group dataset indices by label, preserving first-seen label order."""
        class_indices = defaultdict(list)
        for idx, (_, label) in enumerate(self.dataset.samples):
            class_indices[label].append(idx)
        return class_indices

    def _verify_class_counts(self):
        """Raise ValueError if any class is smaller than the requested count."""
        for label, indices in self.class_indices.items():
            if len(indices) < self.num_samples_per_class:
                raise ValueError(f"Class {label} has only {len(indices)} samples, but requested {self.num_samples_per_class}")

    def _get_indices_per_gpu(self):
        """Return the list of dataset indices assigned to this rank."""
        if not 0 <= self.rank < self.num_replicas:
            raise ValueError(
                f"Invalid rank {self.rank}, rank should be in the interval [0, {self.num_replicas - 1}]"
            )

        per_gpu = self.num_samples_per_class // self.num_replicas
        if per_gpu == 0 and self.num_samples_per_class > 0:
            raise ValueError(
                f"num_samples_per_class ({self.num_samples_per_class}) is smaller than "
                f"num_replicas ({self.num_replicas}); every rank would receive no samples"
            )

        # Private generator: identical draws on every rank, and the process-wide
        # `random` state is left untouched.
        rng = random.Random(self._FALLBACK_SEED if self.seed is None else self.seed)

        # Every rank gets the same-sized window of each class's selection.  The
        # tail beyond num_replicas * per_gpu is dropped rather than handed to
        # the last rank, which would make the ranks' lengths differ.
        start = self.rank * per_gpu
        end = start + per_gpu

        rank_indices = []
        for indices in self.class_indices.values():
            if self.shuffle:
                selected = rng.sample(indices, self.num_samples_per_class)
            else:
                selected = indices[:self.num_samples_per_class]
            rank_indices.extend(selected[start:end])

        # Shuffle the indices within this GPU to mix classes
        if self.shuffle:
            rng.shuffle(rank_indices)

        return rank_indices

    def __iter__(self):
        return iter(self.indices_per_gpu)

    def __len__(self):
        return len(self.indices_per_gpu)