import math
from typing import List, Optional

import pytorch_lightning as pl
import torch
import torchvision
from torch import Tensor

from models.base_model import BaseModel


class VIVI(BaseModel):
    """ViVi model
    based on 'Self-Supervised Learning of Video-Induced Visual Invariances'

    The optimised objective is
    ``omega * object_loss + frame_loss + beta * shot_loss``.

    Args:
        omega: scalar weight applied to object loss
        alpha: margin parameter for triplet loss
        beta: scalar weight for shot loss
        learning_rate: learning rate passed to the Adam optimizer
        num_samples: stored on the instance; not read elsewhere in this class
        semi_hard_negatives: if True, frame-loss negatives are mined from other
            shots of the same video; otherwise they are mined from other videos
        datamodule: optional datamodule, stored on the instance
    """

    def __init__(
        self,
        omega: float = 5.0,
        alpha: float = 0.5,
        beta: float = 0.04,
        learning_rate: float = 1e-3,
        num_samples: Optional[int] = None,
        semi_hard_negatives: bool = False,
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        super().__init__()

        self.alpha = alpha
        self.omega = omega
        self.beta = beta
        self.learning_rate = learning_rate
        self.num_samples = num_samples
        self.semi_hard_negatives = semi_hard_negatives
        self.datamodule = datamodule

        self.resnet50 = ResNet50(pretrained=False)
        # based on embedding generated by the resnet50
        self.embedding_dim = 2048
        self.lstm = LSTMNextShot(input_size=self.embedding_dim, hidden_size=256)

        # shot similarity: bilinear critic g(a, b) = linear1(a) . linear2(b)
        self.linear1 = torch.nn.Linear(self.embedding_dim, 256)
        self.linear2 = torch.nn.Linear(self.embedding_dim, 256)

        # logging: running count of clips whose frame loss was skipped
        # because too few negative frames were available
        self.n_frame_loss_skipped = 0

    def forward(self, x):
        """Outputs average frame embeddings to use for downstream tasks"""
        # frame_embeddings shape is (batch_size, frame, embedding dim)
        frame_embeddings = self.embed_batch(x)
        return torch.mean(frame_embeddings, dim=1)

    def training_step(self, batch, batch_idx):
        """Computes and logs the training loss for one batch.

        Args:
            batch (dict): contains keys "video", "video_index", "shot_index".

        Raises:
            ValueError: if the input video or its embeddings contain NaNs.
        """
        # x is of shape (batch_size, 3, frames, H, W)
        x = batch["video"]
        if torch.any(x.isnan()).item():
            # keep the offending input on the module for post-mortem debugging
            self.bad_x = x
            raise ValueError(f"x contains nans. {x.shape=} {batch_idx=}")

        z = self.embed_batch(x)
        if torch.any(z.isnan()).item():
            self.bad_x = x
            self.bad_z = z
            raise ValueError(f"z contains nans. {x.shape=} {z.shape=} {batch_idx=}")

        return self._shared_loss_step(z, batch, stage="train")

    def validation_step(self, batch, batch_idx):
        """Computes and logs the validation loss for one batch.

        Args:
            batch (dict): contains keys "video", "video_index", "shot_index".
        """
        # x is of shape (batch_size, 3, frames, H, W)
        z = self.embed_batch(batch["video"])
        return self._shared_loss_step(z, batch, stage="val")

    def _shared_loss_step(self, z: torch.Tensor, batch, stage: str):
        """Combines the object, frame and shot losses and logs them under `stage`."""
        z_shots = self.embed_shots(z)

        object_loss = self.compute_object_loss(z)
        frame_loss = self.compute_frame_loss(
            z, batch["shot_index"], batch["video_index"]
        )
        shot_loss = self.compute_shot_loss(z_shots, batch["video_index"])

        loss = self.omega * object_loss + frame_loss + self.beta * shot_loss

        self.log(f"{stage}_loss", loss, sync_dist=True)
        self.log(f"{stage}_object_loss", object_loss, sync_dist=True)
        self.log(f"{stage}_frame_loss", frame_loss, sync_dist=True)
        self.log(f"{stage}_shot_loss", shot_loss, sync_dist=True)

        return loss

    def embed_batch(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed video frames using a ResNet50.
        Note VIVI paper uses ResNet-50v2

        Each clip is passed through the backbone separately, so any batch-norm
        statistics are computed per clip.

        Args:
            x: tensor of shape (batch_size, 3, frames, H, W)

        Returns: tensor of shape (batch_size, frames, frame embedding dim)
        """
        clip_embeddings = []
        for clip in x:
            # (3, frames, H, W) -> (frames, 3, H, W): the frames form the image batch.
            # No squeeze() here: it would drop the frame axis for 1-frame clips.
            frames = clip.permute(1, 0, 2, 3)
            frame_embeddings = self.resnet50(frames)
            # force (frames, embedding dim) regardless of the backbone's output rank
            clip_embeddings.append(frame_embeddings.reshape(frames.shape[0], -1))
        return torch.stack(clip_embeddings)

    def embed_shots(self, z: torch.Tensor) -> torch.Tensor:
        """
        Parameter-free average of frame embeddings in a shot.
        Each sample in batch is assumed to come from a different clip.
        Each clip defines a shot.

        Args:
            z: embeddings of shape (batch_size, frames, embedding dim)

        Returns: tensor of shape (batch_size, embedding_dim)
        """
        return torch.mean(z, dim=1)

    def compute_object_loss(self, z: torch.Tensor):
        """Object loss term.

        NOTE(review): placeholder — currently always returns 0.0, so the
        ``omega``-weighted object term contributes nothing to the total loss.
        """
        return 0.0

    def compute_frame_loss(
        self, z: torch.Tensor, shot_indices: torch.Tensor, video_indices: torch.Tensor
    ) -> float:
        """Compute triplet loss where frames from the same shot are positives.

        Negatives are mined by `mine_negatives`: frames from other shots of the
        same video ("semi-hard") when `semi_hard_negatives` is True, otherwise
        frames from other videos. Clips with fewer candidate negative frames
        than frames per shot are skipped and counted in `n_frame_loss_skipped`.

        Args:
            z: embeddings of shape (batch_size, frames, embedding dim)
            shot_indices: 1-d array of length batch_size indicating the
                shot index of each clip
            video_indices: 1-d array of length batch_size indicating the
                video index of each clip

        Returns: sum over clips of the per-clip mean hinge triplet loss
            (a scalar tensor, or 0.0 if every clip was skipped).
        """
        triplet_loss = 0.0

        for shot_index, video_index, shot in zip(shot_indices, video_indices, z):
            frames_per_shot = shot.shape[0]
            negative_candidates = self.mine_negatives(
                shot_index, video_index, shot_indices, video_indices, z
            )

            # check before sampling: we need one negative per anchor frame
            if negative_candidates.shape[0] < frames_per_shot:
                self.n_frame_loss_skipped += 1
                self.log(
                    "n_frame_loss_skipped", self.n_frame_loss_skipped, sync_dist=True
                )
                continue

            negatives = self._sample_frames(negative_candidates, frames_per_shot)
            # positives are the same shot's frames in shuffled order.
            # NOTE(review): a permutation fixed point pairs a frame with itself.
            positives = shot[torch.randperm(frames_per_shot)]

            # squared euclidean distances, computed directly (no sqrt round-trip)
            positive_dist = (shot - positives).pow(2).sum(dim=1)
            negative_dist = (shot - negatives).pow(2).sum(dim=1)

            hinge = torch.clamp(positive_dist - negative_dist + self.alpha, min=0.0)
            triplet_loss += hinge.mean()

        return triplet_loss

    def mine_negatives(
        self,
        shot_index: int,
        video_index: int,
        shot_indices,
        video_indices: List,
        z: Tensor,
    ) -> Tensor:
        """Returns a tensor of negative frame candidates.

        With `semi_hard_negatives`, candidates come from other shots of the same
        video; otherwise they come from every clip of a different video.

        Args:
            z: embeddings of shape (batch_size, frames, embedding dim)
        Returns: tensor of shape (num candidate frames, embedding dim)
        """
        if self.semi_hard_negatives:
            # semi-hard negatives are from the same video, but a different shot
            mask = (shot_index != shot_indices) & (video_index == video_indices)
        else:
            mask = video_index != video_indices
        # (num selected clips, frames, dim) -> (num selected clips * frames, dim)
        return z[mask].flatten(end_dim=1)

    def _sample_frames(self, shots: torch.Tensor, n: int) -> torch.Tensor:
        """Randomly samples specified number of frames without replacement.

        Args:
            shots: tensor of shape (num of frames, embedding dim)
            n: number of frames to sample

        Returns: tensor of shape (min(n, num of frames), embedding dim)
        """
        perm = torch.randperm(shots.shape[0])
        return shots[perm[:n]]

    def compute_shot_loss(
        self, z_shots: torch.Tensor, video_index: torch.Tensor
    ) -> float:
        """InfoNCE on prediction of the next shot in each video.
        Prediction is generated by an LSTM.
        Similarity between shots is computed using a critic function g
            parameterized by two linear layers.

        Args:
            z_shots: embedding for shots of size (batch_size, embedding dim)
            video_index: 1-D tensor of shape (batch_size)

        NOTE(review): grouping relies on the shots of each video being contiguous
        and ordered in the batch, and on every video having as many shots as the
        first one; `video_index` is only used to infer that shot count.
        """
        shots_per_video = self.infer_shots_in_first_video(video_index)
        videos_per_batch = z_shots.shape[0] // shots_per_video
        # video_shots: (videos_per_batch, shots_per_video, embedding dim)
        video_shots = self.group_shots_by_video(
            z_shots, videos_per_batch, shots_per_video=shots_per_video
        )
        return self.compute_next_shot_prediction_loss(video_shots)

    def group_shots_by_video(
        self,
        z_shots: torch.Tensor,
        videos_per_batch: int,
        shots_per_video: Optional[int] = None,
    ) -> torch.Tensor:
        """Groups consecutive shots into equally sized per-video chunks,
        dropping trailing shots that do not fill a whole chunk.

        Args:
            z_shots: embedding for shots of size (batch_size, embedding dim)
            videos_per_batch: number of chunks to return
            shots_per_video: size of each chunk; defaults to
                batch_size // videos_per_batch

        Returns: tensor (videos_per_batch, shots_per_video, embedding_dim)
        """
        if shots_per_video is None:
            shots_per_video = z_shots.shape[0] // videos_per_batch
        # fixed-size consecutive chunks; torch.chunk would produce ceil-sized
        # chunks and mix shots of neighbouring videos
        n_used = videos_per_batch * shots_per_video
        return z_shots[:n_used].reshape(
            videos_per_batch, shots_per_video, z_shots.shape[-1]
        )

    def compute_next_shot_prediction_loss(self, shots: torch.Tensor):
        """Computes shot loss by splitting shots into all possible prev/next splits.

        For every idx in 1..shots_per_video-1 the LSTM predicts shot `idx` from
        shots `[0, idx)`, scored with InfoNCE across the videos in the batch.

        Args:
            shots: shape is (videos_per_batch, shots_per_video, embedding dim)

        Returns: sum of info nce losses for predicted next shots
            (0.0 when there is only one shot per video)
        """
        loss = 0.0
        for idx in range(1, shots.shape[1]):
            pred_next_shot = self.lstm(shots[:, :idx, :])
            loss += self.compute_info_nce(pred_next_shot, shots[:, idx, :])
        return loss

    def compute_info_nce(
        self, pred_next_shots: torch.Tensor, next_shots: torch.Tensor
    ) -> float:
        """Computes InfoNCE with similarity from a bi-linear function g based on VIVI.
        Uses logsumexp for numerical stability.

        Computed as -1/N sum_i (g(pred_i, shot_i) + log(N) - logsumexp_j g(pred_i, shot_j))
        where N is the number of videos; the true next shot of video i is the
        positive and the next shots of the other videos are negatives.
        """
        videos_per_batch = pred_next_shots.shape[0]
        # (videos_per_batch, videos_per_batch); entry [i, j] = g(pred_i, shot_j)
        similarities = self.shot_similarity(pred_next_shots, next_shots)
        log_ratio = torch.diagonal(similarities) - torch.logsumexp(similarities, dim=1)
        # + log(N) is a constant offset (no effect on gradients)
        return -(log_ratio + math.log(videos_per_batch)).mean()

    def shot_similarity(self, shot_1, shot_2) -> torch.Tensor:
        """Computes pairwise shot similarities using the bi-linear form from VIVI.

        Returns: matrix of shape (len(shot_1), len(shot_2)) whose entry [i, j]
            is linear1(shot_1[i]) . linear2(shot_2[j]).
        """
        return torch.mm(self.linear1(shot_1), self.linear2(shot_2).T)

    def infer_shots_in_first_video(self, video_index: torch.Tensor) -> int:
        """Counts the leading run of entries equal to video_index[0].

        This is used as the number of shots per video. If every entry belongs to
        the same video, the full length is returned.
        """
        # TODO: consider adding shot_per_video as property from the data during initialization
        first_video = video_index[0]
        shots_count = 0
        for index in video_index:
            if index != first_video:
                break
            shots_count += 1
        return shots_count

    def configure_optimizers(self):
        """
        Setup the Adam optimizer.
        """
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


class ResNet50(pl.LightningModule):
    """torchvision ResNet-50 with its final child module (the classifier) removed.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``.
        num_classes: forwarded to ``torchvision.models.resnet50``; it only sizes
            the classifier that is subsequently discarded.
    """

    def __init__(
        self,
        pretrained: bool = True,
        num_classes: int = 1000,
    ):
        super().__init__()
        self.pretrained = pretrained
        self.num_classes = num_classes
        self.model = self.get_model_without_last_layer(
            torchvision.models.resnet50(pretrained=pretrained, num_classes=num_classes)
        )

    def get_model_without_last_layer(self, model: torchvision.models.resnet.ResNet):
        """Returns a Sequential of every child module of `model` except the last."""
        embedding_model = torch.nn.Sequential(*(list(model.children())[:-1]))
        return embedding_model

    def forward(self, x):
        """Produces embedding of size 2048 for each input image.

        Returns: tensor of shape (batch_size, 2048).
        """
        # flatten rather than squeeze(): squeeze() would also drop the batch
        # dimension when batch_size == 1 and return shape (2048,)
        return torch.flatten(self.model(x), start_dim=1)


class LSTMNextShot(pl.LightningModule):
    """LSTM for predicting the next shot from the previous shots.

    A single-layer LSTM reads the sequence of previous shot embeddings and a
    linear head maps its output at the last time step back to the embedding size.

    Args:
        input_size: dimensionality of each shot embedding (and of the prediction).
        hidden_size: number of hidden units in the LSTM.
    """

    def __init__(
        self,
        input_size: int = 2048,
        hidden_size: int = 256,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = 1

        # batch_first=True: the LSTM consumes (batch, sequence, features)
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
        )
        # projects the final hidden output back into the embedding space
        self.linear = torch.nn.Linear(in_features=hidden_size, out_features=input_size)

    def forward(self, x):
        """Predicts the next shot embedding.

        Args:
            x: tensor of shape (batch, num previous shots, input_size)

        Returns: tensor of shape (batch, input_size)
        """
        sequence_output = self.lstm(x)[0]
        final_step = sequence_output[:, -1, :]
        return self.linear(final_step)
