import os
import csv

import numpy as np
import pandas as pd
import sklearn
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision import datasets, transforms, utils
from sklearn.model_selection import train_test_split
from torchsampler import ImbalancedDatasetSampler
import random
import copy


def unison_shuffled_copies(a, b):
    """Shuffle two equally long arrays with one shared random permutation.

    Both inputs are indexed with the same permutation array, so element
    ``i`` of the first result still corresponds to element ``i`` of the
    second. The permutation is drawn from NumPy's global random state.

    Args:
        a: First indexable array (must support indexing by an index array).
        b: Second indexable array of the same length as ``a``.

    Returns:
        Tuple ``(a_shuffled, b_shuffled)``.
    """
    n_items = len(a)
    assert n_items == len(b)
    order = np.random.permutation(n_items)
    shuffled_a = a[order]
    shuffled_b = b[order]
    return shuffled_a, shuffled_b


def train_val_split(train_set, val_set, idx, dataset):
    """Randomly partition the training data between ``train_set`` and ``val_set``.

    Both arguments are expected to be two independent instances of the same
    training dataset. The samples are shuffled once; the first ``idx``
    shuffled samples are written back into ``train_set`` and the remainder
    into ``val_set`` (both objects are modified in place).

    Args:
        train_set: Dataset object that keeps the first part of the split.
        val_set: Dataset object that keeps the remaining part of the split.
        idx: If smaller than 1, interpreted as the validation fraction and
            converted to ``int(len(train_set) * (1 - idx))``. Otherwise it
            is used directly as the number of training samples.
        dataset: Name of the dataset; selects which attributes hold the
            samples and labels.

    Returns:
        The tuple ``(train_set, val_set)``.

    Raises:
        SystemExit: If no split is implemented for ``dataset``.
    """
    if idx < 1:
        idx = int(len(train_set) * (1 - idx))
    if dataset in ["mnist", "fashionmnist", "cifar10", "cifar100"]:
        x, y = unison_shuffled_copies(train_set.data, np.array(train_set.targets))
        train_set.data = x[:idx]
        train_set.targets = y[:idx]
        val_set.data = x[idx:]
        val_set.targets = y[idx:]
    elif dataset in ["gtsrb", "cars"]:
        p = np.random.permutation(len(train_set._samples))
        # Converting (path, label) tuples to an array stringifies the labels.
        samp = np.array(train_set._samples)[p]
        if dataset == "cars":
            # The Cars wrapper indexes and measures its length through
            # ``_samples``, so the split must be applied there as well;
            # otherwise both sets would still expose every training image.
            # Built from the original tuples so labels keep their int type.
            shuffled = [train_set._samples[i] for i in p]
            train_set._samples = shuffled[:idx]
            val_set._samples = shuffled[idx:]
        train_set.data = samp[:idx, 0]
        train_set.targets = samp[:idx, 1]
        val_set.data = samp[idx:, 0]
        val_set.targets = samp[idx:, 1]
    elif dataset == "food":
        x, y = unison_shuffled_copies(
            np.array(train_set._image_files), np.array(train_set._labels)
        )
        train_set._image_files = x[:idx]
        train_set._labels = y[:idx]
        val_set._image_files = x[idx:]
        val_set._labels = y[idx:]
    elif dataset == "svhn":
        x, y = unison_shuffled_copies(train_set.data, np.array(train_set.labels))
        train_set.data = x[:idx]
        train_set.labels = y[:idx]
        val_set.data = x[idx:]
        val_set.labels = y[idx:]
    elif dataset == "imagenet":
        path, label = zip(*train_set.samples)
        x, y = unison_shuffled_copies(np.array(path), np.array(label))
        train_set.samples = list(zip(x[:idx], y[:idx]))
        val_set.samples = list(zip(x[idx:], y[idx:]))
    else:
        print("Validation split not implemented!")
        # Same effect as ``exit(1)`` but does not depend on the ``site``
        # module having injected the ``exit`` builtin.
        raise SystemExit(1)
    return train_set, val_set


def adjust_gtsrb_test_set(test_set):
    """Expose a GTSRB set's samples through ``data`` and ``targets``.

    The ``_samples`` list is converted to a 2-D array; column 0 is stored
    as ``test_set.data`` and column 1 as ``test_set.targets``. The object
    is modified in place.

    Args:
        test_set: Object with a ``_samples`` sequence of two-element rows.

    Returns:
        The same ``test_set`` object.
    """
    table = np.array(test_set._samples)
    first_col, second_col = table[:, 0], table[:, 1]
    test_set.data = first_col
    test_set.targets = second_col
    return test_set


def _make_loader(dataset, kwargs):
    """Build a DataLoader over ``dataset``, or return None if there is none."""
    if dataset is None:
        return None
    return torch.utils.data.DataLoader(dataset, **kwargs)


def _imbalanced_data_loading(dataset, loader, num_classes, kwargs, args):
    """Return a loader over a copy of ``dataset`` with class 0 randomly thinned.

    Every sample of class 0 is kept with probability ``args.class_imb``;
    all other classes are kept entirely. The surviving ``targets`` and
    ``data`` are taken from ``loader.dataset`` and stored on a deep copy of
    ``dataset``, which is then wrapped in a new DataLoader built with
    ``kwargs``.
    """
    sample_probs = torch.ones(num_classes)
    sample_probs[0] = args.class_imb

    idx_to_del = [
        i
        for i, label in enumerate(loader.dataset.targets)
        # int() lets labels stored as tensors, NumPy scalars or digit strings
        # index the probability vector alike.
        if random.random() > sample_probs[int(label)]
    ]
    imb_dataset = copy.deepcopy(dataset)
    imb_dataset.targets = np.delete(loader.dataset.targets, idx_to_del, axis=0)
    imb_dataset.data = np.delete(loader.dataset.data, idx_to_del, axis=0)
    return torch.utils.data.DataLoader(imb_dataset, **kwargs)


def load_data(args):
    """Build the datasets and data loaders selected by ``args.dataset``.

    Attributes read from ``args``: ``dataset``, ``dataset_path``, ``dp``
    (when true, the test-time transform is used for training as well),
    ``val_frac``, ``train_batch``, ``test_batch``, ``workers`` and
    ``class_imb`` (keep-probability for samples of class 0).

    Returns:
        A 10-tuple ``(train_set, train_loader, train_loader_validation,
        val_set, validation_loader, test_set, test_loader, test_set_unnorm,
        num_classes, class_names)``. ``train_loader_validation`` iterates
        training data without shuffling. For the ``eicu`` and ``mimic``
        placeholders the datasets and loaders are ``None``.

    Raises:
        ValueError: If ``args.dataset`` matches none of the known datasets.
    """
    if args.dataset == "mnist":
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        train_transform = transforms.Compose([transforms.ToTensor(), normalize])
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        test_transform_unnorm = transforms.Compose([transforms.ToTensor()])
        fit_transform = test_transform if args.dp else train_transform
        train_set = MNIST(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        val_set = MNIST(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        train_set, val_set = train_val_split(
            train_set, val_set, args.val_frac, args.dataset
        )
        test_set = MNIST(args.dataset_path, train=False, transform=test_transform)
        test_set_unnorm = MNIST(
            args.dataset_path, train=False, transform=test_transform_unnorm
        )
        num_classes = 10
        class_names = [str(i) for i in range(10)]
    elif args.dataset == "fashionmnist":
        normalize = transforms.Normalize((0.5,), (0.5,))
        train_transform = transforms.Compose([transforms.ToTensor(), normalize])
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        test_transform_unnorm = transforms.Compose([transforms.ToTensor()])
        fit_transform = test_transform if args.dp else train_transform
        train_set = FashionMNIST(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        val_set = FashionMNIST(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        train_set, val_set = train_val_split(
            train_set, val_set, args.val_frac, args.dataset
        )
        test_set = FashionMNIST(
            args.dataset_path, train=False, transform=test_transform
        )
        test_set_unnorm = FashionMNIST(
            args.dataset_path, train=False, transform=test_transform_unnorm
        )
        num_classes = 10
        class_names = train_set.classes
    elif args.dataset == "imagenet":
        normalize = transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        )
        train_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
        test_transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        )
        test_transform_unnorm = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )
        fit_transform = test_transform if args.dp else train_transform
        train_set = ImageNet(args.dataset_path, split="train", transform=fit_transform)
        val_set = ImageNet(args.dataset_path, split="train", transform=fit_transform)
        train_set, val_set = train_val_split(
            train_set, val_set, args.val_frac, args.dataset
        )
        test_set = ImageNet(args.dataset_path, split="val", transform=test_transform)
        test_set_unnorm = ImageNet(
            args.dataset_path, split="val", transform=test_transform_unnorm
        )
        num_classes = 1000
        class_names = []
    elif args.dataset == "cifar10":
        normalize = transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
        train_transform = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        test_transform_unnorm = transforms.Compose([transforms.ToTensor()])
        fit_transform = test_transform if args.dp else train_transform
        train_set = CIFAR10(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        val_set = CIFAR10(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        train_set, val_set = train_val_split(
            train_set, val_set, args.val_frac, args.dataset
        )
        test_set = CIFAR10(args.dataset_path, train=False, transform=test_transform)
        test_set_unnorm = CIFAR10(
            args.dataset_path, train=False, transform=test_transform_unnorm
        )
        num_classes = 10
        class_names = train_set.classes
    elif args.dataset == "cifar100":
        normalize = transforms.Normalize(
            mean=(0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
            std=(0.2673342858792401, 0.2564384629170883, 0.27615047132568404),
        )
        train_transform = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        test_transform_unnorm = transforms.Compose([transforms.ToTensor()])
        fit_transform = test_transform if args.dp else train_transform
        train_set = CIFAR100(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        val_set = CIFAR100(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        train_set, val_set = train_val_split(
            train_set, val_set, args.val_frac, args.dataset
        )
        test_set = CIFAR100(args.dataset_path, train=False, transform=test_transform)
        test_set_unnorm = CIFAR100(
            args.dataset_path, train=False, transform=test_transform_unnorm
        )
        num_classes = 100
        class_names = train_set.classes
    elif args.dataset == "svhn":
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        train_transform = transforms.Compose(
            [
                transforms.RandomRotation(15),
                transforms.RandomCrop(32, padding=4),
                transforms.ToTensor(),
                normalize,
            ]
        )
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        test_transform_unnorm = transforms.Compose([transforms.ToTensor()])
        fit_transform = test_transform if args.dp else train_transform
        train_set = SVHN(
            args.dataset_path, split="train", download=True, transform=fit_transform
        )
        val_set = SVHN(
            args.dataset_path, split="train", download=True, transform=fit_transform
        )
        train_set, val_set = train_val_split(
            train_set, val_set, args.val_frac, args.dataset
        )
        test_set = SVHN(
            args.dataset_path, split="test", download=True, transform=test_transform
        )
        test_set_unnorm = SVHN(
            args.dataset_path,
            split="test",
            download=True,
            transform=test_transform_unnorm,
        )
        num_classes = 10
        class_names = [str(i) for i in range(10)]
    elif args.dataset == "gtsrb":
        normalize = transforms.Normalize(
            mean=[0.3403, 0.3121, 0.3214], std=[0.2724, 0.2608, 0.2669]
        )
        train_transform = transforms.Compose(
            [transforms.Resize((32, 32)), transforms.ToTensor(), normalize]
        )
        test_transform = transforms.Compose(
            [transforms.Resize((32, 32)), transforms.ToTensor(), normalize]
        )
        test_transform_unnorm = transforms.Compose(
            [transforms.Resize((32, 32)), transforms.ToTensor()]
        )
        fit_transform = test_transform if args.dp else train_transform
        train_set = GTSRB(
            args.dataset_path, split="train", download=True, transform=fit_transform
        )
        val_set = GTSRB(
            args.dataset_path, split="train", download=True, transform=fit_transform
        )
        train_set, val_set = train_val_split(
            train_set, val_set, args.val_frac, args.dataset
        )
        test_set = GTSRB(
            args.dataset_path, split="test", download=True, transform=test_transform
        )
        test_set = adjust_gtsrb_test_set(test_set)
        test_set_unnorm = GTSRB(
            args.dataset_path,
            split="test",
            download=True,
            transform=test_transform_unnorm,
        )
        test_set_unnorm = adjust_gtsrb_test_set(test_set_unnorm)
        num_classes = 43
        class_names = []
    elif args.dataset == "cars":
        train_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(35),
                transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
                transforms.RandomGrayscale(p=0.5),
                transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
                transforms.RandomPosterize(bits=2, p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        test_transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        test_transform_unnorm = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )
        fit_transform = test_transform if args.dp else train_transform
        train_set = Cars(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        val_set = Cars(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        train_set, val_set = train_val_split(
            train_set, val_set, args.val_frac, args.dataset
        )
        test_set = Cars(
            args.dataset_path, train=False, download=True, transform=test_transform
        )
        test_set_unnorm = Cars(
            args.dataset_path,
            train=False,
            download=True,
            transform=test_transform_unnorm,
        )
        num_classes = 196
        class_names = train_set.classes
    elif args.dataset == "food":
        train_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(35),
                transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
                transforms.RandomGrayscale(p=0.5),
                transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
                transforms.RandomPosterize(bits=2, p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        test_transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        test_transform_unnorm = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )
        fit_transform = test_transform if args.dp else train_transform
        train_set = Food(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        val_set = Food(
            args.dataset_path, train=True, download=True, transform=fit_transform
        )
        train_set, val_set = train_val_split(
            train_set, val_set, args.val_frac, args.dataset
        )
        test_set = Food(args.dataset_path, train=False, transform=test_transform)
        test_set_unnorm = Food(
            args.dataset_path, train=False, transform=test_transform_unnorm
        )
        num_classes = 101
        class_names = train_set.classes
    elif "utkface" in args.dataset:
        dataFrame = pd.read_csv(f"{args.dataset_path}age_gender.gz", compression="gzip")

        # Bucket the raw ages into nine ordinal bins.
        age_bins = [0, 10, 15, 20, 25, 30, 40, 50, 60, 120]
        age_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8]
        dataFrame["bins"] = pd.cut(dataFrame.age, bins=age_bins, labels=age_labels)

        train_dataFrame, test_dataFrame = train_test_split(dataFrame, test_size=0.2)

        train_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.49,), (0.23,))]
        )
        test_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.49,), (0.23,))]
        )
        test_transform_unnorm = transforms.Compose([transforms.ToTensor()])

        train_set = UTKFace(train_dataFrame, transform=train_transform)
        test_set = UTKFace(test_dataFrame, transform=test_transform)
        # No separate validation split here: validation reuses the training set.
        val_set = train_set
        test_set_unnorm = UTKFace(test_dataFrame, transform=test_transform_unnorm)
        num_classes = 2
        class_names = []
    elif "gauss" in args.dataset:
        num_points = 10000
        outlier_points = 1

        # Class 0: large isotropic Gaussian around the origin.
        mean = [0, 0]
        cov = [[1, 0], [0, 1]]
        g11, g12 = np.random.multivariate_normal(mean, cov, num_points).T
        g11_te, g12_te = np.random.multivariate_normal(mean, cov, num_points).T
        g1 = np.stack([g11, g12], axis=1)
        g1_te = np.stack([g11_te, g12_te], axis=1)
        l1 = np.zeros_like(g11)
        l1_te = np.zeros_like(g11_te)

        # Class 1: a single tight outlier cluster.
        mean = [10, 0]
        cov = [[0.1, 0], [0, 0.1]]
        g21, g22 = np.random.multivariate_normal(mean, cov, outlier_points).T
        g21_te, g22_te = np.random.multivariate_normal(mean, cov, outlier_points).T
        g2 = np.stack([g21, g22], axis=1)
        g2_te = np.stack([g21_te, g22_te], axis=1)
        l2 = np.ones_like(g21)
        l2_te = np.ones_like(g21_te)

        # The test points are drawn for every variant, so the test arrays are
        # assembled unconditionally; previously they only existed for the
        # plain "2d_gauss" case and every other variant hit a NameError.
        x_te = np.concatenate([g1_te, g2_te], axis=0)
        y_te = np.concatenate([l1_te, l2_te])

        if args.dataset != "2d_gauss":
            # Variants add a third, class-0 cluster to the training data only.
            if args.dataset == "2d_gauss_left":
                mean = [5, 5]
                cov = [[0.2, 0], [0, 0.2]]
            elif args.dataset == "2d_gauss_center":
                mean = [10, 5]
                cov = [[0.2, 0], [0, 0.2]]
            elif args.dataset == "2d_gauss_right":
                mean = [20, 5]
                cov = [[0.2, 0], [0, 0.2]]
            g31, g32 = np.random.multivariate_normal(mean, cov, int(num_points / 5)).T
            g3 = np.stack([g31, g32], axis=1)
            l3 = np.zeros_like(g31)

            x = np.concatenate([g1, g2, g3], axis=0)
            y = np.concatenate([l1, l2, l3])
        else:
            x = np.concatenate([g1, g2], axis=0)
            y = np.concatenate([l1, l2])

        train_set = torch.utils.data.TensorDataset(
            torch.Tensor(x),
            torch.Tensor(y).long(),
            torch.arange(len(y)),
        )
        val_set = train_set
        test_set = torch.utils.data.TensorDataset(
            torch.Tensor(x_te), torch.Tensor(y_te).long(), torch.arange(len(y_te))
        )
        test_set_unnorm = test_set
        num_classes = 2
        class_names = []
    elif args.dataset == "breastcancer":
        breastcancer = pd.read_csv(
            f"{args.dataset_path}/breastcancer/breast-cancer-wisconsin.csv", header=None
        )
        breastcancer = breastcancer.drop(columns=[0])
        breastcancer = sklearn.utils.shuffle(breastcancer, random_state=0)

        num_train = 500
        # Map the label values 2 / 4 to 0 / 1.
        breastcancer[10] = breastcancer[10].replace(2, 0)
        breastcancer[10] = breastcancer[10].replace(4, 1)
        # Drop rows whose column 6 contains a literal "?".
        breastcancer = breastcancer[~breastcancer[6].str.contains(r"\?")]

        X_tr = breastcancer.iloc[:num_train, 0:9].to_numpy().astype(float)
        y_tr = breastcancer.iloc[:num_train, 9].to_numpy().astype(int)

        X_te = breastcancer.iloc[num_train:, 0:9].to_numpy().astype(float)
        y_te = breastcancer.iloc[num_train:, 9].to_numpy().astype(int)

        standard_scaler = sklearn.preprocessing.StandardScaler()
        X_tr = standard_scaler.fit_transform(X_tr)
        # Scale the test features with the statistics fitted on the training
        # features; re-fitting on the test split would leak test statistics.
        X_te = standard_scaler.transform(X_te)

        X_tr, y_tr = unison_shuffled_copies(X_tr, y_tr)
        idx = int(len(X_tr) * (1 - args.val_frac))
        X_val = X_tr[idx:]
        y_val = y_tr[idx:]
        X_tr = X_tr[:idx]
        y_tr = y_tr[:idx]

        X_tr = torch.tensor(X_tr, dtype=torch.float32)
        y_tr = torch.tensor(y_tr, dtype=torch.int64)
        X_val = torch.tensor(X_val, dtype=torch.float32)
        y_val = torch.tensor(y_val, dtype=torch.int64)
        X_te = torch.tensor(X_te, dtype=torch.float32)
        y_te = torch.tensor(y_te, dtype=torch.int64)

        train_set = TensorDataset(X_tr, y_tr, torch.arange(len(y_tr)))
        val_set = TensorDataset(X_val, y_val, torch.arange(len(y_val)))
        test_set = TensorDataset(X_te, y_te, torch.arange(len(y_te)))
        test_set_unnorm = test_set
        num_classes = 2
        class_names = ["benign", "malignant"]
    elif args.dataset in ("eicu", "mimic"):
        # Placeholder datasets: nothing is loaded here.
        # NOTE(review): presumably the caller loads these elsewhere - confirm.
        train_set = None
        val_set = None
        test_set = None
        # Previously unbound for these datasets, which broke the return below.
        test_set_unnorm = None
        num_classes = 2
        class_names = []
    else:
        raise ValueError(f"Unknown dataset: {args.dataset!r}")

    train_kwargs = {
        "batch_size": args.train_batch,
        "shuffle": True,
        "num_workers": args.workers,
        "pin_memory": True,
    }
    train_kwargs_validation = {
        "batch_size": args.train_batch,
        "shuffle": False,
        "num_workers": args.workers,
        "pin_memory": True,
    }
    test_kwargs = {
        "batch_size": args.test_batch,
        "shuffle": False,
        "num_workers": args.workers,
        "pin_memory": True,
    }

    # Placeholder datasets (None) get a None loader instead of a crash.
    train_loader = _make_loader(train_set, train_kwargs)
    train_loader_validation = _make_loader(train_set, train_kwargs_validation)
    validation_loader = _make_loader(val_set, train_kwargs_validation)
    test_loader = _make_loader(test_set, test_kwargs)

    if args.dataset in ["cifar10", "fashionmnist"]:
        # These two datasets always go through the thinning helper, for all
        # four loaders; with class_imb >= 1 nothing is removed.
        train_loader = _imbalanced_data_loading(
            train_set, train_loader, num_classes, train_kwargs, args
        )
        test_loader = _imbalanced_data_loading(
            test_set, test_loader, num_classes, test_kwargs, args
        )
        validation_loader = _imbalanced_data_loading(
            val_set, validation_loader, num_classes, train_kwargs_validation, args
        )
        train_loader_validation = _imbalanced_data_loading(
            train_set,
            train_loader_validation,
            num_classes,
            train_kwargs_validation,
            args,
        )

    if args.class_imb < 1 and train_loader is not None:
        # NOTE(review): for cifar10 / fashionmnist this thins class 0 a second
        # time on top of the block above - confirm that this is intended.
        train_loader = _imbalanced_data_loading(
            train_set, train_loader, num_classes, train_kwargs, args
        )
        # The test loader keeps the test settings (no shuffling, test batch
        # size); it used to be rebuilt with the training settings.
        test_loader = _imbalanced_data_loading(
            test_set, test_loader, num_classes, test_kwargs, args
        )

    return (
        train_set,
        train_loader,
        train_loader_validation,
        val_set,
        validation_loader,
        test_set,
        test_loader,
        test_set_unnorm,
        num_classes,
        class_names,
    )


class FashionMNIST(datasets.FashionMNIST):
    """Fashion-MNIST variant whose items are ``(image, target, index)``."""

    def __init__(
        self, root, train=True, transform=None, target_transform=None, download=False
    ):
        """Forward all arguments unchanged to ``datasets.FashionMNIST``."""
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index: int):
        """Return the transformed image, its integer label and ``index``."""
        raw = self.data[index]
        label = int(self.targets[index])

        # Hand the transform a greyscale PIL image, like the other datasets.
        picture = Image.fromarray(raw.numpy(), mode="L")

        transform = self.transform
        if transform is not None:
            picture = transform(picture)

        target_transform = self.target_transform
        if target_transform is not None:
            label = target_transform(label)

        return picture, label, index


class MNIST(datasets.MNIST):
    """MNIST variant whose items are ``(image, target, index)``."""

    def __init__(
        self, root, train=True, transform=None, target_transform=None, download=False
    ):
        """Forward all arguments unchanged to ``datasets.MNIST``."""
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index: int):
        """Return the transformed image, its integer label and ``index``."""
        raw = self.data[index]
        label = int(self.targets[index])

        # Hand the transform a greyscale PIL image, like the other datasets.
        picture = Image.fromarray(raw.numpy(), mode="L")

        transform = self.transform
        if transform is not None:
            picture = transform(picture)

        target_transform = self.target_transform
        if target_transform is not None:
            label = target_transform(label)

        return picture, label, index


class CIFAR10(datasets.CIFAR10):
    """CIFAR-10 variant whose items are ``(image, target, index)``."""

    def __init__(
        self, root, train=True, transform=None, target_transform=None, download=False
    ):
        """Forward all arguments unchanged to ``datasets.CIFAR10``."""
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index: int):
        """Return the transformed image, its label and ``index``."""
        raw = self.data[index]
        label = self.targets[index]

        # Hand the transform a PIL image, like the other datasets.
        picture = Image.fromarray(raw)

        transform = self.transform
        if transform is not None:
            picture = transform(picture)

        target_transform = self.target_transform
        if target_transform is not None:
            label = target_transform(label)

        return picture, label, index


class CIFAR100(datasets.CIFAR100):
    """CIFAR-100 variant whose items are ``(image, target, index)``."""

    def __init__(
        self, root, train=True, transform=None, target_transform=None, download=False
    ):
        """Forward all arguments unchanged to ``datasets.CIFAR100``."""
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index: int):
        """Return the transformed image, its label and ``index``."""
        raw = self.data[index]
        label = self.targets[index]

        # Hand the transform a PIL image, like the other datasets.
        picture = Image.fromarray(raw)

        transform = self.transform
        if transform is not None:
            picture = transform(picture)

        target_transform = self.target_transform
        if target_transform is not None:
            label = target_transform(label)

        return picture, label, index


class SVHN(datasets.SVHN):
    """SVHN variant whose items are ``(image, target, index)``."""

    def __init__(
        self, root, split="train", transform=None, target_transform=None, download=False
    ):
        """Forward all arguments unchanged to ``datasets.SVHN``."""
        super().__init__(
            root,
            split=split,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index: int):
        """Return the transformed image, its integer label and ``index``."""
        raw = self.data[index]
        label = int(self.labels[index])

        # Move the first axis to the end before building the PIL image so the
        # transform receives the same input type as for the other datasets.
        channels_last = np.transpose(raw, (1, 2, 0))
        picture = Image.fromarray(channels_last)

        transform = self.transform
        if transform is not None:
            picture = transform(picture)

        target_transform = self.target_transform
        if target_transform is not None:
            label = target_transform(label)

        return picture, label, index


class GTSRB(datasets.GTSRB):
    """GTSRB variant that indexes through ``data`` / ``targets``.

    Items are ``(image, target, index)``. Image paths live in ``self.data``
    and labels in ``self.targets``; both are filled from the parent's
    ``_samples`` at construction time and may be replaced afterwards (for
    example by a train/validation split).
    """

    def __init__(
        self, root, split="train", transform=None, target_transform=None, download=False
    ):
        """Forward all arguments to ``datasets.GTSRB`` and set up the views."""
        super().__init__(
            root,
            split=split,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        # Keep paths and labels in separate, parallel sequences so that
        # __getitem__ works on a freshly built dataset. Previously ``data``
        # held (path, label) tuples and ``targets`` was never set.
        self.data = [path for path, _ in self._samples]
        self.targets = [target for _, target in self._samples]

    def get_classes(self):
        """Return the ``SignName`` column of ``class_names.csv`` as a list.

        The file is read from ``self._base_folder``.
        """
        with open(self._base_folder / "class_names.csv") as csv_file:
            samples = [
                row["SignName"]
                for row in csv.DictReader(
                    csv_file, delimiter=";", skipinitialspace=True
                )
            ]
        # The list used to be built and then dropped (implicit None return).
        return samples

    def __len__(self) -> int:
        """Return the number of entries in ``self.data``."""
        return len(self.data)

    def __getitem__(self, index: int):
        """Load the image at ``index`` and return ``(image, target, index)``."""
        # Labels may be stored as strings (after array conversion), hence int().
        path, target = self.data[index], int(self.targets[index])
        sample = Image.open(path).convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target, index


class Cars(datasets.StanfordCars):
    """Stanford Cars with a boolean ``train`` switch and indexed items."""

    def __init__(
        self, root, train=True, transform=None, target_transform=None, download=False
    ):
        """Map ``train`` to the parent's ``split`` name and forward the rest."""
        split_name = "train" if train else "test"
        super().__init__(
            root,
            split=split_name,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index):
        """Return the parent's ``(image, target)`` pair plus ``index``."""
        image, label = super().__getitem__(index)
        return image, label, index


class Food(datasets.Food101):
    """Food-101 with a boolean ``train`` switch and indexed items."""

    def __init__(
        self, root, train=True, transform=None, target_transform=None, download=False
    ):
        """Map ``train`` to the parent's ``split`` name and forward the rest."""
        split_name = "train" if train else "test"
        super().__init__(
            root,
            split=split_name,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index):
        """Return the parent's ``(image, target)`` pair plus ``index``."""
        image, label = super().__getitem__(index)
        return image, label, index


class ImageNet(datasets.ImageFolder):
    """Image folder rooted at ``<root>/<split>`` that yields indexed items."""

    def __init__(self, root: str, split: str = "train", transform=None):
        """Open the ``split`` sub-directory of ``root`` as an image folder."""
        split_dir = f"{root}/{split}"
        super().__init__(split_dir, transform=transform)

    def __getitem__(self, index):
        """Return the parent's ``(image, target)`` pair plus ``index``."""
        image, label = super().__getitem__(index)
        return image, label, index


class UTKFace(Dataset):
    """In-memory face dataset built from a pandas DataFrame.

    Inputs:
        dataFrame : Pandas dataFrame with the columns ``pixels`` (space
            separated pixel values, 48 * 48 per row), ``bins``, ``gender``
            and ``ethnicity``
        transform : Optional callable applied to each image array; when it
            is ``None`` the raw array is returned

    Each item is ``(data, labels, index)`` where ``labels`` is a tensor
    holding the age bin, gender and ethnicity of the sample, in that order.
    """

    def __init__(self, dataFrame, transform=None):
        # read in the transforms
        self.transform = transform

        # Parse every pixel string into a float vector.
        pixel_rows = dataFrame.pixels.apply(
            lambda x: np.array(x.split(" "), dtype=float)
        )
        arr = np.stack(list(pixel_rows))
        # Divide by 255 and store as float32.
        arr = arr / 255.0
        arr = arr.astype("float32")
        # reshape into 48x48x1
        arr = arr.reshape(arr.shape[0], 48, 48, 1)
        self.data = arr

        # get the age (binned), gender, and ethnicity label arrays
        self.age_label = np.array(dataFrame.bins[:])
        self.gender_label = np.array(dataFrame.gender[:])
        self.eth_label = np.array(dataFrame.ethnicity[:])

    # override the length function
    def __len__(self):
        return len(self.data)

    # override the getitem function
    def __getitem__(self, index):
        # load the data at index and apply the transform, if one was given
        # (the default ``transform=None`` used to raise a TypeError here)
        data = self.data[index]
        if self.transform is not None:
            data = self.transform(data)

        # load the labels into a tuple and convert to a tensor
        labels = torch.tensor(
            (self.age_label[index], self.gender_label[index], self.eth_label[index])
        )

        # return data labels
        return data, labels, index
