import os
from argparse import Namespace
from typing import List, Tuple

import numpy as np  # type: ignore
import torch
import torchvision  # type: ignore
from torch.utils.data import DataLoader
from torchvision import transforms  # type: ignore
from torchvision.datasets import CIFAR10, CIFAR100, MNIST  # type: ignore

from data import few_shot
from data.get_tabular import get_pbp_sets, pbp_sets
from data.higgs import HiggsDataset
from data.toy_classification import Circles, Gaussians, Moons
from data.toy_meta import MetaAll, MetaCircles, MetaGaussians, MetaMoons
from data.toy_regression import GappedSine

T = torch.Tensor


def get_mnist(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the MNIST train, validation and test loaders.

    Training images go through a padded random crop and a random rotation
    before being converted to tensors and normalized; validation and test
    images are only converted and normalized.

    Args:
        args: namespace providing ``data_root``, ``get_val``, ``batch_size``
            and ``num_workers`` (plus ``val_pct`` when ``get_val`` is truthy).

    Returns:
        ``(train_loader, val_loader, test_loader)``. When ``args.get_val`` is
        falsy the validation loader covers the full training split (with the
        evaluation transform); otherwise it covers a random ``val_pct``
        fraction that is removed from the training set.
    """
    mean, std = [0.1307], [0.3081]

    def to_normalized_tensor() -> list:
        # fresh transform instances for every pipeline
        return [transforms.ToTensor(), transforms.Normalize(mean, std)]

    augment = [transforms.RandomCrop(28, padding=4), transforms.RandomRotation(15)]
    train_tx = transforms.Compose(augment + to_normalized_tensor())
    val_tx = transforms.Compose(to_normalized_tensor())

    train = MNIST(args.data_root, train=True, transform=train_tx, download=True)
    # validation reads the training split but without augmentation
    val = MNIST(args.data_root, train=True, transform=val_tx, download=True)
    test = MNIST(args.data_root, train=False, transform=val_tx, download=True)

    if args.get_val:
        total = train.data.size(0)
        order = torch.randperm(total)
        n_val = int(args.val_pct * total)
        held_out, kept = order[:n_val], order[n_val:]

        images, labels = train.data, train.targets
        val.data, val.targets = images[held_out], labels[held_out]
        train.data, train.targets = images[kept], labels[kept]

    common = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    return (
        DataLoader(train, shuffle=True, **common),
        DataLoader(val, shuffle=True, **common),
        DataLoader(test, shuffle=False, **common),
    )


def get_cifar10(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the CIFAR-10 train, validation and test loaders.

    Args:
        args: namespace providing ``data_root``, ``get_val``, ``batch_size``
            and ``num_workers`` (plus ``val_pct`` when ``get_val`` is truthy).

    Returns:
        ``(train_loader, val_loader, test_loader)``. When ``args.get_val`` is
        falsy the validation loader covers the full training split (with the
        evaluation transform); otherwise it covers a random ``val_pct``
        fraction that is removed from the training set.
    """
    # normalization values taken from: https://github.com/Armour/pytorch-nn-practice/blob/master/utils/meanstd.py
    mean = [0.49139968, 0.48215841, 0.44653091]
    std = [0.24703223, 0.24348513, 0.26158784]

    # the simple augmentations applied in the original ResNet paper
    # (https://arxiv.org/pdf/1512.03385.pdf), followed by a random rotation
    train_tx = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    val_tx = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

    train = CIFAR10(args.data_root, train=True, transform=train_tx, download=True)
    # validation reads the training split but without augmentation
    val = CIFAR10(args.data_root, train=True, transform=val_tx, download=True)
    test = CIFAR10(args.data_root, train=False, transform=val_tx, download=True)

    # labels become tensors so they can be indexed with an index array below
    train.targets = torch.tensor(train.targets)
    test.targets = torch.tensor(test.targets)

    if args.get_val:
        total = train.data.shape[0]
        order = np.random.permutation(total)
        n_val = int(args.val_pct * total)
        held_out, kept = order[:n_val], order[n_val:]

        images, labels = train.data, train.targets
        val.data, val.targets = images[held_out], labels[held_out]
        train.data, train.targets = images[kept], labels[kept]

    common = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    return (
        DataLoader(train, shuffle=True, **common),
        DataLoader(val, shuffle=True, **common),
        DataLoader(test, shuffle=False, **common),
    )


def get_cifar100(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the CIFAR-100 train, validation and test loaders.

    Args:
        args: namespace providing ``data_root``, ``get_val``, ``batch_size``
            and ``num_workers`` (plus ``val_pct`` when ``get_val`` is truthy).

    Returns:
        ``(train_loader, val_loader, test_loader)``. When ``args.get_val`` is
        falsy the validation loader covers the full training split (with the
        evaluation transform); otherwise it covers a random ``val_pct``
        fraction that is removed from the training set.
    """
    # normalization values taken from: https://github.com/Armour/pytorch-nn-practice/blob/master/utils/meanstd.py
    mean = [0.50707516, 0.48654887, 0.44091784]
    std = [0.26733429, 0.25643846, 0.27615047]

    # the simple augmentations applied in the original ResNet paper
    # (https://arxiv.org/pdf/1512.03385.pdf), followed by a random rotation
    train_tx = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    val_tx = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

    train = CIFAR100(args.data_root, train=True, transform=train_tx, download=True)
    # built right after `train`, which has already requested the download
    val = CIFAR100(args.data_root, train=True, transform=val_tx)
    test = CIFAR100(args.data_root, train=False, transform=val_tx, download=True)

    # labels become tensors so they can be indexed with an index tensor below
    train.targets = torch.tensor(train.targets)
    test.targets = torch.tensor(test.targets)

    if args.get_val:
        total = train.data.shape[0]
        order = torch.randperm(total)
        n_val = int(args.val_pct * total)
        held_out, kept = order[:n_val], order[n_val:]

        images, labels = train.data, train.targets
        val.data, val.targets = images[held_out], labels[held_out]
        train.data, train.targets = images[kept], labels[kept]

    common = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    return (
        DataLoader(train, shuffle=True, **common),
        DataLoader(val, shuffle=True, **common),
        DataLoader(test, shuffle=False, **common),
    )


def get_omniglot(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the Omniglot few-shot loaders.

    Args:
        args: namespace providing ``corrupt_test``, ``batch_size`` and
            ``num_workers``; it is also forwarded to the dataset classes.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The validation loader
        wraps the same training dataset object as the training loader. The
        test set is ``OmniglotCorruptTest`` when ``args.corrupt_test`` is
        truthy and the ``"test"`` split of ``Omniglot`` otherwise.
    """
    train = few_shot.Omniglot(args, split="train")

    if args.corrupt_test:
        test = few_shot.OmniglotCorruptTest(args)
    else:
        test = few_shot.Omniglot(args, split="test")

    # shuffle doesn't matter because each index of the datasets is a random task
    common = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    return (
        DataLoader(train, shuffle=True, **common),
        DataLoader(train, shuffle=True, **common),
        DataLoader(test, shuffle=False, **common),
    )


def get_miniimagenet(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the MiniImageNet few-shot loaders.

    Args:
        args: namespace providing ``corrupt_test``, ``batch_size`` and
            ``num_workers``; it is also forwarded to the dataset classes.

    Returns:
        ``(train_loader, val_loader, test_loader)`` over the ``"train"`` and
        ``"val"`` splits, and either ``MiniImageNetCorruptTest`` (when
        ``args.corrupt_test`` is truthy) or the ``"test"`` split.
    """
    train = few_shot.MiniImageNet(args, split="train")
    val = few_shot.MiniImageNet(args, split="val")

    if args.corrupt_test:
        test = few_shot.MiniImageNetCorruptTest(args)
    else:
        test = few_shot.MiniImageNet(args, split="test")

    # shuffle doesn't matter because each index of the datasets is a random task
    common = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    return (
        DataLoader(train, shuffle=True, **common),
        DataLoader(val, shuffle=True, **common),
        DataLoader(test, shuffle=False, **common),
    )


def get_gapped_sine(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the gapped-sine regression loaders.

    The value returned by ``train.normalize()`` is passed to the validation
    and test sets' ``normalize`` calls.

    Args:
        args: namespace providing ``device`` and ``batch_size``.

    Returns:
        ``(train_loader, val_loader, test_loader)``; the validation set is
        constructed with the same arguments as the training set.
    """
    train = GappedSine(device=args.device, test=False)
    val = GappedSine(device=args.device, test=False)
    test = GappedSine(device=args.device, test=True)

    # normalize the held-out sets with the value obtained from the training set
    train_params = train.normalize()
    for held_out in (val, test):
        held_out.normalize(train_params)

    bs = args.batch_size
    return (
        DataLoader(train, shuffle=True, batch_size=bs),
        DataLoader(val, shuffle=True, batch_size=bs),
        DataLoader(test, shuffle=False, batch_size=bs),
    )


def get_few_shot_toy_gaussian(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the toy few-shot Gaussian loaders.

    Args:
        args: namespace providing ``n_way``, ``k_shot``, ``test_shots``,
            ``seed`` and ``batch_size``.

    Returns:
        ``(train_loader, None, test_loader)``: 100 training tasks and 5 test
        tasks; there is no validation loader.
    """
    # everything except the task count is shared between the two sets
    shared = dict(n_way=args.n_way, k_shot=args.k_shot, test_shots=args.test_shots, seed=args.seed)
    train = MetaGaussians(total_tasks=100, **shared)
    test = MetaGaussians(total_tasks=5, **shared)

    return (
        DataLoader(train, batch_size=args.batch_size),
        None,
        DataLoader(test, batch_size=args.batch_size),
    )


def get_few_shot_toy_moons(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the toy few-shot moons loaders.

    Args:
        args: namespace providing ``k_shot``, ``test_shots``, ``seed`` and
            ``batch_size``.

    Returns:
        ``(train_loader, None, test_loader)``: 100 training tasks and 5 test
        tasks; there is no validation loader.
    """
    train = MetaMoons(
        k_shot=args.k_shot,
        test_shots=args.test_shots,
        total_tasks=100,
        seed=args.seed,
    )
    # NOTE(review): the test set does not receive `test_shots` and so uses
    # whatever default MetaMoons defines -- confirm this is intended.
    test = MetaMoons(k_shot=args.k_shot, total_tasks=5, seed=args.seed)

    return (
        DataLoader(train, batch_size=args.batch_size),
        None,
        DataLoader(test, batch_size=args.batch_size),
    )


def get_few_shot_toy_circles(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the toy few-shot circles loaders.

    Args:
        args: namespace providing ``k_shot``, ``test_shots``, ``seed`` and
            ``batch_size``.

    Returns:
        ``(train_loader, None, test_loader)``: 100 training tasks and 5 test
        tasks; there is no validation loader.
    """
    train = MetaCircles(
        k_shot=args.k_shot,
        test_shots=args.test_shots,
        total_tasks=100,
        seed=args.seed,
    )
    # NOTE(review): the test set does not receive `test_shots` and so uses
    # whatever default MetaCircles defines -- confirm this is intended.
    test = MetaCircles(k_shot=args.k_shot, total_tasks=5, seed=args.seed)

    return (
        DataLoader(train, batch_size=args.batch_size),
        None,
        DataLoader(test, batch_size=args.batch_size),
    )


def get_few_shot_toy_all(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the loaders over the combined toy few-shot dataset.

    Batches are assembled by a custom collate function that keeps each task's
    items separate instead of stacking them.

    Args:
        args: namespace providing ``batch_size``; it is also forwarded to
            ``MetaAll``.

    Returns:
        ``(train_loader, None, test_loader)``; both datasets are constructed
        with the same arguments and there is no validation loader.
    """

    def collate(lst: List[T]) -> List[T]:
        # Regroup a list of (x_spt, y_spt, x_qry, y_qry) samples into four
        # parallel lists, one per field, without stacking.
        columns = ([], [], [], [])
        for xs, ys, xq, yq in lst:
            for column, item in zip(columns, (xs, ys, xq, yq)):
                column.append(item)
        return columns

    train = MetaAll(args)
    test = MetaAll(args)
    return (
        DataLoader(train, batch_size=args.batch_size, collate_fn=collate),
        None,
        DataLoader(test, batch_size=args.batch_size, collate_fn=collate),
    )


def get_moons(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the toy moons classification loaders.

    Args:
        args: namespace providing ``seed`` and ``batch_size``.

    Returns:
        ``(train_loader, None, test_loader)``; both loaders shuffle and there
        is no validation loader.
    """
    # NOTE(review): both sets are built with the same size argument (1000) and
    # the same seed -- confirm they do not end up with identical samples.
    train = Moons(500 * 2, seed=args.seed)
    test = Moons(1000, seed=args.seed)

    loader_kwargs = dict(batch_size=args.batch_size, shuffle=True)
    return DataLoader(train, **loader_kwargs), None, DataLoader(test, **loader_kwargs)


def get_gaussians(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the toy 10-class Gaussians classification loaders.

    Args:
        args: namespace providing ``seed`` and ``batch_size``.

    Returns:
        ``(train_loader, None, test_loader)``; both loaders shuffle and there
        is no validation loader.
    """
    # NOTE(review): both sets are built with the same size argument (1000),
    # class count and seed -- confirm they do not end up with identical samples.
    train = Gaussians(500 * 2, classes=10, seed=args.seed)
    test = Gaussians(100 * 10, classes=10, seed=args.seed)

    loader_kwargs = dict(batch_size=args.batch_size, shuffle=True)
    return DataLoader(train, **loader_kwargs), None, DataLoader(test, **loader_kwargs)


def get_circles(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the toy circles classification loaders.

    Args:
        args: namespace providing ``seed`` and ``batch_size``.

    Returns:
        ``(train_loader, None, test_loader)``; both loaders shuffle and there
        is no validation loader.
    """
    train = Circles(500 * 10, seed=args.seed)
    test = Circles(1000, seed=args.seed)

    loader_kwargs = dict(batch_size=args.batch_size, shuffle=True)
    return DataLoader(train, **loader_kwargs), None, DataLoader(test, **loader_kwargs)


def get_higgs(args: Namespace) -> Tuple[DataLoader, ...]:
    """Build the HIGGS train, validation and test loaders.

    A random ``val_pct`` fraction of the training rows is always moved into
    the validation set.

    Args:
        args: namespace providing ``data_root``, ``val_pct`` and
            ``batch_size``.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    tiny = False

    train = HiggsDataset(args.data_root, tiny=tiny)
    # `val` is only a container here: its x and y are replaced by the
    # held-out training rows below
    val = HiggsDataset(args.data_root, test=True, tiny=tiny)
    test = HiggsDataset(args.data_root, test=True, tiny=tiny)

    n_rows = train.x.size(0)
    order = torch.randperm(n_rows)
    n_val = int(n_rows * args.val_pct)
    held_out, kept = order[:n_val], order[n_val:]

    features, labels = train.x, train.y
    val.x, val.y = features[held_out], labels[held_out]
    train.x, train.y = features[kept], labels[kept]

    bs = args.batch_size
    train_ldr = DataLoader(train, batch_size=bs, shuffle=True)
    val_ldr = DataLoader(val, batch_size=bs, shuffle=True)
    test_ldr = DataLoader(test, batch_size=bs, shuffle=False)
    return train_ldr, val_ldr, test_ldr


# Maps every supported `args.dataset` name to the function that builds its loaders.
deref = {
    "mnist": get_mnist,
    "cifar10": get_cifar10,
    "cifar100": get_cifar100,
    "omniglot": get_omniglot,
    "miniimagenet": get_miniimagenet,
    "gapped-sine": get_gapped_sine,
    "few-shot-toy-circles": get_few_shot_toy_circles,
    "few-shot-toy-gaussian": get_few_shot_toy_gaussian,
    "few-shot-toy-moons": get_few_shot_toy_moons,
    "few-shot-toy-all": get_few_shot_toy_all,
    "toy-moons": get_moons,
    "toy-circles": get_circles,
    "toy-gaussians": get_gaussians,
    "higgs": get_higgs,
}

# Every name in `pbp_sets` is served by `get_pbp_sets`. The lambda reads the
# dataset name from `args` at call time, so it does not depend on the loop
# variable.
deref.update({
    name: lambda args: get_pbp_sets(args.dataset, args.data_root, args.run, args.val_pct, args.get_val, args.shifted)
    for name in pbp_sets
})


def get_dataset(args: Namespace) -> Tuple[DataLoader, ...]:
    """Dispatch to the loader factory registered for ``args.dataset``.

    Args:
        args: namespace providing ``dataset`` plus whatever the selected
            factory reads.

    Returns:
        Whatever the selected factory returns.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a key of ``deref``.
    """
    factory = deref.get(args.dataset)
    if factory is None:
        raise NotImplementedError(f"dataset: {args.dataset} is not implemented")

    return factory(args)


def plot_samples(x: T, args: Namespace) -> None:
    """Save a grid image of the samples in ``x`` for visual inspection.

    The image is written to ``data/examples/<args.dataset>-example.png``
    (relative to the current working directory); the directory is created if
    it does not exist.

    Args:
        x: batch of images accepted by ``torchvision.utils.make_grid``.
        args: namespace providing ``dataset``, used in the output file name.
    """
    path = os.path.join("data", "examples")
    os.makedirs(path, exist_ok=True)
    # make_grid's keyword is `nrow` (number of images per grid row), not `nrows`
    grid = torchvision.utils.make_grid(x, nrow=5)
    torchvision.utils.save_image(grid, os.path.join(path, f"{args.dataset}-example.png"))
