from torch.utils.data import Sampler,DistributedSampler
import random

import torch
import numpy as np
from collections import defaultdict

class DistributedProportionalSampler(DistributedSampler):
    """Class-covering index sampler for ``DistributedDataParallel`` training.

    The sampler emits a flat stream of dataset indices that is meant to be cut
    into consecutive groups of ``batch_size`` (e.g. by a ``DataLoader`` using the
    same ``batch_size``). Every such group contains one reserved sample of each
    class in ``range(num_classes)``; the remaining ``batch_size - num_classes``
    slots are filled from a shuffled pool of all samples that were not reserved,
    so the fill roughly follows the global label distribution.

    Every rank builds the same batch list (the RNG is seeded with the epoch
    only) and then keeps a disjoint, equally sized subset of *whole* batches, so
    ranks never overlap and all ranks iterate the same number of indices.

    Args:
        labels: Sized sequence of class labels; each label must be convertible
            with ``int()``.
        batch_size: Number of indices per batch; must be at least
            ``num_classes``.
        num_classes: Classes ``0 .. num_classes - 1`` are guaranteed to appear
            in every batch.
        num_replicas: World size; read from ``torch.distributed`` when ``None``.
        rank: Rank of this process; read from ``torch.distributed`` when
            ``None``.
        shuffle: When true, each class list is shuffled before samples are
            reserved. The fill pool and the order inside each batch are always
            shuffled (seeded by the epoch).

    Raises:
        ValueError: If ``batch_size`` is not positive or is smaller than
            ``num_classes``, or if a class in ``range(num_classes)`` has no
            samples.
        RuntimeError: If ``num_replicas``/``rank`` must be inferred but
            ``torch.distributed`` is unavailable or uninitialised.
    """

    def __init__(self, labels, batch_size, num_classes=2, num_replicas=None, rank=None, shuffle=True):
        if batch_size <= 0 or batch_size < num_classes:
            raise ValueError(
                f"batch_size ({batch_size}) must be positive and at least "
                f"num_classes ({num_classes})"
            )

        # Resolve the distributed topology.
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_initialized():
                raise RuntimeError("Requires torch.distributed to be initialized")
            rank = torch.distributed.get_rank()

        # Group sample indices by integer class label.
        class_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            class_indices[int(label)].append(idx)
        missing = [cls for cls in range(num_classes) if cls not in class_indices]
        if missing:
            raise ValueError(f"labels contain no samples for classes {missing}")

        # Initialise the base class so the attributes it defines exist as well;
        # the size attributes are overwritten with this sampler's own values below.
        super().__init__(labels, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        self.labels = labels
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.num_replicas = num_replicas
        self.rank = rank
        self.class_indices = class_indices
        self.epoch = 0

        # Number of indices emitted per epoch across all ranks.
        self.total_size = self._calculate_total_size()
        self.num_samples = self.total_size // self.num_replicas

    def _calculate_total_size(self):
        """Return the number of indices emitted per epoch across all ranks.

        The batch count is limited by the dataset size and by the smallest
        required class (one reserved sample per batch), then rounded down to a
        multiple of ``num_replicas`` so every rank receives whole batches.
        """
        n_batches = len(self.labels) // self.batch_size
        for cls in range(self.num_classes):
            n_batches = min(n_batches, len(self.class_indices.get(cls, ())))
        n_batches -= n_batches % self.num_replicas
        return n_batches * self.batch_size

    def set_epoch(self, epoch):
        """Set the epoch number used to seed the next call to ``__iter__``."""
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over this rank's indices for the current epoch."""
        n_batches = self.total_size // self.batch_size
        if n_batches == 0:
            return iter([])

        # A private generator seeded by the epoch only: identical on every rank
        # (so the shards below are disjoint) and the global RNG stays untouched.
        rng = random.Random(self.epoch)

        reserved = {}   # class -> one reserved index per batch
        fill_pool = []  # everything that is not reserved
        for cls in sorted(self.class_indices):
            members = list(self.class_indices[cls])
            if self.shuffle:
                rng.shuffle(members)
            if 0 <= cls < self.num_classes:
                reserved[cls] = members[:n_batches]
                fill_pool.extend(members[n_batches:])
            else:
                # Labels outside range(num_classes) only ever act as filler.
                fill_pool.extend(members)
        rng.shuffle(fill_pool)

        # The pool always suffices: n_batches * batch_size <= len(labels).
        n_fill = self.batch_size - self.num_classes
        batches = []
        for b in range(n_batches):
            batch = [reserved[cls][b] for cls in range(self.num_classes)]
            batch.extend(fill_pool[b * n_fill:(b + 1) * n_fill])
            rng.shuffle(batch)
            batches.append(batch)

        # n_batches is a multiple of num_replicas, so a strided split hands
        # every rank the same number of whole batches.
        own_batches = batches[self.rank::self.num_replicas]
        return iter([idx for batch in own_batches for idx in batch])

    def __len__(self):
        """Return the number of indices this rank yields per epoch."""
        return self.total_size // self.num_replicas
    
class ProportionalSampler(Sampler):
    """Index sampler whose consecutive ``batch_size`` groups cover every class.

    The sampler yields a flat stream of dataset indices meant to be cut into
    consecutive groups of ``batch_size``. Each group contains one sample of
    every class in ``range(num_classes)``; the remaining slots are filled from a
    shuffled pool of the samples that were not reserved, so the fill roughly
    follows the overall label distribution. When a class has fewer samples than
    there are batches, its slot is filled by re-drawing a random member of that
    class.

    The number of batches is computed up front, so ``len(sampler)`` always
    equals the number of indices produced by ``iter(sampler)``.

    Args:
        labels: Sized sequence of class labels; each label must be convertible
            with ``int()``.
        batch_size: Number of indices per batch; must be at least
            ``num_classes``.
        num_classes: Classes ``0 .. num_classes - 1`` are guaranteed to appear
            in every batch.

    Raises:
        ValueError: If ``batch_size`` is not positive or is smaller than
            ``num_classes``, or if a class in ``range(num_classes)`` has no
            samples.
    """

    def __init__(self, labels, batch_size, num_classes=4):
        if batch_size <= 0 or batch_size < num_classes:
            raise ValueError(
                f"batch_size ({batch_size}) must be positive and at least "
                f"num_classes ({num_classes})"
            )
        self.labels = labels
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.class_indices = defaultdict(list)

        # Group sample indices by integer class label.
        for idx, label in enumerate(labels):
            self.class_indices[int(label)].append(idx)
        missing = [cls for cls in range(num_classes) if cls not in self.class_indices]
        if missing:
            raise ValueError(f"labels contain no samples for classes {missing}")

        # Indices of all samples.
        self.all_indices = list(range(len(labels)))

    def _num_batches(self):
        """Return how many complete batches one pass over the data produces.

        A run of ``n`` batches needs ``min(n, size_c)`` distinct samples of each
        required class ``c`` plus ``n * (batch_size - num_classes)`` fill
        samples; the result is the largest ``n`` the dataset can supply. When
        there are no fill slots the count is capped at the largest class size,
        beyond which batches would consist of repeats only.
        """
        counts = [len(self.class_indices.get(cls, ())) for cls in range(self.num_classes)]
        if counts and min(counts) == 0:
            return 0
        total = len(self.labels)
        n_fill = self.batch_size - self.num_classes
        limit = max(counts) if n_fill == 0 else total // n_fill
        n_batches = 0
        while n_batches < limit:
            nxt = n_batches + 1
            if nxt * n_fill + sum(min(nxt, c) for c in counts) > total:
                break
            n_batches = nxt
        return n_batches

    def __iter__(self):
        """Return an iterator over ``len(self)`` indices, grouped batch by batch."""
        n_batches = self._num_batches()
        if n_batches == 0:
            return iter([])

        reserved = {}   # class -> distinct indices reserved for the first batches
        fill_pool = []  # everything that is not reserved
        for cls, members in self.class_indices.items():
            members = list(members)
            random.shuffle(members)
            if 0 <= cls < self.num_classes:
                reserved[cls] = members[:n_batches]
                fill_pool.extend(members[n_batches:])
            else:
                # Labels outside range(num_classes) only ever act as filler.
                fill_pool.extend(members)
        random.shuffle(fill_pool)

        n_fill = self.batch_size - self.num_classes
        flat = []
        for b in range(n_batches):
            batch = []
            # Step 1: one sample of every class.
            for cls in range(self.num_classes):
                picks = reserved[cls]
                if b < len(picks):
                    batch.append(picks[b])
                else:
                    # Class exhausted: reuse a random member of it.
                    batch.append(random.choice(self.class_indices[cls]))
            # Step 2: fill the remaining slots from the shared pool.
            batch.extend(fill_pool[b * n_fill:(b + 1) * n_fill])
            random.shuffle(batch)
            flat.extend(batch)
        return iter(flat)

    def __len__(self):
        """Return the number of indices yielded by one pass of ``__iter__``."""
        return self._num_batches() * self.batch_size
    
class StratifiedBatchSampler(Sampler):
    """Index sampler that puts a minimum number of samples per class in each batch.

    The sampler yields a flat stream of dataset indices meant to be cut into
    consecutive groups of ``batch_size``. While at least ``batch_size`` samples
    are still unused, each batch receives ``min_samples_per_class`` entries of
    every class that has any samples (unused samples first, topped up with
    already-used or repeated members of the class when it runs short) and is
    then filled up to ``batch_size`` with other unused samples. Whatever remains
    at the end is appended as a final, smaller, shuffled group without any
    per-class guarantee.

    Every dataset index is yielded at least once per pass.

    Args:
        labels: Sized sequence of class labels; each ``int(label)`` must lie in
            ``range(num_classes)`` (anything else raises ``KeyError``).
        batch_size: Target number of indices per batch. A batch is longer when
            ``num_classes * min_samples_per_class`` exceeds it.
        num_classes: Number of classes.
        min_samples_per_class: Entries of each class placed in every full batch.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, labels, batch_size, num_classes=4, min_samples_per_class=5):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.labels = labels
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.min_samples_per_class = min_samples_per_class

        # Sample indices grouped by integer class label.
        self.class_indices = {i: [] for i in range(self.num_classes)}
        for idx, label in enumerate(labels):
            self.class_indices[int(label)].append(idx)

        # Indices of all samples.
        self.all_indices = list(range(len(labels)))

    def _draw_for_class(self, cls, unused):
        """Pick this batch's entries for ``cls``, preferring indices in ``unused``.

        Returns ``min_samples_per_class`` indices, or an empty list when the
        class has no samples at all.
        """
        members = self.class_indices[cls]
        quota = self.min_samples_per_class
        fresh = [idx for idx in members if idx in unused]
        if len(fresh) >= quota:
            return random.sample(fresh, quota)

        # Not enough unused samples: take all of them, then top up.
        picks = fresh
        shortfall = quota - len(fresh)
        seen = [idx for idx in members if idx not in unused]
        if len(seen) >= shortfall:
            picks += random.sample(seen, shortfall)
        elif members:
            # The whole class is smaller than the quota: repeat its members.
            picks += random.choices(members, k=shortfall)
        return picks

    def __iter__(self):
        """Return an iterator over the flattened batches of one pass."""
        unused = set(self.all_indices)
        flat = []
        while len(unused) >= self.batch_size:
            batch = []
            for cls in range(self.num_classes):
                batch.extend(self._draw_for_class(cls, unused))

            # Pad to batch_size with unused samples not already in the batch.
            # At least batch_size samples are unused and the batch holds fewer
            # than batch_size entries, so enough spare samples always exist.
            missing = self.batch_size - len(batch)
            if missing > 0:
                spare = list(unused.difference(batch))
                batch.extend(random.sample(spare, missing))

            random.shuffle(batch)
            flat.extend(batch)

            # Drop everything this batch consumed.
            unused.difference_update(batch)

        # Leftovers form a final, smaller group.
        tail = list(unused)
        random.shuffle(tail)
        flat.extend(tail)
        return iter(flat)

    def __len__(self):
        """Return the dataset size.

        This equals the number of yielded indices unless samples were repeated
        to satisfy the per-class quota, in which case it is a lower bound.
        """
        return len(self.labels)

class BalancedBatchSampler(Sampler):
    """Index sampler that puts a minimum number of unused samples per class in each batch.

    The sampler yields a flat stream of dataset indices meant to be cut into
    consecutive groups of ``batch_size``. While at least ``batch_size`` samples
    are still unused, each batch receives ``min_samples_per_class`` entries of
    every class that still has unused samples (repeating them when fewer than
    the quota remain) and is then filled up to ``batch_size`` with other unused
    samples that are not already in the batch. Whatever remains at the end is
    appended as a final, smaller, shuffled group.

    Every dataset index is yielded at least once per pass.

    Args:
        labels: Sized sequence of class labels; each ``int(label)`` must lie in
            ``range(num_classes)`` (anything else raises ``KeyError``).
        batch_size: Target number of indices per batch. A batch is longer when
            ``num_classes * min_samples_per_class`` exceeds it.
        num_classes: Number of classes.
        min_samples_per_class: Entries of each class placed in every full batch
            while that class still has unused samples.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, labels, batch_size, num_classes=4, min_samples_per_class=2):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.labels = labels
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.min_samples_per_class = min_samples_per_class

        # Sample indices grouped by integer class label.
        self.class_indices = {i: [] for i in range(self.num_classes)}
        for idx, label in enumerate(labels):
            self.class_indices[int(label)].append(idx)

        # Used to make sure every sample appears at least once.
        self.all_indices = list(range(len(labels)))

    def __iter__(self):
        """Return an iterator over the flattened batches of one pass."""
        quota = self.min_samples_per_class
        unused = set(self.all_indices)
        flat = []
        while len(unused) >= self.batch_size:
            batch = []
            for cls in range(self.num_classes):
                # Class membership comes from class_indices so it agrees with
                # the int() coercion applied when the labels were grouped.
                fresh = [idx for idx in self.class_indices[cls] if idx in unused]
                if len(fresh) >= quota:
                    batch.extend(random.sample(fresh, quota))
                elif fresh:
                    # Fewer unused samples than the quota: repeat them.
                    batch.extend(random.choices(fresh, k=quota))

            # Pad to batch_size with unused samples not already in the batch.
            # At least batch_size samples are unused and the batch holds fewer
            # than batch_size entries, so enough spare samples always exist.
            missing = self.batch_size - len(batch)
            if missing > 0:
                spare = list(unused.difference(batch))
                batch.extend(random.sample(spare, missing))

            random.shuffle(batch)
            flat.extend(batch)

            # Drop everything this batch consumed.
            unused.difference_update(batch)

        # Leftovers that cannot fill a batch are still emitted.
        tail = list(unused)
        random.shuffle(tail)
        flat.extend(tail)
        return iter(flat)

    def __len__(self):
        """Return the dataset size.

        This equals the number of yielded indices unless samples were repeated
        to satisfy the per-class quota, in which case it is a lower bound.
        """
        return len(self.labels)
