from torch.utils.data import DataLoader, Dataset
from datasets import kinetics, mmnist, imagenet, fashion_mnist
import pytest
import pytorch_lightning as pl
from pytorch_lightning.trainer import supporters
import torch

from datasets.shapes import (
    Shapes,
    ShapesCanonicalDataModule,
    ShapesDataModule,
    ShapesPairs,
    ShapesPairsDataModule,
    ShapesPairsFixedEval,
)
from datasets.shapes_generation import attributes
from typing import Dict, Set


class TestKinetics:
    """Smoke tests for ``kinetics.KineticsDataModule`` loaders."""

    def test_dataloaders(self):
        batch_size = 8
        dm = kinetics.KineticsDataModule(batch_size=batch_size)
        # Build and check the train loader first, then the validation loader.
        for make_loader in (dm.train_dataloader, dm.val_dataloader):
            _check_video_loader_batch_size(make_loader(), batch_size)


class TestImageMMNISTSimCLR:
    """Shape checks for ``mmnist.ImageMMNISTSimCLRDataModule`` (currently skipped)."""

    @pytest.mark.skip(reason="skipping MMNIST test")
    def test_dataloaders(self):
        batch_size = 8
        dm = mmnist.ImageMMNISTSimCLRDataModule(batch_size=batch_size)
        # Both loaders are created before either is checked.
        loaders = (dm.train_dataloader(), dm.val_dataloader())
        for loader in loaders:
            self._check_loader_batch_size(loader, batch_size)

    def _check_loader_batch_size(self, loader: DataLoader, batch_size: int):
        """Assert the first batch is ``((x1, x2, x), y)`` with 224x224 RGB images.

        ``x1``/``x2`` are presumably the two SimCLR views of ``x`` — confirm
        against the data module.
        """
        img_size = 224
        (first_view, second_view, image), labels = next(iter(loader))
        expected = (batch_size, 3, img_size, img_size)
        for tensor in (first_view, second_view, image):
            assert tensor.shape == expected
        assert labels.shape == (batch_size,)


class TestImageNet:
    """Shape checks for the ImageNet data modules."""

    def test_dataloader(self):
        """The plain loader yields ``(x, y)`` with ``image_size`` RGB images."""
        batch_size = 8
        image_size = 224
        # Pass the local size through so the requested size and the asserted
        # shape cannot drift apart.
        dm = imagenet.ImageNetDataModule(
            batch_size=batch_size, image_size=image_size
        )
        x, y = next(iter(dm.train_dataloader()))
        assert x.shape == (batch_size, 3, image_size, image_size)
        assert y.shape == (batch_size,)

    def test_simclr_dataloader(self):
        """The SimCLR loader yields ``((x1, x2, x), y)`` with equally-shaped images."""
        batch_size = 8
        image_size = 224
        dm = imagenet.ImageNetSimCLRDataModule(
            batch_size=batch_size, image_size=image_size
        )
        (x1, x2, x), y = next(iter(dm.train_dataloader()))
        expected_shape = (batch_size, 3, image_size, image_size)
        assert x.shape == expected_shape
        assert x1.shape == expected_shape
        assert x2.shape == expected_shape
        assert y.shape == (batch_size,)


class TestYouTube8M:
    """Placeholder for YouTube-8M data tests; no tests are implemented yet."""


def _check_video_loader_batch_size(loader: DataLoader, batch_size: int):
    """Assert that the first batch of ``loader`` holds ``batch_size`` samples.

    The batch must be indexable by ``"label"`` and ``"video"``. The label
    container's length and the leading dimension of the video entry are both
    compared against ``batch_size``.
    """
    first_batch = next(iter(loader))
    num_labels = len(first_batch["label"])
    assert num_labels == batch_size
    leading_dim = first_batch["video"].shape[0]
    assert leading_dim == batch_size


class TestUnbalancedFashionMNIST:
    """Tests for the unbalanced rotated Fashion-MNIST dataset and data module."""

    IMAGE_SIZE = 28

    def test_dataset(self):
        dataset = fashion_mnist.UnbalancedRotatedFashionMNISTDataset(
            rotate_every_n_images=1
        )
        image, label = dataset[0]
        assert image.shape == (1, self.IMAGE_SIZE, self.IMAGE_SIZE)
        assert type(label) is int

    def test_dataset_all_rotated_indices(self):
        dataset = fashion_mnist.UnbalancedRotatedFashionMNISTDataset(
            rotate_every_n_images=1
        )
        rotated = dataset.get_rotated_indices()
        # With a rotation period of 1, every index is reported as rotated.
        assert rotated == list(range(len(dataset)))

    def test_dataset_half_rotated_indices(self):
        dataset = fashion_mnist.UnbalancedRotatedFashionMNISTDataset(
            rotate_every_n_images=2
        )
        # Period 2: even indices are rotated, odd indices are not.
        rotated = dataset.get_rotated_indices()
        assert rotated == list(range(0, len(dataset), 2))

        unrotated = dataset.get_unrotated_indices()
        assert unrotated == list(range(1, len(dataset), 2))

    @pytest.mark.slow
    def test_datamodule(self):
        batch_size = 8
        dm = fashion_mnist.UnbalancedRotatedFashionMNISTDataModule(
            rotate_every_n_train_images=2, batch_size=batch_size
        )
        dm.setup()
        loaders = (dm.train_dataloader(), dm.val_dataloader())
        for loader in loaders:
            assert type(loader) is DataLoader

        expected = {
            "x_shape": (batch_size, 1, self.IMAGE_SIZE, self.IMAGE_SIZE),
            "y_shape": (batch_size,),
        }
        for loader in loaders:
            self._check_shape(loader, **expected)

    def _check_shape(self, dl, x_shape=(32, 1, 28, 28), y_shape=(32,)):
        """Assert the first ``(x, y)`` batch of ``dl`` matches the given shapes."""
        images, labels = next(iter(dl))
        assert images.shape == x_shape
        assert labels.shape == y_shape


class TestFashionMNIST:
    """Tests for ``fashion_mnist.RotatedFashionMNISTDataModule``."""

    IMAGE_SIZE = 28
    TOTAL_TRAIN_SIZE = 60000

    def test_dataset_setup(self):
        batch_size = 8
        angles = [0.0, 90.0]
        portion_of_train_to_rotate = 0.5
        # NOTE(review): val_percent is not passed to the data module, so this
        # assumes the module's default validation fraction is 0.2 — confirm.
        val_percent = 0.2

        dm = fashion_mnist.RotatedFashionMNISTDataModule(
            batch_size=batch_size,
            portion_of_train_to_rotate=portion_of_train_to_rotate,
            angles=angles,
        )
        dm.setup()

        # The literal 0.5 mirrors portion_of_train_to_rotate above; this
        # formula would change for a different portion.
        multiplier = len(angles) + 1
        expected_train_size = multiplier * self.TOTAL_TRAIN_SIZE * 0.5
        expected_train_size *= 1 - val_percent

        assert len(dm.train_set) == expected_train_size

    def test_compute_split(self):
        dm = fashion_mnist.RotatedFashionMNISTDataModule()
        assert dm._compute_split(list(range(8)), percent=0.5) == [4, 4]

    def test_uneven_compute_split(self):
        dm = fashion_mnist.RotatedFashionMNISTDataModule()
        assert dm._compute_split(list(range(100)), percent=0.8) == [80, 20]

    def test_dataloaders(self):
        batch_size = 8
        dm = fashion_mnist.RotatedFashionMNISTDataModule(batch_size=batch_size)
        dm.setup()

        images, labels = next(iter(dm.train_dataloader()))
        assert images.shape == (batch_size, 1, self.IMAGE_SIZE, self.IMAGE_SIZE)
        assert labels.shape == (batch_size,)


class TestShapes:
    """Tests for ``ShapesDataModule`` and the ``Shapes`` dataset.

    The ``dm`` fixture and the batch-checking helper can be overridden by
    subclasses to run these tests against other Shapes data modules.
    """

    BATCH_SIZE = 2
    DATA_DIR = "DATADIR/datasets/shapes_renderings_overlapping_small"
    IMG_SIZE = 224

    @pytest.fixture(scope="class")
    def dm(self) -> pl.LightningDataModule:
        """Class-scoped data module that has already had ``setup()`` called."""
        dm = ShapesDataModule(batch_size=self.BATCH_SIZE, data_dir=self.DATA_DIR)
        dm.setup()
        return dm

    def test_num_classes(self):
        """Label indices depend on ``use_imagenet_classes``."""
        dm_1k_classes = ShapesDataModule(
            batch_size=self.BATCH_SIZE,
            data_dir=self.DATA_DIR,
            use_imagenet_classes=True,
        )
        dm_1k_classes.setup()
        _, y, _ = next(iter(dm_1k_classes.train_dataloader()))["train_canonical"]
        # check at least one of the class idx is larger than 20
        assert (y > 20).any().item()

        dm_15_classes = ShapesDataModule(
            batch_size=self.BATCH_SIZE,
            data_dir=self.DATA_DIR,
            use_imagenet_classes=False,
        )
        dm_15_classes.setup()
        _, y, _ = next(iter(dm_15_classes.train_dataloader()))["train_canonical"]
        # check all class idx < 15 (total synsets in fov.csv file for the small
        # dataset). `.all()` is required: `.any()` would pass as soon as a
        # single label happened to be in range.
        assert (y < 15).all().item()

    def test_dataset_sample(self, dm):
        instance_ids = dm.fov_df.index.get_level_values(0).unique()[:2].tolist()
        dataset = Shapes(data_dir=self.DATA_DIR, instance_ids=instance_ids)
        x, y, fov = dataset[1]
        assert x.shape == (3, self.IMG_SIZE, self.IMG_SIZE)
        assert type(y) is int
        assert "pose" in fov

    def test_views(self):
        num_views = 9
        views = attributes.Views(view_start=0.0, view_end=360.0, num_views=num_views)
        angles = views.generate()
        assert len(angles) == (num_views * 2 - 1)

    def test_datamodule_setup(self, dm):
        assert len(dm.fov_df) > 0
        assert len(dm.train_canonical) > 0
        assert len(dm.diverse_2d_train_canonical) > 0
        assert len(dm.diverse_3d_train_canonical) > 0
        assert len(dm.train_canonical) * len(dm.views) == len(
            dm.diverse_2d_train_canonical
        ) + len(dm.diverse_3d_train_canonical)

    def test_train_dataloaders(self, dm):
        train_loaders = dm.train_dataloader()
        assert isinstance(train_loaders, supporters.CombinedLoader)
        self._check_train_loader_shapes(train_loaders, self.BATCH_SIZE)

    def test_val_test_dataloaders(self, dm):
        val_loaders = dm.val_dataloader()
        assert len(val_loaders) > 0
        assert len(val_loaders) == len(dm.val_loader_names)

        test_loaders = dm.test_dataloader()
        assert len(test_loaders) > 0
        assert len(test_loaders) == len(dm.test_loader_names)

    def _check_train_loader_shapes(
        self, dataloader: supporters.CombinedLoader, batch_size: int
    ):
        """Assert every named sub-batch of the first combined batch has ``batch_size`` items.

        ``batch_size`` is used for every check, including the tensor shapes.
        """
        batch = next(iter(dataloader))
        for loader_name in batch:
            assert type(loader_name) is str
            x, y, fov = batch[loader_name]
            assert x.shape == (batch_size, 3, self.IMG_SIZE, self.IMG_SIZE)
            assert y.shape == (batch_size,)
            assert "pose" in fov
            assert len(fov["synset"]) == batch_size
            assert len(fov["pose"][0]) == batch_size


class TestShapesCanonical(TestShapes):
    """Run the inherited ``TestShapes`` checks against ``ShapesCanonicalDataModule``."""

    BATCH_SIZE = 2
    DATA_DIR = "DATADIR/datasets/shapes_renderings_overlapping_small"
    IMG_SIZE = 224

    @pytest.fixture(scope="class")
    def dm(self) -> pl.LightningDataModule:
        """Class-scoped canonical data module that has already had ``setup()`` called."""
        data_module = ShapesCanonicalDataModule(
            data_dir=self.DATA_DIR, batch_size=self.BATCH_SIZE
        )
        data_module.setup()
        return data_module


class TestShapesPairsDataModule(TestShapes):
    """Run the inherited ``TestShapes`` checks against ``ShapesPairsDataModule``.

    Pair batches have a different layout from single-sample batches, so the
    batch-shape, setup and val/test-loader checks are overridden here.
    """

    BATCH_SIZE = 2
    DATA_DIR = "DATADIR/datasets/shapes_renderings_overlapping_small"
    IMG_SIZE = 224

    @pytest.fixture(scope="class")
    def dm(self) -> pl.LightningDataModule:
        """Class-scoped pairs data module that has already had ``setup()`` called."""
        dm = ShapesPairsDataModule(batch_size=self.BATCH_SIZE, data_dir=self.DATA_DIR)
        dm.setup()
        return dm

    def _check_train_loader_shapes(
        self, dataloader: supporters.CombinedLoader, batch_size: int
    ):
        """Assert each named sub-batch is ``(x1, x2, y, fov1, fov2, delta)``.

        ``x1`` and ``x2`` each unpack to three image tensors. All six images
        and ``y`` must have ``batch_size`` as their leading dimension.
        """
        batch = next(iter(dataloader))
        image_shape = (batch_size, 3, self.IMG_SIZE, self.IMG_SIZE)
        for loader_name in batch:
            assert type(loader_name) is str
            x1, x2, y, fov1, fov2, _delta = batch[loader_name]
            (x1_a1, x1_a2, x1_o), (x2_a1, x2_a2, x2_o) = x1, x2
            for images in (x1_a1, x1_a2, x2_a1, x2_a2, x1_o, x2_o):
                assert images.shape == image_shape
            assert y.shape == (batch_size,)

            self._check_fov(fov1, batch_size)
            self._check_fov(fov2, batch_size)

    def _check_fov(self, fov: dict, batch_size: int):
        """Assert a collated fov dict has a pose entry and ``batch_size`` entries."""
        assert "pose" in fov
        assert len(fov["synset"]) == batch_size
        assert len(fov["pose"][0]) == batch_size

    def test_datamodule_setup(self, dm):
        assert len(dm.fov_df) > 0
        assert len(dm.train_canonical) > 0
        assert len(dm.train_diverse_2d) > 0
        assert len(dm.train_diverse_3d) > 0
        assert len(dm.val_diverse_2d) > 0
        assert len(dm.val_diverse_3d) > 0

    def test_val_test_dataloaders(self, dm):
        val_loaders = dm.val_dataloader()
        assert len(val_loaders) > 0
        assert len(val_loaders) == len(dm.val_loader_names)
        # test loaders are skipped because the pairs data module doesn't have any


class TestShapesPairsDatasets:
    BATCH_SIZE = 2
    DATA_DIR = "DATADIR/datasets/shapes_renderings_overlapping_small"
    INSTANCE_IDS = [
        "139478e7e85aabf27274021d5552b63f",
        "ff9c1754252b9ebf73c7253ec9acd58b",
        "14fa15f31d713b7153b838b6058a8d95",
        "179d23a446719d27592ecd319dfd8c5d",
        "18ff360b39e5306272c797c96ca37495",
        "1c2267e8b8a099a47457e76661a399e9",
    ]
    IMG_SIZE = 224

    def test_pairs_instantiation(self):
        ds = ShapesPairs(data_dir=self.DATA_DIR, instance_ids=self.INSTANCE_IDS)
        assert isinstance(ds, Dataset)

    def test_fixed_pairs_instantiation(self):
        poses = [(0.0, 0.0, 0.0), (40.0, 40.0, 40.0)]
        ds_fixed = ShapesPairsFixedEval(
            data_dir=self.DATA_DIR,
            instance_ids=self.INSTANCE_IDS,
            num_pairs_per_instance=3,
            poses=poses,
        )
        assert isinstance(ds_fixed, Dataset)

    def test_length(self):
        ds = ShapesPairs(data_dir=self.DATA_DIR, instance_ids=self.INSTANCE_IDS)
        assert len(ds) == len(self.INSTANCE_IDS)

    def test_pair_counts(self):
        num_pairs_per_instance = 3
        poses = [(0.0, 0.0, 0.0), (40.0, 40.0, 40.0)]
        ds_fixed = ShapesPairsFixedEval(
            data_dir=self.DATA_DIR,
            instance_ids=self.INSTANCE_IDS,
            num_pairs_per_instance=num_pairs_per_instance,
            poses=poses,
        )
        assert isinstance(ds_fixed, Dataset)
        assert (
            len(ds_fixed.paired_indices)
            == len(self.INSTANCE_IDS) * num_pairs_per_instance
        )

    def test_pairs_shapes(self):
        ds = ShapesPairs(
            data_dir=self.DATA_DIR,
            instance_ids=self.INSTANCE_IDS,
            poses=[(0.0, 0.0, 0.0), (40.0, 40.0, 40.0)],
        )
        self._check_sample_shape(ds)

    def test_pairs_fixed_shapes(self):
        ds = ShapesPairsFixedEval(
            data_dir=self.DATA_DIR,
            instance_ids=self.INSTANCE_IDS,
            poses=[(0.0, 0.0, 0.0), (40.0, 40.0, 40.0)],
        )
        self._check_sample_shape(ds)

    def _check_sample_shape(self, ds: Dataset):
        x1, x2, y, fov1, fov2, delta = ds[0]
        (x1_aug1, x1_aug2, x1_o), (x2_aug1, x2_aug2, x2_o) = x1, x2
        assert x1_aug1.shape == (3, self.IMG_SIZE, self.IMG_SIZE)
        assert x1_aug2.shape == (3, self.IMG_SIZE, self.IMG_SIZE)
        assert x2_aug1.shape == (3, self.IMG_SIZE, self.IMG_SIZE)
        assert x2_aug2.shape == (3, self.IMG_SIZE, self.IMG_SIZE)
        assert x1_o.shape == (3, self.IMG_SIZE, self.IMG_SIZE)
        assert x2_o.shape == (3, self.IMG_SIZE, self.IMG_SIZE)

        assert isinstance(fov1, dict)
        assert isinstance(fov2, dict)
        assert delta.shape == (1,)
