import os
import random
import tarfile

import numpy as np
import torch
import torchvision


class FilteredDataset(torch.utils.data.Dataset):
    """A filtered, optionally transformed view over another indexable dataset.

    Only positions ``i`` for which ``filter_func(metainfo_func(dataset, i))`` is
    truthy are exposed; item ``k`` of this view is ``dataset[indices[k]]``,
    passed through ``transform`` when one is given.

    Args:
        dataset: any object supporting ``len()`` and integer indexing.
        filter_func: predicate over the per-item metadata; ``None`` keeps every item.
        transform: optional callable applied to each item returned by ``__getitem__``.
        metainfo_func: callable ``(dataset, i) -> metadata`` whose result is fed to
            ``filter_func``. Defaults to ``dataset[i]`` (the full item), which may be
            expensive; callers can pass a cheaper accessor such as a targets lookup.
    """

    def __init__(self, dataset, filter_func=None, transform=None, metainfo_func=None):
        super().__init__()
        self.indices = []
        self.dataset = dataset
        self.filter_func = filter_func
        self.transform = transform
        if metainfo_func is None:
            metainfo_func = lambda dataset, i: dataset[i]
        self.metainfo_func = metainfo_func
        self.filter()

    def filter(self):
        """(Re)compute ``self.indices`` from the current dataset and filter.

        The index list is rebuilt from scratch, so calling this more than once
        does not duplicate entries.
        """
        if self.filter_func is None:
            self.indices = list(range(len(self.dataset)))
        else:
            self.indices = [
                i
                for i in range(len(self.dataset))
                if self.filter_func(self.metainfo_func(self.dataset, i))
            ]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        x = self.dataset[self.indices[index]]
        if self.transform is not None:
            x = self.transform(x)
        return x


def get_dataloader(dataset, batch_size, shuffle=False, num_workers=4, pin_memory=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: map-style dataset to load from.
        batch_size: number of samples per batch.
        shuffle: whether the loader reshuffles the data every epoch.
        num_workers: number of loader worker processes (0 loads in the main
            process). Defaults to 4, the value this helper always used.
        pin_memory: whether batches are copied into page-locked memory.
            Defaults to True, the value this helper always used.

    Returns:
        The configured ``DataLoader``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
    )


def create_cifar_filter_func(classes, get_val):
    """Return a label predicate that tests membership in ``classes``.

    When ``get_val`` is truthy, the label is shifted down by 50 before the
    membership test. A raw label ``c + 50`` then matches class ``c``, which
    selects the upper half of CIFAR-100 for validation. Otherwise the raw label
    is tested directly.
    """

    def label_in_shifted_classes(label):
        return (label - 50) in classes

    def label_in_classes(label):
        return label in classes

    return label_in_shifted_classes if get_val else label_in_classes


def create_task_transform(class_to_task_class):
    """Build a transform that relabels ``(input, label)`` samples.

    The returned ``torchvision.transforms.Lambda`` maps a sample to
    ``(sample[0], class_to_task_class[sample[1]])``. The input is left untouched
    and the label is replaced by its lookup in ``class_to_task_class``.
    """

    def relabel(sample):
        image = sample[0]
        label = sample[1]
        return (image, class_to_task_class[label])

    return torchvision.transforms.Lambda(relabel)


def get_split_cifar100(dataset_location, batch_size, get_val=False, saved_tasks=None):
    """Build per-task train/test dataloaders for Split-CIFAR100.

    The 100 class ids are shuffled with the global ``random`` module and cut into
    20 consecutive tasks of 5 classes. ``saved_tasks`` replaces that split when
    given; the shuffle still runs either way. Each label is remapped to the
    class's position inside its task.

    Args:
        dataset_location: root directory handed to ``torchvision.datasets.CIFAR100``.
        batch_size: batch size for every dataloader.
        get_val: if truthy, each task's "test" loader is built from a random 10%
            of that task's training subset instead of the CIFAR-100 test split.
        saved_tasks: optional list of class-id sequences overriding the random split.

    Returns:
        ``(dataloaders, tasks)``: one ``{"train": ..., "test": ...}`` dict of
        dataloaders per task, plus the task split that was used.
    """
    normalize = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    image_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor(),
            normalize,
        ]
    )

    def load_cifar100(train):
        return torchvision.datasets.CIFAR100(
            dataset_location,
            transform=image_transform,
            target_transform=None,
            train=train,
        )

    test_dataset = load_cifar100(False)
    train_dataset = load_cifar100(True)

    class_pool = list(range(100))
    random.shuffle(class_pool)
    if saved_tasks is None:
        tasks = [class_pool[start : start + 5] for start in range(0, 100, 5)]
    else:
        tasks = saved_tasks

    class_to_task_class = {
        class_id: slot for task in tasks for slot, class_id in enumerate(task)
    }

    def targets_of(dataset, i):
        # Read the raw label straight from ``targets`` so images are not decoded.
        return dataset.targets[i]

    def task_subset(source, task):
        return FilteredDataset(
            source,
            create_cifar_filter_func(task, False),
            transform=create_task_transform(class_to_task_class),
            metainfo_func=targets_of,
        )

    dataloaders = []
    for task in tasks:
        train_subset = task_subset(train_dataset, task)
        if not get_val:
            test_subset = task_subset(test_dataset, task)
        else:
            train_subset, test_subset = get_random_split(train_subset, 0.9)
        dataloaders.append(
            {
                "train": get_dataloader(train_subset, batch_size, shuffle=True),
                "test": get_dataloader(test_subset, batch_size, shuffle=False),
            }
        )
    return dataloaders, tasks


def get_split_mnist(dataset_location, batch_size, get_val=False, saved_tasks=None):
    """Build per-task train/test dataloaders for Split-MNIST.

    The 10 digit classes are shuffled with the global ``random`` module and cut
    into 5 tasks of 2 classes each. ``saved_tasks`` replaces that split when
    given; the shuffle still runs either way. Images are converted to RGB,
    resized to 224 and normalized. Labels are remapped to the class's position
    inside its task.

    Args:
        dataset_location: root directory handed to ``torchvision.datasets.MNIST``
            (``download=False``, so the data must already be present).
        batch_size: batch size for every dataloader.
        get_val: if truthy, each task's "test" loader is built from a random 10%
            of that task's training subset instead of the MNIST test split.
        saved_tasks: optional list of class-id sequences overriding the random split.

    Returns:
        ``(dataloaders, tasks)``: one ``{"train": ..., "test": ...}`` dict of
        dataloaders per task, plus the task split that was used.
    """
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Lambda(lambda x: x.convert("RGB")),
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    train_dataset = torchvision.datasets.MNIST(
        dataset_location, True, transform=transform, download=False
    )
    test_dataset = torchvision.datasets.MNIST(
        dataset_location, False, transform=transform, download=False
    )
    classes = list(range(10))
    random.shuffle(classes)
    tasks = [classes[i * 2 : (i + 1) * 2] for i in range(5)]
    if saved_tasks is not None:
        tasks = saved_tasks
    class_to_task_class = {}
    for task in tasks:
        for task_class, class_id in enumerate(task):
            class_to_task_class[class_id] = task_class

    def target_as_int(dataset, i):
        # MNIST stores ``targets`` as a tensor, so ``targets[i]`` is a 0-d tensor.
        # Converting to ``int`` makes the class-membership filter a plain integer
        # comparison (cheaper, and also correct for hash-based containers).
        return int(dataset.targets[i])

    dataloaders = []
    for task in tasks:
        train_subset = FilteredDataset(
            train_dataset,
            create_cifar_filter_func(task, False),
            transform=create_task_transform(class_to_task_class),
            metainfo_func=target_as_int,
        )
        if get_val:
            train_subset, test_subset = get_random_split(train_subset, 0.9)
        else:
            test_subset = FilteredDataset(
                test_dataset,
                create_cifar_filter_func(task, False),
                transform=create_task_transform(class_to_task_class),
                metainfo_func=target_as_int,
            )
        task_dataloaders = {
            "train": get_dataloader(train_subset, batch_size, shuffle=True),
            "test": get_dataloader(test_subset, batch_size, shuffle=False),
        }
        dataloaders.append(task_dataloaders)
    return dataloaders, tasks


def get_cifar_50(dataset_location, batch_size, get_val=False, saved_tasks=None):
    """Build per-task dataloaders over 50 CIFAR-100 classes (5 tasks x 10 classes).

    Class ids ``0..49`` are shuffled with the global ``random`` module and cut
    into 5 consecutive tasks of 10. ``saved_tasks`` replaces that split when
    given; the shuffle still runs either way. With ``get_val`` truthy, the raw
    CIFAR-100 labels ``50..99`` are used instead: the filter subtracts 50 before
    matching, and a ``target_transform`` subtracts 50 from every returned label.
    Labels are then remapped to the class's position inside its task.

    Args:
        dataset_location: root directory handed to ``torchvision.datasets.CIFAR100``.
        batch_size: batch size for every dataloader.
        get_val: select the upper 50 classes (shifted to ``0..49``) instead of the lower 50.
        saved_tasks: optional list of class-id sequences overriding the random split.

    Returns:
        ``(dataloaders, tasks)``: one ``{"train": ..., "test": ...}`` dict of
        dataloaders per task, plus the task split that was used.
    """
    target_transform = (
        torchvision.transforms.Lambda(lambda label: label - 50) if get_val else None
    )
    image_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )

    def load_cifar100(train):
        return torchvision.datasets.CIFAR100(
            dataset_location,
            transform=image_transform,
            target_transform=target_transform,
            train=train,
        )

    test_dataset = load_cifar100(False)
    train_dataset = load_cifar100(True)

    class_pool = list(range(50))
    random.shuffle(class_pool)
    if saved_tasks is None:
        tasks = [class_pool[start : start + 10] for start in range(0, 50, 10)]
    else:
        tasks = saved_tasks

    class_to_task_class = {
        class_id: slot for task in tasks for slot, class_id in enumerate(task)
    }

    def targets_of(dataset, i):
        # ``targets`` holds raw labels (target_transform is not applied to it),
        # which is why the filter itself performs the -50 shift when get_val is set.
        return dataset.targets[i]

    def task_subset(source, task):
        return FilteredDataset(
            source,
            create_cifar_filter_func(task, get_val),
            transform=create_task_transform(class_to_task_class),
            metainfo_func=targets_of,
        )

    dataloaders = []
    for task in tasks:
        train_subset = task_subset(train_dataset, task)
        test_subset = task_subset(test_dataset, task)
        dataloaders.append(
            {
                "train": get_dataloader(train_subset, batch_size, shuffle=True),
                "test": get_dataloader(test_subset, batch_size, shuffle=False),
            }
        )
    return dataloaders, tasks


def get_5_dataset(dataset_location, batch_size, get_val=False, saved_tasks=None):
    """Build dataloaders for the 5-dataset benchmark, one task per dataset.

    The five dataset names are shuffled with the global ``random`` module.
    ``saved_tasks`` replaces that order when given; the shuffle still runs.
    Each name is loaded with ``get_dataset``.

    Returns:
        ``(dataloaders, tasks)``: a list of ``{"train": ..., "test": ...}`` dicts
        in task order, plus the list of dataset names that was used.
    """
    dataset_order = ["cifar10", "mnist", "svhn", "not_mnist", "fashion_mnist"]
    random.shuffle(dataset_order)
    chosen = dataset_order if saved_tasks is None else saved_tasks
    loaders = [
        get_dataset(name, batch_size, dataset_location, get_val=get_val)
        for name in chosen
    ]
    return loaders, chosen


def get_random_split(dataset, split):
    """Randomly partition ``dataset`` into two ``Subset`` views.

    A permutation is drawn from NumPy's global RNG. The first
    ``int(len(dataset) * split)`` permuted indices form the first subset and the
    remainder form the second.

    Returns:
        ``(first, second)`` as ``torch.utils.data.Subset`` objects.
    """
    order = np.random.permutation(len(dataset))
    cut = int(len(order) * split)
    head, tail = order[:cut], order[cut:]
    return torch.utils.data.Subset(dataset, head), torch.utils.data.Subset(dataset, tail)


def check_not_mnist_files(path):
    """Return False if ``path`` contains one of two excluded notMNIST file names.

    Used as ``is_valid_file`` for ``ImageFolder``. Any other path is accepted.
    NOTE(review): presumably these two images are unreadable in the archive —
    confirm.
    """
    excluded_names = (
        "RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png",
        "Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png",
    )
    return not any(name in path for name in excluded_names)


def get_not_mnist(dataset_location, transform):
    """
    Parses and returns the downloaded notMNIST dataset.

    Extracts ``notMNIST_small.tar.gz`` from ``dataset_location``, unless the
    ``notMNIST_small`` directory already exists there. It then loads that
    directory with ``torchvision.datasets.ImageFolder`` (class labels come from
    its subdirectories), skipping files rejected by ``check_not_mnist_files``.
    Finally it splits the result 90/10 with ``get_random_split``.

    NOTE: an existing ``notMNIST_small`` directory is trusted as-is. Delete it to
    force a fresh extraction, e.g. after an interrupted run.

    Args:
        dataset_location: directory holding the tarball / extracted folder.
        transform: image transform passed to ``ImageFolder``.

    Returns:
        ``(train_dataset, test_dataset)`` as ``torch.utils.data.Subset`` objects.
    """
    image_root = os.path.join(dataset_location, "notMNIST_small")
    if not os.path.isdir(image_root):
        tar_path = os.path.join(dataset_location, "notMNIST_small.tar.gz")
        with tarfile.open(tar_path) as tar:
            if hasattr(tarfile, "data_filter"):
                # The "data" filter refuses absolute paths, ".." components and
                # links escaping the destination (path-traversal protection).
                tar.extractall(dataset_location, filter="data")
            else:
                # Older Pythons without extraction filters.
                tar.extractall(dataset_location)
    dataset = torchvision.datasets.ImageFolder(
        image_root,
        transform=transform,
        is_valid_file=check_not_mnist_files,
    )
    train_dataset, test_dataset = get_random_split(dataset, 0.9)
    return train_dataset, test_dataset


def get_dataset(dataset_name, batch_size, dataset_location, get_val=False):
    """Build train/test dataloaders for one member of the 5-dataset benchmark.

    Data is read from ``os.path.join(dataset_location, dataset_name)`` with
    ``download=False``. Images are resized to 224, converted to tensors and
    normalized. MNIST-style datasets are converted to RGB first. Labels pass
    through an identity mapping over ``range(10)``, so a label outside 0..9
    would raise ``IndexError`` when the sample is loaded.

    Args:
        dataset_name: one of "cifar10", "svhn", "mnist", "fashion_mnist", "not_mnist".
        batch_size: batch size for both dataloaders.
        dataset_location: parent directory containing one folder per dataset.
        get_val: if truthy, the "test" loader is a random 10% of the training
            data and the "train" loader uses the remaining 90%.

    Returns:
        ``{"train": DataLoader, "test": DataLoader}``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    supported = ("cifar10", "svhn", "mnist", "fashion_mnist", "not_mnist")
    if dataset_name not in supported:
        # Fail fast with a clear message. Otherwise an unknown name reaches the
        # code below with ``train_dataset`` unbound (UnboundLocalError).
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of {supported}"
        )
    dataset_location = os.path.join(dataset_location, dataset_name)
    transforms_list = [
        torchvision.transforms.Resize(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
    print(dataset_location)
    if dataset_name in ["mnist", "fashion_mnist", "not_mnist"]:
        # Grayscale sources: replicate to 3 channels to match the RGB normalization.
        transforms_list.insert(
            0, torchvision.transforms.Lambda(lambda x: x.convert("RGB"))
        )
    transform = torchvision.transforms.Compose(transforms_list)
    if dataset_name == "cifar10":
        train_dataset = torchvision.datasets.CIFAR10(
            dataset_location, True, transform=transform, download=False
        )
        test_dataset = torchvision.datasets.CIFAR10(
            dataset_location, False, transform=transform, download=False
        )
    elif dataset_name == "svhn":
        train_dataset = torchvision.datasets.SVHN(
            dataset_location, "train", transform=transform, download=False
        )
        test_dataset = torchvision.datasets.SVHN(
            dataset_location, "test", transform=transform, download=False
        )
    elif dataset_name == "mnist":
        train_dataset = torchvision.datasets.MNIST(
            dataset_location, True, transform=transform, download=False
        )
        test_dataset = torchvision.datasets.MNIST(
            dataset_location, False, transform=transform, download=False
        )
    elif dataset_name == "fashion_mnist":
        train_dataset = torchvision.datasets.FashionMNIST(
            dataset_location, True, transform=transform, download=False
        )
        test_dataset = torchvision.datasets.FashionMNIST(
            dataset_location, False, transform=transform, download=False
        )
    else:  # "not_mnist" — guaranteed by the validation above
        train_dataset, test_dataset = get_not_mnist(dataset_location, transform)

    class_to_task_class = list(range(10))
    if get_val:
        train_dataset, test_dataset = get_random_split(train_dataset, 0.9)
    train_dataset = FilteredDataset(
        train_dataset, transform=create_task_transform(class_to_task_class)
    )
    test_dataset = FilteredDataset(
        test_dataset, transform=create_task_transform(class_to_task_class)
    )
    dataloaders = {
        "train": get_dataloader(train_dataset, batch_size, shuffle=True),
        "test": get_dataloader(test_dataset, batch_size),
    }
    return dataloaders
