#!/usr/bin/env python3
"""
验证 DistributedBatchSampler 的索引分配逻辑
无需真实分布式环境，模拟多卡行为
"""

import torch
from torch.utils.data import SequentialSampler, Dataset

# 复制你的 DistributedBatchSampler 类（简化版）
class DistributedBatchSampler:
    """Batch sampler that hands each rank its slice of every global batch.

    Indices are drawn from ``sampler`` and grouped into global batches of
    ``batch_size`` items. Every rank iterates the same global batches and
    keeps only its own share, so the indices yielded are the sampler's
    (global) indices, not rank-local ones.

    Args:
        sampler: Iterable of dataset indices.
        batch_size: Size of the *global* batch shared by all ranks.
        drop_last: If true, a trailing batch with fewer than ``batch_size``
            indices is discarded; otherwise its per-rank share is yielded.
        rank: Index of the simulated process, in ``[0, world_size)``.
        world_size: Number of simulated processes.
        interleave: If true, a rank takes every ``world_size``-th index
            starting at ``rank``; otherwise it takes one contiguous chunk.

    Attributes:
        start_iter: Number of leading global batches to skip when iterating.
        wrap_around: Stored for callers; not read by this simplified class.
    """

    def __init__(self, sampler, batch_size, drop_last, rank, world_size, interleave=False):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.rank = rank
        self.world_size = world_size
        self.interleave = interleave
        self.start_iter = 0
        self.wrap_around = 0

    def __iter__(self):
        """Yield this rank's list of indices for each global batch."""
        batch = []
        i = 0
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                tbatch = self._batch(batch)
                if i >= self.start_iter:
                    yield tbatch
                i += 1
                batch = []
        # Honour drop_last: previously the trailing partial batch was always
        # discarded. The share is yielded on every rank (it may be shorter
        # than usual, or empty) so that all ranks see the same batch count.
        if batch and not self.drop_last and i >= self.start_iter:
            yield self._batch(batch)

    def _batch(self, batch):
        """Return the part of the global ``batch`` that belongs to this rank."""
        if self.interleave:
            # Strided split: rank, rank + world_size, rank + 2 * world_size, ...
            return batch[self.rank : self.batch_size : self.world_size]
        # Contiguous split; the bounds are computed from the nominal batch
        # size, so slicing a short trailing batch simply truncates the share.
        start = self.rank * self.batch_size // self.world_size
        end = (self.rank + 1) * self.batch_size // self.world_size
        return batch[start:end]

# 模拟数据集
class DummyDataset(Dataset):
    """Map-style dataset of ``size`` items whose item at index ``i`` is ``i``.

    Returning the index itself makes it easy to see which dataset indices a
    sampler actually visits.
    """

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``idx`` itself.

        Raises:
            IndexError: If ``idx`` is outside ``[0, size)``. Without this,
                iterating the dataset directly (``for x in dataset``) would
                never terminate, because the sequence-iteration protocol
                stops only on ``IndexError``.
        """
        if not 0 <= idx < self.size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.size}")
        return idx

def test_sampler_indices():
    """Check how DistributedBatchSampler distributes indices across ranks.

    For both ``interleave`` modes, every rank is simulated in-process over the
    same sequential sampler. The first few global batches are collected per
    rank and the function verifies that, taken together, the ranks cover
    exactly the leading global indices with no index assigned twice.

    Raises:
        AssertionError: If any mode yields overlapping or non-global indices,
            so the check cannot silently "pass" while printing a bad result.
    """
    total_samples = 100
    global_batch_size = 16
    world_size = 2
    # Each rank receives global_batch_size // world_size (= 8) indices per
    # global batch, so 4 batches give the 32 indices per rank inspected here.
    batches_to_check = 4
    per_rank_count = batches_to_check * (global_batch_size // world_size)

    dataset = DummyDataset(total_samples)

    print("=" * 80)
    print(
        f"测试场景: total_samples={total_samples}, "
        f"global_batch={global_batch_size}, world_size={world_size}"
    )
    print("=" * 80)

    failed_modes = []

    # Exercise both interleave modes.
    for interleave in [False, True]:
        print(f"\n{'='*40}")
        print(f"interleave = {interleave}")
        print(f"{'='*40}")

        # Simulate every rank with its own sampler over the same dataset.
        indices_by_rank = []
        for rank in range(world_size):
            batch_sampler = DistributedBatchSampler(
                SequentialSampler(dataset),
                batch_size=global_batch_size,
                drop_last=True,
                rank=rank,
                world_size=world_size,
                interleave=interleave,
            )
            collected = []
            for batch in batch_sampler:
                collected.extend(batch)
                if len(collected) >= per_rank_count:
                    break
            indices_by_rank.append(collected[:per_rank_count])

        print()
        for rank, indices in enumerate(indices_by_rank):
            print(f"Rank {rank} 前{per_rank_count}个索引: {indices}")

        # Key checks: the union over ranks must be exactly the first
        # world_size * per_rank_count global indices, each appearing once.
        all_indices = sorted(idx for indices in indices_by_rank for idx in indices)
        expected = list(range(world_size * per_rank_count))
        is_global = all_indices == expected
        has_overlap = len(set(all_indices)) != len(all_indices)

        global_mark = "✓" if is_global else "✗"
        overlap_mark = "✗" if has_overlap else "✓"
        print(f"\n{global_mark} 索引类型: {'全局索引' if is_global else '局部索引或错误'}")
        print(f"{overlap_mark} 是否有重叠: {has_overlap}")
        print(f"{global_mark} 覆盖范围: {min(all_indices)} ~ {max(all_indices)}")

        if has_overlap or not is_global:
            failed_modes.append(f"interleave={interleave}")

    if failed_modes:
        raise AssertionError("索引分配验证失败: " + ", ".join(failed_modes))

def test_mask_alignment():
    """Show, per rank, the first sampled index that falls inside a mask window.

    The mask covers global indices ``[mask_start, mask_start + mask_length)``.
    For each simulated rank the function finds the first index the rank
    receives inside that window and prints the matching offset into the mask.
    """
    print("\n" + "=" * 80)
    print("Mask 对齐测试")
    print("=" * 80)

    total_samples = 10000
    mask_start = 5000
    mask_length = 1000
    global_batch_size = 1024
    world_size = 2
    mask_end = mask_start + mask_length

    dataset = DummyDataset(total_samples)

    # Simulate both ranks.
    for rank in [0, 1]:
        batch_sampler = DistributedBatchSampler(
            SequentialSampler(dataset),
            batch_size=global_batch_size,
            drop_last=True,
            rank=rank,
            world_size=world_size,
            interleave=False,  # the sampler's default mode
        )

        # First index handed to this rank that lies inside the mask window.
        first_masked_idx = next(
            (
                idx
                for batch in batch_sampler
                for idx in batch
                if mask_start <= idx < mask_end
            ),
            None,
        )

        # Compare against None explicitly: index 0 is a valid hit and must
        # not be mistaken for "not found" by a truthiness test.
        found = first_masked_idx is not None
        mask_offset = first_masked_idx - mask_start if found else None

        print(f"\nRank {rank}:")
        print(f"  第一个被mask的样本全局索引: {first_masked_idx}")
        print(f"  在 global_mask 中的位置: {mask_offset if found else 'N/A'}")

        # Report which mask entry this rank should read.
        if found:
            print(f"  ✓ 应读取 global_mask[{mask_offset}]")

if __name__ == "__main__":
    # Run both simulations, then print the summary and usage notes.
    test_sampler_indices()
    test_mask_alignment()

    rule = "=" * 80
    closing_lines = (
        "\n" + rule,
        "结论:",
        rule,
        "1. DistributedBatchSampler 使用全局索引 ✓",
        "2. 不同 rank 访问不同的索引子集，无重叠 ✓",
        "3. 你的 mask 逻辑理论上正确 ✓",
        "\n注意:",
        "- interleave=False 时: rank0 拿连续的前半，rank1 拿连续的后半",
        "- interleave=True 时: rank0 拿偶数索引，rank1 拿奇数索引",
        "- 检查你的 make_data_loader 中是否显式设置了 interleave 参数",
    )
    for text in closing_lines:
        print(text)
