from datasets import fashion_mnist, imagenet, medmnist_dataset, cifar
from torch.utils.data import DataLoader
import pytest
import torch


class TestFashionMNIST:
    def test_dataloaders(self):
        dm = fashion_mnist.FashionMNISTDataModule()
        dm.setup("train")

        train_dataloader = dm.train_dataloader()
        val_dataloader = dm.val_dataloader()
        test_dataloader = dm.test_dataloader()

        for i, loader in enumerate([train_dataloader, val_dataloader, test_dataloader]):
            assert len(loader) > 0
            stage = "train" if i == 0 else "eval"
            self.check_batch_size(loader, stage)

    def check_batch_size(self, loader: DataLoader, stage: str):
        batch = next(iter(loader))
        if stage == "train":
            x, y, n = batch
        else:
            x, y = batch

        assert x.shape[0] == y.shape[0]


class TestCIFAR100:
    def test_dataloaders(self):
        dm = cifar.CIFAR100DataModule()
        dm.setup()

        train_dataloader = dm.train_dataloader()
        val_dataloader = dm.val_dataloader()

        for i, loader in enumerate([train_dataloader, val_dataloader]):
            assert len(loader) > 0
            stage = "train" if i == 0 else "eval"
            self.check_batch_size(loader, stage)

    def check_batch_size(self, loader: DataLoader, stage: str):
        batch = next(iter(loader))
        if stage == "train":
            x, y, n = batch
        else:
            x, y = batch

        assert x.shape[0] == y.shape[0]


class TestImageNet:
    def test_dataloaders(self):
        dm = imagenet.ImageNetBlurredDataModule()
        dm.setup()

        train_dataloader = dm.train_dataloader()
        val_dataloader = dm.val_dataloader()

        for i, loader in enumerate([train_dataloader, val_dataloader]):
            assert len(loader) > 0
            stage = "train" if i == 0 else "val"
            self.check_batch_size(loader, stage)

    def check_batch_size(self, loader: DataLoader, stage: str):
        batch = next(iter(loader))
        x, y = batch
        assert x.shape[0] == y.shape[0]


class TestImageNetSimCLR:
    def test_dataloaders(self):
        dm = imagenet.ImageNetBlurredSimCLRDataModule()
        dm.setup()

        train_dataloader = dm.train_dataloader()
        val_dataloader = dm.val_dataloader()

        for i, loader in enumerate([train_dataloader, val_dataloader]):
            assert len(loader) > 0
            stage = "train" if i == 0 else "val"
            self.check_batch_size(loader, stage)

    def check_batch_size(self, loader: DataLoader, stage: str):
        batch = next(iter(loader))
        if stage == "train":
            (x1, x2), y, n = batch
        else:
            x1, y = batch

        assert x1.shape[0] == y.shape[0]
        assert x1.shape[1:] == (3, 224, 224)


class TestMedMNIST:
    def test_diet_dataset(self):
        medmnist_dm = medmnist_dataset.MedMNISTDataModule()
        train_dataset, test_dataset = medmnist_dm.get_medmnist()
        assert len(train_dataset) > 80000
        assert len(test_dataset) > 6000

        x, y = train_dataset[3]
        assert x.shape == (3, 32, 32)
        assert y.shape == torch.Size([])

        x_test, y_test = test_dataset[3]
        assert x_test.shape == (3, 32, 32)
        assert y_test.shape == torch.Size([])

    def test_diet_dataloaders(self):
        dm = medmnist_dataset.MedMNISTDataModule(batch_size=8)
        dm.setup("")
        train_dataloader = dm.train_dataloader()
        train_batch = next(iter(train_dataloader))
        x, y, n = train_batch
        assert x.shape == (8, 3, 32, 32)
        assert y.shape == torch.Size([8])
        assert n.shape == torch.Size([8, 1])

    def test_simclr_dataset(self):
        dm = medmnist_dataset.MedMNISTSimCLRDataModule(batch_size=8)
        dm.setup("")
        assert len(dm.train_dataset) > 80000
        assert len(dm.test_dataset) > 6000

        (x1, x2), y, _ = dm.train_dataset[3]
        assert x1.shape == (3, 32, 32)
        assert x2.shape == (3, 32, 32)
        assert y.shape == torch.Size([])

    def test_simclr_dataloaders(self):
        batch_size = 8
        dm = medmnist_dataset.MedMNISTSimCLRDataModule(batch_size=batch_size)
        dm.setup("")
        train_dataloader = dm.train_dataloader()
        train_batch = next(iter(train_dataloader))
        (x1, x2), y, _ = train_batch
        assert x1.shape == (batch_size, 3, 32, 32)
        assert x2.shape == (batch_size, 3, 32, 32)
        assert y.shape == torch.Size([batch_size])

        val_dataloader = dm.val_dataloader()
        val_batch = next(iter(val_dataloader))
        self.check_lightly_eval_batch(val_batch, batch_size)

    def test_moco_dataloaders(self):
        batch_size = 8
        dm = medmnist_dataset.MedMNISTMoCoDataModule(batch_size=batch_size)
        dm.setup("")

        train_dl = dm.train_dataloader()
        val_dl = dm.val_dataloader()
        test_dl = dm.test_dataloader()

        batch_train = next(iter(train_dl))

        x, y, _ = batch_train
        x_q, x_k = x
        assert y.shape == (batch_size,)
        assert x_q.shape == x_k.shape
        assert x_q.shape == (batch_size, 3, 32, 32)

        batch_val = next(iter(val_dl))
        self.check_lightly_eval_batch(batch_val, batch_size)
        batch_test = next(iter(test_dl))
        self.check_lightly_eval_batch(batch_test, batch_size)

    def check_lightly_eval_batch(self, batch, batch_size):
        # last return value is because we use a lightly dataset for eval
        x, y, _ = batch
        assert y.shape == (batch_size,)
        assert x.shape == (batch_size, 3, 32, 32)


class TestCIFAR10MoCoDataModule:
    """Shape checks for the CIFAR-10 MoCo data module."""

    def test_dataloader(self):
        """Train batches carry a (query, key) view pair; val batches are plain (images, labels)."""
        batch_size = 8
        datamodule = cifar.CIFAR10MoCoDataModule(batch_size=batch_size)
        datamodule.setup()

        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()

        # Draw both batches before asserting anything.
        train_batch = next(iter(train_loader))
        val_batch = next(iter(val_loader))
        expected_image_shape = (batch_size, 3, 32, 32)

        (query, key), train_labels, _ = train_batch
        assert train_labels.shape == (batch_size,)
        assert query.shape == key.shape
        assert query.shape == expected_image_shape

        val_images, val_labels = val_batch
        assert val_labels.shape == (batch_size,)
        assert val_images.shape == expected_image_shape
