import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
import torchmetrics
import pytorchvideo.models.resnet
from models.base_model import BaseModel
from typing import Optional


class VideoClassificationModule(BaseModel):
    """Lightning module that trains a pytorchvideo 3D ResNet video classifier.

    Computes a cross-entropy loss and keeps a separate accuracy metric per
    stage (train / val / test). Builds either an Adam optimizer or SGD with a
    cosine-annealing learning-rate schedule, depending on ``optimizer``.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        weight_decay: float = 1e-4,
        momentum: float = 0.9,
        num_classes: int = 400,
        batch_size: int = 8,
        max_epochs: int = 100,
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        """
        Args:
            learning_rate: Initial learning rate passed to the optimizer.
            optimizer: ``"adam"`` selects Adam; any other value selects SGD
                with cosine annealing (see ``configure_optimizers``).
            weight_decay: Weight decay passed to either optimizer.
            momentum: Momentum; only used by the SGD optimizer.
            num_classes: Number of output classes of the ResNet head.
            batch_size: Accepted for interface compatibility; not used by
                this class.
            max_epochs: Period (``T_max``) of the cosine annealing schedule.
            datamodule: Optional datamodule, stored as ``self.datamodule``.
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.num_classes = num_classes
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.momentum = momentum

        # used for cosine annealing
        self.max_epochs = max_epochs
        self.datamodule = datamodule

        self.model = self.create_resnet()

        # One metric instance per stage so their accumulated state never mixes.
        # NOTE(review): torchmetrics >= 0.11 requires
        # ``Accuracy(task="multiclass", num_classes=...)``; the bare call below
        # only works on older releases — confirm the pinned torchmetrics version.
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()

    def create_resnet(self):
        """Return the slow network from the SlowFast paper.

        Built with ``pytorchvideo.models.resnet.create_resnet`` as a depth-50
        ResNet with 3 input channels, BatchNorm3d and ReLU. Its head has
        ``self.num_classes`` outputs.
        """
        return pytorchvideo.models.resnet.create_resnet(
            input_channel=3,
            model_depth=50,
            model_num_class=self.num_classes,
            norm=nn.BatchNorm3d,
            activation=nn.ReLU,
        )

    def forward(self, x):
        """Run the wrapped ResNet on ``x`` and return its logits."""
        return self.model(x)

    def shared_step(self, batch, stage: str = "train"):
        """Compute the loss and update and log the accuracy for one batch.

        Args:
            batch: Mapping with a ``"video"`` tensor and a ``"label"`` tensor.
            stage: One of ``"train"``, ``"val"`` or ``"test"``. It selects the
                ``{stage}_accuracy`` metric and the logged key prefix.

        Returns:
            The cross-entropy loss tensor.
        """
        # The model expects a video tensor of shape (B, C, T, H, W), which is the
        # format provided by the dataset
        x = batch["video"]
        y = batch["label"]
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        accuracy_metric = getattr(self, f"{stage}_accuracy")
        accuracy_metric(F.softmax(y_hat, dim=-1), y)
        self.log(f"{stage}_loss", loss)
        # Logging the metric object (not the per-batch value) lets Lightning
        # compute it over the whole epoch and reset it afterwards.
        self.log(
            f"{stage}_accuracy",
            accuracy_metric,
            sync_dist=True,
            on_epoch=True,
            on_step=False,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook; delegates to ``shared_step``."""
        loss = self.shared_step(batch, stage="train")
        return loss

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; delegates to ``shared_step``."""
        loss = self.shared_step(batch, stage="val")
        return loss

    def test_step(self, batch, batch_idx):
        """Lightning test hook; delegates to ``shared_step``."""
        loss = self.shared_step(batch, stage="test")
        return loss

    def test_on_image_step(self, batch, batch_idx):
        """Evaluate on still images by treating each one as a single-frame clip.

        TODO: currently not called. Consider deleting.

        Args:
            batch: ``(images, labels, _)`` tuple. Presumably the images are
                (B, C, H, W) — TODO confirm against the image dataloader.

        Returns:
            The cross-entropy loss tensor.
        """
        x, y, _ = batch
        # Insert a length-1 time axis at dim 2 so the video model gets 5D input.
        x = x[:, :, None, ...]
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        # Update the *test* metric: results are logged as "test_accuracy", and
        # touching val_accuracy here would corrupt the validation state.
        self.test_accuracy(F.softmax(y_hat, dim=-1), y)
        self.log("test_loss", loss)
        self.log(
            "test_accuracy",
            self.test_accuracy,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        return loss

    def configure_optimizers(self):
        """Build the optimizer selected by ``self.optimizer``.

        Returns:
            A ``torch.optim.Adam`` instance when ``self.optimizer == "adam"``.
            Otherwise, the ``([optimizer], [scheduler])`` pair from ``sgd()``.
            Any value other than exactly ``"adam"`` falls back to SGD.
        """
        if self.optimizer == "adam":
            return torch.optim.Adam(
                self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
            )
        return self.sgd()

    def sgd(self):
        """Build SGD with momentum and a cosine-annealing LR schedule.

        The schedule's period is ``self.max_epochs``. Lightning steps a
        scheduler returned in this list form once per epoch by default, so
        the learning rate is annealed per epoch, not per step.

        Returns:
            ``([optimizer], [scheduler])``.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.max_epochs, last_epoch=-1
        )
        return [optimizer], [scheduler]
