from copy import deepcopy
import csv
import hashlib
import pathlib
from typing import Tuple, Dict, Union, Sequence, Optional, List

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, ConcatDataset, Dataset, Subset
from torchvision import transforms

from .transforms import assemble_transform
from ..utils import ROOT_DIR, get_logger
from . import custom_transforms

# Base directory holding the locally stored datasets (the ImageNet getter uses its own root).
DATAPATH: pathlib.Path = ROOT_DIR / "data"
# csv dialect options; presumably meant for csv.reader/csv.writer — not referenced inside this module.
CSV_KWARGS = dict(delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL)
logger = get_logger("datasets")

# Defaults shared by the DataLoaders created in this module.
NUM_WORKERS = 4
BATCH_SIZE = 32


class DualDomainDataset(Dataset):
    """
    Wrap a dataset so that each sample is returned under two different transforms.

    Items of the wrapped dataset are unpacked as ``(data, target)``; indexing the wrapper yields
    ``(t0(data), t1(data), target)``.
    """

    def __init__(self, base_dataset, t0: transforms.Compose, t1: transforms.Compose):
        super().__init__()
        self.dataset: Dataset = base_dataset
        # Transform producing the first and the second view, respectively.
        self.t0: transforms.Compose = t0
        self.t1: transforms.Compose = t1
        # NOTE(review): not used within this class; kept in case external code reads it.
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        first_view = self.t0(sample)
        second_view = self.t1(sample)
        return first_view, second_view, label

def get_subsets(dataset_name: str, dataset: Dataset, fractions: List, shuffle: bool = True,
                balance: bool = True, batch_size: int = BATCH_SIZE) -> Tuple:
    """
    Split ``dataset`` with a seed derived deterministically from ``dataset_name``.

    The seed is the SHA-1 digest of the UTF-8 encoded name reduced to 32 bits, so a given dataset name
    always produces the same split. All other arguments are forwarded to ``split_dataset``.
    """
    name_digest = hashlib.sha1(dataset_name.encode('utf-8')).digest()
    # Equivalent to int(hexdigest, 16) % 2**32.
    seed = int.from_bytes(name_digest, byteorder="big") & 0xFFFFFFFF
    logger.info(f"Generating splits using local seed {seed} for dataset {dataset_name}")

    return split_dataset(dataset, fractions, seed=seed, shuffle=shuffle, balance=balance,
                         batch_size=batch_size)


def _collect_labels(dataset: Dataset, batch_size: int) -> torch.Tensor:
    """
    Return a 1-D tensor with the target of every item of ``dataset``, in dataset order.

    The target is taken as the last element of each collated batch, which covers both ``(data, target)``
    items and the ``(view0, view1, target)`` items yielded by DualDomainDataset.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)
    targets = [batch[-1] for batch in loader]
    if not targets:
        return torch.zeros(0, dtype=torch.long)
    return torch.cat(targets)


def _balanced_split_positions(labels: torch.Tensor, fractions: List) -> List[torch.Tensor]:
    """
    Split the positions ``0..len(labels)-1`` class by class according to ``fractions``.

    For each class (ascending label order) the positions carrying that label are cut into consecutive
    chunks of ``int(fraction * class_count)`` elements; chunk ``k`` of every class goes to split ``k``.
    """
    chunks_per_split: List[List[torch.Tensor]] = [[] for _ in fractions]
    classes, class_counts = torch.unique(labels, sorted=True, return_counts=True)
    # .tolist() gives Python ints, so the size computation below runs in float64 rather than float32.
    for class_id, class_count in zip(classes, class_counts.tolist()):
        # as_tuple=True always yields a 1-D tensor, even for a class with a single sample
        # (``.nonzero().squeeze()`` collapses to a 0-d tensor that cannot be sliced).
        class_positions = (labels == class_id).nonzero(as_tuple=True)[0]
        start = 0
        for split_idx, fraction in enumerate(fractions):
            end = start + int(fraction * class_count)
            chunks_per_split[split_idx].append(class_positions[start:end])
            start = end
    return [torch.cat(chunks) if chunks else torch.zeros(0, dtype=torch.long) for chunks in chunks_per_split]


def _sequential_split_positions(num_items: int, fractions: List) -> List[torch.Tensor]:
    """Cut the positions ``0..num_items-1`` into consecutive ranges of ``int(fraction * num_items)`` each."""
    positions = []
    start = 0
    for fraction in fractions:
        end = start + int(fraction * num_items)
        positions.append(torch.arange(start=start, end=end))
        start = end
    return positions


def split_dataset(dataset: Dataset, fractions: List, seed: int = 1, shuffle: bool = True,
                  balance: bool = False, batch_size: int = BATCH_SIZE) -> Tuple[Subset, ...]:
    """
    Split the dataset into disjoint subsets according to the fractions specified in the fractions list.

    :param dataset: Dataset to split
    :param fractions: each element is the fraction of the dataset assigned to the corresponding split.
        The fractions must sum to at most one (a tiny tolerance absorbs float round-off).
    :param seed: Seed of the NumPy RandomState used for shuffling
    :param shuffle: whether to shuffle the dataset before splitting it into subsets
    :param balance: whether to split every class separately, so each subset keeps (up to rounding) the
        class distribution of the original dataset (no sampling with replacement). This requires one
        sequential pass over ``dataset`` to collect the labels.
    :param batch_size: Batch size used for the sequential pass that collects the labels
    :return: Tuple of subsets, one per entry of ``fractions``
    """
    fraction_sum = sum(fractions)
    # Tolerance for float round-off, e.g. 0.1 + 0.2 + 0.7 evaluates to slightly more than 1.
    assert fraction_sum <= 1 + 1e-9, f"Fractions must sum to at most 1, got {fraction_sum}"

    indices = torch.arange(len(dataset))
    if shuffle:
        rng = np.random.RandomState(seed=seed)
        np_indices = indices.numpy()
        rng.shuffle(np_indices)
        indices = torch.tensor(np_indices)

    if balance:
        labels = _collect_labels(dataset, batch_size)
        # Split position i refers to dataset item indices[i], so the labels must be reordered with exactly
        # the same permutation as the indices; otherwise the per-class split is not balanced at all.
        positions_per_split = _balanced_split_positions(labels[indices], fractions)
    else:
        positions_per_split = _sequential_split_positions(len(dataset), fractions)

    # Map split positions back to (possibly shuffled) dataset indices.
    return tuple(Subset(dataset, indices[positions]) for positions in positions_per_split)


############################################################################################################
# DATASET GETTERS
############################################################################################################

def get_mnist_dataset(transformer: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """Return the MNIST (train, test) datasets; only the train constructor is asked to download."""
    root = str(DATAPATH / "mnist")
    train_data = torchvision.datasets.MNIST(root, train=True, download=True, transform=transformer)
    test_data = torchvision.datasets.MNIST(root, train=False, transform=transformer)
    return train_data, test_data


def get_cifar10_dataset(transformer: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """Return the CIFAR10 (train, test) datasets, downloading either split if missing."""
    root = str(DATAPATH / "cifar10")
    train_data, test_data = (
        torchvision.datasets.CIFAR10(root, train=is_train, download=True, transform=transformer)
        for is_train in (True, False)
    )
    return train_data, test_data


def get_fashion_mnist_dataset(all_transforms: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """Return the FashionMNIST (train, test) datasets; only the train constructor is asked to download."""
    root = str(DATAPATH / "fashion_data")
    train_data = torchvision.datasets.FashionMNIST(root, train=True, download=True, transform=all_transforms)
    test_data = torchvision.datasets.FashionMNIST(root, train=False, transform=all_transforms)
    return train_data, test_data

def get_svhn_dataset(all_transforms: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """Return SVHN as (train + extra concatenated, test), downloading each split if missing."""
    root = str(DATAPATH / "svhn")
    train_split, extra_split, test_data = (
        torchvision.datasets.SVHN(root, split=split_name, download=True, transform=all_transforms)
        for split_name in ('train', 'extra', 'test')
    )
    # The "extra" split is folded into training data.
    return ConcatDataset((train_split, extra_split)), test_data

def get_imagenet_dataset(all_transforms: transforms.Compose,
                         root: Union[str, pathlib.Path] = "/media/sdd/datasets/imagenet/") \
        -> Tuple[Dataset, Dataset, Dataset]:
    """
    Build ImageNet train/val/test datasets with fixed, ImageNet-specific transforms.

    ``all_transforms`` is not used by this getter. Train images get random resize + random crop (224)
    + horizontal flip + normalization + AlexNet-style lighting noise; val/test images get
    resize (256) + center crop (224) + normalization.

    :param all_transforms: ignored
    :param root: root directory handed to ``torchvision.datasets.ImageNet``
    :return: (train, val, test). The ImageNet test split has no ground truth, so val is also used as test.
    """
    # Eigenvalues/eigenvectors for custom_transforms.Lighting: AlexNet-style color augmentation as used
    # in the ResNet paper. Taken from:
    # https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/fastai_imagenet.py
    imagenet_pca = {
        'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
        'eigvec': torch.Tensor([
            [-0.5675, 0.7192, 0.4009],
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203],
        ])
    }
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transforms = transforms.Compose([custom_transforms.RandomResize((256, 480)),
                                           transforms.RandomCrop(224),
                                           transforms.RandomHorizontalFlip(p=0.5),
                                           transforms.ToTensor(),
                                           normalize,
                                           custom_transforms.Lighting(0.1, imagenet_pca['eigval'],
                                                                      imagenet_pca['eigvec'])
                                           ])

    val_transforms = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         normalize
                                         ])

    imagenet_root = str(root)
    train_data = torchvision.datasets.ImageNet(imagenet_root, split='train', transform=train_transforms)
    val_data = torchvision.datasets.ImageNet(imagenet_root, split='val', transform=val_transforms)
    # Test dataset has no GT. Use val as test.
    test_data = torchvision.datasets.ImageNet(imagenet_root, split='val', transform=val_transforms)

    return train_data, val_data, test_data

def get_emnist_dataset(all_transforms: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """Return the EMNIST "letters" (train, test) datasets; only the train constructor is asked to download."""
    letters_root = str(DATAPATH / "letters_data")
    train_data = torchvision.datasets.EMNIST(letters_root, split='letters', train=True, download=True,
                                             transform=all_transforms)
    test_data = torchvision.datasets.EMNIST(letters_root, split='letters', train=False,
                                            transform=all_transforms)
    return train_data, test_data


def get_kaist_ic_dataset(all_transforms: transforms.Compose, dataset_domain: List):
    """
    Load the KAIST-IC train/val/test ImageFolder datasets for one time of day and one sensor.

    :param all_transforms: transform applied to every split
    :param dataset_domain: must contain "day" or "night" (day wins if both) and "rgb" or "ir" (rgb wins
        if both); "rgb" maps to the "visible" folder, "ir" to "lwir"
    :return: (train, val, test) ImageFolder datasets
    :raises ValueError: if the time of day or the image type cannot be determined
    """
    kaist_ic_root = DATAPATH / "kaist_ic"

    if "day" in dataset_domain:
        time_of_day = "day"
    elif "night" in dataset_domain:
        time_of_day = "night"
    else:
        raise ValueError(f"KAIST-IC dataset_domain must contain 'day' or 'night', got {dataset_domain!r}")

    if "rgb" in dataset_domain:
        image_type = "visible"
    elif "ir" in dataset_domain:
        image_type = "lwir"
    else:
        raise ValueError(f"KAIST-IC dataset_domain must contain 'rgb' or 'ir', got {dataset_domain!r}")

    train_data, val_data, test_data = (
        torchvision.datasets.ImageFolder(str(kaist_ic_root / split / time_of_day / image_type),
                                         transform=all_transforms)
        for split in ("train", "val", "test")
    )

    return train_data, val_data, test_data


def get_syn2real_dataset(all_transforms: transforms.Compose, dataset_domain: List):
    """
    Load the Syn2Real train/val/test ImageFolder datasets of a single domain.

    :param all_transforms: transform applied to every split
    :param dataset_domain: sequence whose first element is the domain, "synthetic" or "coco"
    :return: (train, val, test) ImageFolder datasets
    :raises ValueError: if ``dataset_domain`` is empty or its first element is not a known domain
    """
    syn2real_root = DATAPATH / "syn2real"

    # An empty dataset_domain (the parse_domain_config default) would otherwise raise an opaque IndexError.
    if not dataset_domain:
        raise ValueError("SYN2REAL requires dataset_domain to name a domain ('synthetic' or 'coco')")

    domain = dataset_domain[0]
    if domain not in ("synthetic", "coco"):
        raise ValueError(f"Unknown SYN2REAL domain {domain!r}; expected 'synthetic' or 'coco'")

    domain_root = syn2real_root / domain
    train_data = torchvision.datasets.ImageFolder(str(domain_root / "train"), transform=all_transforms)
    val_data = torchvision.datasets.ImageFolder(str(domain_root / "val"), transform=all_transforms)
    test_data = torchvision.datasets.ImageFolder(str(domain_root / "test"), transform=all_transforms)

    return train_data, val_data, test_data


def get_office31_dataset(all_transforms: transforms.Compose, dataset_domain: List):
    """
    Load one or more Office-31 domains as a single dataset.

    :param all_transforms: transform applied to every image
    :param dataset_domain: non-empty collection of domain names from "amazon", "dslr", "webcam"
    :return: 1-tuple holding the domain's ImageFolder dataset, or a ConcatDataset of all requested domains
    :raises ValueError: if ``dataset_domain`` is empty or contains an unknown domain
    """
    office31_root = DATAPATH / "office31"
    valid_domains = ("amazon", "dslr", "webcam")

    domains = list(dataset_domain)
    if not domains:
        raise ValueError(f"Office31 requires at least one domain out of {valid_domains}")
    # Validate everything before loading any images.
    unknown = [domain for domain in domains if domain not in valid_domains]
    if unknown:
        raise ValueError(f"Unknown Office31 domain(s) {unknown}; expected any of {valid_domains}")

    datasets = [torchvision.datasets.ImageFolder(str(office31_root / domain / "images"), transform=all_transforms)
                for domain in domains]

    concatted_dataset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
    return (concatted_dataset,)


def get_raw_datasets(dataset_name: str, transformer: transforms.Compose, dataset_domain: Optional[List] = None):
    """
    Dispatch to the dataset getter registered for ``dataset_name``.

    :param dataset_name: one of "MNIST", "CIFAR10", "FashionMNIST", "SVHN", "ImageNet", "EMNIST", "Toy",
        "KAIST-IC", "SYN2REAL", "Office31"
    :param transformer: transform handed to the selected getter
    :param dataset_domain: domain selector, only forwarded to the KAIST-IC, SYN2REAL and Office31 getters;
        ``None`` means an empty list
    :return: the tuple of datasets returned by the selected getter (its length depends on the dataset)
    :raises NotImplementedError: for an unknown ``dataset_name``
    """
    # A fresh list per call instead of a shared mutable default argument.
    if dataset_domain is None:
        dataset_domain = []

    if dataset_name == "MNIST":
        return get_mnist_dataset(transformer)
    elif dataset_name == "CIFAR10":
        return get_cifar10_dataset(transformer)
    elif dataset_name == "FashionMNIST":
        return get_fashion_mnist_dataset(transformer)
    elif dataset_name == "SVHN":
        return get_svhn_dataset(transformer)
    elif dataset_name == "ImageNet":
        return get_imagenet_dataset(transformer)
    elif dataset_name == "EMNIST":
        return get_emnist_dataset(transformer)
    elif dataset_name == "Toy":
        # NOTE(review): get_toy_dataset is neither defined nor imported in the visible module, so this
        # branch raises NameError unless it is provided elsewhere — confirm.
        return get_toy_dataset(transformer)
    elif dataset_name == "KAIST-IC":
        return get_kaist_ic_dataset(transformer, dataset_domain)
    elif dataset_name == "SYN2REAL":
        return get_syn2real_dataset(transformer, dataset_domain)
    elif dataset_name == "Office31":
        return get_office31_dataset(transformer, dataset_domain)
    else:
        raise NotImplementedError(f"Unknown dataset name: {dataset_name!r}")

############################################################################################################
############################################################################################################


def assert_domain_match(domain0, domain1):
    """
    Check that two domain configs differ at most in their ``"transform"`` entry.

    Every key of ``domain0`` other than ``"transform"`` must be present in ``domain1`` with an equal value.
    Identical ``"transform"`` entries are allowed but logged as critical.

    :raises AssertionError: if a non-transform key of ``domain0`` is missing from ``domain1`` or differs
    """
    for key, value0 in domain0.items():
        if key == "transform":
            if key in domain1 and domain1[key] == value0:
                logger.critical("Used a PairwiseDataloader with two exactly matching domains.")
            continue
        if key not in domain1 or not domain1[key] == value0:
            # Raised explicitly rather than via ``assert`` so the check survives ``python -O``;
            # a missing key is reported as a mismatch instead of surfacing as a KeyError.
            raise AssertionError("PairwiseDataloader: Domains must match in all keys except for <transform> "
                                 f"(mismatch on key {key!r})")


def parse_domain_config(domain_config: Union[Dict, Sequence[Dict]]) -> Tuple[str, int, bool, Tuple[Dataset, ...],
                                                                             transforms.Compose]:
    """
    Resolve one or two domain configs into raw datasets plus loader settings.

    A dict or a 1-element sequence is a single-domain config: the raw datasets from ``get_raw_datasets``
    are returned as-is. A 2-element sequence is a dual-domain config: both configs must match in all keys
    except ``"transform"``, the raw datasets are built without a transform, and each is wrapped in a
    DualDomainDataset applying the two configs' transforms.

    :param domain_config: a config dict, or a sequence of one or two config dicts. Keys read from the
        (first) config: "shuffle", "batch_size", "dataset", optional "dataset_domain" and optional
        "data_augmentation".
    :return: (dataset_name, batch_size, shuffle, tuple of datasets, data augmentation transform or None)
    :raises ValueError: if a sequence with zero or more than two configs is passed
    """
    transformers: Optional[Tuple[transforms.Compose, transforms.Compose]]

    configs = [domain_config] if isinstance(domain_config, dict) else domain_config
    if len(configs) not in (1, 2):
        raise ValueError("You can pass a maximum of two domain configs to create a dual domain dataloader")
    is_dual = len(configs) == 2
    primary = configs[0]

    if is_dual:
        assert_domain_match(configs[0], configs[1])
        # The raw data stays untransformed; DualDomainDataset applies both transforms below.
        transformers = None, None
    else:
        transformers = assemble_transform(primary)

    shuffle = primary["shuffle"]
    batch_size = primary["batch_size"]
    dataset_name = primary["dataset"]
    dataset_domain = primary.get("dataset_domain", [])

    # Read from the primary config dict: a list/tuple of configs has no .get().
    if primary.get("data_augmentation", False):
        # if data augmentation is specified as an option, we only apply the basic, necessary transformations
        # all other transformations are later only applied to the train dataset using the data_augment_transformer
        raw_transformer = transformers[1]
        data_augment_transformer = transformers[0]
    else:
        raw_transformer = transformers[0]
        data_augment_transformer = None

    datasets = get_raw_datasets(dataset_name, raw_transformer, dataset_domain)

    if is_dual:
        # assemble_transform returns an indexable pair (indexed [0]/[1] above); element 0 is the transform
        # used when no separate augmentation is requested. The pair itself is not callable.
        t0 = assemble_transform(configs[0])[0]
        t1 = assemble_transform(configs[1])[0]
        # A tuple (not a generator) so callers can both unpack and index the result.
        datasets = tuple(DualDomainDataset(dataset, t0, t1) for dataset in datasets)

    return dataset_name, batch_size, shuffle, datasets, data_augment_transformer


def get_torchvision_dataloaders(task_config: Union[Dict, Sequence[Dict]], val_fraction: float, num_workers: int,
                                val_from_train: bool = True, balance: bool = True) \
        -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build train/val/test loaders for a torchvision-style dataset.

    :param task_config: single or dual domain config, see ``parse_domain_config``
    :param val_fraction: with ``val_from_train`` this fraction of the train split becomes the validation
        set. Otherwise the raw getter must return (train, val, test); train and val are then both split with
        fractions ``[1 - val_fraction, val_fraction]`` and only the first part of each is kept.
    :param num_workers: DataLoader worker count
    :param val_from_train: whether the validation set is carved out of the train split
    :param balance: whether splits keep the per-class distribution (see ``split_dataset``)
    :return: (train_loader, val_loader, test_loader)
    """
    assert val_fraction < 1
    fractions = [1 - val_fraction, val_fraction]

    if val_from_train:
        dataset_name, batch_size, shuffle, (train_data, test_data), \
            data_augment_transformer = parse_domain_config(task_config)

        train_subset, val_subset = get_subsets(dataset_name, train_data, fractions, shuffle,
                                               batch_size=batch_size, balance=balance)
    else:
        dataset_name, batch_size, shuffle, (train_data, val_data, test_data), \
            data_augment_transformer = parse_domain_config(task_config)

        # batch_size is forwarded here too, consistent with the branch above.
        train_subset, _ = get_subsets(dataset_name, train_data, fractions, shuffle,
                                      batch_size=batch_size, balance=balance)
        val_subset, _ = get_subsets(dataset_name, val_data, fractions, shuffle,
                                    batch_size=batch_size, balance=balance)

    return create_dataloaders(train_subset, val_subset, test_data, batch_size, shuffle, num_workers,
                              data_augment_transformer)

def _apply_transform_recursively(dataset: Dataset, transform: transforms.Compose) -> None:
    """Install ``transform`` on every leaf dataset reachable through ConcatDataset/Subset-style wrappers."""
    if hasattr(dataset, "transforms"):
        # Leaf torchvision dataset. Classification datasets such as MNIST, CIFAR10, SVHN and ImageFolder
        # apply ``self.transform`` in __getitem__, while others read ``self.transforms``; set both so the
        # augmentation actually takes effect, and keep the existing target_transform.
        dataset.transform = transform
        dataset.transforms = torchvision.datasets.vision.StandardTransform(
            transform, getattr(dataset, "target_transform", None))
    elif hasattr(dataset, "datasets"):
        # for ConcatDataset
        for child in dataset.datasets:
            _apply_transform_recursively(child, transform)
    elif hasattr(dataset, "dataset"):
        # for Subset and other single-dataset wrappers
        _apply_transform_recursively(dataset.dataset, transform)
    else:
        raise NotImplementedError(f"Cannot apply a transform to a dataset of type {type(dataset).__name__}")


def create_dataloaders(train_dataset: Dataset, val_dataset: Dataset, test_dataset: Dataset,
                       batch_size: int, shuffle: bool, num_workers: int,
                       data_augment_transformer: Optional[transforms.Compose] = None) \
        -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Wrap the three datasets into DataLoaders sharing batch size, shuffle flag and worker count.

    :param data_augment_transformer: if given, the train dataset is deep-copied and this transform is
        installed on the copy's underlying dataset(s), so val/test (which may share that underlying
        dataset through Subsets) keep their original transform
    :return: (train_loader, valid_loader, test_loader)
    :raises NotImplementedError: if the transform cannot be installed on the train dataset's structure
    """
    if data_augment_transformer is not None:
        train_dataset = deepcopy(train_dataset)
        _apply_transform_recursively(train_dataset, data_augment_transformer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return train_loader, valid_loader, test_loader


############################################################################################################
# Dataloader creator functions for all types of datasets
############################################################################################################


def get_mnist_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build MNIST (28, 28) train/val/test loaders; 10% of the train split is held out for validation."""
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_cifar10_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build CIFAR10 (32, 32) train/val/test loaders.

    10% of the official train split is held out for validation; the official test split is used as test set.

    :param task_config: single or dual domain config, see ``parse_domain_config``
    :return: (train_loader, val_loader, test_loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_fashion_mnist_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build FashionMNIST (28, 28) train/val/test loaders; 10% of the train split is held out for validation."""
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_svhn_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build SVHN (32, 32) train/val/test loaders; 10% of the train(+extra) data is held out for validation."""
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders

def get_imagenet_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build ImageNet train/val/test loaders with (224, 224) images.

    The dataset getter already supplies a separate val split, so nothing is carved out of train
    (val_fraction 0, val_from_train False) and no class balancing is performed.
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.0, num_workers=NUM_WORKERS,
                                          balance=False, val_from_train=False)
    return loaders

def get_emnist_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build EMNIST letters (28, 28) train/val/test loaders; 10% of the train split is held out for validation."""
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_kaist_ic_dataloaders(domain_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build loaders for the KAIST multispectral pedestrian dataset recast as image classification
    ((64, 64) images), using the dataset's own train/val/test folders without further splitting.
    """
    _, batch_size, shuffle, splits, data_augment_transformer = parse_domain_config(domain_config)
    train_data, val_data, test_data = splits

    return create_dataloaders(train_data, val_data, test_data, batch_size, shuffle, NUM_WORKERS,
                              data_augment_transformer)


def get_syn2real_dataloaders(domain_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build Syn2Real loaders with (64, 64) images from the domain's own train/val/test folders.
    """
    _, batch_size, shuffle, splits, data_augment_transformer = parse_domain_config(domain_config)
    train_data, val_data, test_data = splits

    return create_dataloaders(train_data, val_data, test_data, batch_size, shuffle, NUM_WORKERS,
                              data_augment_transformer)


def get_office31_dataloaders(domain_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build Office-31 loaders by splitting the (possibly multi-domain) dataset 60/20/20 into
    class-balanced train/val/test subsets with a seed derived from the dataset name.
    """
    dataset_name, batch_size, shuffle, datasets, data_augment_transformer = parse_domain_config(domain_config)

    # The Office31 getter yields exactly one (possibly concatenated) dataset. next(iter(...)) works for any
    # iterable, including a lazily built one, whereas indexing requires a sequence.
    full_dataset = next(iter(datasets))

    train_data, val_data, test_data = get_subsets(dataset_name, full_dataset, [0.6, 0.2, 0.2], shuffle,
                                                  balance=True, batch_size=batch_size)

    return create_dataloaders(train_data, val_data, test_data, batch_size, shuffle, NUM_WORKERS,
                              data_augment_transformer)
