from mmhug.registry import HF_MODELS, MODELS
from mmhug.trainers.base_trainer_model import BaseTrainerModel
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from torchmetrics import Accuracy, AUROC
from transformers import (
    VideoMAEImageProcessor,
    VideoMAEModel,
)


@MODELS.register_module()
class TrainierVideoMAEClassifier(BaseTrainerModel):
    """VideoMAE backbone with a linear classification head and a task-specific loss.

    The backbone is built through ``HF_MODELS``; its pooled output is fed to a
    single ``nn.Linear`` layer producing ``num_classes`` logits. The loss and
    the metric objects are selected from ``task``.

    Args:
        backbone: Value handed to ``HF_MODELS.build`` to create the backbone.
            NOTE(review): annotated as ``VideoMAEModel`` but it is passed to
            the registry, so it is presumably a build config -- confirm.
        task: One of ``"binary"``, ``"multiclass"`` or ``"multilabel"``.
        num_classes: Number of logits produced by the classification head.
        label_key: Key of the batch dict under which the labels are stored.

    Raises:
        ValueError: If ``task`` is not one of the supported values.
    """

    def __init__(
        self,
        backbone: VideoMAEModel,
        task: str = "multiclass",  # multiclass for appearance, multilabel for action and emotion
        num_classes: int = 8,  # 8 for emotion, 40 for appearance, 35 for action
        label_key: str = "emotion",
    ):
        super().__init__()
        self.backbone = HF_MODELS.build(backbone)
        # Only present when the backbone is configured for mean pooling; it is
        # applied to the token average in `forward_loss`.
        self.fc_norm = (
            nn.LayerNorm(self.backbone.config.hidden_size)
            if self.backbone.config.use_mean_pooling
            else None
        )

        self.num_classes = num_classes
        self.classifier = nn.Linear(self.backbone.config.hidden_size, num_classes)

        self.task = task
        if task == "binary":
            self.loss_fn = BCEWithLogitsLoss()
            self.acc_fn = Accuracy(task=task, num_classes=1)
            self.auc_fn = AUROC(task=task, num_classes=1)
        elif task == "multiclass":
            self.loss_fn = CrossEntropyLoss()
            self.acc_fn = Accuracy(task=task, num_classes=num_classes)
            self.auc_fn = AUROC(task=task, num_classes=num_classes)
        elif task == "multilabel":
            # Multilabel predictions are flattened in `forward_loss`, so every
            # (sample, label) pair is scored as an independent binary decision.
            self.loss_fn = BCEWithLogitsLoss()
            self.acc_fn = Accuracy(task="binary", num_classes=1)
            self.auc_fn = AUROC(task="binary", num_classes=1)
        else:
            raise ValueError(f"Unknown task: {task}")

        self.label_key = label_key

    def forward_loss(self, batch):
        """Run the backbone and head on a batch and compute the training loss.

        Args:
            batch: Mapping that contains the input under ``"video"`` and the
                labels under ``self.label_key``.

        Returns:
            A dict with the single key ``"loss"`` holding the loss tensor.
        """
        y = batch[self.label_key]
        video = batch["video"]
        if isinstance(video, list):
            # Diagnostic output for batches that arrive as a list of clips
            # instead of a single tensor.
            print([v.shape for v in video])

        hidden = self.backbone(video).last_hidden_state
        if self.fc_norm is not None:
            # Mean-pooling configuration: average all tokens and normalize,
            # mirroring the pooling of Hugging Face's
            # VideoMAEForVideoClassification head.
            feats = self.fc_norm(hidden.mean(1))
        else:
            feats = hidden[:, 0]
        y_hat = self.classifier(feats)

        if self.task == "multiclass":
            if y.dim() == y_hat.dim() - 1:
                # Class-index labels: CrossEntropyLoss requires an integer
                # dtype here; a float target would be read as probabilities.
                target = y.long()
            else:
                # One-hot / soft labels shaped like the logits.
                target = y.float()
        else:
            # binary / multilabel: BCEWithLogitsLoss needs logits and targets
            # of identical shape, so compare them element-wise as flat vectors.
            y_hat = y_hat.flatten()
            target = y.flatten().float()

        loss = self.loss_fn(y_hat, target)
        return {"loss": loss}
