"""
    Functions and classes for loading datasets.
    Contains loaders for:
        CIFAR-10
        CIFAR-100 (fine labels and 20 super-class labels)
        SVHN
        CelebA
        TinyImageNet
"""
import os
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, CelebA
import glob
import copy
from shutil import move

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# The 100 fine-grained CIFAR-100 class names.
# A name's position in this list is used as its fine-class index (it is looked
# up with `CIFAR_100_fine_labels.index(...)` when building the fine -> coarse
# mapping in `cifar100_20_dataloaders`).
# NOTE(review): presumably ordered to match the integer targets produced by
# torchvision's CIFAR100 dataset -- verify before relying on it.
CIFAR_100_fine_labels = [
    "apple",  # id 0
    "aquarium_fish",
    "baby",
    "bear",
    "beaver",
    "bed",
    "bee",
    "beetle",
    "bicycle",
    "bottle",
    "bowl",
    "boy",
    "bridge",
    "bus",
    "butterfly",
    "camel",
    "can",
    "castle",
    "caterpillar",
    "cattle",
    "chair",
    "chimpanzee",
    "clock",
    "cloud",
    "cockroach",
    "couch",
    "crab",
    "crocodile",
    "cup",
    "dinosaur",
    "dolphin",
    "elephant",
    "flatfish",
    "forest",
    "fox",
    "girl",
    "hamster",
    "house",
    "kangaroo",
    "computer_keyboard",
    "lamp",
    "lawn_mower",
    "leopard",
    "lion",
    "lizard",
    "lobster",
    "man",
    "maple_tree",
    "motorcycle",
    "mountain",
    "mouse",
    "mushroom",
    "oak_tree",
    "orange",
    "orchid",
    "otter",
    "palm_tree",
    "pear",
    "pickup_truck",
    "pine_tree",
    "plain",
    "plate",
    "poppy",
    "porcupine",
    "possum",
    "rabbit",
    "raccoon",
    "ray",
    "road",
    "rocket",
    "rose",
    "sea",
    "seal",
    "shark",
    "shrew",
    "skunk",
    "skyscraper",
    "snail",
    "snake",
    "spider",
    "squirrel",
    "streetcar",
    "sunflower",
    "sweet_pepper",
    "table",
    "tank",
    "telephone",
    "television",
    "tiger",
    "tractor",
    "train",
    "trout",
    "tulip",
    "turtle",
    "wardrobe",
    "whale",
    "willow_tree",
    "wolf",
    "woman",
    "worm",
]

# The 20 CIFAR-100 super-class (coarse) names.
# A name's position in this list is used as its coarse-class index (it is
# looked up with `CIFAR_100_super_class.index(...)` in
# `cifar100_20_dataloaders`). The names must match the keys of
# `mapping_coarse_fine` exactly, otherwise that lookup raises ValueError.
CIFAR_100_super_class = [
    "aquatic mammals",
    "fish",
    "flowers",
    "food containers",
    "fruit and vegetables",
    "household electrical device",
    "household furniture",
    "insects",
    "large carnivores",
    "large man-made outdoor things",
    "large natural outdoor scenes",
    "large omnivores and herbivores",
    "medium-sized mammals",
    "non-insect invertebrates",
    "people",
    "reptiles",
    "small mammals",
    "trees",
    "vehicles 1",
    "vehicles 2",
]

# Maps each super-class name (a key that also appears in
# `CIFAR_100_super_class`) to the five fine-grained class names it contains
# (each of which also appears in `CIFAR_100_fine_labels`).
# `cifar100_20_dataloaders` iterates this dict to derive the
# fine-index -> coarse-index lookup, so the spelling of every name here has to
# agree with the two lists above.
mapping_coarse_fine = {
    "aquatic mammals": ["beaver", "dolphin", "otter", "seal", "whale"],
    "fish": ["aquarium_fish", "flatfish", "ray", "shark", "trout"],
    "flowers": ["orchid", "poppy", "rose", "sunflower", "tulip"],
    "food containers": ["bottle", "bowl", "can", "cup", "plate"],
    "fruit and vegetables": ["apple", "mushroom", "orange", "pear", "sweet_pepper"],
    "household electrical device": [
        "clock",
        "computer_keyboard",
        "lamp",
        "telephone",
        "television",
    ],
    "household furniture": ["bed", "chair", "couch", "table", "wardrobe"],
    "insects": ["bee", "beetle", "butterfly", "caterpillar", "cockroach"],
    "large carnivores": ["bear", "leopard", "lion", "tiger", "wolf"],
    "large man-made outdoor things": [
        "bridge",
        "castle",
        "house",
        "road",
        "skyscraper",
    ],
    "large natural outdoor scenes": ["cloud", "forest", "mountain", "plain", "sea"],
    "large omnivores and herbivores": [
        "camel",
        "cattle",
        "chimpanzee",
        "elephant",
        "kangaroo",
    ],
    "medium-sized mammals": ["fox", "porcupine", "possum", "raccoon", "skunk"],
    "non-insect invertebrates": ["crab", "lobster", "snail", "spider", "worm"],
    "people": ["baby", "boy", "girl", "man", "woman"],
    "reptiles": ["crocodile", "dinosaur", "lizard", "snake", "turtle"],
    "small mammals": ["hamster", "mouse", "rabbit", "shrew", "squirrel"],
    "trees": ["maple_tree", "oak_tree", "palm_tree", "pine_tree", "willow_tree"],
    "vehicles 1": ["bicycle", "bus", "motorcycle", "pickup_truck", "train"],
    "vehicles 2": ["lawn_mower", "rocket", "streetcar", "tank", "tractor"],
}


def cifar10_dataloaders_no_val(
    batch_size=128, data_dir="datasets/cifar10", num_workers=2
):
    """Build CIFAR-10 loaders without carving a validation split out of training.

    The whole official training split is used for training (random crop with
    4-pixel padding + random horizontal flip). The "validation" loader and the
    test loader are both built from the official test split with only
    ``ToTensor`` applied. Note that the printed banner still mentions a
    45000/5000 train/validation split, which this function does not perform.

    Args:
        batch_size: Batch size shared by all three loaders.
        data_dir: Dataset root directory; the data is downloaded there if absent.
        num_workers: Number of worker processes used by every loader.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    augmenting_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    plain_transform = transforms.Compose([transforms.ToTensor()])

    banner_lines = (
        "Dataset information: CIFAR-10\t 45000 images for training \t 5000 images for validation\t",
        "10000 images for testing\t no normalize applied in data_transform",
        "Data augmentation = randomcrop(32,4) + randomhorizontalflip",
    )
    for banner_line in banner_lines:
        print(banner_line)

    train_set = CIFAR10(
        data_dir, train=True, transform=augmenting_transform, download=True
    )
    # Validation and test are two separate dataset objects over the same split.
    val_set = CIFAR10(data_dir, train=False, transform=plain_transform, download=True)
    test_set = CIFAR10(data_dir, train=False, transform=plain_transform, download=True)

    # Options common to all loaders; only `shuffle` differs between them.
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )
    train_loader = DataLoader(train_set, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **shared_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **shared_kwargs)

    return train_loader, val_loader, test_loader


def svhn_dataloaders(
    batch_size=128,
    data_dir="datasets/svhn",
    num_workers=2,
    class_to_replace: int = None,
    num_indexes_to_replace=None,
    indexes_to_replace=None,
    seed: int = 1,
    only_mark: bool = False,
    shuffle=True,
    no_aug=False,
    mode: str=None,
    ratio: float = 0.1
):
    """Build SVHN train/validation/test loaders with optional "forgetting" marks.

    10% of every class is drawn from the training split (using a
    ``RandomState(seed)`` generator) to form the validation set. Depending on
    ``mode``, some of the remaining training samples are then *marked* by
    rewriting their label ``y`` as ``-y - 1`` (marked labels are negative and
    the original label stays recoverable):

    * ``"one_class"``: every training sample of one randomly chosen class is
      marked, and that class is removed from the test set.
    * ``"one_class_random"``: 90% of one randomly chosen class is marked.
    * ``"random"``: about ``ratio`` of the training set is marked (a per-class
      pool of ``2 * ratio`` is drawn, then half of the combined pool is kept).
    * any other value (including ``None``): nothing is marked.

    The class in the two ``one_class*`` modes is drawn from NumPy's *global*
    RNG (``np.random.randint``), not from the seeded generator.

    Args:
        batch_size: Batch size shared by all three loaders.
        data_dir: Dataset root directory; the data is downloaded there if absent.
        num_workers: Accepted for interface compatibility but not consulted;
            the loaders always use 4 workers.
        class_to_replace: Unused.
        num_indexes_to_replace: Unused.
        indexes_to_replace: Unused.
        seed: Seed for the split / forgetting-set generator and for the
            per-worker NumPy seed. ``None`` disables the worker seeding.
        only_mark: Unused.
        shuffle: Whether the training loader shuffles its samples.
        no_aug: Unused (no augmentation is applied to SVHN here).
        mode: Forgetting scenario, see above.
        ratio: Fraction used to size the forgetting pool in ``"random"`` mode.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    train_transform = transforms.Compose([transforms.ToTensor()])
    test_transform = transforms.Compose([transforms.ToTensor()])

    train_set = SVHN(data_dir, split="train", transform=train_transform, download=True)
    test_set = SVHN(data_dir, split="test", transform=test_transform, download=True)

    train_set.labels = np.array(train_set.labels)
    test_set.labels = np.array(test_set.labels)

    rng = np.random.RandomState(seed)

    # Stratified validation split: 10% of each class, without replacement.
    valid_parts = []
    for cls in range(max(train_set.labels) + 1):
        class_idx = np.where(train_set.labels == cls)[0]
        valid_parts.append(
            rng.choice(class_idx, int(0.1 * len(class_idx)), replace=False)
        )
    valid_idx = np.hstack(valid_parts)

    # Keep references to the full arrays; fancy/boolean indexing below copies.
    all_data, all_labels = train_set.data, train_set.labels

    valid_set = copy.deepcopy(train_set)
    valid_set.data = all_data[valid_idx]
    valid_set.labels = all_labels[valid_idx]

    # Everything not picked for validation stays in training, in original
    # order (a mask makes the ordering explicit instead of relying on how a
    # Python set happens to iterate).
    keep_mask = np.ones(len(all_labels), dtype=bool)
    keep_mask[valid_idx] = False
    train_set.data = all_data[keep_mask]
    train_set.labels = all_labels[keep_mask]

    # Per-class index table plus the candidate pool for "random" forgetting.
    # The pool is always drawn so that the generator state does not depend on
    # `mode`.
    all_class_idx = dict()
    pool_parts = []
    for cls in range(max(train_set.labels) + 1):
        class_idx = np.where(train_set.labels == cls)[0]
        all_class_idx[cls] = class_idx
        pool_parts.append(
            rng.choice(class_idx, int(ratio * 2 * len(class_idx)), replace=False)
        )
    forgetting_index = np.hstack(pool_parts)

    if mode == "one_class":
        forgetting_cls = np.random.randint(0, 10)
        forgetting_index = all_class_idx[forgetting_cls]
        assert np.unique(train_set.labels[forgetting_index]).shape[0] == 1
        train_set.labels[forgetting_index] = -train_set.labels[forgetting_index] - 1
        # Filter data first: the mask is computed from the still-unfiltered labels.
        test_set.data = test_set.data[test_set.labels != forgetting_cls]
        test_set.labels = test_set.labels[test_set.labels != forgetting_cls]
        print(f">>>>>>>>>>>>>>>>>>>>>> Forgetting class: {forgetting_cls}")
    elif mode == "one_class_random":
        forgetting_cls = np.random.randint(0, 10)
        forgetting_index = rng.choice(
            all_class_idx[forgetting_cls],
            int(0.9 * len(all_class_idx[forgetting_cls])),
            replace=False,
        )
        train_set.labels[forgetting_index] = -train_set.labels[forgetting_index] - 1
    elif mode == "random":
        forgetting_index = rng.choice(
            forgetting_index, int(0.5 * len(forgetting_index)), replace=False
        )
        train_set.labels[forgetting_index] = -train_set.labels[forgetting_index] - 1

    # Report the sizes actually produced by the split.
    print(
        f"Dataset information: SVHN\t {len(train_set)} images for training \t "
        f"{len(valid_set)} images for validation\t"
    )

    # The worker count is fixed here; `num_workers` is intentionally not used
    # so that existing runs keep their current loader configuration.
    loader_args = {"num_workers": 4, "pin_memory": True}

    def _init_fn(worker_id):
        # Every worker re-seeds NumPy's global RNG with the same value.
        np.random.seed(int(seed))

    worker_init = _init_fn if seed is not None else None

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=worker_init,
        **loader_args,
    )
    val_loader = DataLoader(
        valid_set,
        batch_size=batch_size,
        shuffle=False,
        worker_init_fn=worker_init,
        **loader_args,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        worker_init_fn=worker_init,
        **loader_args,
    )

    return train_loader, val_loader, test_loader


def cifar100_dataloaders(
    batch_size=128,
    data_dir="datasets/cifar100",
    num_workers=2,
    class_to_replace: int = None,
    num_indexes_to_replace=None,
    indexes_to_replace=None,
    seed: int = 1,
    only_mark: bool = False,
    shuffle=True,
    no_aug=False,
    mode: str = None,
    ratio: float = 0.1,
):
    """Build CIFAR-100 train/validation/test loaders with optional "forgetting" marks.

    10% of every class is drawn from the training split (using a
    ``RandomState(seed)`` generator) to form the validation set. Depending on
    ``mode``, some of the remaining training samples are then *marked* by
    rewriting their target ``y`` as ``-y - 1`` (marked targets are negative
    and the original target stays recoverable):

    * ``"one_class"``: every training sample of one randomly chosen class is
      marked, and that class is removed from the test set.
    * ``"one_class_random"``: 440 samples of one randomly chosen class are
      marked (``rng.choice`` raises ``ValueError`` if the class has fewer).
    * ``"random"``: about ``ratio`` of the training set is marked (a per-class
      pool of ``2 * ratio`` is drawn, then half of the combined pool is kept).
    * any other value (including ``None``): nothing is marked.

    The class in the two ``one_class*`` modes is drawn from NumPy's *global*
    RNG (``np.random.randint``), not from the seeded generator.

    NOTE(review): the validation set is a deep copy of the training set, so it
    keeps the training transform (including augmentation unless ``no_aug``).

    Args:
        batch_size: Batch size shared by all three loaders.
        data_dir: Dataset root directory; the data is downloaded there if absent.
        num_workers: Accepted for interface compatibility but not consulted;
            the loaders always use 4 workers.
        class_to_replace: Unused.
        num_indexes_to_replace: Unused.
        indexes_to_replace: Unused.
        seed: Seed for the split / forgetting-set generator and for the
            per-worker NumPy seed. ``None`` disables the worker seeding.
        only_mark: Unused.
        shuffle: Whether the training loader shuffles its samples.
        no_aug: If true, the training transform is only ``ToTensor``;
            otherwise random crop (padding 4) and horizontal flip are added.
        mode: Forgetting scenario, see above.
        ratio: Fraction used to size the forgetting pool in ``"random"`` mode.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    if no_aug:
        train_transform = transforms.Compose([transforms.ToTensor()])
    else:
        train_transform = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
    test_transform = transforms.Compose([transforms.ToTensor()])

    train_set = CIFAR100(data_dir, train=True, transform=train_transform, download=True)
    test_set = CIFAR100(data_dir, train=False, transform=test_transform, download=True)
    train_set.targets = np.array(train_set.targets)
    test_set.targets = np.array(test_set.targets)

    rng = np.random.RandomState(seed)

    # Stratified validation split: 10% of each class, without replacement.
    valid_parts = []
    for cls in range(max(train_set.targets) + 1):
        class_idx = np.where(train_set.targets == cls)[0]
        valid_parts.append(
            rng.choice(class_idx, int(0.1 * len(class_idx)), replace=False)
        )
    valid_idx = np.hstack(valid_parts)

    # Keep references to the full arrays; fancy/boolean indexing below copies.
    all_data, all_targets = train_set.data, train_set.targets

    valid_set = copy.deepcopy(train_set)
    valid_set.data = all_data[valid_idx]
    valid_set.targets = all_targets[valid_idx]

    # Everything not picked for validation stays in training, in original
    # order (a mask makes the ordering explicit instead of relying on how a
    # Python set happens to iterate).
    keep_mask = np.ones(len(all_targets), dtype=bool)
    keep_mask[valid_idx] = False
    train_set.data = all_data[keep_mask]
    train_set.targets = all_targets[keep_mask]

    # Per-class index table plus the candidate pool for "random" forgetting.
    # The pool is always drawn so that the generator state does not depend on
    # `mode`.
    all_class_idx = dict()
    pool_parts = []
    for cls in range(max(train_set.targets) + 1):
        class_idx = np.where(train_set.targets == cls)[0]
        all_class_idx[cls] = class_idx
        pool_parts.append(
            rng.choice(class_idx, int(ratio * 2 * len(class_idx)), replace=False)
        )
    forgetting_index = np.hstack(pool_parts)

    if mode == "one_class":
        forgetting_cls = np.random.randint(0, 100)
        forgetting_index = all_class_idx[forgetting_cls]
        assert np.unique(train_set.targets[forgetting_index]).shape[0] == 1
        train_set.targets[forgetting_index] = -train_set.targets[forgetting_index] - 1
        # Filter data first: the mask is computed from the still-unfiltered targets.
        test_set.data = test_set.data[test_set.targets != forgetting_cls]
        test_set.targets = test_set.targets[test_set.targets != forgetting_cls]
        print(f">>>>>>>>>>>>>>>>>>>>>> Forgetting class: {forgetting_cls}")
    elif mode == "one_class_random":
        forgetting_cls = np.random.randint(0, 100)
        forgetting_index = rng.choice(
            all_class_idx[forgetting_cls], 440, replace=False
        )
        train_set.targets[forgetting_index] = -train_set.targets[forgetting_index] - 1
    elif mode == "random":
        forgetting_index = rng.choice(
            forgetting_index, int(0.5 * len(forgetting_index)), replace=False
        )
        train_set.targets[forgetting_index] = -train_set.targets[forgetting_index] - 1

    # Report what was actually built (sizes after the split / class filtering,
    # and the augmentation really in use).
    print(
        f"Dataset information: CIFAR-100\t {len(train_set)} images for training \t "
        f"{len(valid_set)} images for validation\t"
    )
    print(f"{len(test_set)} images for testing\t no normalize applied in data_transform")
    print(
        "Data augmentation = "
        + ("none" if no_aug else "randomcrop(32,4) + randomhorizontalflip")
    )

    # The worker count is fixed here; `num_workers` is intentionally not used
    # so that existing runs keep their current loader configuration.
    loader_args = {"num_workers": 4, "pin_memory": True}

    def _init_fn(worker_id):
        # Every worker re-seeds NumPy's global RNG with the same value.
        np.random.seed(int(seed))

    worker_init = _init_fn if seed is not None else None

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=worker_init,
        **loader_args,
    )
    val_loader = DataLoader(
        valid_set,
        batch_size=batch_size,
        shuffle=False,
        worker_init_fn=worker_init,
        **loader_args,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        worker_init_fn=worker_init,
        **loader_args,
    )

    return train_loader, val_loader, test_loader


def cifar100_20_dataloaders(
    batch_size=128,
    data_dir="datasets/cifar100",
    num_workers=2,
    class_to_replace: int = None,
    num_indexes_to_replace=None,
    indexes_to_replace=None,
    seed: int = 1,
    only_mark: bool = False,
    shuffle=True,
    no_aug=False,
    all_classes=False,
    mode:str = None,
):
    """Build CIFAR-100 loaders labelled with the 20 super-classes.

    The data is split and (optionally) marked at fine-class level, then every
    dataset's ``targets`` is replaced by the super-class index while the
    fine-level labels are kept in a ``fine_targets`` attribute.

    10% of every fine class is drawn from the training split (using a
    ``RandomState(seed)`` generator) to form the validation set. Depending on
    ``mode``, training samples are *marked* by rewriting the label ``y`` as
    ``-y - 1``; this encoding is applied to both the fine and the coarse
    label of a marked sample:

    * ``"one_class"``: every training sample of one fine class (drawn with the
      global ``np.random.randint``) is marked. The test set is not filtered.
    * ``"random"``: for each of the 20 super-classes one fine class is drawn
      with the seeded generator and all of its training samples are marked.
    * any other value (including ``None``): nothing is marked.

    Each of the train and test datasets also receives ``index_mapping``
    (fine index -> coarse index) and ``label_mapping`` (fine name -> coarse
    name) attributes; the validation set gets deep copies of them.

    NOTE(review): the validation set is a deep copy of the training set, so it
    keeps the training transform (including augmentation unless ``no_aug``).

    Args:
        batch_size: Batch size shared by all three loaders.
        data_dir: Dataset root directory; the data is downloaded there if absent.
        num_workers: Accepted for interface compatibility but not consulted;
            the loaders always use 4 workers.
        class_to_replace: Unused.
        num_indexes_to_replace: Unused.
        indexes_to_replace: Unused.
        seed: Seed for the split / forgetting-set generator and for the
            per-worker NumPy seed. ``None`` disables the worker seeding.
        only_mark: Unused.
        shuffle: Whether the training loader shuffles its samples.
        no_aug: If true, the training transform is only ``ToTensor``;
            otherwise random crop (padding 4) and horizontal flip are added.
        all_classes: Unused.
        mode: Forgetting scenario, see above.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    if no_aug:
        train_transform = transforms.Compose([transforms.ToTensor()])
    else:
        train_transform = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
    test_transform = transforms.Compose([transforms.ToTensor()])

    # coarse_fine_mapping_index: fine-class index -> super-class index.
    # coarse_fine_mapping:       fine-class name  -> super-class name.
    coarse_fine_mapping_index = dict()
    coarse_fine_mapping = dict()
    for coarse_name, fine_names in mapping_coarse_fine.items():
        coarse_index = CIFAR_100_super_class.index(coarse_name)
        for fine_name in fine_names:
            fine_index = CIFAR_100_fine_labels.index(fine_name)
            coarse_fine_mapping_index[fine_index] = coarse_index
            coarse_fine_mapping[fine_name] = coarse_name

    train_set = CIFAR100(data_dir, train=True, transform=train_transform, download=True)
    test_set = CIFAR100(data_dir, train=False, transform=test_transform, download=True)

    train_set.index_mapping = coarse_fine_mapping_index
    train_set.label_mapping = coarse_fine_mapping
    test_set.index_mapping = coarse_fine_mapping_index
    test_set.label_mapping = coarse_fine_mapping

    # coarse_fine_mapping_index_dict: super-class index -> list of the
    # fine-class indices it contains (in insertion order of the mapping).
    coarse_fine_mapping_index_dict = dict()
    for fine_index, coarse_index in coarse_fine_mapping_index.items():
        coarse_fine_mapping_index_dict.setdefault(coarse_index, []).append(fine_index)

    train_set.targets = np.array(train_set.targets)
    test_set.targets = np.array(test_set.targets)

    rng = np.random.RandomState(seed)

    # Stratified validation split: 10% of each fine class, without replacement.
    valid_parts = []
    for fine_cls in range(max(train_set.targets) + 1):
        class_idx = np.where(train_set.targets == fine_cls)[0]
        valid_parts.append(
            rng.choice(class_idx, int(0.1 * len(class_idx)), replace=False)
        )
    valid_idx = np.hstack(valid_parts)

    # Keep references to the full arrays; fancy/boolean indexing below copies.
    all_data, all_targets = train_set.data, train_set.targets

    valid_set = copy.deepcopy(train_set)
    valid_set.data = all_data[valid_idx]
    valid_set.targets = all_targets[valid_idx]

    # Everything not picked for validation stays in training, in original
    # order (a mask makes the ordering explicit instead of relying on how a
    # Python set happens to iterate).
    keep_mask = np.ones(len(all_targets), dtype=bool)
    keep_mask[valid_idx] = False
    train_set.data = all_data[keep_mask]
    train_set.targets = all_targets[keep_mask]

    # Training-sample positions grouped by fine class and by super-class.
    # The super-class table was previously keyed off the fine-class table by
    # mistake, which produced duplicated fine-class indices; it is now grouped
    # by the super-class itself. No mode below consumes it at the moment.
    all_fine_class_idx = dict()
    super_class_parts = dict()
    for fine_cls in range(max(train_set.targets) + 1):
        fine_class_idx = np.where(train_set.targets == fine_cls)[0]
        all_fine_class_idx[fine_cls] = fine_class_idx
        super_class_parts.setdefault(
            coarse_fine_mapping_index[fine_cls], []
        ).append(fine_class_idx)
    all_super_class_idx = {
        super_cls: np.hstack(parts) for super_cls, parts in super_class_parts.items()
    }

    forgetting_index = None
    if mode == "one_class":
        # Forget one whole fine class.
        forgetting_cls = np.random.randint(0, 100)
        forgetting_index = all_fine_class_idx[forgetting_cls]
        assert np.unique(train_set.targets[forgetting_index]).shape[0] == 1
        train_set.targets[forgetting_index] = -train_set.targets[forgetting_index] - 1
    elif mode == "random":
        # Forget one whole fine class out of every super-class.
        selected_parts = []
        for super_cls in range(20):
            selected_fine_class = rng.choice(coarse_fine_mapping_index_dict[super_cls])
            selected_parts.append(all_fine_class_idx[selected_fine_class])
        forgetting_index = np.hstack(selected_parts)
        train_set.targets[forgetting_index] = -train_set.targets[forgetting_index] - 1

    # Report what was actually built (sizes after the split, and the
    # augmentation really in use).
    print(
        f"Dataset information: CIFAR-100\t {len(train_set)} images for training \t "
        f"{len(valid_set)} images for validation\t"
    )
    print(f"{len(test_set)} images for testing\t no normalize applied in data_transform")
    print(
        "Data augmentation = "
        + ("none" if no_aug else "randomcrop(32,4) + randomhorizontalflip")
    )

    # The worker count is fixed here; `num_workers` is intentionally not used
    # so that existing runs keep their current loader configuration.
    loader_args = {"num_workers": 4, "pin_memory": True}

    def _init_fn(worker_id):
        # Every worker re-seeds NumPy's global RNG with the same value.
        np.random.seed(int(seed))

    # Preserve the fine-level labels, then switch `targets` to super-classes.
    train_set.fine_targets = train_set.targets
    valid_set.fine_targets = valid_set.targets
    test_set.fine_targets = test_set.targets

    # A marked fine label -f-1 becomes the marked coarse label -c-1.
    train_set.targets = np.array(
        [
            coarse_fine_mapping_index[i]
            if i > -1
            else -coarse_fine_mapping_index[-i - 1] - 1
            for i in train_set.fine_targets
        ]
    )
    valid_set.targets = np.array(
        [coarse_fine_mapping_index[i] for i in valid_set.fine_targets]
    )
    test_set.targets = np.array(
        [coarse_fine_mapping_index[i] for i in test_set.fine_targets]
    )

    worker_init = _init_fn if seed is not None else None

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=worker_init,
        **loader_args,
    )
    val_loader = DataLoader(
        valid_set,
        batch_size=batch_size,
        shuffle=False,
        worker_init_fn=worker_init,
        **loader_args,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        worker_init_fn=worker_init,
        **loader_args,
    )

    return train_loader, val_loader, test_loader


def cifar100_dataloaders_no_val(
    batch_size=128, data_dir="datasets/cifar100", num_workers=2
):
    """Build CIFAR-100 loaders without carving a validation split out of training.

    The whole official training split is used for training (random crop with
    4-pixel padding + random horizontal flip). The "validation" loader and the
    test loader are both built from the official test split with only
    ``ToTensor`` applied. Note that the printed banner still mentions a
    separate training/validation split, which this function does not perform.

    Args:
        batch_size: Batch size shared by all three loaders.
        data_dir: Dataset root directory; the data is downloaded there if absent.
        num_workers: Number of worker processes used by every loader.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    augmenting_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    plain_transform = transforms.Compose([transforms.ToTensor()])

    banner_lines = (
        "Dataset information: CIFAR-100\t 45000 images for training \t 500 images for validation\t",
        "10000 images for testing\t no normalize applied in data_transform",
        "Data augmentation = randomcrop(32,4) + randomhorizontalflip",
    )
    for banner_line in banner_lines:
        print(banner_line)

    train_set = CIFAR100(
        data_dir, train=True, transform=augmenting_transform, download=True
    )
    # Validation and test are two separate dataset objects over the same split.
    val_set = CIFAR100(data_dir, train=False, transform=plain_transform, download=True)
    test_set = CIFAR100(data_dir, train=False, transform=plain_transform, download=True)

    # Options common to all loaders; only `shuffle` differs between them.
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )
    train_loader = DataLoader(train_set, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **shared_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **shared_kwargs)

    return train_loader, val_loader, test_loader


def celeba_dataloaders(
    batch_size=256,
    data_dir="./Unlearn-Sparse/data/celebA/",
    num_workers=4,
    forget_ratio=0.0,
    seed: int = 1,
    only_mark: bool = False,
    shuffle=True,
    no_aug=False,
):
    """Build CelebA train/validation/test loaders for the attribute-31 task.

    Each sample's target is column 31 of its attribute vector (extracted by
    the local ``get_smile`` target transform). When ``forget_ratio`` is
    non-zero, that fraction of the distinct identities in the training split
    is chosen at random and column 31 of every image of those identities is
    overwritten with ``-1`` to mark it as "to be forgotten".

    The identities are drawn with NumPy's *global* RNG (``np.random.choice``),
    so the selection is only reproducible if the caller seeds ``np.random``.

    Args:
        batch_size: Batch size shared by all three loaders.
        data_dir: Dataset root directory; the data must already be present
            (``download=False``).
        num_workers: Number of worker processes used by every loader.
        forget_ratio: Fraction of distinct training identities to mark.
        seed: Value used to re-seed NumPy inside every loader worker.
            ``None`` disables the worker seeding.
        only_mark: Unused.
        shuffle: Whether the training loader shuffles its samples.
        no_aug: If true, the training transform is only ``ToTensor``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    if no_aug:
        # NOTE(review): unlike the test transform, this path applies no
        # crop/resize, so training images keep their original size -- confirm
        # this is intended.
        train_transform = transforms.Compose([transforms.ToTensor()])
    else:
        train_transform = transforms.Compose(
            [
                transforms.CenterCrop((178, 178)),
                transforms.Resize((128, 128)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(0.2),
                transforms.ToTensor(),
            ]
        )
    test_transform = transforms.Compose(
        [
            transforms.CenterCrop((178, 178)),
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ]
    )

    def get_smile(attr):
        # Target transform: keep only attribute column 31.
        return attr[31]

    def _load_split(split, transform):
        # All three splits share every option except the split name/transform.
        return CelebA(
            root=data_dir,
            split=split,
            transform=transform,
            target_type="attr",
            target_transform=get_smile,
            download=False,
        )

    train_set = _load_split("train", train_transform)
    valid_set = _load_split("valid", test_transform)
    test_set = _load_split("test", test_transform)

    # TODO: remove the identity
    loader_args = {"num_workers": num_workers, "pin_memory": True}

    if forget_ratio != 0.0:
        # One identity id per training image, flattened to 1-D.
        identity = np.asarray(train_set.identity).reshape(-1)
        unique_identities = np.unique(identity)
        size = int(unique_identities.shape[0] * forget_ratio)
        # Sample from the *distinct* identities: drawing from the per-image
        # column would favour identities with many images and could select
        # the same identity several times, forgetting fewer than `size`.
        if size > 0:
            identities_selected = np.random.choice(
                unique_identities, size=size, replace=False
            )
            indexes = np.concatenate(
                [np.where(identity == i)[0] for i in identities_selected]
            )
            train_set.attr[indexes, 31] = -1

    def _init_fn(worker_id):
        # Every worker re-seeds NumPy's global RNG with the same value.
        np.random.seed(int(seed))

    worker_init = _init_fn if seed is not None else None

    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=worker_init,
        **loader_args,
    )

    val_loader = DataLoader(
        valid_set,
        batch_size=batch_size,
        shuffle=False,
        worker_init_fn=worker_init,
        **loader_args,
    )

    test_loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        worker_init_fn=worker_init,
        **loader_args,
    )

    return train_loader, val_loader, test_loader


class TinyImageNetDataset(Dataset):
    """In-memory dataset built from the sample list of an image-folder dataset.

    Every selected image is opened, converted to RGB, turned into a tensor
    (and optionally normalized) once at construction time; the resulting
    tensors are stacked into ``self.imgs``. The source dataset's ``transform``
    is applied lazily to the cached tensor on every ``__getitem__`` call.

    Args:
        image_folder_set: Object exposing ``imgs`` (a sequence whose items
            hold the image path at index 0 and the target at index 1) and
            ``transform`` (a callable or ``None``).
        norm_trans: Optional callable applied to each image tensor right after
            ``ToTensor``.
        start: Index of the first sample to load.
        end: Index one past the last sample to load. ``-1`` (the default) or
            ``None`` means "through the last sample".
    """

    def __init__(self, image_folder_set, norm_trans=None, start=0, end=-1):
        self.imgs = []
        self.targets = []
        self.transform = image_folder_set.transform
        to_tensor = transforms.ToTensor()  # built once, reused for every image
        selected = self._select_samples(image_folder_set.imgs, start, end)
        for sample in tqdm(selected):
            self.targets.append(sample[1])
            img = to_tensor(Image.open(sample[0]).convert("RGB"))
            if norm_trans is not None:
                img = norm_trans(img)
            self.imgs.append(img)
        self.imgs = torch.stack(self.imgs)

    @staticmethod
    def _select_samples(samples, start, end):
        """Return ``samples[start:end]``, treating ``end`` of -1/None as "to the end".

        A plain ``samples[start:-1]`` would silently drop the last sample,
        which is not what the ``end=-1`` default is meant to express.
        """
        stop = None if end is None or end == -1 else end
        return samples[start:stop]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``, applying ``transform`` if set."""
        img = self.imgs[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[idx]


class TinyImageNet:
    """
    TinyImageNet dataset.

    On construction the class sets up the train / test transforms and the
    ``train/``, ``val/`` and ``test/`` paths below ``args.data_dir``.  If the
    validation images are still in the flat ``val/images`` folder, they are
    redistributed into per-class folders (see ``_reorganize_val_split``).
    """

    def __init__(self, args, normalize=False):
        """
        Args:
            args: object exposing ``data_dir``, the dataset root directory.
            normalize: when true, a ``transforms.Normalize`` layer is created
                and applied to every image at load time.
        """
        self.args = args

        self.norm_layer = (
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            if normalize
            else None
        )

        self.tr_train = transforms.Compose(
            [
                transforms.RandomCrop(64, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        )
        self.tr_test = transforms.Compose([])

        self.train_path = os.path.join(args.data_dir, "train/")
        self.val_path = os.path.join(args.data_dir, "val/")
        self.test_path = os.path.join(args.data_dir, "test/")

        self._reorganize_val_split(args.data_dir)

    def _reorganize_val_split(self, data_dir):
        """Move flat ``val/images/*`` files into per-class val / test folders.

        Does nothing when ``val/images`` is absent.  Otherwise:

        * an existing ``test/`` directory is renamed to ``test_original`` and
          replaced by an empty one;
        * ``val/val_annotations.txt`` (tab separated: file name, class folder,
          ...) is read to map every image to its class;
        * for each class, images are moved to ``val/<class>/images`` while
          that folder holds fewer than 25 files, and to
          ``test/<class>/images`` afterwards;
        * the emptied ``val/images`` directory is removed.
        """
        flat_images_dir = os.path.join(self.val_path, "images")
        if not os.path.exists(flat_images_dir):
            return

        if os.path.exists(self.test_path):
            os.rename(self.test_path, os.path.join(data_dir, "test_original"))
            os.mkdir(self.test_path)

        val_dict = {}
        val_anno_path = os.path.join(self.val_path, "val_annotations.txt")
        with open(val_anno_path, "r") as f:
            for line in f:
                split_line = line.split("\t")
                val_dict[split_line[0]] = split_line[1]

        paths = glob.glob(os.path.join(data_dir, "val/images/*"))
        # Resolve every class before touching the filesystem, so an image
        # missing from the annotations fails before anything is moved.
        # basename() is separator-agnostic, unlike splitting on "/".
        class_of = {path: str(val_dict[os.path.basename(path)]) for path in paths}

        # makedirs also creates a missing ``test/`` parent, which a plain
        # mkdir of ``test/<class>`` would not.
        for folder in set(class_of.values()):
            os.makedirs(os.path.join(self.val_path, folder, "images"), exist_ok=True)
            os.makedirs(os.path.join(self.test_path, folder, "images"), exist_ok=True)

        for path in paths:
            folder = class_of[path]
            val_images_dir = os.path.join(self.val_path, folder, "images")
            # Counted on disk so files left by an earlier, interrupted run
            # are taken into account.
            if len(glob.glob(os.path.join(val_images_dir, "*"))) < 25:
                dest = os.path.join(val_images_dir, os.path.basename(path))
            else:
                dest = os.path.join(
                    self.test_path, folder, "images", os.path.basename(path)
                )
            move(path, dest)

        os.rmdir(flat_images_dir)

    def data_loaders(
        self,
        batch_size=128,
        data_dir="datasets/tiny",
        num_workers=2,
        class_to_replace: int = None,
        num_indexes_to_replace=None,
        indexes_to_replace=None,
        seed: int = 1,
        only_mark: bool = False,
        shuffle=True,
        no_aug=False,
    ):
        """Load train / test images into memory and wrap them in DataLoaders.

        Args:
            batch_size: batch size used by all three loaders.
            data_dir, num_workers, shuffle, no_aug: accepted but not read by
                this method (paths come from ``__init__``; loader settings are
                fixed below).
            class_to_replace: class forwarded to ``replace_class`` on the
                training set.  Mutually exclusive with ``indexes_to_replace``.
            num_indexes_to_replace: forwarded to ``replace_class``.  When it is
                ``None`` or ``500`` the class is also removed from the test
                set.
            indexes_to_replace: indexes forwarded to ``replace_indexes``.
            seed: seeds the split RNG and the worker init function; ``seed - 1``
                is passed to the replace helpers.
            only_mark: forwarded to the replace helpers.

        Returns:
            ``(train_loader, val_loader, test_loader)``.

        Raises:
            ValueError: if both ``class_to_replace`` and ``indexes_to_replace``
                are given.
        """
        train_set = ImageFolder(self.train_path, transform=self.tr_train)
        train_set = TinyImageNetDataset(train_set, self.norm_layer)
        test_set = ImageFolder(self.test_path, transform=self.tr_test)
        test_set = TinyImageNetDataset(test_set, self.norm_layer)
        train_set.targets = np.array(train_set.targets)
        rng = np.random.RandomState(seed)
        valid_set = copy.deepcopy(train_set)
        valid_idx = []
        for i in range(max(train_set.targets) + 1):
            class_idx = np.where(train_set.targets == i)[0]
            # The hold-out fraction is 0.0, so no sample is actually drawn.
            valid_idx.append(
                rng.choice(class_idx, int(0.0 * len(class_idx)), replace=False)
            )
        valid_idx = np.hstack(valid_idx)
        train_set_copy = copy.deepcopy(train_set)

        valid_set.imgs = train_set_copy.imgs[valid_idx]
        valid_set.targets = train_set_copy.targets[valid_idx]

        train_idx = list(set(range(len(train_set))) - set(valid_idx))

        train_set.imgs = train_set_copy.imgs[train_idx]
        train_set.targets = train_set_copy.targets[train_idx]

        if class_to_replace is not None and indexes_to_replace is not None:
            raise ValueError(
                "Only one of `class_to_replace` and `indexes_to_replace` can be specified"
            )
        if class_to_replace is not None:
            replace_class(
                train_set,
                class_to_replace,
                num_indexes_to_replace=num_indexes_to_replace,
                seed=seed - 1,
                only_mark=only_mark,
            )
            # NOTE(review): 500 is presumably the per-class training size,
            # i.e. "the whole class is replaced" - confirm.
            if num_indexes_to_replace is None or num_indexes_to_replace == 500:
                test_set.targets = np.array(test_set.targets)
                test_set.imgs = test_set.imgs[test_set.targets != class_to_replace]
                test_set.targets = test_set.targets[
                    test_set.targets != class_to_replace
                ]
                print(test_set.targets)
                test_set.targets = test_set.targets.tolist()
        if indexes_to_replace is not None:
            replace_indexes(
                dataset=train_set,
                indexes=indexes_to_replace,
                seed=seed - 1,
                only_mark=only_mark,
            )

        loader_args = {"num_workers": 0, "pin_memory": False}

        def _init_fn(worker_id):
            np.random.seed(int(seed))

        train_loader = DataLoader(
            train_set,
            batch_size=batch_size,
            shuffle=True,
            worker_init_fn=_init_fn if seed is not None else None,
            **loader_args,
        )
        # NOTE(review): the validation loader wraps ``test_set``; ``valid_set``
        # is empty because of the 0.0 hold-out fraction above - confirm this
        # is intended.
        val_loader = DataLoader(
            test_set,
            batch_size=batch_size,
            shuffle=False,
            worker_init_fn=_init_fn if seed is not None else None,
            **loader_args,
        )
        test_loader = DataLoader(
            test_set,
            batch_size=batch_size,
            shuffle=False,
            worker_init_fn=_init_fn if seed is not None else None,
            **loader_args,
        )
        print(
            f"Traing loader: {len(train_loader.dataset)} images, Test loader: {len(test_loader.dataset)} images"
        )
        return train_loader, val_loader, test_loader


def cifar10_dataloaders(
    batch_size=128,
    data_dir="datasets/cifar10",
    num_workers=2,
    class_to_replace: int = None,
    num_indexes_to_replace=None,
    indexes_to_replace=None,
    seed: int = 1,
    only_mark: bool = False,
    shuffle=True,
    no_aug=False,
    mode: str = None,
    ratio: float = 0.1,
):
    """Build CIFAR-10 train / validation / test loaders with forgetting marks.

    A stratified 10% of the training set is held out as the validation set.
    Depending on ``mode`` some training labels are then marked as "to forget"
    by rewriting label ``y`` as ``-y - 1``:

    * ``"one_class"``: every training sample of one randomly picked class is
      marked and that class is removed from the test set;
    * ``"one_class_random"``: 4400 samples of one randomly picked class are
      marked;
    * ``"random"``: half of a per-class pool (``ratio * 2`` of each class) is
      marked;
    * anything else: nothing is marked.

    Args:
        batch_size: batch size used by all three loaders.
        data_dir: root directory handed to ``CIFAR10`` (downloaded if needed).
        seed: seeds the split / sampling RNG and the loader workers.
        no_aug: when true, the training transform is ``ToTensor`` only.
        mode: forgetting scenario, see above.
        ratio: per-class fraction used to size the ``"random"`` pool.
        num_workers, class_to_replace, num_indexes_to_replace,
        indexes_to_replace, only_mark, shuffle: accepted but not read by this
            function.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    augmentations = (
        []
        if no_aug
        else [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    )
    train_transform = transforms.Compose(augmentations + [transforms.ToTensor()])
    test_transform = transforms.Compose([transforms.ToTensor()])

    print(
        "Dataset information: CIFAR-10\t 45000 images for training \t 5000 images for validation\t"
    )
    print("10000 images for testing\t no normalize applied in data_transform")
    print("Data augmentation = randomcrop(32,4) + randomhorizontalflip")

    train_set = CIFAR10(data_dir, train=True, transform=train_transform, download=True)
    test_set = CIFAR10(data_dir, train=False, transform=test_transform, download=True)

    # Array labels allow boolean / fancy indexing below.
    train_set.targets = np.array(train_set.targets)
    test_set.targets = np.array(test_set.targets)

    print(seed)
    rng = np.random.RandomState(seed)

    def _draw(members, fraction):
        # Sample ``fraction`` of ``members`` without replacement.
        return rng.choice(members, int(fraction * len(members)), replace=False)

    # --- stratified 10% validation hold-out --------------------------------
    valid_set = copy.deepcopy(train_set)
    valid_idx = np.hstack(
        [
            _draw(np.where(train_set.targets == cls)[0], 0.1)
            for cls in range(max(train_set.targets) + 1)
        ]
    )
    full_train = copy.deepcopy(train_set)

    valid_set.data = full_train.data[valid_idx]
    valid_set.targets = full_train.targets[valid_idx]

    train_idx = list(set(range(len(train_set))) - set(valid_idx))

    train_set.data = full_train.data[train_idx]
    train_set.targets = full_train.targets[train_idx]

    # --- per-class index table and candidate forgetting pool ---------------
    all_class_idx = {}
    pool_parts = []
    for cls in range(max(train_set.targets) + 1):
        members = np.where(train_set.targets == cls)[0]
        all_class_idx[cls] = members
        pool_parts.append(_draw(members, ratio * 2))
    forgetting_index = np.hstack(pool_parts)

    def _mark_forgotten(idx):
        # y -> -y - 1, so that class 0 is also distinguishable once marked.
        train_set.targets[idx] = -train_set.targets[idx] - 1

    if mode == "one_class":
        forgetting_cls = np.random.randint(0, 10)
        forgetting_index = all_class_idx[forgetting_cls]
        assert np.unique(train_set.targets[forgetting_index]).shape[0] == 1
        _mark_forgotten(forgetting_index)
        kept = test_set.targets != forgetting_cls
        test_set.data = test_set.data[kept]
        test_set.targets = test_set.targets[kept]
        print(f">>>>>>>>>>>>>>>>>>>>>> Forgetting class: {forgetting_cls}")
    elif mode == "one_class_random":
        forgetting_cls = np.random.randint(0, 10)
        forgetting_index = rng.choice(
            all_class_idx[forgetting_cls], 4400, replace=False
        )
        _mark_forgotten(forgetting_index)
    elif mode == "random":
        forgetting_index = rng.choice(
            forgetting_index, int(0.5 * len(forgetting_index)), replace=False
        )
        _mark_forgotten(forgetting_index)

    loader_args = {"num_workers": 4, "pin_memory": True}

    def _init_fn(worker_id):
        np.random.seed(int(seed))

    worker_init = _init_fn if seed is not None else None

    def _loader(dataset, shuffled):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffled,
            worker_init_fn=worker_init,
            **loader_args,
        )

    train_loader = _loader(train_set, True)
    val_loader = _loader(valid_set, False)
    test_loader = _loader(test_set, False)

    return train_loader, val_loader, test_loader


def replace_indexes(
    dataset: torch.utils.data.Dataset, indexes, seed=0, only_mark: bool = False
):
    """Overwrite or mark the samples of ``dataset`` at ``indexes``, in place.

    Labels are looked up under the first usable attribute among ``targets``,
    ``labels`` and ``_labels``.

    Args:
        dataset: dataset exposing ``__len__``, an index-assignable label array
            under one of the attributes above and, when ``only_mark`` is
            false, an index-assignable ``data`` array.
        indexes: positions of the samples to replace or mark.
        seed: seed of the RNG that picks the replacement samples.
        only_mark: if true, samples are kept and each label ``y`` is rewritten
            as ``-y - 1`` (the ``- 1`` keeps class 0 distinguishable).  If
            false, data and labels at ``indexes`` are overwritten with those
            of samples drawn (with replacement) from the remaining positions.

    Raises:
        AttributeError, TypeError: if none of the label attributes can be
            updated; the error of the last attempt is raised.
    """
    if not only_mark:
        rng = np.random.RandomState(seed)
        new_indexes = rng.choice(
            list(set(range(len(dataset))) - set(indexes)), size=len(indexes)
        )
        # NOTE: only datasets that keep their images under ``data`` are
        # supported on this path.
        dataset.data[indexes] = dataset.data[new_indexes]

    last_error = None
    for attr in ("targets", "labels", "_labels"):
        try:
            labels = getattr(dataset, attr)
            if only_mark:
                labels[indexes] = -labels[indexes] - 1
            else:
                labels[indexes] = labels[new_indexes]
        except (AttributeError, TypeError) as err:
            # Attribute missing or not array-indexable: try the next name.
            last_error = err
            continue
        # Stop at the first attribute that was updated.  A ``try/else`` here
        # would also touch ``_labels`` after a successful ``targets`` update.
        return
    raise last_error


def replace_class(
    dataset: torch.utils.data.Dataset,
    class_to_replace: int,
    num_indexes_to_replace: int = None,
    seed: int = 0,
    only_mark: bool = False,
    sub_class=False,
    identity=False,
):
    """Select the samples of one class and hand them to ``replace_indexes``.

    Args:
        dataset: dataset whose samples are replaced or marked in place.
        class_to_replace: class to select; ``-1`` selects every position of
            ``dataset.targets``.
        num_indexes_to_replace: if given, only this many of the selected
            indexes (drawn without replacement using ``seed``) are passed on.
        seed: seed for the sub-sampling RNG, also forwarded to
            ``replace_indexes``.
        only_mark: forwarded to ``replace_indexes``.
        sub_class: match ``class_to_replace`` against ``dataset.fine_targets``.
        identity: ignore ``class_to_replace`` and select all samples of 80
            entries drawn from ``dataset.identity`` with the global NumPy RNG.

    Raises:
        AssertionError: if ``num_indexes_to_replace`` exceeds the number of
            selected samples.
    """
    if class_to_replace == -1:
        indexes = np.flatnonzero(np.ones_like(dataset.targets))
    else:
        # The handlers are deliberately catch-all: whatever goes wrong with
        # the primary lookup, fall back to ``labels`` and then ``_labels``.
        try:
            if sub_class:
                fine_labels = np.array(dataset.fine_targets)
                indexes = np.flatnonzero(fine_labels == class_to_replace)
            elif identity:
                identities_selected = np.random.choice(
                    dataset.identity.T[0], size=80, replace=False
                )
                indexes = np.concatenate(
                    [
                        np.where(dataset.identity == person)[0]
                        for person in identities_selected
                    ]
                )
            else:
                coarse_labels = np.array(dataset.targets)
                indexes = np.flatnonzero(coarse_labels == class_to_replace)
        except:
            try:
                indexes = np.flatnonzero(np.array(dataset.labels) == class_to_replace)
            except:
                indexes = np.flatnonzero(np.array(dataset._labels) == class_to_replace)

    if num_indexes_to_replace is not None:
        assert num_indexes_to_replace <= len(
            indexes
        ), f"Want to replace {num_indexes_to_replace} indexes but only {len(indexes)} samples in dataset"
        sampler = np.random.RandomState(seed)
        indexes = sampler.choice(indexes, size=num_indexes_to_replace, replace=False)
        print(f"Replacing indexes {indexes}")

    replace_indexes(dataset, indexes, seed, only_mark)


if __name__ == "__main__":
    # Smoke test: build the CIFAR-10 loaders with default settings and print,
    # for every training batch, the shape of its set of distinct labels.
    train_loader, val_loader, test_loader = cifar10_dataloaders()
    for _images, batch_labels in train_loader:
        distinct_labels = torch.unique(batch_labels)
        print(distinct_labels.shape)
