import os
import random
import time
import warnings 

import PIL.Image as PImage
import numpy as np
import torch
import torchvision
from timm.data import AutoAugment as TimmAutoAugment
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
from timm.data.distributed_sampler import RepeatAugSampler
from timm.data.transforms_factory import transforms_imagenet_eval
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from torchvision.transforms import AutoAugment as TorchAutoAugment
from torchvision.transforms import transforms, TrivialAugmentWide

# Pick the bicubic interpolation constant: use torchvision's InterpolationMode
# enum when this torchvision release provides it, otherwise fall back to the
# PIL constant.
try:
    from torchvision.transforms import InterpolationMode
    interpolation = InterpolationMode.BICUBIC
except ImportError:
    # Only a missing symbol should trigger the fallback; a bare ``except``
    # would also swallow KeyboardInterrupt/SystemExit and unrelated errors.
    import PIL.Image  # import the submodule explicitly; still binds the name ``PIL``
    interpolation = PIL.Image.BICUBIC


def create_classification_dataset(data_path, img_size, rep_aug, workers, batch_size_per_gpu, world_size, global_rank):
    """Build the train/val data loaders for an ImageFolder-style classification dataset.

    Args:
        data_path: Dataset root holding ``train`` and ``val`` sub-folders. The
            path of either sub-folder itself is accepted as well.
        img_size: Input size handed to the train and val transforms.
        rep_aug: Repeat count for timm's ``RepeatAugSampler``; a falsy value
            selects a plain ``DistributedSampler`` instead.
        workers: Number of DataLoader worker processes.
        batch_size_per_gpu: Training batch size on each rank. Validation uses
            twice this value on each rank.
        world_size: Number of ranks sharing the validation set.
        global_rank: Rank of the current process, used to shard validation.

    Returns:
        A 7-tuple ``(loader_train, iters_train, val_iterator, iters_val,
        num_classes, total_train_samples, total_val_samples)``.
        ``val_iterator`` is ``iter(loader_val)``; its batch sampler never
        terminates, so callers are expected to draw ``iters_val`` batches
        per evaluation pass.
    """
    mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    # UserWarnings are silenced only while the pipeline is assembled.
    # catch_warnings restores the caller's filters on exit (also when an
    # exception is raised) instead of wiping every filter in the process.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=UserWarning)

        trans_train = create_transform(
            is_training=True, input_size=img_size,
            auto_augment='v0', interpolation='bicubic', re_prob=0.25, re_mode='pixel', re_count=1,
            mean=mean, std=std,
        )
        if img_size < 384:
            # Below 384px: swap the first AutoAugment stage for TrivialAugmentWide
            # and evaluate with timm's eval transform (crop_pct=0.95).
            for i, t in enumerate(trans_train.transforms):
                if isinstance(t, (TorchAutoAugment, TimmAutoAugment)):
                    trans_train.transforms[i] = TrivialAugmentWide(interpolation=interpolation)
                    break
            trans_val = transforms_imagenet_eval(img_size=img_size, interpolation='bicubic', crop_pct=0.95, mean=mean, std=std)
        else:
            # 384px and above: keep AutoAugment and evaluate on a direct resize
            # to (img_size, img_size) without cropping.
            trans_val = transforms.Compose([
                transforms.Resize((img_size, img_size), interpolation=interpolation),
                transforms.ToTensor(), transforms.Normalize(mean=mean, std=std),
            ])
        print_transform(trans_train, '[train]')
        print_transform(trans_val, '[val]')

        # Resolve the dataset root, then address its train/val sub-folders.
        imagenet_folder = _resolve_dataset_root(data_path)
        train_dir = os.path.join(imagenet_folder, 'train')
        val_dir = os.path.join(imagenet_folder, 'val')

        dataset_train = torchvision.datasets.ImageFolder(train_dir, trans_train)
        dataset_val = torchvision.datasets.ImageFolder(val_dir, trans_val)

        # Class count comes from the training split; the validation split is
        # assumed to contain the same classes.
        num_classes = len(dataset_train.classes)
        total_train_samples = len(dataset_train)
        total_val_samples = len(dataset_val)

        if rep_aug:
            print(f'[数据集] 使用重复增强: 计数={rep_aug}')
            train_sp = RepeatAugSampler(dataset_train, shuffle=True, num_repeats=rep_aug)
        else:
            train_sp = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True, drop_last=True)

        loader_train = DataLoader(
            dataset=dataset_train, num_workers=workers, pin_memory=True,
            batch_size=batch_size_per_gpu, sampler=train_sp, persistent_workers=workers > 0,
            worker_init_fn=worker_init_fn,
        )
        iters_train = len(loader_train)
        print(f'[数据集: 训练] bs={world_size}x{batch_size_per_gpu}={world_size * batch_size_per_gpu}, 迭代次数={iters_train}')

        val_ratio = 2
        # DistInfiniteBatchSampler expects the GLOBAL batch size and divides it
        # by world_size itself, so scale the per-rank validation batch by
        # world_size. This keeps each rank at val_ratio * batch_size_per_gpu,
        # which is exactly what the message below reports.
        val_glb_batch_size = world_size * val_ratio * batch_size_per_gpu
        loader_val = DataLoader(
            dataset=dataset_val, num_workers=workers, pin_memory=True,
            batch_sampler=DistInfiniteBatchSampler(
                world_size, global_rank, len(dataset_val),
                glb_batch_size=val_glb_batch_size, filling=False, shuffle=False,
            ),
            worker_init_fn=worker_init_fn,
        )
        iters_val = len(loader_val)
        print(f'[数据集: 验证] bs={world_size}x{val_ratio * batch_size_per_gpu}={val_ratio * world_size * batch_size_per_gpu}, 迭代次数={iters_val}')

        # NOTE(review): purpose of this pause is not visible here; presumably it
        # leaves the summary above on screen before training output starts.
        time.sleep(3)

    # Seven values are returned; see the docstring for their meaning.
    return loader_train, iters_train, iter(loader_val), iters_val, num_classes, total_train_samples, total_val_samples


def _resolve_dataset_root(data_path):
    """Return the absolute dataset root, given the root or its ``train``/``val`` sub-folder."""
    root = os.path.abspath(data_path)
    # Compare the last path component rather than the raw string suffix, so a
    # root such as ``/data/imagenet_val`` is not truncated to ``/data/imagenet_``.
    if os.path.basename(root) in ('train', 'val'):
        root = os.path.dirname(root)
    return root


def worker_init_fn(worker_id):
    """Re-seed NumPy's and Python's global RNGs inside a DataLoader worker.

    Both generators receive ``torch.initial_seed()`` reduced modulo 2**32.
    ``worker_id`` is accepted to satisfy the DataLoader hook signature and is
    not used.

    Reference: https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    """
    # Masking with 0xFFFFFFFF keeps the low 32 bits, i.e. the value modulo 2**32.
    seed32 = torch.initial_seed() & 0xFFFFFFFF
    for reseed in (np.random.seed, random.seed):
        reseed(seed32)


def print_transform(transform, s):
    """Print the stages of a composed transform under a header tagged with ``s``.

    Args:
        transform: Object exposing an iterable ``transforms`` attribute.
        s: Tag inserted into the header line (e.g. ``'[train]'``).

    Output is the header, one line per stage, then a dashed rule followed by
    a blank line.
    """
    header = f'转换 {s} = '
    footer = '---------------------------\n'
    # One print call with a newline separator emits the same lines as
    # printing the header, each stage and the footer individually.
    print(header, *transform.transforms, footer, sep='\n')


class DistInfiniteBatchSampler(Sampler):
    """Never-ending batch sampler that serves each rank one contiguous shard of the indices.

    The (optionally shuffled) index list ``0 .. dataset_len - 1`` is cut into
    ``world_size`` contiguous shards and this rank iterates over its own shard
    in batches of ``glb_batch_size // world_size`` indices, starting over each
    time the shard is exhausted. ``len(sampler)`` reports the number of global
    batches per epoch; iteration itself never stops.

    Args:
        world_size: Number of ranks the indices are split across.
        global_rank: Rank whose shard this sampler serves.
        dataset_len: Number of samples in the dataset.
        glb_batch_size: Batch size summed over all ranks; must be divisible
            by ``world_size``.
        seed: Base seed for shuffling; the epoch counter is added to it.
        filling: If True, append leading indices so the global list reaches a
            whole number of global batches.
        shuffle: If True, indices are permuted and re-permuted after every
            pass over the shard.
    """

    def __init__(self, world_size, global_rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True):
        assert glb_batch_size % world_size == 0, (
            f'glb_batch_size ({glb_batch_size}) must be divisible by world_size ({world_size})'
        )
        self.world_size, self.rank = world_size, global_rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size

        # Ceil division: number of global batches needed to cover the dataset.
        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()

    def gener_indices(self):
        """Return this rank's shard as a tuple of ints and record its length in ``self.max_p``."""
        # A multiple of world_size, because glb_batch_size is one.
        global_max_p = self.iters_per_ep * self.glb_batch_size
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            global_indices = torch.randperm(self.dataset_len, generator=g)
        else:
            global_indices = torch.arange(self.dataset_len)
        filling = global_max_p - global_indices.shape[0]
        if filling > 0 and self.filling:
            # Pad with the leading indices so the last global batch is full.
            global_indices = torch.cat((global_indices, global_indices[:filling]))
        global_indices = tuple(global_indices.tolist())

        # world_size + 1 integer boundaries delimiting the contiguous shards.
        seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int).tolist()
        local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
        self.max_p = len(local_indices)
        return local_indices

    def __iter__(self):
        """Yield tuples of indices forever, restarting (and reshuffling) after each pass.

        Raises:
            ValueError: On the first ``next()`` if this rank's shard is empty.
        """
        if self.max_p == 0:
            # With nothing to yield the loop below would spin forever without
            # ever producing a batch, hanging the consumer.
            raise ValueError(
                f'rank {self.rank} of {self.world_size} has no indices to sample '
                f'(dataset_len={self.dataset_len})'
            )
        self.epoch = 0
        while True:
            self.epoch += 1
            for p in range(0, self.max_p, self.batch_size):
                # The final slice of a pass may be shorter than batch_size.
                yield self.indices[p:p + self.batch_size]
            if self.shuffle:
                self.indices = self.gener_indices()

    def __len__(self):
        """Return the number of global batches per epoch."""
        return self.iters_per_ep