
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from avalanche.benchmarks import nc_benchmark
import numpy as np
import os
from PIL import Image

from avalanche.benchmarks.datasets.mini_imagenet.mini_imagenet import \
    MiniImageNetDataset

# Default training pipeline: array/tensor -> PIL image -> tensor -> per-channel
# normalization. No augmentation (e.g. ``RandomHorizontalFlip``) is applied.
# NOTE(review): the mean/std below look like the commonly used CIFAR-10
# statistics rather than ImageNet ones -- confirm this is intentional.
_default_train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

# Default evaluation pipeline: array/tensor -> PIL image -> tensor -> per-channel
# normalization, using the same constants as the default training pipeline.
# NOTE(review): the mean/std below look like the commonly used CIFAR-10
# statistics rather than ImageNet ones -- confirm this is intentional.
_default_test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])


def SplitMiniImageNet(root_path, n_experiences=20, return_task_id=False, seed=0,
                      fixed_class_order=None,
                      train_transform=_default_train_transform,
                      test_transform=_default_test_transform,
                      preprocessed=False, **kwargs):
    """
    Creates a class-incremental CL benchmark from the Mini ImageNet dataset.

    The train/test datasets are built from the ``train`` sub-folder of
    ``root_path`` and then split into experiences with ``nc_benchmark``.

    :param root_path: Root path of the dataset; images are read from
        ``os.path.join(root_path, "train")``.
    :param n_experiences: The number of experiences in the benchmark.
    :param return_task_id: If True, the benchmark is created with
        ``task_labels=True`` and class IDs are remapped to start from zero
        in each experience. If False, ``task_labels=False`` is used and
        class IDs are remapped to start from zero from the first experience.
    :param seed: Forwarded to ``nc_benchmark`` to initialize the random
        number generator that defines the class order. Can be None.
    :param fixed_class_order: A list of class IDs used to define the class
        order, forwarded to ``nc_benchmark``. Defaults to None.
    :param train_transform: The transformation to apply to the training data
        (see the torchvision.transforms documentation). Defaults to the
        module-level default train transformation.
    :param test_transform: The transformation to apply to the test data
        (see the torchvision.transforms documentation). Defaults to the
        module-level default test transformation.
    :param preprocessed: Must be False. Loading preprocessed Mini-ImageNet
        images is not supported.
    :param kwargs: Accepted for backward compatibility and ignored.

    :raises NotImplementedError: If ``preprocessed`` is True.

    :returns: The benchmark object produced by ``nc_benchmark``.
    """
    # Fail fast with an explicit error: previously this path only printed a
    # message and then crashed on an unbound ``train_set`` variable.
    if preprocessed:
        raise NotImplementedError(
            "Preprocessed Mini-ImageNet images are not available; "
            "call SplitMiniImageNet with preprocessed=False.")

    train_set, test_set = _get_mini_imagenet_dataset(
        os.path.join(root_path, "train"))

    # The two modes differ only in the task-label flag and in how class IDs
    # are remapped, so select the remapping keyword and share a single call.
    if return_task_id:
        class_id_mapping = {'class_ids_from_zero_in_each_exp': True}
    else:
        class_id_mapping = {'class_ids_from_zero_from_first_exp': True}

    return nc_benchmark(
        train_dataset=train_set,
        test_dataset=test_set,
        n_experiences=n_experiences,
        task_labels=bool(return_task_id),
        seed=seed,
        fixed_class_order=fixed_class_order,
        per_exp_classes=None,
        train_transform=train_transform,
        eval_transform=test_transform,
        **class_id_mapping)


# def _get_mini_imagenet_dataset(path):
#     """ Create from ImageNet. """
#     train_set = MiniImageNetDataset(path, split='all')
#
#     train_set_images = np.array([np.array(img[0]) for img in train_set])
#     train_set_labels = np.array(train_set.targets)
#
#     train_x, test_x = [], []
#     train_y, test_y = [], []
#
#     for target in np.unique(train_set.targets):
#         subset_x = train_set_images[train_set_labels == target]
#         subset_y = train_set_labels[train_set_labels == target]
#         train_x.extend(subset_x[:500])
#         test_x.extend(subset_x[500:])
#         train_y.extend(subset_y[:500])
#         test_y.extend(subset_y[500:])
#
#     return XYDataset(train_x, train_y), XYDataset(test_x, test_y)


def _get_mini_imagenet_dataset(path, cache_dir='/data/qingyi/miniimagenet/',
                               train_per_class=500):
    """
    Build the Mini ImageNet train/test datasets, caching the raw arrays.

    If ``train_set_images.npy`` and ``train_set_labels.npy`` both exist in
    ``cache_dir`` they are loaded directly. Otherwise the dataset is created
    with ``MiniImageNetDataset(path, split='all')``, converted to numpy
    arrays, validated and saved to ``cache_dir`` for later calls.

    For every class (in ascending label order) the first ``train_per_class``
    samples go to the training set and all remaining samples of that class
    go to the test set; the original sample order is kept within a class.

    :param path: Path handed to ``MiniImageNetDataset`` when no cache exists.
    :param cache_dir: Directory holding the cached ``.npy`` files. It is
        created when the cache has to be written.
    :param train_per_class: Number of samples per class assigned to the
        training set. Must be non-negative.

    :raises ValueError: If ``train_per_class`` is negative, or if the freshly
        generated arrays do not have the expected shapes.

    :returns: A ``(train, test)`` tuple of ``XYDataset`` instances.
    """
    if train_per_class < 0:
        raise ValueError("train_per_class must be non-negative!")

    images_file = os.path.join(cache_dir, 'train_set_images.npy')
    labels_file = os.path.join(cache_dir, 'train_set_labels.npy')

    # Reuse the cached numpy arrays when both files are present.
    if os.path.exists(images_file) and os.path.exists(labels_file):
        print("Loading precomputed numpy arrays for images and labels...")
        images = np.load(images_file)
        labels = np.load(labels_file)
    else:
        print("Generating MiniImageNet dataset and saving as numpy arrays...")
        # Create the cache directory up front so that a permission problem
        # surfaces before the expensive dataset conversion below.
        os.makedirs(cache_dir, exist_ok=True)
        full_set = MiniImageNetDataset(path, split='all')

        # Convert images and labels to numpy arrays.
        images = np.array([np.array(sample[0]) for sample in full_set])
        labels = np.array(full_set.targets)

        # Validate with explicit raises (asserts are stripped under ``-O``)
        # so that a malformed dataset is never written to the cache.
        if images.shape != (60000, 84, 84, 3):
            raise ValueError("Unexpected shape for train_set_images!")
        if labels.shape != (60000,):
            raise ValueError("Unexpected shape for train_set_labels!")

        np.save(images_file, images)
        np.save(labels_file, labels)

    # Collect sample indices per class, then gather each output once instead
    # of extending Python lists with individual image arrays.
    train_parts, test_parts = [], []
    for target in np.unique(labels):
        class_idx = np.flatnonzero(labels == target)
        train_parts.append(class_idx[:train_per_class])
        test_parts.append(class_idx[train_per_class:])

    no_samples = np.empty(0, dtype=np.intp)
    train_idx = np.concatenate(train_parts) if train_parts else no_samples
    test_idx = np.concatenate(test_parts) if test_parts else no_samples

    return (XYDataset(images[train_idx], labels[train_idx]),
            XYDataset(images[test_idx], labels[test_idx]))


class XYDataset(Dataset):
    """
    Minimal labelled dataset backed by two parallel indexable containers.

    Samples are stored in ``x`` and their labels in ``targets``. Any extra
    keyword arguments given to the constructor are attached to the instance
    as attributes of the same name.
    """

    def __init__(self, x, y, **kwargs):
        self.x = x
        self.targets = y
        # Expose every additional keyword argument as an instance attribute.
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.x)``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        return self.x[idx], self.targets[idx]


# Public API of this module: only the benchmark factory is exported.
__all__ = ['SplitMiniImageNet']
