from mmhug.models.custom_transformers.wan22_audio.transformer_motion_extractor_mae_v2 import (
    SapiensMotionExtractorV2,
)
from mmhug.registry import HF_MODELS, MODELS
from mmhug.trainers.base_trainer_model import BaseTrainerModel
from torch import nn
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
import torch.nn.functional as F
from einops import rearrange, repeat

@MODELS.register_module()
class TrainerSlipmaeClassifier(BaseTrainerModel):
    """Linear classification head trained on top of a motion-extractor backbone.

    Every frame of the input clip is pushed through the backbone on its own,
    the returned tokens are mean-pooled over the token axis, and a single
    linear layer maps the pooled feature to ``num_classes`` logits. The clip
    label is copied to every frame of that clip before the loss is computed.

    Args:
        backbone: Config handed to ``HF_MODELS.build`` to create the backbone.
        use_prompt_tokens: Pool ``output.prompt_tokens`` when True, otherwise
            pool ``output.patch_tokens``.
        task: One of ``"binary"``, ``"multiclass"`` or ``"multilabel"``.
            Original guidance: multiclass for appearance, multilabel for
            action and emotion.
        num_classes: Number of logits produced by the head. Original guidance:
            8 for emotion, 40 for appearance, 35 for action.
        label_key: Key under which the labels are stored in the batch.
        embed_dim: Width of the pooled token feature fed to the head. The
            default (1024) is the value that used to be hard-coded.

    Raises:
        ValueError: If ``task`` is not one of the supported names.

    NOTE(review): the defaults combine ``task="multiclass"`` with
    ``label_key="emotion"``, while the guidance above lists emotion as a
    multilabel task -- confirm which one is intended.
    """

    # Loss class per supported task; "binary" and "multilabel" share the same
    # sigmoid-based loss, "multiclass" uses softmax cross-entropy.
    _LOSS_BY_TASK = {
        "binary": BCEWithLogitsLoss,
        "multiclass": CrossEntropyLoss,
        "multilabel": BCEWithLogitsLoss,
    }

    def __init__(
        self,
        backbone,
        use_prompt_tokens: bool = True,
        task: str = "multiclass",
        num_classes: int = 8,
        label_key: str = "emotion",
        embed_dim: int = 1024,
    ):
        super().__init__()
        # Reject an unsupported task before paying for the backbone build.
        if task not in self._LOSS_BY_TASK:
            raise ValueError(f"Unknown task: {task}")

        self.backbone: SapiensMotionExtractorV2 = HF_MODELS.build(backbone)
        self.backbone.init_weights()
        self.classifier = nn.Linear(embed_dim, num_classes)

        self.task = task
        self.loss_fn = self._LOSS_BY_TASK[task]()
        self.label_key = label_key
        self.use_prompt_tokens = use_prompt_tokens

    def forward_loss(self, batch) -> dict:
        """Compute the per-frame classification loss for one batch.

        Args:
            batch: Mapping that provides ``"video"`` (rearranged below as
                ``b t c h w``) and the labels under ``self.label_key``.
                Labels may be 2-D ``(b, c)`` or 1-D ``(b,)``; 1-D labels are
                treated as class indices for ``"multiclass"`` and as a single
                target column otherwise.

        Returns:
            ``{"loss": loss}`` where ``loss`` is the output of the task's
            loss function.
        """
        labels = batch[self.label_key]
        video = batch["video"]
        num_frames = video.shape[1]

        # Fold time into the batch axis so the backbone sees single frames.
        frames = rearrange(video, "b t c h w -> (b t) c h w")

        # Copy each clip label to all of its frames. Both patterns flatten as
        # "(b t)", so labels stay aligned with the frames above.
        if labels.ndim == 1 and self.task == "multiclass":
            # Class-index targets: CrossEntropyLoss needs integer indices here.
            targets = repeat(labels, "b -> (b t)", t=num_frames).long()
        else:
            if labels.ndim == 1:
                labels = labels.unsqueeze(-1)
            # Float targets: required by BCEWithLogitsLoss; for multiclass they
            # are passed to CrossEntropyLoss as per-class float targets.
            targets = repeat(labels, "b c -> (b t) c", t=num_frames).float()

        output = self.backbone(frames, do_mask=False)

        # Mean-pool over the token axis: (b n d) -> (b d).
        tokens = output.prompt_tokens if self.use_prompt_tokens else output.patch_tokens
        logits = self.classifier(tokens.mean(dim=1))

        # No flattening for multilabel: the loss is built with default
        # arguments (mean reduction over all elements), so it is not needed.
        return {"loss": self.loss_fn(logits, targets)}
