from models.lie_ssl.shapes_model import SimCLRFramesMoreParamsModule, SimCLRLieModule
from models.resnet_3d.model import VideoClassificationModule
from models.resnet_2d.model import ImageClassificationModule
from models.resnet_2d.shapes_model import (
    ShapesFineTuneModule,
    ShapesLinearModule,
    ShapesModule,
)
from models.base_model import BaseModel
from models.vivi import vivi
from models.vit import vit_shapes
from models.cnn import cnn
from models.finetuning.linear_finetuner import LinearFineTuner
import pytorch_lightning as pl
from datasets.kinetics import KineticsDataModule
from datasets.mmnist import MMNISTDataModule
import torch
import pytest
from torch.utils.data import DataLoader
from pl_bolts.datamodules import SklearnDataModule
from sklearn.datasets import make_classification
from models.simclr.simclr import SimCLR
from models.simclr import simclr_shapes
from models.mae.mae_shapes import MAEFinetuner, MAELieModule, MAELinearEval
from models.clip import clip_shapes
from tests import dummy_data
import os


@pytest.fixture(scope="module")
def dummy_video_dataloader():
    """Module-scoped fixture: a DataLoader over the dummy video dataset.

    Batches hold 8 samples and are loaded with 4 worker processes.
    """
    return DataLoader(dummy_data.VideoDataset(), batch_size=8, num_workers=4)


class TestResNet2D:
    """Smoke test for the 2D ResNet image classification module."""

    def test_instantiation(self):
        """A default-constructed module exposes the expected ``model_name``."""
        expected_name = "ImageClassificationModule"
        module = ImageClassificationModule()
        assert module.model_name == expected_name


class TestResNet3D:
    """Fit tests for the 3D ResNet video classification module."""

    @pytest.mark.slow
    def test_fit(self):
        """Run a ``fast_dev_run`` fit against the Kinetics data module."""
        module = VideoClassificationModule()
        assert isinstance(module, pl.LightningModule)
        kinetics = KineticsDataModule()
        pl.Trainer(fast_dev_run=True).fit(module, kinetics)

    def test_dummy_fit(self, dummy_video_dataloader):
        """Run a ``fast_dev_run`` fit against the dummy video dataloader."""
        module = VideoClassificationModule()
        assert isinstance(module, pl.LightningModule)
        pl.Trainer(fast_dev_run=True).fit(module, dummy_video_dataloader)


class TestVIVI:
    """Tests for ``vivi.VIVI``: embeddings, shot grouping, losses, fit and forward."""

    # Skip marker shared by the tests that need the Moving MNIST dataset on disk.
    _requires_mmnist = pytest.mark.skipif(
        not os.path.exists(f"/datasets01/{dummy_data.USER}/movingmnist"),
        reason="no dataset found",
    )

    @pytest.fixture(scope="class")
    def vivi_model(self):
        """Class-scoped fixture: one default-constructed VIVI model."""
        return vivi.VIVI()

    @pytest.fixture(scope="class")
    def mmnist_batch(self):
        """Class-scoped fixture: first training batch from ``MMNISTDataModule``.

        The data module is built with a batch size of 16.
        """
        batch_size = 16
        mmnist_dm = MMNISTDataModule(batch_size=batch_size)
        batch = next(iter(mmnist_dm.train_dataloader()))
        return batch

    def test_resnet50(self):
        """Tests resnet50 inference"""
        resnet50 = vivi.ResNet50()
        x = torch.rand(2, 3, 256, 256)
        pred = resnet50(x)
        assert pred.shape == (2, dummy_data.EMBEDDING_DIM)

    def test_embedding(self, dummy_video_dataloader, vivi_model):
        """Frame embeddings contain no NaNs and have shape (batch, frames, dim)."""
        batch = next(iter(dummy_video_dataloader))
        z = vivi_model.embed_batch(batch["video"])
        assert not torch.any(z.isnan()).item()
        assert z.shape == (
            dummy_video_dataloader.batch_size,
            dummy_data.FRAMES_PER_VIDEO,
            dummy_data.EMBEDDING_DIM,
        )

    def test_shot_embedding(self, dummy_video_dataloader, vivi_model):
        """Shot embeddings built from frame embeddings have shape (batch, dim)."""
        batch = next(iter(dummy_video_dataloader))
        z = vivi_model.embed_batch(batch["video"])
        z_shots = vivi_model.embed_shots(z)
        assert z_shots.shape == (
            dummy_video_dataloader.batch_size,
            dummy_data.EMBEDDING_DIM,
        )

    @_requires_mmnist
    def test_shot_embedding_on_mmnist(self, vivi_model, mmnist_batch):
        """Check frame and shot embedding shapes on a Moving MNIST batch."""
        batch_size = mmnist_batch["video"].shape[0]
        assert mmnist_batch["video"].shape == (batch_size, 3, 8, 224, 224)
        z = vivi_model.embed_batch(mmnist_batch["video"])
        assert z.shape == (batch_size, 8, dummy_data.EMBEDDING_DIM)
        z_shots = vivi_model.embed_shots(z)
        assert z_shots.shape == (batch_size, dummy_data.EMBEDDING_DIM)

    def test_infer_shots_in_first_video(self, vivi_model):
        """Three videos with 4 consecutive index entries each yield 4 shots."""
        shots_per_video = 4
        video_index = []
        for i in [13, 66, 17]:
            video_index += [i] * shots_per_video
        assert shots_per_video == vivi_model.infer_shots_in_first_video(video_index)

    @_requires_mmnist
    def test_infer_shots_in_first_video_on_mmnist(self, vivi_model, mmnist_batch):
        """The Moving MNIST batch's ``video_index`` is expected to yield 2 shots."""
        shots_per_video = vivi_model.infer_shots_in_first_video(
            mmnist_batch["video_index"]
        )
        assert shots_per_video == 2

    def test_group_shots_by_video(self, vivi_model):
        """Evenly divisible shots group into (videos, shots_per_video, dim)."""
        batch_size = 32
        videos_per_batch = 4
        z_shots = torch.rand(batch_size, dummy_data.EMBEDDING_DIM)
        video_shots = vivi_model.group_shots_by_video(z_shots, videos_per_batch)

        shots_per_video = batch_size // videos_per_batch
        assert video_shots.shape == (
            videos_per_batch,
            shots_per_video,
            dummy_data.EMBEDDING_DIM,
        )

    def test_group_shots_by_video_with_uneven_inputs(self, vivi_model):
        """Grouping when the batch size is not divisible by the video count."""
        batch_size = 32
        videos_per_batch = 3
        z_shots = torch.rand(batch_size, dummy_data.EMBEDDING_DIM)
        video_shots = vivi_model.group_shots_by_video(z_shots, videos_per_batch)

        # ceiling of batch_size / videos_per_batch
        shots_per_video = -(-batch_size // videos_per_batch)

        # -1 to account for uneven chunks
        # when batch_size isn't divisible by videos_per_batch
        assert video_shots.shape == (
            videos_per_batch - 1,
            shots_per_video,
            dummy_data.EMBEDDING_DIM,
        )

    def test_shot_similarity(self, vivi_model):
        """Test critic function from VIVI for computing shot similarity"""
        shot_shape = (3, 2048)
        shot_1, shot_2 = torch.rand(*shot_shape), torch.rand(*shot_shape)
        g = vivi_model.shot_similarity(shot_1, shot_2)
        assert g.shape == (3, 3)

    def test_info_nce(self, vivi_model):
        """The InfoNCE loss converts (via ``.item()``) to a Python float."""
        videos_per_batch = 4
        pred_next_shots = torch.rand(videos_per_batch, dummy_data.EMBEDDING_DIM)
        next_shots = torch.rand(videos_per_batch, dummy_data.EMBEDDING_DIM)
        loss = vivi_model.compute_info_nce(pred_next_shots, next_shots).item()
        # isinstance is the idiomatic type check (previously `type(...) is float`).
        assert isinstance(loss, float)

    def test_compute_frame_loss(self, vivi_model):
        """The frame loss on random embeddings is strictly positive."""
        shot_indices = torch.tensor([0, 1, 0, 1])
        video_indices = torch.tensor([0, 0, 1, 1])
        z = torch.rand(len(video_indices), 8, 2048)
        assert vivi_model.compute_frame_loss(z, shot_indices, video_indices) > 0.0

    def test_sample_frames(self, vivi_model):
        """Sampling 5 frames from 16 clips returns a (5, 2048) tensor."""
        clips = torch.rand(16, 2048)
        assert vivi_model._sample_frames(clips, 5).shape == (5, 2048)

    def test_fit(self, dummy_video_dataloader, vivi_model):
        """Run a ``fast_dev_run`` fit on the dummy video dataloader."""
        trainer = pl.Trainer(fast_dev_run=True)
        trainer.fit(vivi_model, dummy_video_dataloader)

    def test_forward(self, vivi_model):
        """A forward pass on a random video batch returns shape (batch, dim)."""
        batch_size = 5
        x = torch.rand(batch_size, 3, 8, 224, 224)
        z = vivi_model(x)
        assert z.shape == (batch_size, dummy_data.EMBEDDING_DIM)


class TestLSTM:
    """Inference test for the next-shot LSTM predictor."""

    def test_inference(self):
        """Predicting from 3 previous shots per video yields one shot per video."""
        predictor = vivi.LSTMNextShot(
            input_size=dummy_data.EMBEDDING_DIM, hidden_size=256
        )
        video_count = 8
        history = torch.rand(video_count, 3, dummy_data.EMBEDDING_DIM)
        # Reference tensor used only for its shape.
        target = torch.rand(video_count, dummy_data.EMBEDDING_DIM)

        predicted = predictor(history)
        assert predicted.shape == target.shape


class DummySSL(pl.LightningModule):
    """Stand-in self-supervised model whose forward pass is the identity.

    Exposes ``embedding_dim`` so it can be passed to ``LinearFineTuner`` in the
    tests below.
    """

    def __init__(self, embedding_dim: int = 30):
        """Store ``embedding_dim`` and register an identity layer."""
        super().__init__()
        self.id = torch.nn.Identity()
        self.embedding_dim = embedding_dim

    def forward(self, x):
        """Return ``x`` passed through the identity layer."""
        passthrough = self.id(x)
        return passthrough


class TestLinearFineTuner:
    """Tests for ``LinearFineTuner`` wrapped around the identity ``DummySSL`` model."""

    def test_inference(self):
        """A batch of 8 embeddings maps to an (8, 10) output."""
        embedding_dim = 30
        backbone = DummySSL(embedding_dim=embedding_dim)
        assert backbone.embedding_dim == embedding_dim
        finetuner = LinearFineTuner(backbone, num_classes=10)
        inputs = torch.rand(8, embedding_dim)
        logits = finetuner(inputs)
        assert logits.shape == (8, 10)

    @pytest.mark.slow
    def test_fit(self):
        """Fit the fine-tuner on a synthetic sklearn classification dataset."""
        num_classes = 10
        embedding_dim = 20
        # 300 samples with 20 features, 17 of them informative, over 10 classes.
        features, labels = make_classification(
            n_samples=300,
            n_features=embedding_dim,
            n_classes=num_classes,
            n_informative=embedding_dim - 3,
        )
        data_module = SklearnDataModule(features, labels)

        backbone = DummySSL(embedding_dim=embedding_dim)
        finetuner = LinearFineTuner(backbone, num_classes=num_classes)
        pl.Trainer().fit(finetuner, datamodule=data_module)


class TestSimCLR:
    """Instantiation and fit tests for ``SimCLR``."""

    def test_simclr_instantiation(self):
        """``SimCLR`` can be constructed with default arguments."""
        model = SimCLR()
        assert isinstance(model, SimCLR)

    @pytest.mark.slow
    def test_simclr_fit(self):
        """Run a ``fast_dev_run`` fit on the dummy SimCLR image data module."""
        images = dummy_data.SimCLRImageDataModule(batch_size=8)
        model = SimCLR()
        pl.Trainer(fast_dev_run=True).fit(model, images)


class TestFashionCNN:
    """Smoke test for ``cnn.FashionCNN`` inference."""

    def test_inference(self):
        """Run a forward pass on a random batch and check the batch dimension.

        The earlier version ended with a bare comparison
        (``y_hat.shape == (8,)``) whose result was discarded, so the test
        could never fail on the output shape.
        """
        batch_size = 8
        model = cnn.FashionCNN()
        x_batch = torch.rand(batch_size, 1, 28, 28)
        y_hat = model(x_batch)
        # NOTE(review): the previous, unasserted expectation was shape (8,).
        # Only the leading batch dimension is enforced here; tighten this to
        # the full shape once FashionCNN's output shape is confirmed.
        assert y_hat.shape[0] == batch_size


class TestShapesCNN:
    """Output-shape tests for the shapes ResNet modules on 600x600 RGB inputs."""

    def test_inference(self):
        """``ShapesModule`` maps a batch of 4 images to a (4, 1000) output."""
        module = ShapesModule()
        images = torch.rand(4, 3, 600, 600)
        assert module(images).shape == (4, 1000)

    def test_shapes_linear_module(self):
        """``ShapesLinearModule`` maps a batch of 4 images to a (4, 15) output."""
        module = ShapesLinearModule()
        images = torch.rand(4, 3, 600, 600)
        assert module(images).shape == (4, 15)

    def test_shapes_finetune_module(self):
        """``ShapesFineTuneModule`` reports 15 classes and outputs (4, 15)."""
        module = ShapesFineTuneModule()
        assert module.num_classes == 15
        images = torch.rand(4, 3, 600, 600)
        assert module(images).shape == (4, 15)


class TestSimCLRLinearEvalModule:
    """Tests for ``simclr_shapes.SimCLRLinearEvalModule``."""

    def test_instantiation(self):
        """The default-constructed module is a ``BaseModel``."""
        module = simclr_shapes.SimCLRLinearEvalModule()
        assert isinstance(module, BaseModel)

    def test_inference(self):
        """A batch of 224x224 RGB images maps to (batch, num_classes)."""
        module = simclr_shapes.SimCLRLinearEvalModule()
        batch_size = 8
        images = torch.rand(batch_size, 3, 224, 224)
        logits = module(images)
        assert logits.shape == (batch_size, module.num_classes)

    def test_train(self):
        """Fit for one epoch on the dummy shapes data, capped at 3 batches."""
        batch_size = 8
        num_classes = 52
        shapes_dm = dummy_data.ShapesDataModule(
            batch_size=batch_size, num_classes=num_classes
        )
        module = simclr_shapes.SimCLRLinearEvalModule(datamodule=shapes_dm)
        # CPU-only run limited to 3 train and 3 validation batches.
        trainer = pl.Trainer(
            gpus=0, max_epochs=1, limit_train_batches=3, limit_val_batches=3
        )
        trainer.fit(module, datamodule=shapes_dm)


class TestViTShapesLinearEval:
    """Inference test for ``vit_shapes.ViTLinearEval``."""

    def test_inference(self):
        """A batch of 224x224 RGB images maps to (batch, num_classes)."""
        batch_size = 8
        model = vit_shapes.ViTLinearEval()
        images = torch.rand(batch_size, 3, 224, 224)
        logits = model(images)
        assert logits.shape == (batch_size, model.num_classes)


class TestMAE:
    """Output-shape tests for the MAE linear-eval and fine-tuning modules."""

    def test_linear_eval(self):
        """``MAELinearEval`` maps 224x224 RGB images to (batch, num_classes)."""
        batch_size = 8
        model = MAELinearEval()
        images = torch.rand(batch_size, 3, 224, 224)
        logits = model(images)
        assert logits.shape == (batch_size, model.num_classes)

    def test_finetuning(self):
        """``MAEFinetuner`` maps 224x224 RGB images to (batch, num_classes)."""
        batch_size = 8
        model = MAEFinetuner()
        images = torch.rand(batch_size, 3, 224, 224)
        logits = model(images)
        assert logits.shape == (batch_size, model.num_classes)


class TestCLIPShapesLinearEval:
    """Inference test for ``clip_shapes.CLIPLinearEval``."""

    def test_inference(self):
        """A batch of 224x224 RGB images maps to (batch, num_classes)."""
        batch_size = 8
        model = clip_shapes.CLIPLinearEval()
        images = torch.rand(batch_size, 3, 224, 224)
        logits = model(images)
        assert logits.shape == (batch_size, model.num_classes)


class TestSimCLRLieModule:
    """Inference test for ``SimCLRLieModule``."""

    def test_inference(self):
        """With 5 neighbors requested, the output is (batch, 6, 2048)."""
        batch_size, feat_dim, num_neighbors = 8, 2048, 5

        module = SimCLRLieModule()
        images = torch.rand(batch_size, 3, 224, 224)
        features = module(images, num_neighbors=num_neighbors)

        expected_shape = (batch_size, num_neighbors + 1, feat_dim)
        assert features.shape == expected_shape


class TestSimCLRFramesMoreParamsModule:
    """Inference test for ``SimCLRFramesMoreParamsModule``."""

    def test_inference(self):
        """A batch of 224x224 RGB images maps to (batch, 1, 2048)."""
        batch_size, feat_dim = 8, 2048

        module = SimCLRFramesMoreParamsModule()
        images = torch.rand(batch_size, 3, 224, 224)
        features = module(images)

        expected_shape = (batch_size, 1, feat_dim)
        assert features.shape == expected_shape


class TestMAELieModule:
    """Shape tests for ``MAELieModule``: encoder, representation and transform."""

    def test_mae_encoder(self):
        """Check encoder output shapes with no masking and with 75% masking."""
        batch_size, feature_dim = 8, 768

        module = MAELieModule()
        images = torch.rand(batch_size, 3, 224, 224)

        # Without masking the ViT output should be (batch_size, 197 patches, 768).
        unmasked, _, _ = module.mae.forward_encoder(images, mask_ratio=0.0)
        assert unmasked.shape == (batch_size, 197, feature_dim)

        # 75% is the default masking ratio used; 50 tokens are expected to remain.
        masked, _, _ = module.mae.forward_encoder(images, mask_ratio=0.75)
        assert masked.shape == (batch_size, 50, feature_dim)

    def test_compute_rep(self):
        """``compute_rep`` returns one 768-dim vector per image."""
        batch_size, feature_dim = 8, 768

        module = MAELieModule()
        images = torch.rand(batch_size, 3, 224, 224)
        rep = module.compute_rep(images)
        assert rep.shape == (batch_size, feature_dim)

    def test_transform(self):
        """``transform`` preserves the shape of the representation."""
        batch_size = 8
        num_lie_generators = 3

        module = MAELieModule(dim_d=num_lie_generators)

        images = torch.rand(batch_size, 3, 224, 224)
        rep = module.compute_rep(images)

        transformed = module.transform(rep)
        assert transformed.shape == rep.shape

