
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import save_image
import torch.nn.functional as F
import os
import numpy as np
import warnings
from misc import utils
import pandas as pd
import pickle
import time
# Silence every Python warning for the whole process (side effect of importing this module).
warnings.filterwarnings("ignore")

# File suffixes accepted as images by the ImageFolder dataset below.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

# Per-channel normalization statistics, keyed by dataset name.
MEANS = {
    'cifar': [0.4914, 0.4822, 0.4465],
    'imagenet': [0.485, 0.456, 0.406],
}
STDS = {
    'cifar': [0.2023, 0.1994, 0.2010],
    'imagenet': [0.229, 0.224, 0.225],
}

# CIFAR-10 and CIFAR-100 reuse the very same list objects as 'cifar'.
MEANS.update(cifar10=MEANS['cifar'], cifar100=MEANS['cifar'])
STDS.update(cifar10=STDS['cifar'], cifar100=STDS['cifar'])

# Remaining datasets (MNIST-style sets have a single channel).
MEANS.update(svhn=[0.4377, 0.4438, 0.4728], mnist=[0.1307], fashion=[0.2861])
STDS.update(svhn=[0.1980, 0.2010, 0.1970], mnist=[0.3081], fashion=[0.3530])

from torchvision.datasets import ImageFolder as TorchvisionImageFolder
import random
from collections import defaultdict


class InMemoryImageFolder(TorchvisionImageFolder):
    """torchvision ``ImageFolder`` with optional per-class subsampling and RAM caching.

    Every image is converted to RGB. With ``load_in_memory=True`` all (possibly
    subsampled) images are decoded once at construction time and kept as PIL
    images; otherwise each image is read from disk on access.
    """

    def __init__(self, root, transform=None, random_sample=False, ipc=None, load_in_memory=True):
        """
        Args:
            root (string): Path to the dataset root directory.
            transform (callable, optional): Transform applied to each sample in ``__getitem__``.
            random_sample (bool): Whether to randomly subsample each class.
            ipc (int, optional): When ``random_sample`` is True, the number of images
                sampled per class (classes with fewer images keep all of them).
                ``None`` or a non-positive value disables subsampling.
            load_in_memory (bool): Decode all images up front and keep them in RAM.
        """
        super().__init__(root)

        # Subsample before loading so only the selected images are decoded.
        if random_sample and ipc is not None and ipc > 0:
            self._subsample_per_class(ipc)

        self.load_in_memory = load_in_memory
        if load_in_memory:
            self.in_memory_images = self._load_pil_images()
        # _load_pil_images never applies the transform; it is only used in __getitem__.
        self.transform = transform

    def _subsample_per_class(self, ipc):
        """Keep at most ``ipc`` randomly chosen samples per class (uses the global ``random`` state)."""
        indices_by_class = defaultdict(list)
        # self.targets is built by the torchvision constructor: one class index per sample.
        for i, target in enumerate(self.targets):
            indices_by_class[target].append(i)

        subset_indices = []
        for class_idx in sorted(indices_by_class):
            class_indices = indices_by_class[class_idx]
            subset_indices.extend(random.sample(class_indices, min(ipc, len(class_indices))))

        self.samples = [self.samples[i] for i in subset_indices]
        self.targets = [self.targets[i] for i in subset_indices]
        # torchvision's ImageFolder exposes ``imgs`` as an alias of ``samples``;
        # rebind it so it does not keep listing the unsampled files.
        self.imgs = self.samples

    def _load_pil_images(self):
        """Decode every sample with ``self.loader`` and return the RGB PIL images in sample order."""
        pil_images = []
        for path, _ in self.samples:
            pil_images.append(self.loader(path).convert('RGB'))
        return pil_images

    def __getitem__(self, index):
        """Return ``(sample, target)``, reading from the RAM cache or from disk."""
        if self.load_in_memory:
            sample = self.in_memory_images[index]
            target = self.targets[index]
        else:
            path, target = self.samples[index]
            sample = self.loader(path).convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target


class ImageFolder(datasets.DatasetFolder):
    """Dataset class for loading subsets with specified IPC.

    Restricts an image-folder dataset to ``nclass`` classes (taken from a class
    list file or from a seeded permutation of the class folders). It can also:
      * keep only ``ipc`` images per class;
      * cache the images in memory;
      * additionally return the class index in the full folder listing (the
        "original" label), a per-image mode id read from a CSV file, and the
        image path.
    """

    # Class-list file read by ``find_subclasses`` for each ``spec`` when ``seed == 0``.
    _CLASS_LIST_FILES = {
        'woof': 'misc/class_woof.txt',
        'nette': 'misc/class_nette.txt',
        'imagenet100': 'misc/class100.txt',
        'imagenet1k': 'misc/class_indices.txt',
        'IDC': 'misc/class_IDC.txt',
        'imageA': 'misc/imagenet-a.txt',
        'imageB': 'misc/imagenet-b.txt',
        'imageC': 'misc/imagenet-c.txt',
        'imageD': 'misc/imagenet-d.txt',
        'imageE': 'misc/imagenet-e.txt',
    }

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 loader=datasets.folder.default_loader,
                 is_valid_file=None,
                 load_memory=False,
                 load_transform=None,
                 nclass=100,
                 phase=0,
                 slct_type='random',
                 ipc=-1,
                 seed=-1,
                 spec='none',
                 return_origin=False,
                 return_path=False,
                 mode_id_file=None):
        """
        Args:
            root: dataset root with one sub-directory per class.
            transform / target_transform: applied in ``__getitem__`` to the image / labels.
            loader: callable mapping a file path to an image.
            is_valid_file: optional file filter; when given, ``IMG_EXTENSIONS`` is not used.
            load_memory: load all selected images at construction time.
            load_transform: transform applied once per image while loading into memory.
            nclass: number of classes to keep; values >= 1000 keep every class folder.
            phase: selects the ``phase``-th consecutive chunk of ``nclass`` classes.
            slct_type: per-class selection strategy used with ``ipc`` (only 'random').
            ipc: images per class to keep; <= 0 keeps all images.
            seed: 0 reads the class list file chosen by ``spec``; any other value
                (which must be a valid NumPy seed) seeds a permutation of the
                discovered class folders.
            spec: key into ``_CLASS_LIST_FILES`` (used when ``seed == 0``).
            return_origin: also return the label in the full class listing.
            return_path: also return the image path (only honoured together with
                ``return_origin``).
            mode_id_file: optional CSV with ``image_id`` and ``mode_id`` columns; the
                mode id of each image is added to the returned tuple.
        """
        self.extensions = IMG_EXTENSIONS if is_valid_file is None else None
        super(ImageFolder, self).__init__(root,
                                          loader,
                                          self.extensions,
                                          transform=transform,
                                          target_transform=target_transform,
                                          is_valid_file=is_valid_file)

        # Override the class list / mapping discovered by DatasetFolder.
        self.spec = spec
        self.return_origin = return_origin
        if nclass < 1000:
            self.classes, self.class_to_idx = self.find_subclasses(nclass=nclass,
                                                                   phase=phase,
                                                                   seed=seed)
        else:
            self.classes, self.class_to_idx = self.find_classes(self.root)
        self.original_labels = self.find_original_classes()
        self.nclass = nclass

        # Rebuild the sample list with the (possibly restricted) class_to_idx.
        # This must happen before ``_subset``, which groups ``self.samples`` by
        # label: the list built by DatasetFolder uses labels from the full listing.
        self.samples = datasets.folder.make_dataset(self.root, self.class_to_idx,
                                                    self.extensions, is_valid_file)
        if ipc > 0:
            self.samples = self._subset(slct_type=slct_type, ipc=ipc)
        self.targets = [s[1] for s in self.samples]
        self.original_targets = [self.original_labels[s[1]] for s in self.samples]

        self.load_memory = load_memory
        self.load_transform = load_transform
        if self.load_memory:
            self.imgs = self._load_images(load_transform)
        else:
            self.imgs = self.samples
        self.return_path = return_path
        self.mode_id_file = mode_id_file
        if self.mode_id_file is not None:
            self.mode_id_df = pd.read_csv(self.mode_id_file).set_index("image_id")
            # Keyed by file name (last path component).
            self.mode_ids = [self.mode_id_df.loc[s[0].split("/")[-1]]["mode_id"]
                             for s in self.samples]

    def find_subclasses(self, nclass=100, phase=0, seed=0):
        """Finds the class folders in a dataset.

        Selects the ``phase``-th block of ``nclass`` classes. The candidates come
        either from the class list file registered for ``self.spec`` (``seed == 0``)
        or from a permutation of ``self.classes`` drawn after
        ``np.random.seed(seed)``. Note that this reseeds NumPy's global RNG.

        Returns:
            (classes, class_to_idx), with labels ``0 .. nclass-1`` in list order.

        Raises:
            AssertionError: unknown ``spec``, or fewer than ``nclass`` classes selected.
        """
        phase = max(0, phase)
        cls_from = nclass * phase
        cls_to = nclass * (phase + 1)
        if seed == 0:
            file_list = self._CLASS_LIST_FILES.get(self.spec)
            if file_list is None:
                raise AssertionError('spec does not exist!')
            with open(file_list, 'r') as f:
                classes = [line.split('\n')[0] for line in f]
            classes = classes[cls_from:cls_to]
        else:
            np.random.seed(seed)
            class_indices = np.random.permutation(len(self.classes))[cls_from:cls_to]
            classes = [self.classes[i] for i in class_indices]

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        assert len(classes) == nclass

        return classes, class_to_idx

    def find_original_classes(self):
        """Return, for each selected class, its index in the sorted list of class folders under root."""
        # Only directories are classes; stray files in root would otherwise shift the indices.
        all_classes = sorted(entry.name for entry in os.scandir(self.root) if entry.is_dir())
        return [all_classes.index(class_name) for class_name in self.classes]

    def _subset(self, slct_type='random', ipc=10):
        """Keep the first ``ipc`` samples of every class, preserving ``self.samples`` order.

        NOTE: despite its name, 'random' does no shuffling here; the selection is the
        first ``ipc`` entries per class in file order.

        Raises:
            AssertionError: some class has fewer than ``ipc`` samples, or unknown ``slct_type``.
        """
        n = len(self.samples)
        idx_class = [[] for _ in range(self.nclass)]
        for i, (_, label) in enumerate(self.samples):
            idx_class[label].append(i)

        min_class = np.array([len(idx_class[c]) for c in range(self.nclass)]).min()
        assert ipc <= min_class

        if slct_type == 'random':
            indices = np.arange(n)
        else:
            raise AssertionError('selection type does not exist!')

        samples_subset = []
        idx_class_slct = [[] for _ in range(self.nclass)]
        for i in indices:
            label = self.samples[i][1]
            if len(idx_class_slct[label]) < ipc:
                idx_class_slct[label].append(i)
                samples_subset.append(self.samples[i])

            # Stop early once every class is full.
            if len(samples_subset) == ipc * self.nclass:
                break

        return samples_subset

    def _load_images(self, transform=None):
        """Load images on memory

        Reads every entry of ``self.samples`` with ``self.loader``; ``transform``
        (if given) is applied once at load time.
        """
        imgs = []
        for path, _ in self.samples:
            sample = self.loader(path)
            if transform is not None:
                sample = transform(sample)
            imgs.append(sample)

        # Blank out the current console line.
        print(" " * 50, end='\r')
        return imgs

    def __getitem__(self, index):
        """Return ``(sample, target)`` plus the optional extras.

        Tuple layout:
          * ``mode_id_file`` set: ``(sample, target, original_target, mode_id[, path])``
            if ``return_origin`` else ``(sample, target, mode_id)``;
          * otherwise: ``(sample, target, original_target[, path])``
            if ``return_origin`` else ``(sample, target)``.
        """
        # Resolve the path in both modes so return_path also works when
        # images are cached in memory.
        path = self.samples[index][0]
        if self.load_memory:
            sample = self.imgs[index]
        else:
            sample = self.loader(path)

        target = self.targets[index]
        original_target = self.original_targets[index]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
            original_target = self.target_transform(original_target)

        if self.mode_id_file is not None:
            if self.load_memory:
                mode_id = self.mode_ids[index]
            else:
                mode_id = self.mode_id_df.loc[path.split("/")[-1]]['mode_id']
            # Return original labels for DiT generation
            if self.return_origin:
                if self.return_path:
                    return sample, target, original_target, mode_id, path
                return sample, target, original_target, mode_id
            return sample, target, mode_id

        # Return original labels for DiT generation
        if self.return_origin:
            if self.return_path:
                return sample, target, original_target, path
            return sample, target, original_target
        return sample, target


def transform_cifar(augment=False, from_tensor=False, normalize=True):
    """Build the CIFAR ``(train_transform, test_transform)`` pair.

    Args:
        augment: add a 32-pixel random crop (padding 4) and a random horizontal
            flip to the training pipeline only.
        from_tensor: inputs are already tensors, so ``ToTensor`` is omitted.
        normalize: finish both pipelines with CIFAR mean/std normalization.
    """
    aug = []
    if augment:
        aug = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
        print("Dataset with basic Cifar augmentation")

    cast = [] if from_tensor else [transforms.ToTensor()]
    normal_fn = ([transforms.Normalize(mean=MEANS['cifar'], std=STDS['cifar'])]
                 if normalize else [])

    # The test pipeline shares the cast/normalize steps but never augments.
    return (transforms.Compose(cast + aug + normal_fn),
            transforms.Compose(cast + normal_fn))


def transform_svhn(augment=False, from_tensor=False, normalize=True):
    """Return ``(train_transform, test_transform)`` for SVHN.

    ``augment`` adds a padded (4 px) 32-pixel random crop, without flipping, to
    the training pipeline. ``from_tensor`` drops ``ToTensor``. ``normalize``
    appends normalization with the SVHN channel statistics.
    """
    aug = [transforms.RandomCrop(32, padding=4)] if augment else []
    if augment:
        print("Dataset with basic SVHN augmentation")

    to_tensor = [] if from_tensor else [transforms.ToTensor()]

    normalization = []
    if normalize:
        normalization.append(transforms.Normalize(mean=MEANS['svhn'], std=STDS['svhn']))

    train_transform = transforms.Compose(to_tensor + aug + normalization)
    test_transform = transforms.Compose(to_tensor + normalization)
    return train_transform, test_transform


def transform_mnist(augment=False, from_tensor=False, normalize=True):
    """Create the MNIST train/test transforms.

    Args:
        augment: prepend nothing, but insert a 28-pixel random crop (padding 4)
            after tensor conversion in the training pipeline.
        from_tensor: skip ``ToTensor`` because inputs are already tensors.
        normalize: append MNIST mean/std normalization to both pipelines.

    Returns:
        ``(train_transform, test_transform)``.
    """
    if augment:
        augmentation = [transforms.RandomCrop(28, padding=4)]
        print("Dataset with basic MNIST augmentation")
    else:
        augmentation = []

    head = [transforms.ToTensor()] if not from_tensor else []
    tail = [transforms.Normalize(mean=MEANS['mnist'], std=STDS['mnist'])] if normalize else []

    return transforms.Compose(head + augmentation + tail), transforms.Compose(head + tail)


def transform_fashion(augment=False, from_tensor=False, normalize=True):
    """Create the FashionMNIST ``(train_transform, test_transform)`` pair.

    The training pipeline optionally includes a 28-pixel random crop with
    4-pixel padding (``augment``). The test pipeline never augments. ``from_tensor``
    omits ``ToTensor``, and ``normalize`` adds FashionMNIST mean/std normalization.
    """
    crop_steps = []
    if augment:
        crop_steps = [transforms.RandomCrop(28, padding=4)]
        print("Dataset with basic FashionMNIST augmentation")

    if not from_tensor:
        tensor_steps = [transforms.ToTensor()]
    else:
        tensor_steps = []

    norm_steps = []
    if normalize:
        norm_steps = [transforms.Normalize(mean=MEANS['fashion'], std=STDS['fashion'])]

    train_pipeline = transforms.Compose(tensor_steps + crop_steps + norm_steps)
    test_pipeline = transforms.Compose(tensor_steps + norm_steps)
    return train_pipeline, test_pipeline


def transform_imagenet(size=-1,
                       augment=False,
                       from_tensor=False,
                       normalize=True,
                       rrc=True,
                       rrc_size=-1):
    """Build ``(train_transform, test_transform)`` for ImageNet-style folders.

    Args:
        size: ``> 0`` resizes and center-crops both splits to ``size``;
            ``0`` applies no resizing (``rrc_size`` must then be ``> 0``);
            ``< 0`` uses RandomResizedCrop(224) for training and
            Resize(256) + CenterCrop(224) for testing.
        augment: add a horizontal flip, ``utils.ColorJitter`` and ``utils.Lighting``
            to the training pipeline (these run after tensor conversion).
        from_tensor: omit ``ToTensor`` because inputs are already tensors.
        normalize: append ImageNet mean/std normalization.
        rrc: with ``augment`` and ``size >= 0``, put a
            RandomResizedCrop(rrc_size, scale=(0.5, 1.0)) in front of the augmentations.
        rrc_size: crop size for that RandomResizedCrop; ``-1`` means ``size``.
    """
    if size > 0:
        resize_train = [transforms.Resize(size), transforms.CenterCrop(size)]
        resize_test = [transforms.Resize(size), transforms.CenterCrop(size)]
    elif size == 0:
        resize_train, resize_test = [], []
        assert rrc_size > 0, "Set RRC size!"
    else:
        resize_train = [transforms.RandomResizedCrop(224)]
        resize_test = [transforms.Resize(256), transforms.CenterCrop(224)]

    aug = []
    if augment:
        jitter = utils.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
        # PCA-based lighting noise: fixed eigenvalues / eigenvectors.
        light = utils.Lighting(alphastd=0.1,
                               eigval=[0.2175, 0.0188, 0.0045],
                               eigvec=[
                                   [-0.5675, 0.7192, 0.4009],
                                   [-0.5808, -0.0045, -0.8140],
                                   [-0.5836, -0.6948, 0.4203],
                               ])
        aug = [transforms.RandomHorizontalFlip(), jitter, light]

        if rrc and size >= 0:
            crop_size = size if rrc_size == -1 else rrc_size
            aug.insert(0, transforms.RandomResizedCrop(crop_size, scale=(0.5, 1.0)))
        else:
            print("Dataset with basic imagenet augmentation")

    cast = [] if from_tensor else [transforms.ToTensor()]
    normal_fn = ([transforms.Normalize(mean=MEANS['imagenet'], std=STDS['imagenet'])]
                 if normalize else [])

    train_transform = transforms.Compose(resize_train + cast + aug + normal_fn)
    test_transform = transforms.Compose(resize_test + cast + normal_fn)
    return train_transform, test_transform


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Each pass calls ``iter(self.sampler)`` again, so a sampler that randomizes
    per iteration produces a fresh order on every pass. If a whole pass yields
    nothing (e.g. an empty sampler), iteration stops instead of spinning in an
    endless busy loop.

    Args:
        sampler (Sampler): any re-iterable object; ``len(self)`` is ``len(sampler)``.
    """
    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            produced = False
            for item in self.sampler:
                produced = True
                yield item
            if not produced:
                return

    def __len__(self):
        return len(self.sampler)


class ClassBatchSampler(object):
    """Intra-class batch sampler.

    Holds one endless stream of index batches per class. Iterating visits the
    classes in order and yields the next batch of each in turn; the iteration
    never ends on its own.

    Args:
        cls_idx: sequence where ``cls_idx[c]`` lists the dataset indices of class ``c``.
        batch_size: maximum batch size; a class with fewer examples uses its own size.
        drop_last: forwarded to ``BatchSampler``.
    """
    def __init__(self, cls_idx, batch_size, drop_last=True):
        def _endless_batches(indices):
            # Random order within the class, re-shuffled by the repeat wrapper on each pass.
            batches = torch.utils.data.BatchSampler(
                torch.utils.data.SubsetRandomSampler(indices),
                batch_size=min(len(indices), batch_size),
                drop_last=drop_last,
            )
            return iter(_RepeatSampler(batches))

        self.samplers = [_endless_batches(indices) for indices in cls_idx]

    def __iter__(self):
        while True:
            for stream in self.samplers:
                yield next(stream)

    def __len__(self):
        # Number of classes (one batch per class per round).
        return len(self.samplers)


class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """Multi epochs data loader

    Wraps the batch sampler in ``_RepeatSampler`` and creates the underlying
    iterator once, so the same iterator is reused across epochs. Each
    ``__iter__`` call draws ``len(self)`` batches from it. If the dataset yields
    uint8 tensors, the first element of every batch is converted to float with
    ``ConvertImageDtype``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Toggle DataLoader's private (name-mangled) init flag so batch_sampler
        # can be reassigned after construction.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()  # Init iterator and sampler once

        # Probe a single sample once (it may be loaded from disk) to decide the
        # dtype conversion and the reported device.
        first_data = self.dataset[0][0]
        self.convert = None
        if first_data.dtype == torch.uint8:
            self.convert = transforms.ConvertImageDtype(torch.float)

        self.device = 'cpu' if first_data.device == torch.device('cpu') else 'cuda'

    def __len__(self):
        return len(self.batch_sampler)

    def __iter__(self):
        for _ in range(len(self)):
            # Batches may carry extra fields (e.g. original labels or paths);
            # only the first one (the data) is converted, the rest pass through.
            data, *rest = next(self.iterator)
            if self.convert is not None:
                data = self.convert(data)
            yield (data, *rest)


class ClassDataLoader(MultiEpochsDataLoader):
    """Basic class loader (might be slow for processing data)

    Adds per-class sampling to MultiEpochsDataLoader. ``class_sample`` builds a
    batch by indexing the dataset directly in the calling process. Both
    ``class_sample`` and ``sample`` return CUDA tensors with uint8 data
    converted to float.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.nclass = self.dataset.nclass
        self.cls_idx = [[] for _ in range(self.nclass)]
        for i in range(len(self.dataset)):
            self.cls_idx[self.dataset.targets[i]].append(i)
        self.class_sampler = ClassBatchSampler(self.cls_idx, self.batch_size, drop_last=True)

        # Row c holds batch_size copies of label c. Built with arange instead of
        # a list of numpy arrays, which torch.tensor converts slowly.
        self.cls_targets = (torch.arange(self.nclass, dtype=torch.long, device='cuda')
                            .unsqueeze(1)
                            .repeat(1, self.batch_size))

    def class_sample(self, c, ipc=-1):
        """Return ``(data, target)`` for class ``c`` on CUDA.

        With ``ipc > 0`` the first ``ipc`` samples of the class are used.
        Otherwise the next random batch from that class's sampler is used.
        """
        if ipc > 0:
            indices = self.cls_idx[c][:ipc]
        else:
            indices = next(self.class_sampler.samplers[c])

        data = torch.stack([self.dataset[i][0] for i in indices])
        # Same dtype handling as sample() and the other class loaders.
        if self.convert is not None:
            data = self.convert(data)
        target = torch.tensor([self.dataset.targets[i] for i in indices])
        return data.cuda(), target.cuda()

    def sample(self):
        """Return the next (all-class) batch as CUDA tensors."""
        data, target = next(self.iterator)
        if self.convert is not None:
            data = self.convert(data)

        return data.cuda(), target.cuda()


class ClassMemDataLoader():
    """Class loader with data on GPUs

    Copies every sample of ``dataset`` to ``device`` once. It serves random
    batches (``sample`` / ``__iter__``) and per-class batches (``class_sample``)
    by indexing the cached tensors. uint8 data is converted to float with
    ``ConvertImageDtype``.

    Args:
        dataset: indexable dataset with ``targets`` and ``nclass`` attributes; only
            ``item[0]`` of each item is cached (assumed to be a tensor).
        batch_size: batch size for both random and per-class sampling.
        drop_last: drop the final incomplete random batch.
        device: device that holds the cached data and targets.
    """
    def __init__(self, dataset, batch_size, drop_last=False, device='cuda'):
        self.device = device
        self.batch_size = batch_size

        self.dataset = dataset
        self.data = [d[0].to(device) for d in dataset]  # presumably uint8 images (see convert below)
        self.targets = torch.tensor(dataset.targets, dtype=torch.long, device=device)

        sampler = torch.utils.data.SubsetRandomSampler([i for i in range(len(dataset))])
        self.batch_sampler = torch.utils.data.BatchSampler(sampler,
                                                           batch_size=batch_size,
                                                           drop_last=drop_last)
        self.iterator = iter(_RepeatSampler(self.batch_sampler))

        self.nclass = dataset.nclass
        self.cls_idx = [[] for _ in range(self.nclass)]
        # Group by label from one host-side copy of the targets; indexing the
        # device tensor element by element would transfer each label separately.
        for i, label in enumerate(self.targets.tolist()):
            self.cls_idx[label].append(i)
        self.class_sampler = ClassBatchSampler(self.cls_idx, self.batch_size, drop_last=True)
        # Row c holds batch_size copies of label c.
        self.cls_targets = (torch.arange(self.nclass, dtype=torch.long, device=self.device)
                            .unsqueeze(1)
                            .repeat(1, batch_size))

        self.convert = None
        if self.data[0].dtype == torch.uint8:
            self.convert = transforms.ConvertImageDtype(torch.float)

    def class_sample(self, c, ipc=-1):
        """Return ``(data, target)`` for class ``c``.

        With ``ipc > 0`` the first ``ipc`` samples of the class are used.
        Otherwise the next random batch from that class's sampler is used.
        ``target`` has one label per returned sample.
        """
        if ipc > 0:
            indices = self.cls_idx[c][:ipc]
        else:
            indices = next(self.class_sampler.samplers[c])

        data = torch.stack([self.data[i] for i in indices])
        if self.convert is not None:
            data = self.convert(data)

        # A cls_targets row always has batch_size entries, which does not match
        # ipc batches or classes smaller than batch_size; size the labels to the data.
        target = torch.full((len(indices),), int(c), dtype=torch.long, device=self.device)
        return data, target

    def sample(self):
        """Return the next random (all-class) batch."""
        indices = next(self.iterator)
        data = torch.stack([self.data[i] for i in indices])
        if self.convert is not None:
            data = self.convert(data)
        target = self.targets[indices]

        return data, target

    def __len__(self):
        return len(self.batch_sampler)

    def __iter__(self):
        for _ in range(len(self)):
            data, target = self.sample()
            yield data, target


class ClassPartMemDataLoader(MultiEpochsDataLoader):
    """Class loader for ImageNet-100 with multi-processing.

    This loader caches the samples of the target subclasses on the GPU while
    still loading the full training data from storage. ``sample`` draws regular
    (all-class) batches through the inherited iterator. ``class_sample`` serves
    batches for the classes in ``subclass_list`` from the GPU cache.

    Args:
        subclass_list: labels whose samples are cached on the GPU.
        real_to_idx: expected to map each label in ``subclass_list`` to its position
            in that list, i.e. to the matching entry of ``self.class_sampler.samplers``.
        *args, **kwargs: forwarded to ``MultiEpochsDataLoader``.
    """
    def __init__(self, subclass_list, real_to_idx, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.nclass = self.dataset.nclass
        self.mem_cls = subclass_list
        self.real_to_idx = real_to_idx

        # cls_idx[c] holds positions in self.data_mem (not dataset indices).
        self.cls_idx = [[] for _ in range(self.nclass)]
        idx = 0
        self.data_mem = []
        print("Load target class data on memory..")
        for i in range(len(self.dataset)):
            c = self.dataset.targets[i]
            if c in self.mem_cls:
                self.data_mem.append(self.dataset[i][0].cuda())
                self.cls_idx[c].append(idx)
                idx += 1

        if self.data_mem[0].dtype == torch.uint8:
            self.convert = transforms.ConvertImageDtype(torch.float)
        print(f"Subclass: {subclass_list}, {len(self.data_mem)}")

        class_batch_size = 64
        self.class_sampler = ClassBatchSampler([self.cls_idx[c] for c in subclass_list],
                                               class_batch_size,
                                               drop_last=True)
        # Row c holds class_batch_size copies of label c.
        self.cls_targets = (torch.arange(self.nclass, dtype=torch.long, device='cuda')
                            .unsqueeze(1)
                            .repeat(1, class_batch_size))

    def class_sample(self, c, ipc=-1):
        """Return ``(data, target)`` for cached class ``c``.

        With ``ipc > 0`` the first ``ipc`` cached samples are used. Otherwise the
        next random batch from the class's sampler is used. ``target`` has one
        label per returned sample.
        """
        if ipc > 0:
            indices = self.cls_idx[c][:ipc]
        else:
            idx = self.real_to_idx[c]
            indices = next(self.class_sampler.samplers[idx])

        data = torch.stack([self.data_mem[i] for i in indices])
        if self.convert is not None:
            data = self.convert(data)

        # Size the labels to the batch (a cls_targets row always has 64 entries).
        target = torch.full((len(indices),), int(c), dtype=torch.long, device='cuda')
        return data, target

    def sample(self):
        """Return the next (all-class) batch as CUDA tensors."""
        data, target = next(self.iterator)
        if self.convert is not None:
            data = self.convert(data)

        return data.cuda(), target.cuda()


def load_data(args, tsne=False, detailed=True):
    """Load training and soft_train data

    Builds the train / validation datasets for ``args.dataset`` and wraps them in
    ``MultiEpochsDataLoader`` instances.

    Args:
        args: options namespace. Fields read here: dataset, augment, batch_size,
            workers, size, nclass, ipc, plus data_dir (cifar/svhn/mnist/fashion),
            imagenet_dir/dseed/slct_type/load_memory/spec (imagenet),
            places365dir/places365_random/train_dataset_dir (places365) and
            accumulation_steps (when size == 1024).
            NOTE: ``args.load_memory`` (small ImageNet subsets) and
            ``args.batch_size`` (size == 1024) are modified in place.
        tsne: unused; kept for backward compatibility.
        detailed: print ImageNet subset statistics.

    Returns:
        ``(train_dataset, train_loader, val_loader, args.nclass)``. The class count
        returned is ``args.nclass`` as configured, not one detected from the data.

    Raises:
        Exception: unknown ``args.dataset``.
    """
    if args.dataset.startswith('cifar'):
        train_transform, test_transform = transform_cifar(augment=args.augment)

        if args.dataset == 'cifar100':
            train_dataset = datasets.CIFAR100(args.data_dir, train=True, transform=train_transform)
            val_dataset = datasets.CIFAR100(args.data_dir, train=False, transform=test_transform)
        elif args.dataset == 'cifar10':
            train_dataset = datasets.CIFAR10(args.data_dir, train=True, transform=train_transform)
            val_dataset = datasets.CIFAR10(args.data_dir, train=False, transform=test_transform)
        else:
            raise Exception('unknown dataset: {}'.format(args.dataset))

    elif args.dataset == 'svhn':
        train_transform, test_transform = transform_svhn(augment=args.augment)

        train_dataset = datasets.SVHN(os.path.join(args.data_dir, 'svhn'),
                                      split='train',
                                      download=False,
                                      transform=train_transform)
        val_dataset = datasets.SVHN(os.path.join(args.data_dir, 'svhn'),
                                    split='test',
                                    download=False,
                                    transform=test_transform)

    elif args.dataset == 'fashion':
        train_transform, test_transform = transform_fashion(augment=args.augment)

        train_dataset = datasets.FashionMNIST(args.data_dir, train=True, transform=train_transform)
        val_dataset = datasets.FashionMNIST(args.data_dir, train=False, transform=test_transform)

    elif args.dataset == 'mnist':
        train_transform, test_transform = transform_mnist(augment=args.augment)

        train_dataset = datasets.MNIST(args.data_dir, train=True, transform=train_transform)
        val_dataset = datasets.MNIST(args.data_dir, train=False, transform=test_transform)

    elif args.dataset == 'imagenet':
        val_subdir = 'validation'
        if len(args.imagenet_dir) == 1:
            traindir = os.path.join(args.imagenet_dir[0], 'train')
            valdir = os.path.join(args.imagenet_dir[0], val_subdir)
        else:
            traindir = args.imagenet_dir[0]
            valdir = os.path.join(args.imagenet_dir[1], val_subdir)

        train_transform, test_transform = transform_imagenet(augment=args.augment,
                                                             size=args.size,
                                                             from_tensor=False)
        # Small, low-resolution subsets are always cached in memory (mutates args).
        if args.nclass <= 20 and args.size <= 256:
            args.load_memory = True
        # NOTE: ipc is not forwarded, so every training image of the selected
        # classes is loaded; the "subsample" message below only reports args.ipc.
        train_dataset = ImageFolder(traindir,
                                    train_transform,
                                    nclass=args.nclass,
                                    seed=args.dseed,
                                    slct_type=args.slct_type,
                                    load_memory=args.load_memory,
                                    spec=args.spec)
        val_dataset = ImageFolder(valdir,
                                  test_transform,
                                  nclass=args.nclass,
                                  seed=args.dseed,
                                  load_memory=args.load_memory,
                                  spec=args.spec)

        # Train and validation splits must select the same classes in the same order.
        nclass = len(train_dataset.classes)
        assert nclass == len(val_dataset.classes)
        for train_cls, val_cls in zip(train_dataset.classes, val_dataset.classes):
            assert train_cls == val_cls
        assert np.array(train_dataset.targets).max() == nclass - 1
        assert np.array(val_dataset.targets).max() == nclass - 1

        if detailed:
            print("Subclass is extracted: ")
            print(" #class: ", nclass)
            print(" #train: ", len(train_dataset.targets))
            if args.ipc > 0:
                print(f"  => subsample ({args.slct_type} ipc {args.ipc})")
            print(" #valid: ", len(val_dataset.targets))

    elif args.dataset == 'places365':
        transform_train = transforms.Compose([
            transforms.Resize((args.size, args.size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        transform_val = transforms.Compose([
            transforms.Resize((args.size, args.size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Dataset directories: either the original train split (randomly
        # subsampled per class) or an externally prepared training directory.
        val_dir = os.path.join(args.places365dir, 'val')
        if args.places365_random:
            train_dir = os.path.join(args.places365dir, 'train')
        else:
            train_dir = args.train_dataset_dir

        # Only small class subsets are cached in RAM.
        load_in_memory = args.nclass <= 10
        train_dataset = InMemoryImageFolder(
            root=train_dir,
            transform=transform_train,
            random_sample=args.places365_random,  # whether to subsample each class
            ipc=args.ipc,  # images kept per class when subsampling
            load_in_memory=load_in_memory,
        )

        val_dataset = InMemoryImageFolder(root=val_dir, transform=transform_val, load_in_memory=load_in_memory)
        print(f"Load {args.nclass} classes from {train_dir} to train")
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    # At 1024 resolution, gradient accumulation is used, so the per-step batch
    # is divided by accumulation_steps (mutates args).
    if args.size == 1024:
        args.batch_size = args.batch_size // args.accumulation_steps

    train_loader = MultiEpochsDataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         persistent_workers=args.workers > 0,
                                         pin_memory=True)
    # Half the training batch size, but never 0 (a zero batch size is rejected).
    val_loader = MultiEpochsDataLoader(val_dataset,
                                       batch_size=max(1, args.batch_size // 2),
                                       shuffle=False,
                                       persistent_workers=True,
                                       num_workers=8,
                                       pin_memory=True)
    return train_dataset, train_loader, val_loader, args.nclass


def load_resized_data(args):
    """Load original training data (fixed spatial size, no augmentation) for condensation.

    Args:
        args: namespace providing ``dataset`` plus the fields read by the
            selected branch (``data_dir``; for ImageNet also ``imagenet_dir``,
            ``size``, ``nclass``, ``phase``, ``dseed``, ``load_memory``) and
            ``batch_size`` for the validation loader.

    Returns:
        tuple: ``(train_dataset, val_loader)``. Training images are not
        normalized (``ToTensor`` only, or uint8 tensors when ImageNet is kept in
        memory); validation images are normalized with the dataset statistics.

    Raises:
        Exception: if ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(args.data_dir, train=True, transform=transforms.ToTensor())
        normalize = transforms.Normalize(mean=MEANS['cifar10'], std=STDS['cifar10'])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        val_dataset = datasets.CIFAR10(args.data_dir, train=False, transform=transform_test)
        train_dataset.nclass = 10

    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(args.data_dir,
                                          train=True,
                                          transform=transforms.ToTensor())

        normalize = transforms.Normalize(mean=MEANS['cifar100'], std=STDS['cifar100'])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        val_dataset = datasets.CIFAR100(args.data_dir, train=False, transform=transform_test)
        train_dataset.nclass = 100

    elif args.dataset == 'svhn':
        train_dataset = datasets.SVHN(os.path.join(args.data_dir, 'svhn'),
                                      split='train',
                                      transform=transforms.ToTensor())
        # SVHN exposes its labels as `labels`; alias them so callers can use `targets`.
        train_dataset.targets = train_dataset.labels

        normalize = transforms.Normalize(mean=MEANS['svhn'], std=STDS['svhn'])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        val_dataset = datasets.SVHN(os.path.join(args.data_dir, 'svhn'),
                                    split='test',
                                    transform=transform_test)
        train_dataset.nclass = 10

    elif args.dataset == 'mnist':
        train_dataset = datasets.MNIST(args.data_dir, train=True, transform=transforms.ToTensor())

        normalize = transforms.Normalize(mean=MEANS['mnist'], std=STDS['mnist'])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        val_dataset = datasets.MNIST(args.data_dir, train=False, transform=transform_test)
        train_dataset.nclass = 10

    elif args.dataset == 'fashion':
        train_dataset = datasets.FashionMNIST(args.data_dir,
                                              train=True,
                                              transform=transforms.ToTensor())

        normalize = transforms.Normalize(mean=MEANS['fashion'], std=STDS['fashion'])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        val_dataset = datasets.FashionMNIST(args.data_dir, train=False, transform=transform_test)
        train_dataset.nclass = 10

    elif args.dataset == 'imagenet':
        traindir = os.path.join(args.imagenet_dir, 'train')
        valdir = os.path.join(args.imagenet_dir, 'val')

        # We preprocess images to the fixed size (default: 224)
        resize = transforms.Compose([
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.PILToTensor()
        ])

        if args.load_memory:
            # Resize once at load time and keep uint8 tensors in memory.
            transform = None
            load_transform = resize
        else:
            # Resize on every access and convert to float.
            transform = transforms.Compose([resize, transforms.ConvertImageDtype(torch.float)])
            load_transform = None

        _, test_transform = transform_imagenet(size=args.size)
        train_dataset = ImageFolder(traindir,
                                    transform=transform,
                                    nclass=args.nclass,
                                    phase=args.phase,
                                    seed=args.dseed,
                                    load_memory=args.load_memory,
                                    load_transform=load_transform)
        val_dataset = ImageFolder(valdir,
                                  test_transform,
                                  nclass=args.nclass,
                                  phase=args.phase,
                                  seed=args.dseed,
                                  load_memory=False)

    else:
        # Without this, an unknown name fails later with an unrelated NameError.
        raise Exception('unknown dataset: {}'.format(args.dataset))

    val_loader = MultiEpochsDataLoader(val_dataset,
                                       # A batch size of 0 is rejected by DataLoader.
                                       batch_size=max(1, args.batch_size // 2),
                                       shuffle=False,
                                       persistent_workers=True,
                                       num_workers=4)

    assert train_dataset[0][0].shape[-1] == val_dataset[0][0].shape[-1]  # width check

    return train_dataset, val_loader


def img_denormlaize(img, dataname='imagenet', mean=None, std=None):
    """Undo per-channel normalization on a batch of NCHW images.

    Computes ``img * std + mean``, broadcasting the statistics over channel
    dimension 1.

    Args:
        img (torch.Tensor): normalized images of shape (N, C, H, W); C must
            equal the number of statistics.
        dataname (str): key into ``MEANS``/``STDS``, used for whichever of
            ``mean``/``std`` is not given explicitly.
        mean (sequence of float, optional): per-channel means overriding
            ``MEANS[dataname]``.
        std (sequence of float, optional): per-channel stds overriding
            ``STDS[dataname]``.

    Returns:
        torch.Tensor: denormalized images with the same shape and device as
        ``img``. Floating-point inputs keep their dtype; integer inputs are
        promoted to the default float dtype.

    Raises:
        KeyError: if a statistic is looked up for an unknown ``dataname``.
    """
    if mean is None:
        mean = MEANS[dataname]
    if std is None:
        std = STDS[dataname]
    nch = img.shape[1]

    # Build the stats in img's dtype so half/bfloat16 inputs are not silently
    # promoted to float32. Integer inputs still need a float result.
    dtype = img.dtype if img.is_floating_point() else None
    mean = torch.tensor(mean, device=img.device, dtype=dtype).reshape(1, nch, 1, 1)
    std = torch.tensor(std, device=img.device, dtype=dtype).reshape(1, nch, 1, 1)

    return img * std + mean


def save_img(save_dir, img, unnormalize=True, max_num=200, size=64, nrow=10, dataname='imagenet'):
    """Save up to ``max_num`` images from an NCHW batch as one image grid.

    Args:
        save_dir (str): output file path handed to ``save_image``.
        img (torch.Tensor): image batch; only the first ``max_num`` entries
            are written.
        unnormalize (bool): undo ``dataname`` normalization before saving.
        max_num (int): maximum number of images in the grid.
        size (int): images wider than this are resized to ``size``.
        nrow (int): images per grid row.
        dataname (str): statistics key used when ``unnormalize`` is set.
    """
    batch = img[:max_num].detach()
    if unnormalize:
        batch = img_denormlaize(batch, dataname=dataname)
    batch = batch.clamp(min=0., max=1.)

    # Only downsample: resizing happens when the width exceeds ``size``.
    too_wide = batch.shape[-1] > size
    if too_wide:
        batch = F.interpolate(batch, size)
    save_image(batch.cpu(), save_dir, nrow=nrow)


if __name__ == '__main__':
    # Script entry: build the ImageNet training subset without augmentation or
    # normalization, sample `args.ipc` images per class through
    # `ClassDataLoader.class_sample`, and save the concatenated tensor.
    from argument import args

    traindir = os.path.join(args.imagenet_dir, 'train')
    train_transform, test_transform = transform_imagenet(augment=False,
                                                         from_tensor=False,
                                                         size=args.size,
                                                         rrc=False,
                                                         normalize=False)
    train_dataset = ImageFolder(traindir,
                                train_transform,
                                nclass=args.nclass,
                                seed=args.dseed,
                                slct_type=args.slct_type,
                                ipc=args.ipc,
                                load_memory=args.load_memory)
    loader = ClassDataLoader(train_dataset,
                             batch_size=args.batch_real,
                             num_workers=args.workers,
                             shuffle=True,
                             pin_memory=True,
                             drop_last=True)
    # class_sample returns (images, labels); only the images are kept.
    data = torch.cat([loader.class_sample(c, args.ipc)[0] for c in range(args.nclass)])
    print(data.shape)
    save_path = "./results/samples/init/data.pt"
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(data, save_path)
    print("image saved!")


class KNNDataset(torch.utils.data.Dataset):
    """Dataset of image groups whose paths and labels come from a pickle file.

    The pickle must hold a dict with ``'paths'`` (one list of image paths per
    item) and ``'labels'`` (one label per item). Each item's class is re-derived
    from the parent directory name of its first path and mapped through
    ``class_to_idx``.
    """

    def __init__(self, dataset_file, transform, spec='woof'):
        """
        Args:
            dataset_file (str): path to the pickle described in the class
                docstring. NOTE(review): ``pickle.load`` can execute arbitrary
                code; only load files produced by a trusted source.
            transform (callable): applied to every loaded image; its outputs
                must share one shape because they are stacked with ``torch.stack``.
            spec (str): class-list name used by ``find_subclasses``.
        """
        with open(dataset_file, 'rb') as f:
            self.data = pickle.load(f)
        nclass = 10
        phase = 0
        seed = 0
        # `spec` must be set before `find_subclasses`, which reads it.
        self.spec = spec
        self.classes, self.class_to_idx = self.find_subclasses(nclass=nclass,
                                                               phase=phase,
                                                               seed=seed)

        self.loader = datasets.folder.default_loader
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(images, new_label, label)`` for item ``index``.

        ``images`` stacks the transformed images of every path in the item,
        ``new_label`` is the index of the item's class directory in
        ``class_to_idx``, and ``label`` is the label stored in the pickle.
        """
        paths = self.data['paths'][index]
        label = int(self.data['labels'][index])
        # Paths look like '.../<class_dir>/<file>'; the class is the parent dir.
        label_name = paths[0].split("/")[-2]
        new_label = self.class_to_idx[label_name]
        images = torch.stack([self.transform(self.loader(p)) for p in paths])
        return images, new_label, label

    def __len__(self):
        return len(self.data['paths'])

    def find_subclasses(self, nclass=100, phase=0, seed=0):
        """Select the ``nclass`` classes of split ``phase`` and index them.

        With ``seed == 0`` the class names are read, one per line, from the
        text file registered for ``self.spec``. Otherwise they are taken from a
        seeded random permutation of ``self.classes``.

        Returns:
            tuple: ``(classes, class_to_idx)``, where ``class_to_idx`` maps each
            selected class name to its position in ``classes``.

        Raises:
            AssertionError: if ``self.spec`` is unknown, or if fewer than
                ``nclass`` classes are available for the requested split.
        """
        phase = max(0, phase)
        cls_from = nclass * phase
        cls_to = nclass * (phase + 1)
        if seed == 0:
            spec_files = {
                'woof': 'misc/class_woof.txt',
                'nette': 'misc/class_nette.txt',
                'imagenet100': 'misc/class100.txt',
                'imagenet1k': 'misc/class_indices.txt',
                'IDC': 'misc/class_IPC.txt',
            }
            # Dispatch on self.spec: the old code read a module-global `args`
            # that only exists when this file runs as a script.
            if self.spec not in spec_files:
                raise AssertionError(f"spec '{self.spec}' does not exist!")
            with open(spec_files[self.spec], 'r') as f:
                classes = [line.rstrip('\n') for line in f][cls_from:cls_to]
        else:
            # Requires `self.classes` to already hold the full class list;
            # `__init__` only ever calls this with seed=0.
            np.random.seed(seed)
            class_indices = np.random.permutation(len(self.classes))[cls_from:cls_to]
            classes = [self.classes[i] for i in class_indices]

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        assert len(classes) == nclass

        return classes, class_to_idx


# class ModeDataset(torch.utils.data.Dataset):

#     def __init__(self, imagenet_dir, dataset_file, transform, iterations=100, spec='woof'):
#         with open(dataset_file, 'rb') as f:
#             self.data = pickle.load(f)
#         nclass = 10
#         phase = 0
#         seed = 0
#         self.spec = spec
#         self.classes, self.class_to_idx = self.find_subclasses(nclass=nclass,
#                                                                    phase=phase,
#                                                                    seed=seed)

#         self.loader = datasets.folder.default_loader
#         self.transform = transform
#         self.iterations = iterations
#         self.classes_ids = list(self.data.keys())
#         self.imagenet_dir = imagenet_dir
        

#     def __getitem__(self, index):
#         class_id = np.random.choice(self.classes_ids)
#         return self.get_batch_by_class(class_id)
    
#     def get_batch_by_class(self, class_id):
#         label_name = self.data[class_id][0][0].split("/")[-2]
#         new_label = self.class_to_idx[label_name]
#         dataset1 = list()
#         for i in range(10):
#             mode_id = i
#             img_path = np.random.choice(self.data[class_id][mode_id])
#             split_path = img_path.split("/")
#             img_path = os.path.join(self.imagenet_dir, split_path[-2], split_path[-1])
#             img = self.loader(img_path)
#             img = self.transform(img)
#             dataset1.append(img)
#         dataset1 = torch.stack(dataset1)
#         dataset2 = list()
#         for i in range(10):
#             mode_id = i
#             img_path = np.random.choice(self.data[class_id][mode_id])
#             split_path = img_path.split("/")
#             img_path = os.path.join(self.imagenet_dir, split_path[-2], split_path[-1])
#             img = self.loader(img_path)
#             img = self.transform(img)
#             dataset2.append(img)
#         dataset2 = torch.stack(dataset2)
#         return dataset1, dataset2, torch.tensor([new_label] * 10), torch.tensor([class_id] * 10)



#     def __len__(self):
#         return self.iterations

#     def find_subclasses(self, nclass=100, phase=0, seed=0):
#         """Finds the class folders in a dataset.
#         """
#         classes = []
#         phase = max(0, phase)
#         cls_from = nclass * phase
#         cls_to = nclass * (phase + 1)
#         if seed == 0:
#             if self.spec == 'woof':
#                 file_list = './misc/class_woof.txt'
#             elif self.spec == 'nette':
#                 file_list = './misc/class_nette.txt'
#             else:
#                 file_list = './misc/class100.txt'
#             with open(file_list, 'r') as f:
#                 class_name = f.readlines()
#             for c in class_name:
#                 c = c.split('\n')[0]
#                 classes.append(c)
#             classes = classes[cls_from:cls_to]
#         else:
#             np.random.seed(seed)
#             class_indices = np.random.permutation(len(self.classes))[cls_from:cls_to]
#             for i in class_indices:
#                 classes.append(self.classes[i])

#         class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
#         assert len(classes) == nclass

#         return classes, class_to_idx


class ModeDataset(torch.utils.data.Dataset):
    """Dataset that draws one image from each "mode" of a randomly chosen class.

    The pickle at ``dataset_file`` must hold a dict mapping a class id to a list
    of modes, where each mode is a list of image paths. Each item is a pair of
    ``num_modes``-image batches for one class. The dataset length is the fixed
    ``iterations`` count.
    """

    # Number of modes sampled per class; also the length of every returned batch.
    num_modes = 10

    def __init__(self, imagenet_dir, dataset_file, transform, iterations=100, spec='woof'):
        """
        Args:
            imagenet_dir (str): root directory. Image paths are rebuilt as
                ``<imagenet_dir>/<class_dir>/<file_name>`` from the stored paths.
            dataset_file (str): pickle described in the class docstring.
                NOTE(review): ``pickle.load`` can execute arbitrary code; only
                load files produced by a trusted source.
            transform (callable): applied to every loaded image; its outputs
                must share one shape because they are stacked with ``torch.stack``.
            iterations (int): value returned by ``__len__``.
            spec (str): class-list name used by ``find_subclasses``.
        """
        with open(dataset_file, 'rb') as f:
            self.data = pickle.load(f)
        nclass = 10
        phase = 0
        seed = 0
        # `spec` must be set before `find_subclasses`, which reads it.
        self.spec = spec
        self.classes, self.class_to_idx = self.find_subclasses(nclass=nclass,
                                                               phase=phase,
                                                               seed=seed)

        self.loader = datasets.folder.default_loader
        self.transform = transform
        self.iterations = iterations
        self.classes_ids = list(self.data.keys())
        self.imagenet_dir = imagenet_dir

    def __getitem__(self, index):
        """Ignore ``index`` and return a random-class batch (see ``get_batch_by_class``)."""
        class_id = np.random.choice(self.classes_ids)
        return self.get_batch_by_class(class_id)

    def get_batch_by_class(self, class_id, shuffle=True):
        """Build two independent mode batches for ``class_id``.

        Args:
            class_id: key into the loaded pickle dict.
            shuffle (bool): pick a random image per mode; otherwise use each
                mode's first image, which makes both batches use the same paths.

        Returns:
            tuple: ``(batch1, batch2, new_labels, class_ids)``. Each batch stacks
            ``num_modes`` transformed images (mode 0 first). ``new_labels``
            repeats the class's index in ``class_to_idx``, and ``class_ids``
            repeats ``class_id``, both ``num_modes`` times.
        """
        label_name = self.data[class_id][0][0].split("/")[-2]
        new_label = self.class_to_idx[label_name]
        dataset1 = self._load_modes(class_id, shuffle)
        dataset2 = self._load_modes(class_id, shuffle)
        return (dataset1, dataset2,
                torch.tensor([new_label] * self.num_modes),
                torch.tensor([class_id] * self.num_modes))

    def _load_modes(self, class_id, shuffle):
        """Load, transform and stack one image from each mode of ``class_id``."""
        images = []
        for mode_id in range(self.num_modes):
            mode_paths = self.data[class_id][mode_id]
            img_path = np.random.choice(mode_paths) if shuffle else mode_paths[0]
            # Keep only '<class_dir>/<file>' from the stored path and re-root it.
            split_path = img_path.split("/")
            img_path = os.path.join(self.imagenet_dir, split_path[-2], split_path[-1])
            images.append(self.transform(self.loader(img_path)))
        return torch.stack(images)

    def __len__(self):
        return self.iterations

    def find_subclasses(self, nclass=100, phase=0, seed=0):
        """Select the ``nclass`` classes of split ``phase`` and index them.

        With ``seed == 0`` the class names are read, one per line, from the
        text file registered for ``self.spec`` (``misc/class100.txt`` for any
        unlisted spec). Otherwise they are taken from a seeded random
        permutation of ``self.classes``.

        Returns:
            tuple: ``(classes, class_to_idx)``, where ``class_to_idx`` maps each
            selected class name to its position in ``classes``.

        Raises:
            AssertionError: if fewer than ``nclass`` classes are available for
                the requested split.
        """
        phase = max(0, phase)
        cls_from = nclass * phase
        cls_to = nclass * (phase + 1)
        if seed == 0:
            spec_files = {
                'woof': 'misc/class_woof.txt',
                'nette': 'misc/class_nette.txt',
                'IDC': 'misc/class_IPC.txt',
            }
            # Dispatch on self.spec: the old code read a module-global `args`
            # that only exists when this file runs as a script.
            file_list = spec_files.get(self.spec, 'misc/class100.txt')
            with open(file_list, 'r') as f:
                classes = [line.rstrip('\n') for line in f][cls_from:cls_to]
        else:
            # Requires `self.classes` to already hold the full class list;
            # `__init__` only ever calls this with seed=0.
            np.random.seed(seed)
            class_indices = np.random.permutation(len(self.classes))[cls_from:cls_to]
            classes = [self.classes[i] for i in class_indices]

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        assert len(classes) == nclass

        return classes, class_to_idx