import os
import unittest
from argparse import Namespace
from functools import partial

import torch
import torchvision  # type: ignore
from matplotlib import pyplot as plt  # type: ignore
from matplotlib.lines import Line2D  # type: ignore
from torch.utils.data import DataLoader
from utils import get_color

from data.few_shot import (MiniImageNet, MiniImageNetCorruptTest, Omniglot,
                           OmniglotCorruptTest)
from data.higgs import HiggsDataset
from data.particle_jets import AtlasJets, Jets
from data.toy_classification import Circles, Gaussians, Moons
from data.toy_meta import MetaCircles, MetaGaussians, MetaMoons
from data.toy_regression import GappedSine

# Root directory holding the on-disk datasets; override with the DATA_ROOT environment variable.
DATA_ROOT = os.environ.get("DATA_ROOT", "/home/datasets")


class DatasetTests(unittest.TestCase):
    """Smoke tests for the project's datasets.

    Most tests only check that a dataset can be constructed, indexed and batched. They print
    tensor shapes and write example images/plots under ``data/examples`` for manual inspection.
    The few-shot, Higgs and jet datasets are read from ``DATA_ROOT``.
    """

    # (n_way, k_shot) episode configurations exercised by the few-shot tests.
    _WAY_SHOT = ((5, 1), (5, 5), (20, 1), (20, 5))

    @staticmethod
    def _print_episode_shapes(dataset, n_batches: int = 3, batch_size: int = 32) -> None:
        """Print the shapes of episode 0 and of the first ``n_batches`` shuffled DataLoader batches.

        Each dataset item is unpacked as ``(spt_x, spt_y, qry_x, qry_y)``.
        """
        spt_x, spt_y, qry_x, qry_y = dataset[0]
        print(f"support: {spt_x.size()} {spt_y.size()}")
        print(f"query: {qry_x.size()} {qry_y.size()}")

        ldr = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for i, (spt_x, spt_y, qry_x, qry_y) in enumerate(ldr):
            print(spt_x.size(), spt_y.size(), qry_x.size(), qry_y.size())
            if i == n_batches - 1:
                break

    @staticmethod
    def _save_episode_grids(dataset, path: str, support_nrow: int, query_nrow: int, n_episodes: int = 10) -> None:
        """Save support/query image grids of the first ``n_episodes`` episodes into ``path``.

        Files are written as ``support-{i}.png`` and ``query-{i}.png``; ``*_nrow`` is the number of
        images per grid row passed to ``torchvision.utils.make_grid``.
        """
        os.makedirs(path, exist_ok=True)
        for i in range(n_episodes):
            spt_x, _, qry_x, _ = dataset[i]
            print(spt_x.size(), qry_x.size())
            grid = torchvision.utils.make_grid(spt_x, nrow=support_nrow)
            torchvision.utils.save_image(grid, os.path.join(path, f"support-{i}.png"))

            grid = torchvision.utils.make_grid(qry_x, nrow=query_nrow)
            torchvision.utils.save_image(grid, os.path.join(path, f"query-{i}.png"))

    def test_dataset_numpy_corr_loaders(self) -> None:
        for name, DS in zip(["omniglot", "miniimagenet"], [OmniglotCorruptTest, MiniImageNetCorruptTest]):
            for n_way, k_shot in self._WAY_SHOT:
                args = Namespace(
                    data_root=DATA_ROOT,
                    ood_test=False,
                    dataset=name,
                    n_way=n_way,
                    k_shot=k_shot,
                    train_query_shots=k_shot,
                    val_query_shots=15,
                    run=0,
                    corrupt_test=True
                )

                print(f"numpy {name} nway: {n_way} kshot: {k_shot}")
                dataset = DS(args)
                self._print_episode_shapes(dataset)

                # Per the view() below, each class's query images are stored contiguously as
                # (val_query_shots, n_corruptions), so one grid row per class needs
                # val_query_shots * n_corruptions columns. The old hard-coded 15 * 6 assumed 6 levels.
                # max(1, ...) guards make_grid against nrow == 0.
                n_corruptions = dataset[0][2].size(0) // (args.val_query_shots * n_way)
                self._save_episode_grids(
                    dataset,
                    f"data/examples/{name}-corrupt-numpy/plain-n-{n_way}-k-{k_shot}",
                    support_nrow=k_shot,
                    query_nrow=max(1, args.val_query_shots * n_corruptions),
                )

                level_path = f"data/examples/{name}-corrupt-numpy-plot-by-level/plain-n-{n_way}-k-{k_shot}"
                os.makedirs(level_path, exist_ok=True)
                for i in range(10):
                    spt_x, _, qry_x, _ = dataset[i]
                    _, ch, h, w = qry_x.size()
                    n_corruptions = qry_x.size(0) // (args.val_query_shots * n_way)

                    # The support set does not depend on the corruption level, so save it once per episode.
                    grid = torchvision.utils.make_grid(spt_x, nrow=k_shot)
                    torchvision.utils.save_image(grid, os.path.join(level_path, f"support-{i}.png"))

                    # Regroup the flat query batch as (class, shot, level, C, H, W) and write one grid per
                    # level, with one row of val_query_shots images per class.
                    qry_x = qry_x.view(n_way, args.val_query_shots, n_corruptions, ch, h, w)
                    for j, level in enumerate(torch.split(qry_x, 1, dim=2)):
                        grid = torchvision.utils.make_grid(level.reshape(-1, ch, h, w), nrow=args.val_query_shots)
                        torchvision.utils.save_image(grid, os.path.join(level_path, f"query-{i}-level-{j}.png"))

    def test_fewshot_ood_test_dataloaders(self) -> None:
        for name, DS in zip(["omniglot", "miniimagenet"], [Omniglot, MiniImageNet]):
            for n_way, k_shot in self._WAY_SHOT:
                if n_way == 20 and name == "miniimagenet":
                    # there aren't enough classes in the miniimagenet dataset to handle 20 * 2 way for the query set.
                    continue

                args = Namespace(
                    data_root=DATA_ROOT,
                    ood_test=True,
                    dataset=name,
                    n_way=n_way,
                    k_shot=k_shot,
                    train_query_shots=k_shot,
                    val_query_shots=k_shot,
                    run=0,
                )

                print(f"numpy ood {name} nway: {n_way} kshot: {k_shot}")
                dataset = DS(args, split="test")
                self._print_episode_shapes(dataset)
                self._save_episode_grids(
                    dataset,
                    f"data/examples/{name}-ood-test-numpy/plain-n-{n_way}-k-{k_shot}",
                    support_nrow=k_shot,
                    query_nrow=k_shot,
                )

    def test_dataset_numpy_loaders(self) -> None:
        for name, DS in zip(["omniglot", "miniimagenet"], [Omniglot, MiniImageNet]):
            for n_way, k_shot in self._WAY_SHOT:
                # Exercise the train split twice: once plainly and once with ood_test enabled.
                for ood_test, label, subdir in (
                    (False, "", "plain"),
                    (True, " WITH OOD CLASS TEST SET:", "plain-ood-class-metatest"),
                ):
                    args = Namespace(
                        data_root=DATA_ROOT,
                        ood_test=ood_test,
                        dataset=name,
                        run=0,
                        n_way=n_way,
                        k_shot=k_shot,
                        train_query_shots=k_shot,
                        val_query_shots=k_shot,
                        train_test_shots=k_shot
                    )

                    print(f"numpy {name}{label} nway: {n_way} kshot: {k_shot}")
                    dataset = DS(args, split="train")
                    self._print_episode_shapes(dataset)
                    self._save_episode_grids(
                        dataset,
                        f"data/examples/{name}-numpy/{subdir}-n-{n_way}-k-{k_shot}",
                        support_nrow=k_shot,
                        query_nrow=15,
                    )
                    # Drop the reference so the dataset can be freed before the next one is built.
                    del dataset

    def test_dataset_toy_classification(self) -> None:
        cols = 3
        for ds_name, dataset_class in zip(["moons", "circles", "gaussian"], [Moons, Circles, Gaussians]):
            print(ds_name)
            fig, axes = plt.subplots(nrows=1, ncols=cols, figsize=(cols * 7, 6))
            for i, ax in enumerate(axes):
                ds = dataset_class(100, seed=i)
                # Only the first batch (up to 100 points) is plotted.
                x, y = next(iter(DataLoader(ds, batch_size=100)))

                ax.scatter(x[:, 0], x[:, 1], c=[get_color(v.item()) for v in y], s=50, edgecolors=(0, 0, 0, 0.5), linewidths=2.0)
                ax.set_title(f"dataset sample: {i}")

            path = os.path.join("data", "examples", f"{ds_name}")
            os.makedirs(path, exist_ok=True)
            fig.tight_layout()
            fig.savefig(os.path.join(path, "example.pdf"))
            fig.savefig(os.path.join(path, "example.png"))
            # Close the figure so figures do not accumulate across iterations.
            plt.close(fig)

    def test_dataset_toy_regression(self) -> None:
        train, test = GappedSine(torch.device("cpu")), GappedSine(torch.device("cpu"), test=True)
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 6))

        ax.scatter(train.x.numpy(), train.y.numpy(), label="train")
        # Sort the test points by x so the line plot is drawn left to right.
        sort_idx = torch.argsort(test.x.squeeze(-1))
        print(test.x.size())
        ax.plot(test.x[sort_idx], test.y[sort_idx], label="test")
        ax.legend()

        path = os.path.join("data", "examples", "gapped-sine")
        os.makedirs(path, exist_ok=True)
        fig.savefig(os.path.join(path, "example.pdf"))
        fig.savefig(os.path.join(path, "example.png"))
        plt.close(fig)

    def test_dataset_toy_classification_few_shot(self) -> None:
        cols = 3
        for ds_name, ds_class in zip(
            ["meta-moons", "meta-circles", "meta-gaussians"],
            [partial(MetaMoons, seed=1), partial(MetaCircles, seed=1), partial(MetaGaussians, seed=1, k_shot=5, test_shots=15)]
        ):
            ds = ds_class()  # type: ignore
            fig, axes = plt.subplots(nrows=1, ncols=cols, figsize=(cols * 7, 6))
            for i, ax in enumerate(axes):
                # NOTE(review): index 0 is used for every panel, so each call is presumably expected to
                # sample a fresh random task rather than look one up by index; confirm against the dataset.
                xtr, ytr, xte, yte = ds[0]

                ax.scatter(xtr[:, 0], xtr[:, 1], c=[get_color(v.item()) for v in ytr], s=50, edgecolors=(0, 0, 0, 0.5), linewidths=2.0)
                ax.scatter(xte[:, 0], xte[:, 1], c=[get_color(v.item()) for v in yte], marker='*', s=20)
                ax.set_title(f"task: {i}")
                if i == cols - 1:
                    legend_elements = [
                        Line2D([0], [0], marker='o', color='w', label='train', markerfacecolor='black', markersize=10),
                        Line2D([0], [0], marker='*', color='w', label='test', markerfacecolor='black', markersize=10),
                    ]
                    ax.legend(handles=legend_elements)

            path = os.path.join("data", "examples", f"{ds_name}")
            os.makedirs(path, exist_ok=True)
            fig.tight_layout()
            fig.savefig(os.path.join(path, "metatrain-example.pdf"))
            fig.savefig(os.path.join(path, "metatrain-example.png"))
            plt.close(fig)

    def test_higgs_dataset(self) -> None:
        # ran the train and test set to verify they are correct, tiny is set to true to avoid a long running test
        train, test = HiggsDataset(DATA_ROOT, test=False, tiny=True), HiggsDataset(DATA_ROOT, test=True, tiny=True)

        xtrain, ytrain = train[0]
        xtest, ytest = test[0]

        # Each item is a 28-feature vector with a scalar (0-dim) label.
        self.assertEqual(xtrain.size(0), 28)
        self.assertEqual(xtest.size(0), 28)
        self.assertEqual(len(ytrain.size()), 0)
        self.assertEqual(len(ytest.size()), 0)

    def test_jets_dataset(self) -> None:
        for s in ["val", "test"]:
            ds = Jets(DATA_ROOT, split=s)
            x, y = ds[0]
            # Each item is a (set size, feature dim) matrix with a length-1 label.
            set_size, dim = x.size()

            self.assertEqual(set_size, 100)
            self.assertEqual(dim, 6)
            self.assertEqual(y.size(0), 1)

    def test_atlas_jets_dataset(self) -> None:
        ds = AtlasJets(DATA_ROOT, split="train")
        for i in range(100):
            x, y = ds[i]
            print(x.size(), y.size())
