import os
from argparse import Namespace
from typing import Any, List, Tuple, Union

import h5py  # type: ignore
import numpy as np  # type: ignore
import torch
from torch.utils.data import Dataset

# Names of the image corruption types introduced in https://arxiv.org/abs/1903.12261
# (17 entries, listed alphabetically).
# NOTE(review): presumably the position of a name in this list matches the index along the
# corruption axis of the pre-generated "-c" arrays — confirm against the script that built them.
CORRUPTIONS = [
    "brightness",
    "contrast",
    "defocus_blur",
    "elastic_transform",
    "fog",
    "frost",
    "gaussian_blur",
    "gaussian_noise",
    "glass_blur",
    "impulse_noise",
    "jpeg_compression",
    "pixelate",
    "saturate",
    "shot_noise",
    "spatter",
    "speckle_noise",
    "zoom_blur",
]


# Short alias for torch.Tensor, used to keep the Tuple[...] return annotations readable.
T = torch.Tensor
# Annotation for a value that may be either a plain Python list or a tensor.
LT = Union[List, T]


class NumpyDataset(Dataset):
    """Base class for episodic (N-way, K-shot) datasets that are held fully in memory.

    ``__init__`` only stores the episode configuration. Subclasses must assign the attributes
    that are declared but not set here: ``data`` (indexed as ``data[class, instance_indices]``),
    ``dim``, ``ch``, ``n_classes``, ``n_per_class``, ``test_episodes`` and ``out_xform``.

    Every ``__getitem__`` call samples a fresh random task with numpy's global RNG, so the
    integer index passed by a sampler/DataLoader is ignored.
    """

    def __init__(self, args: Namespace, split: str = "none"):
        """Store the episode configuration.

        Args:
            args: namespace providing ``data_root``, ``n_way``, ``k_shot``,
                ``train_query_shots``, ``val_query_shots`` and ``ood_test``.
                ``ood_test``: if true, query sets additionally contain samples from random
                classes that are not part of the support set.
            split: ``"train"`` selects ``train_query_shots`` as the number of query shots;
                every other value selects ``val_query_shots``.
        """
        self.root = args.data_root
        self.split = split
        self.n_way = args.n_way
        self.k_shot = args.k_shot
        self.test_shots = args.train_query_shots if split == "train" else args.val_query_shots
        self.ood_test = args.ood_test

        # Declared for type checkers only; concrete values are assigned by subclasses.
        self.data: Any
        self.dim: int
        self.ch: int
        self.n_classes: int
        self.n_per_class: int
        self.test_episodes: int
        self.out_xform: Any

    def __len__(self) -> int:
        """Return the nominal number of episodes for the current split.

        This is not really important for training, since every task is sampled randomly and
        progress is measured in tasks rather than epochs; it is a round number that decides when
        validation happens. For val/test it is the dataset-specific ``test_episodes``.

        Raises:
            NotImplementedError: if ``split`` is not one of "train", "val" or "test".
        """
        if self.split == "train":
            return 100000
        if self.split in ("val", "test"):
            return self.test_episodes
        raise NotImplementedError("split not implemented")

    @staticmethod
    def _span(pos: int, size: int) -> slice:
        """Return the slice covering the ``pos``-th consecutive chunk of length ``size``."""
        return slice(pos * size, (pos + 1) * size)

    def random_classes_and_shots(self, n_way: int, k_shot: int, test_shots: int) -> Tuple[Any, ...]:
        """Sample the classes and per-class instance indices of one episode.

        Args:
            n_way: number of distinct classes to draw.
            k_shot: number of support instances per class.
            test_shots: number of query instances per class.

        Returns:
            ``(classes, idx)``: ``classes`` has shape ``(n_way,)`` and holds distinct class ids;
            ``idx`` has shape ``(n_way, k_shot + test_shots)`` and each row holds distinct
            instance indices in ``[0, n_per_class)``.
        """
        classes = np.random.choice(self.n_classes, size=n_way, replace=False)
        # Use the arguments rather than self.k_shot / self.test_shots so that the requested
        # sizes are actually honoured (they were silently ignored before).
        per_class = k_shot + test_shots
        idx = np.array([np.random.choice(self.n_per_class, size=per_class, replace=False) for _ in range(n_way)])
        return classes, idx

    def get_empty(self, qry_size_factor: int = 1) -> Tuple[T, T, T, T]:
        """Allocate zero-filled ``(spt_x, spt_y, qry_x, qry_y)`` buffers for one episode.

        Args:
            qry_size_factor: multiplier applied to the query-set length (2 for OOD episodes).
        """
        n_support = self.n_way * self.k_shot
        n_query = self.n_way * self.test_shots * qry_size_factor
        return (
            torch.zeros(n_support, self.ch, self.dim, self.dim),
            torch.zeros(n_support),
            torch.zeros(n_query, self.ch, self.dim, self.dim),
            torch.zeros(n_query),
        )

    def getitem_ood_test(self, i: int) -> Tuple[T, T, T, T]:
        """Sample an episode whose query set is followed by samples of unseen classes.

        ``2 * n_way`` classes are drawn. The first ``n_way`` form a regular episode; the remaining
        ``n_way`` only contribute query samples, which fill the second half of the query set.
        Their labels are again ``0..n_way-1`` (the position inside the OOD half), so consumers
        must tell the two halves apart by position. ``i`` is ignored.
        """
        spt_x, spt_y, qry_x, qry_y = self.get_empty(qry_size_factor=2)
        classes, idx = self.random_classes_and_shots(self.n_way * 2, self.k_shot, self.test_shots)

        # regular part of the episode: support + query from the first n_way classes
        for pos, cl in enumerate(classes[: self.n_way]):
            train_idx, test_idx = idx[pos, : self.k_shot], idx[pos, self.k_shot :]

            spt_dst = self._span(pos, self.k_shot)
            spt_x[spt_dst] = self.data[cl, train_idx]
            spt_y[spt_dst] = pos

            qry_dst = self._span(pos, self.test_shots)
            qry_x[qry_dst] = self.data[cl, test_idx]
            qry_y[qry_dst] = pos

        # OOD part: query samples only, written after the n_way regular query chunks
        for pos, cl in enumerate(classes[self.n_way :]):
            row = self.n_way + pos
            # Read the index row that was sampled for this class; previously rows 0..n_way-1
            # were re-used here and the rows drawn for the OOD classes were never read.
            test_idx = idx[row, self.k_shot :]

            qry_dst = self._span(row, self.test_shots)
            qry_x[qry_dst] = self.data[cl, test_idx]
            qry_y[qry_dst] = pos

        return self.out_xform(spt_x), spt_y.long(), self.out_xform(qry_x), qry_y.long()

    def getitem(self, i: int) -> Tuple[T, T, T, T]:
        """Sample a regular N-way episode; ``i`` is ignored.

        Returns:
            ``(spt_x, spt_y, qry_x, qry_y)`` where the images went through ``out_xform`` and the
            labels are ``int64`` values in ``[0, n_way)``, grouped by class.
        """
        spt_x, spt_y, qry_x, qry_y = self.get_empty()
        classes, idx = self.random_classes_and_shots(self.n_way, self.k_shot, self.test_shots)

        for pos, cl in enumerate(classes):
            train_idx, test_idx = idx[pos, : self.k_shot], idx[pos, self.k_shot :]

            spt_dst = self._span(pos, self.k_shot)
            spt_x[spt_dst] = self.data[cl, train_idx]
            spt_y[spt_dst] = pos

            qry_dst = self._span(pos, self.test_shots)
            qry_x[qry_dst] = self.data[cl, test_idx]
            qry_y[qry_dst] = pos

        return self.out_xform(spt_x), spt_y.long(), self.out_xform(qry_x), qry_y.long()

    def __getitem__(self, i: int) -> Tuple[T, T, T, T]:
        """Return a random episode, with OOD query samples appended when ``ood_test`` is set."""
        if self.ood_test:
            return self.getitem_ood_test(i)
        return self.getitem(i)


class Omniglot(NumpyDataset):
    """Omniglot episodes built from the pre-packed ``omniglot/omniglot.npy`` array."""

    def __init__(self, args: Namespace, split: str = "none"):
        """Load the classes of ``split`` and add their 90/180/270 degree rotations as new classes.

        Args:
            args: episode configuration (see ``NumpyDataset``); ``args.run`` additionally selects
                which ``{run}-stats.txt`` / ``{run}-{split}.txt`` files are read.
            split: only "train" and "test" have a split file; anything else raises ``KeyError``.
        """
        super().__init__(args, split=split)

        omniglot_dir = os.path.join(self.root, "omniglot")
        splits_dir = os.path.join(omniglot_dir, "splits")

        # the split files hold the class indices (into omniglot.npy) that belong to each split
        self.data = np.load(os.path.join(omniglot_dir, "omniglot.npy"))
        self.stats = np.loadtxt(os.path.join(splits_dir, f"{args.run}-stats.txt"))

        split_file = {
            "train": os.path.join(splits_dir, f"{args.run}-train.txt"),
            "test": os.path.join(splits_dir, f"{args.run}-test.txt"),
        }
        self.split_idx = np.loadtxt(split_file[split]).astype(int)

        # Rotations are taken over axes (3, 4) and stacked along the class axis, so every rotated
        # copy of a character counts as a class of its own (4x the number of classes).
        upright = self.data[self.split_idx]
        rotations = [np.rot90(upright, k=quarter_turns, axes=(3, 4)) for quarter_turns in (1, 2, 3)]
        stacked = np.concatenate([upright] + rotations)

        # stats are loaded but normalisation is not applied; pixels are rescaled by out_xform
        self.data = torch.from_numpy(stacked)

        self.n_classes = self.data.shape[0]
        self.n_per_class = 20
        self.ch = 1
        self.dim = 28
        self.test_episodes = 1000
        self.out_xform = lambda x: 1 - (x / 255.)


class MiniImageNet(NumpyDataset):
    """miniImageNet episodes read from ``miniimagenet/{split}_data.hdf5``."""

    def __init__(self, args: Namespace, split: str = "none"):
        """Load every class of ``split`` into memory.

        Args:
            args: episode configuration (see ``NumpyDataset``).
                ``ood_test``: if set to true, the test sets have random OOD classes
                (not seen during training).
            split: used verbatim to build the HDF5 file name ``{split}_data.hdf5``.
        """
        super().__init__(args, split=split)

        h5_path = os.path.join(self.root, "miniimagenet", f"{split}_data.hdf5")
        # The context manager closes the HDF5 handle once the data has been copied into memory;
        # np.stack materialises each per-class dataset while the file is still open.
        with h5py.File(h5_path, "r") as h5_file:
            datasets = h5_file["datasets"]
            stacked = np.stack([datasets[name] for name in datasets])

        # move the trailing channel axis in front of the two spatial axes:
        # (class, instance, dim, dim, ch) -> (class, instance, ch, dim, dim)
        self.data = torch.from_numpy(np.transpose(stacked, (0, 1, 4, 2, 3)))

        self.dim = 84
        self.ch = 3
        self.n_classes = self.data.shape[0]
        self.n_per_class = 600
        self.test_episodes = 600
        # No mean/std normalisation is applied; pixel values are only divided by 255.
        self.out_xform = lambda x: x / 255.


def get_empty_corrupt(n_way: int, k_shot: int, test_shots: int, ch: int, dim: int) -> Tuple[T, T, T, T]:
    """Allocate zero-filled episode buffers with explicitly given sizes.

    Same idea as ``NumpyDataset.get_empty``, but every size is an argument so that the corrupted
    datasets can request a larger query set.

    Returns:
        ``(spt_x, spt_y, qry_x, qry_y)`` with shapes ``(n_way * k_shot, ch, dim, dim)``,
        ``(n_way * k_shot,)``, ``(n_way * test_shots, ch, dim, dim)`` and ``(n_way * test_shots,)``.
    """
    n_support = n_way * k_shot
    n_query = n_way * test_shots
    image_shape = (ch, dim, dim)

    support_x = torch.zeros(n_support, *image_shape)
    support_y = torch.zeros(n_support)
    query_x = torch.zeros(n_query, *image_shape)
    query_y = torch.zeros(n_query)
    return support_x, support_y, query_x, query_y


class OmniglotCorruptTest(NumpyDataset):
    """Test-split Omniglot episodes whose images come from the pre-generated corrupted set.

    After loading, ``data`` is indexed as ``data[class, intensity, corruption, instance]``
    (see ``getitem``). Each class of an episode is assigned one random corruption type; its
    support images are read at intensity index 0 and its query images at every intensity.
    """

    def __init__(self, args: Namespace) -> None:
        """Load ``corruptions/omniglot-c/npy/{class}.npy`` for every class of the test split.

        Args:
            args: episode configuration (see ``NumpyDataset``); ``args.run`` selects the
                ``{run}-stats.txt`` / ``{run}-test.txt`` split files.
        """
        super().__init__(args, split="test")

        splits_dir = os.path.join(self.root, "omniglot", "splits")
        # stats are kept as an attribute but no normalisation is applied to the data
        self.stats = np.loadtxt(os.path.join(splits_dir, f"{args.run}-stats.txt"))
        self.split_idx = np.loadtxt(os.path.join(splits_dir, f"{args.run}-test.txt")).astype(int)

        data_path = os.path.join(self.root, "corruptions", "omniglot-c", "npy")
        per_class = [np.load(os.path.join(data_path, f"{cls}.npy")) for cls in self.split_idx]

        # swap axes 1 and 3 so that the intensity axis comes right after the class axis
        self.data = torch.from_numpy(np.array(per_class)).transpose(1, 3)

        self.n_classes = self.data.shape[0]
        # Axis 1 is the intensity axis after the transpose above (getitem slices it with ":");
        # axis 3, which was read here before, is the per-class instance axis.
        self.levels = self.data.shape[1]
        self.n_corruptions = self.data.shape[2]

        self.dim = 28
        self.ch = 1
        self.n_per_class = 20
        self.test_episodes = 1000
        self.out_xform = lambda x: 1 - (x / 255.)

    def getitem(self, i: int) -> Tuple[T, T, T, T]:
        """Sample one corrupted episode; ``i`` is ignored.

        Returns:
            ``(spt_x, spt_y, qry_x, qry_y)``. The query set holds
            ``n_way * test_shots * levels`` images because every sampled query instance is
            included once per corruption intensity.
        """
        query_per_class = self.test_shots * self.levels
        spt_x, spt_y, qry_x, qry_y = get_empty_corrupt(self.n_way, self.k_shot, query_per_class, self.ch, self.dim)

        classes, idx = self.random_classes_and_shots(self.n_way, self.k_shot, self.test_shots)
        # one corruption type per class, drawn independently (repeats allowed)
        corrs = np.random.choice(self.n_corruptions, size=classes.shape[0], replace=True)

        for pos, (cl, c) in enumerate(zip(classes, corrs)):
            # The number of query instances follows test_shots (it used to be hard-coded to 15,
            # which broke the assignment below for any larger test_shots).
            train_idx = idx[pos, : self.k_shot]
            test_idx = idx[pos, self.k_shot : self.k_shot + self.test_shots]

            # support: intensity index 0 only; query: all intensities of the same instances
            tr = self.data[cl, 0, c, train_idx]
            te = self.data[cl, :, c, test_idx]

            # flatten to (n_images, ch, H, W); tensor.reshape avoids routing tensors through numpy
            tr = tr.reshape(-1, self.ch, tr.shape[-2], tr.shape[-1])
            te = te.reshape(-1, self.ch, te.shape[-2], te.shape[-1])

            spt_dst = slice(pos * self.k_shot, (pos + 1) * self.k_shot)
            spt_x[spt_dst] = tr
            spt_y[spt_dst] = pos

            qry_dst = slice(pos * query_per_class, (pos + 1) * query_per_class)
            qry_x[qry_dst] = te
            qry_y[qry_dst] = pos

        return self.out_xform(spt_x), spt_y.long(), self.out_xform(qry_x), qry_y.long()


class MiniImageNetCorruptTest(NumpyDataset):
    """Test-split miniImageNet episodes drawn from the pre-generated corrupted set.

    ``data`` is indexed as ``data[corruption, intensity, class * n_per_class + instance]``.
    Each class of an episode is assigned one random corruption type; its support images are read
    at intensity index 0 and its query images at every intensity.
    """

    def __init__(self, args: Namespace) -> None:
        """Load ``corruptions/miniimagenet-c/test.npy``.

        Args:
            args: episode configuration (see ``NumpyDataset``).
        """
        super().__init__(args, split="test")

        self.path = os.path.join(f"{self.n_way}-way", f"{self.k_shot}-shot", f"{self.test_shots}-testshot")
        raw = torch.from_numpy(np.load(os.path.join(self.root, "corruptions", "miniimagenet-c", "test.npy")))
        # stored as (17, 6, 12000, 84, 84, 3) --> (corruptions, intensities, 20 classes * 600 instances, dim, dim, ch)
        # Move only the channel axis: (17, 6, 12000, 3, 84, 84). The previous transpose(5, 3)
        # produced the same shape but also swapped the two spatial axes, i.e. every image was
        # transposed relative to MiniImageNet, which permutes its data with (0, 1, 4, 2, 3).
        self.data = raw.permute(0, 1, 2, 5, 3, 4)

        self.n_per_class = 600
        self.n_corruptions = self.data.shape[0]
        # axis 1 holds the corruption intensities; axis 2 (read here before) is the image axis
        self.levels = self.data.shape[1]
        self.n_classes = self.data.shape[2] // self.n_per_class

        self.dim = 84
        self.ch = 3
        self.test_episodes = 600
        # No mean/std normalisation is applied; pixel values are only divided by 255.
        self.out_xform = lambda x: x / 255.

    def getitem(self, i: int) -> Tuple[T, T, T, T]:
        """Sample one corrupted episode; ``i`` is ignored.

        Returns:
            ``(spt_x, spt_y, qry_x, qry_y)``. The query set holds
            ``n_way * test_shots * levels`` images because every sampled query instance is
            included once per corruption intensity.
        """
        query_per_class = self.test_shots * self.levels
        spt_x, spt_y, qry_x, qry_y = get_empty_corrupt(self.n_way, self.k_shot, query_per_class, self.ch, self.dim)

        classes, idx = self.random_classes_and_shots(self.n_way, self.k_shot, self.test_shots)
        # one corruption type per class, drawn independently (repeats allowed)
        corrs = np.random.choice(self.n_corruptions, size=classes.shape[0], replace=True)

        for pos, (cl, c) in enumerate(zip(classes, corrs)):
            # The number of query instances follows test_shots (it used to be hard-coded to 15,
            # which broke the assignment below for any larger test_shots).
            train_idx = idx[pos, : self.k_shot]
            test_idx = idx[pos, self.k_shot : self.k_shot + self.test_shots]

            # images of all classes share one flat axis; class cl starts at cl * n_per_class
            offset = cl * self.n_per_class
            tr_idx, te_idx = offset + train_idx, offset + test_idx

            # tr -> (shots, 3, 84, 84), te -> (6, test_shots, 3, 84, 84)
            tr = self.data[c, 0, tr_idx]
            te = self.data[c, :, te_idx]
            # merge the intensity and instance axes; tensor.reshape avoids a detour through numpy
            te = te.reshape(-1, self.ch, te.shape[-2], te.shape[-1])

            spt_dst = slice(pos * self.k_shot, (pos + 1) * self.k_shot)
            spt_x[spt_dst] = tr
            spt_y[spt_dst] = pos

            qry_dst = slice(pos * query_per_class, (pos + 1) * query_per_class)
            qry_x[qry_dst] = te
            qry_y[qry_dst] = pos

        return self.out_xform(spt_x), spt_y.long(), self.out_xform(qry_x), qry_y.long()
