import torch
from mmhug.models.custom_transformers.sapiens.heatmap_head import HeatmapHead
from mmhug.models.custom_transformers.sapiens.vit_sapiens import SapiensVisionTransformer
from mmhug.models.custom_transformers.wan22_audio.transformer_motion_extractor_mae_v2 import (
    SapiensMotionExtractorV2,
)
from mmhug.registry import HF_MODELS, MODELS
from mmhug.trainers.base_trainer_model import BaseTrainerModel
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
import torch.nn.functional as F
from einops import rearrange, repeat

@MODELS.register_module()
class TrainerSlipmaeFaceOnlyClassifier(BaseTrainerModel):
    """Classification trainer on top of a masked-video backbone.

    A frozen Sapiens ViT (``raw_sapiens``) and a frozen heatmap head
    (``heatmap_head``) turn every frame into heatmaps. Those heatmaps are
    converted into a per-location masking probability that is passed to the
    trainable ``backbone`` (called with ``do_mask=True``). A linear
    ``classifier`` is trained on the mean-pooled backbone tokens.

    Only ``backbone`` and ``classifier`` are trainable. ``raw_sapiens`` and
    ``heatmap_head`` are kept frozen and in eval mode; see :meth:`train`.
    """

    # Loss class used for each supported ``task`` value.
    _TASK_LOSSES = {
        "binary": BCEWithLogitsLoss,
        "multiclass": CrossEntropyLoss,
        "multilabel": BCEWithLogitsLoss,
    }

    def train(self, mode: bool = True):
        """Set training mode while keeping the Sapiens models frozen.

        ``nn.Module.train`` recursively switches every submodule, which would
        put ``raw_sapiens`` and ``heatmap_head`` into training mode. They are
        therefore forced back to eval mode with gradients disabled.

        Args:
            mode: Whether to set training mode (True) or eval mode (False).

        Returns:
            ``self``, matching ``nn.Module.train``. This matters because
            ``nn.Module.eval()`` returns ``self.train(False)``, and because
            callers may chain calls such as ``model.train().to(device)``.
        """
        super().train(mode)
        for frozen in (self.raw_sapiens, self.heatmap_head):
            frozen.eval()
            frozen.requires_grad_(False)
        return self

    def __init__(
        self,
        backbone,
        raw_sapiens: SapiensVisionTransformer,
        heatmap_head: HeatmapHead,
        use_prompt_tokens: bool = True,
        task: str = "multiclass",  # multiclass for appearance, multilabel for action and emotion
        num_classes: int = 8,  # 8 for emotion, 40 for appearance, 35 for action
        label_key: str = "emotion",
        embed_dim: int = 1024,
    ):
        """Build the frozen Sapiens models, the backbone, and the classifier.

        Args:
            backbone: Config passed to ``HF_MODELS.build``. The built module
                is called with ``do_mask``/``mask_prob`` in
                :meth:`forward_loss`.
            raw_sapiens: Config for the frozen Sapiens ViT, passed to
                ``HF_MODELS.build``. The annotation names the built type.
            heatmap_head: Config for the frozen heatmap head, passed to
                ``HF_MODELS.build``.
            use_prompt_tokens: If True, classify the mean-pooled
                ``output.prompt_tokens``. Otherwise classify the mean-pooled
                ``output.patch_tokens``.
            task: One of ``"binary"``, ``"multiclass"`` or ``"multilabel"``.
                Selects the loss function.
            num_classes: Output size of the linear classifier.
            label_key: Key in the batch dict that holds the targets.
            embed_dim: Feature size of the pooled backbone tokens, i.e. the
                input size of the classifier. Defaults to 1024, the value
                that used to be hard-coded.

        Raises:
            ValueError: If ``task`` is not supported.
        """
        super().__init__()
        # Validate before building any sub-network so a bad config fails fast.
        if task not in self._TASK_LOSSES:
            raise ValueError(f"Unknown task: {task}")

        # Frozen feature extractor and heatmap head; never trained.
        self.raw_sapiens = HF_MODELS.build(raw_sapiens)
        self.raw_sapiens.init_weights()
        self.raw_sapiens.eval()
        self.raw_sapiens.requires_grad_(False)

        self.heatmap_head = HF_MODELS.build(heatmap_head)
        self.heatmap_head.init_weights()
        self.heatmap_head.eval()
        self.heatmap_head.requires_grad_(False)

        # Trainable parts.
        self.backbone: SapiensMotionExtractorV2 = HF_MODELS.build(backbone)
        self.backbone.init_weights()
        self.classifier = nn.Linear(embed_dim, num_classes)

        self.task = task
        self.loss_fn = self._TASK_LOSSES[task]()

        self.label_key = label_key

        self.use_prompt_tokens = use_prompt_tokens

    @torch.no_grad()
    def _sapiens_heatmap_chunk(self, imgs: Tensor) -> Tensor:
        """Run the frozen Sapiens ViT and heatmap head on one chunk of images."""
        feats = self.raw_sapiens(imgs, out_type="featmap").last_hidden_state
        heatmap, _ = self.heatmap_head(feats, decode_kpt=False)
        return heatmap

    @torch.no_grad()
    def get_sapiens_heatmap(self, imgs: Tensor, max_batch_size: int = 32) -> Tensor:
        """Get Sapiens heatmaps for a batch of images, processed in chunks.

        Args:
            imgs (Tensor):
                A batch of images: [B, C, H, W].
            max_batch_size (int):
                Largest number of images sent through the frozen models at
                once. This bounds peak memory. Must be positive.

        Returns:
            heatmap (Tensor):
                A batch of heatmaps: [B, N, H', W']. N is the number of
                heatmap channels; the spatial size is whatever
                ``heatmap_head`` produces.

        Raises:
            ValueError: If ``max_batch_size`` is not positive.
        """
        if max_batch_size <= 0:
            raise ValueError(f"max_batch_size must be positive, got {max_batch_size}")
        if imgs.shape[0] <= max_batch_size:
            return self._sapiens_heatmap_chunk(imgs)
        # Split once with the requested chunk size, so every chunk honours
        # max_batch_size (not just the default).
        chunks = imgs.split(max_batch_size, dim=0)
        return torch.cat([self._sapiens_heatmap_chunk(chunk) for chunk in chunks], dim=0)

    def forward_loss(self, batch):
        """Compute the classification loss for one batch.

        Each frame is classified independently and receives its clip's label.

        Args:
            batch (dict): Must contain ``"video"`` with shape [b, t, c, h, w]
                and ``self.label_key`` with shape [b, c]. The label tensor is
                cast to float before the loss is computed.

        Returns:
            dict: ``{"loss": loss}``.
        """
        labels = batch[self.label_key]
        video = batch["video"]
        num_frames = video.shape[1]
        # Fold time into the batch axis and repeat each clip label per frame.
        frames = rearrange(video, "b t c h w -> (b t) c h w")
        labels = repeat(labels, "b c -> (b t) c", t=num_frames)

        heatmap = self.get_sapiens_heatmap(frames)
        # For each frame: resize the heatmaps to a 32x32 grid, take the max
        # over heatmap channels, and use 1 - max as the per-location
        # ``mask_prob``, flattened to [(b t), 1024]. Presumably this keeps
        # regions with a strong keypoint response (the face) unmasked, and
        # 32x32 matches the backbone's patch grid -- confirm against the
        # backbone.
        # NOTE(review): this assumes heatmap values lie in [0, 1]; otherwise
        # 1 - max is not a valid probability.
        peak = (
            F.interpolate(heatmap.float(), (32, 32), mode="bilinear", align_corners=False)
            .to(heatmap.dtype)
            .amax(dim=1)
        )
        keepface_mask_prob = (1 - peak).view(heatmap.shape[0], -1)

        output = self.backbone(frames, do_mask=True, mask_prob=keepface_mask_prob)

        # Mean-pool over the token axis: [(b t), n, d] -> [(b t), d].
        tokens = output.prompt_tokens if self.use_prompt_tokens else output.patch_tokens
        logits = self.classifier(tokens.mean(dim=1))

        if self.task == "multilabel":
            # BCE is elementwise, so flattening leaves the mean-reduced loss
            # unchanged.
            logits = logits.flatten()
            labels = labels.flatten()

        loss = self.loss_fn(logits, labels.float())
        return {"loss": loss}
