import torch.nn as nn
import pytorchvideo.models.resnet
import pytorch_lightning as pl
from typing import List
import torch


class CVLR(pl.LightningModule):
    """CVLR model.

    Self-supervised 3D-ResNet video encoder trained with the InfoNCE objective
    from 'Spatiotemporal Contrastive Video Representation Learning'.

    Args:
        num_samples: currently unused by this class; kept for interface
            compatibility.
        tau: temperature of the InfoNCE loss.
        n_clips: number of clips sampled per video (stored as
            ``n_clips_per_video``). Only the first two clips of each video
            form the positive pair in ``training_step``.
        learning_rate: learning rate of the SGD optimizer.
        dropout: passed as ``dropout_rate`` to pytorchvideo's
            ``create_resnet``. NOTE(review): with ``head=None`` this likely
            has no effect — confirm.
    """

    def __init__(
        self,
        num_samples=None,
        tau=0.1,
        n_clips=2,
        learning_rate: float = 0.32,
        dropout=0.5,
    ):
        super().__init__()

        self.model_name = __class__.__name__
        self.tau = tau
        self.learning_rate = learning_rate
        self.dropout = dropout
        self.resnet3d = self.create_resnet()
        self.n_clips_per_video = n_clips

    def create_resnet(self):
        """Creates a modified version of the slow pathway of SlowFast.

        Builds a headless ResNet-50 with pytorchvideo and appends a global
        ``AdaptiveAvgPool3d(1)``. A 5-D clip batch is therefore mapped to
        embeddings of shape ``(batch, channels, 1, 1, 1)``.
        """
        resnet = pytorchvideo.models.resnet.create_resnet(
            input_channel=3,
            model_depth=50,
            dropout_rate=self.dropout,
            norm=nn.BatchNorm3d,
            activation=nn.ReLU,
            stem_conv_kernel_size=(5, 7, 7),
            stem_conv_stride=(2, 2, 2),
            head=None,
        )
        # The first child of the returned Net is its ModuleList of blocks
        # (no head block, since head=None). Extend it with global pooling.
        blocks = next(resnet.children())
        blocks.append(nn.AdaptiveAvgPool3d(1))

        return pytorchvideo.models.net.Net(blocks=blocks)

    def training_step(self, batch, batch_idx):
        """Computes the InfoNCE training loss for one batch.

        Args:
            batch (dict): contains keys "video", "video_index", "clip_index".
                ``batch["video"]`` is the clip tensor fed to the encoder;
                ``batch["video_index"]`` identifies the source video of
                each clip.
            batch_idx (int): index of the batch (used in the error message).

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if the input clips contain NaNs. The offending
                tensor is kept on ``self.bad_x`` for inspection.
        """
        # x is of shape (batch_size, 3, frames, H, W)
        x = batch["video"]
        if torch.any(x.isnan()).item():
            self.bad_x = x
            raise ValueError(f"x contains nans. {x.shape=} {batch_idx=}")

        z = self.resnet3d(x)

        # (N videos, clips per video, embedding dim)
        video_shots = self.group_shots_by_video(batch["video_index"], z)
        video_shots = video_shots.flatten(start_dim=2)
        N = video_shots.shape[0]

        # Order the embeddings as [clip 0 of every video; clip 1 of every
        # video]. Row i (i < N) is then the anchor of video i, and its
        # positive sits at column N + i. This holds regardless of how the
        # dataloader ordered the batch.
        embeddings = torch.cat([video_shots[:, 0], video_shots[:, 1]], dim=0)
        embeddings = torch.nn.functional.normalize(embeddings, dim=1)
        similarity = embeddings @ embeddings.T  # cosine similarities, (2N, 2N)

        positive_pairs = similarity.diagonal(offset=N)  # sim(z_i, z_i'), (N,)
        # Exclude each anchor's similarity with itself from the denominator.
        # -inf contributes exp(-inf) = 0 to the logsumexp; the previous
        # value 0 still contributed exp(0) = 1.
        self_mask = torch.eye(
            similarity.shape[0], dtype=torch.bool, device=similarity.device
        )
        all_pairs = similarity.masked_fill(self_mask, float("-inf"))

        loss = self.compute_info_nce(positive_pairs, all_pairs, N)

        self.log("train_loss", loss, sync_dist=True)

        return loss

    def compute_info_nce(self, positive_pairs, all_pairs, N):
        """Computes the InfoNCE loss from cosine similarities.

        loss = -1/N * sum_{i<N} [ sim(z_i, z_i') / tau
                                  - logsumexp_j( sim(z_i, z_j) / tau ) ]

        logsumexp is used for numerical stability.

        Args:
            positive_pairs: the N positive-pair similarities. Trailing
                singleton dimensions are allowed.
            all_pairs: square similarity matrix whose row i belongs to the
                anchor of ``positive_pairs[i]``. Entries that must not appear
                in the denominator (e.g. self-similarity) should be -inf.
                Trailing singleton dimensions are allowed.
            N: number of videos (anchors).

        Returns:
            Scalar loss tensor.
        """
        # reshape rather than squeeze: squeeze would also drop the batch dim
        # when N == 1.
        positive_pairs = positive_pairs.reshape(-1)[:N]
        all_pairs = all_pairs.reshape(all_pairs.shape[0], all_pairs.shape[1])

        per_video = positive_pairs / self.tau - torch.logsumexp(
            all_pairs[:N] / self.tau, dim=1
        )

        return -per_video.mean()

    def configure_optimizers(self):
        """
        Setup the SGD optimizer (momentum 0.9, lr=self.learning_rate) over
        all model parameters.
        """
        return torch.optim.SGD(self.parameters(), momentum=0.9, lr=self.learning_rate)

    def group_shots_by_video(self, video_index, z):
        """Groups clip embeddings by their source video.

        Args:
            video_index: per-row video identifiers for ``z`` (presumably a
                1-D tensor of length ``z.shape[0]``).
            z: embeddings with the batch on dim 0.

        Returns:
            Tensor of shape ``(num_videos, clips_per_video, *z.shape[1:])``.
            Videos are ordered by ascending index (``Tensor.unique`` sorts),
            and clips keep their batch order.

        Assumes every video contributes the same number of clips; otherwise
        ``torch.stack`` raises.
        """
        return torch.stack([z[video_index == v] for v in video_index.unique()], dim=0)


if __name__ == "__main__":
    # Smoke check: building the module instantiates the 3D ResNet backbone
    # with the default hyperparameters.
    model = CVLR()
