import torch
import os
import timm
from timm.data import Mixup, AugMixDataset, create_transform
from timm.data.distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
from torchvision import datasets
from PIL import Image
from timm.data.mixup import cutmix_bbox_and_lam
from torch.utils.data import Dataset

class MultiTransformDataset(Dataset):
    """Dataset wrapper that yields one view per transform for every sample.

    Indexing returns ``(views, target)`` where ``views[i]`` is the result of
    ``transforms[i]`` applied to a copy of the base sample, e.g.
    ``[clean_view, aug_view]``.
    """

    def __init__(self, base_dataset, transforms):
        self.transforms = transforms  # list of callables, one per view
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        sample, label = self.base[idx]
        outputs = []
        for transform in self.transforms:
            # Each transform receives its own copy of the sample (a PIL image
            # when the base dataset has transform=None, the recommended setup).
            outputs.append(transform(sample.copy()))
        return outputs, label

class MixupWithIdx(timm.data.mixup.Mixup):
    """timm ``Mixup`` variant that also reports the mixing parameters it used.

    Unlike the parent class, ``__call__`` returns four values,
    ``(x, y_mix, lam, cutmix_mask)``, so the caller can see which lambda and
    (in ``'batch'`` mode) which CutMix box were applied.
    """

    def _mix_batch(self, x):
        """Mix ``x`` in place with its batch-reversed copy using one lambda.

        Returns:
            ``(lam, cutmix_mask)``. ``cutmix_mask`` is the ``(yl, yh, xl, xh)``
            box that was pasted when CutMix was chosen, otherwise ``None``.
        """
        lam, use_cutmix = self._params_per_batch()

        if lam == 1.:
            # Nothing to mix. Still return a pair so __call__ can unpack it
            # (returning a bare float here used to raise TypeError).
            return 1., None

        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
            x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
            return lam, (yl, yh, xl, xh)

        x_flipped = x.flip(0).mul_(1. - lam)
        x.mul_(lam).add_(x_flipped)
        return lam, None

    def __call__(self, x, y):
        """Apply mixup/cutmix to a batch and build the mixed targets.

        Args:
            x: input batch; modified in place in ``'batch'`` mode.
            y: integer class targets for the batch.

        Returns:
            ``(x, y_mix, lam, cutmix_mask)``. ``cutmix_mask`` is only populated
            in ``'batch'`` mode; the inherited ``'elem'`` and ``'pair'`` mixers
            return just ``lam``, so it is ``None`` for them.
        """
        # Default for the modes that do not report a box; without this the
        # return statement raised NameError in 'elem' and 'pair' modes.
        cutmix_mask = None

        # 1. Obtain lam (same dispatch as the timm implementation).
        if self.mode == 'elem':
            lam = self._mix_elem(x)
        elif self.mode == 'pair':
            lam = self._mix_pair(x)
        else:
            lam, cutmix_mask = self._mix_batch(x)

        # 2. Build the soft targets the same way the timm implementation does.
        y_mix = timm.data.mixup.mixup_target(
            y, self.num_classes, lam, self.label_smoothing)

        return x, y_mix, lam, cutmix_mask
    
class ImageNet(torch.utils.data.Dataset):
    """Image classification dataset described by a plain-text meta file.

    Every non-blank line of ``meta_file`` must contain at least two
    whitespace-separated fields: an image path relative to ``root`` and an
    integer class label. Any further fields are ignored.

    TODO: move to tools data_set (original author's note).

    Args:
        root: directory that the image paths in ``meta_file`` are relative to.
        meta_file: path of the text file listing ``<image_path> <label>`` pairs.
        transform: optional callable applied to the loaded RGB PIL image.

    Raises:
        ValueError: if a non-blank line has fewer than two fields.
    """

    def __init__(self, root, meta_file='../src_data/val.txt', transform=None):  # ./imagenet_val_1k.txt
        self.data_dir = root
        self.meta_file = meta_file
        self.transform = transform
        self._indices = []

        # Use a context manager so the file handle is closed deterministically.
        with open(meta_file, encoding="utf-8") as meta:
            for line_no, line in enumerate(meta, start=1):
                # split() tolerates tabs and repeated spaces between fields.
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing empty line) are skipped
                    # instead of crashing with IndexError.
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{meta_file}:{line_no}: expected '<image_path> <label>', got {line!r}")
                img_path, label = fields[0], fields[1]
                self._indices.append((os.path.join(self.data_dir, img_path), label))

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``; the label is converted to ``int``."""
        img_path, label = self._indices[index]
        # convert() returns a new, fully loaded image, so the source file can
        # be closed as soon as the conversion is done.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        label = int(label)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

def build_dataset(args, num_aug_splits=0):
    """Build the train/eval dataloaders and the optional mixup function.

    Args:
        args: configuration namespace. Attributes read here: ``train_dir``,
            ``eval_dir``, ``input_size``, ``interpolation``,
            ``train_interpolation``, ``mean``, ``std``, ``crop_pct``,
            ``no_aug``, ``scale``, ``ratio``, ``hflip``, ``vflip``,
            ``color_jitter``, ``aa``, ``reprob``, ``remode``, ``recount``,
            ``resplit``, ``distributed``, ``aug_repeats``, ``batch_size``,
            ``num_workers``, ``pin_mem``, ``mixup``, ``cutmix``,
            ``cutmix_minmax``, ``mixup_prob``, ``mixup_switch_prob``,
            ``mixup_mode``, ``smoothing``, ``num_classes`` and, optionally,
            ``clean_first_split`` / ``aug_splits``.
        num_aug_splits: number of AugMix splits for the legacy path
            (values > 1 wrap the training set in ``AugMixDataset``).

    Returns:
        ``(dataloader_train, dataloader_eval, mixup_fn)`` where ``mixup_fn`` is
        a ``MixupWithIdx`` instance, or ``None`` when mixup/cutmix is disabled.
    """
    # Datasets are created without transforms; these are attached below.
    dataset_train = datasets.ImageFolder(root=args.train_dir, transform=None)
    dataset_eval = datasets.ImageFolder(root=args.eval_dir, transform=None)
    # Alternative for list-file based eval sets: ImageNet(root=args.eval_dir)

    # Deterministic preprocessing with no random augmentation.
    plain_kwargs = dict(
        is_training=False,
        use_prefetcher=False,
        interpolation=args.interpolation,
        mean=args.mean,
        std=args.std,
        crop_pct=args.crop_pct,
    )
    # Settings shared by every augmented training transform. The interpolation,
    # re_num_splits and separate arguments differ per path and are passed
    # explicitly at each call site.
    aug_kwargs = dict(
        is_training=True,
        use_prefetcher=False,
        no_aug=args.no_aug,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        mean=args.mean,
        std=args.std,
        crop_pct=args.crop_pct,
        tf_preprocessing=False,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )

    if getattr(args, 'clean_first_split', False) and args.aug_splits == 2:
        # Explicit "clean + augmented" two-view training set.
        # The clean view must carry no train-time augmentation. Building it
        # with is_training=True and default arguments made timm apply its
        # default random crop / flip / colour jitter, so it is built as a
        # non-training transform instead.
        tfm_clean = create_transform(args.input_size, **plain_kwargs)
        # The augmented view matches the regular single-view training transform.
        tfm_aug = create_transform(
            args.input_size,
            interpolation=args.train_interpolation or args.interpolation,
            re_num_splits=0,  # no split-aware RE; only the aug branch uses RE
            separate=False,
            **aug_kwargs
        )
        dataset_train = MultiTransformDataset(dataset_train, [tfm_clean, tfm_aug])
    else:
        # Legacy path: optional AugMixDataset driven by num_aug_splits.
        if num_aug_splits > 1:
            dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
        train_interpolation = args.train_interpolation
        if args.no_aug or not train_interpolation:
            train_interpolation = args.interpolation
        re_num_splits = 0
        if args.resplit:
            re_num_splits = num_aug_splits or 2
        # Single training transform, used by AugMixDataset or the plain dataset.
        dataset_train.transform = create_transform(
            args.input_size,
            interpolation=train_interpolation,
            re_num_splits=re_num_splits,
            separate=num_aug_splits > 0,
            **aug_kwargs
        )

    dataset_eval.transform = create_transform(args.input_size, **plain_kwargs)

    # Create samplers.
    sampler_train = None
    sampler_eval = None
    if args.distributed and not isinstance(dataset_train, torch.utils.data.IterableDataset):
        if args.aug_repeats:
            sampler_train = RepeatAugSampler(dataset_train, num_repeats=args.aug_repeats)
        else:
            sampler_train = torch.utils.data.distributed.DistributedSampler(dataset_train)
    else:
        assert args.aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

    # IMPORTANT: every rank must run the same number of evaluation steps,
    # otherwise collectives (e.g. all_reduce) on the last batch do not match.
    # OrderedDistributedSampler can give uneven per-rank lengths when
    # len(dataset) % world_size != 0, so torch's DistributedSampler is used.
    # With drop_last=False it pads with repeated samples, so all samples are
    # kept and a few may be evaluated twice.
    if args.distributed and not isinstance(dataset_eval, torch.utils.data.IterableDataset):
        sampler_eval = torch.utils.data.distributed.DistributedSampler(
            dataset_eval, shuffle=False, drop_last=False
        )

    # Without a sampler the DataLoader itself must shuffle the training data;
    # a hard-coded shuffle=False made single-process training iterate the
    # ImageFolder in its fixed, class-sorted order every epoch.
    shuffle_train = (
        sampler_train is None
        and not isinstance(dataset_train, torch.utils.data.IterableDataset)
    )

    # Create dataloaders.
    dataloader_train = torch.utils.data.DataLoader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        shuffle=shuffle_train,
        num_workers=args.num_workers,
        sampler=sampler_train,
        collate_fn=None,
        pin_memory=args.pin_mem,
        drop_last=True
    )
    dataloader_eval = torch.utils.data.DataLoader(
        dataset=dataset_eval,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=sampler_eval,
        collate_fn=None,
        pin_memory=args.pin_mem,
        # Keep all samples; DistributedSampler pads to equalize steps.
        drop_last=False
    )

    # Set up mixup / cutmix.
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        # MixupWithIdx is used in place of timm's Mixup so that lam and the
        # cutmix box are returned to the caller as well.
        mixup_fn = MixupWithIdx(**mixup_args)

    return dataloader_train, dataloader_eval, mixup_fn
