from typing import Dict, List, Sequence, Union
from einops import rearrange
from torch import Tensor
import torch
from mmhug.datasets.utils.kpt2face import _fallback_face_indices_ex_ear_308
from mmhug.models.custom_transformers.auto_avsr.datamodule import TextTransform
from mmhug.models.custom_transformers.auto_avsr.espnet.pytorch_backend.ctc import CTC
from mmhug.models.custom_transformers.auto_avsr.espnet.pytorch_backend.decoder.transformer_decoder import (
    TransformerDecoder,
)
from mmhug.models.custom_transformers.auto_avsr.espnet.pytorch_backend.encoder.conformer_encoder import (
    ConformerEncoder,
)
from mmhug.models.custom_transformers.auto_avsr.espnet.pytorch_backend.transformer.add_sos_eos import (
    add_sos_eos,
)
from mmhug.models.custom_transformers.auto_avsr.espnet.pytorch_backend.transformer.label_smoothing_loss import (
    LabelSmoothingLoss,
)
from mmhug.models.custom_transformers.auto_avsr.espnet.pytorch_backend.transformer.mask import (
    target_mask,
)
from mmhug.models.custom_transformers.auto_avsr.espnet.scorers.ctc import (
    CTCPrefixScorer,
)
from mmhug.models.custom_transformers.sapiens.heatmap_head import HeatmapHead
from mmhug.models.custom_transformers.sapiens.vit_sapiens import (
    SapiensVisionTransformer,
)
from mmhug.models.custom_transformers.slipmae.transformer_slipmae_encoder import (
    SlipmaeEncoder,
)
from mmhug.registry import HF_MODELS, MODELS
from mmhug.trainers.base_trainer_model import BaseTrainerModel
from mmhug.trainers.trainer_slipmae.utils import (
    keepface_mask_prob_from_heatmap,
    keepface_mask_prob_from_kpts,
)


def pad_to_max_length(
    seqs: Sequence[Union[List[int], torch.Tensor]],
    pad_value: int,
    dtype: torch.dtype = torch.long,
    batch_first: bool = True,
    max_len: Union[int, None] = None,
) -> torch.Tensor:
    """Stack variable-length sequences into one padded 2-D tensor.

    Every sequence is converted to ``dtype`` and flattened to 1-D, then
    right-padded with ``pad_value`` (or truncated) to a common length.

    Args:
        seqs: Sequences to pad; each item is a list of ints or a tensor.
        pad_value: Fill value for positions past the end of a sequence.
        dtype: dtype of the returned tensor.
        batch_first: If True the result is ``[B, T]``, otherwise ``[T, B]``.
        max_len: Fixed target length. ``None`` uses the longest sequence in
            the batch. Longer sequences are truncated to ``max_len``.

    Returns:
        The padded tensor, allocated on the default (CPU) device. An empty
        ``seqs`` yields a ``[0, 0]`` placeholder regardless of ``max_len``.
    """
    if len(seqs) == 0:
        # Empty batch: [0, 0] is the same shape in either layout.
        return torch.empty((0, 0), dtype=dtype)

    # Normalise every item to a 1-D tensor of the requested dtype. Lists are
    # flattened too, so nested lists behave like multi-dimensional tensors.
    tensors: List[torch.Tensor] = [
        (
            s.to(dtype=dtype)
            if isinstance(s, torch.Tensor)
            else torch.tensor(s, dtype=dtype)
        ).flatten()
        for s in seqs
    ]

    target_len = max_len if max_len is not None else max(t.numel() for t in tensors)

    out = torch.full((len(tensors), target_len), pad_value, dtype=dtype)
    for i, t in enumerate(tensors):
        n = min(t.numel(), target_len)
        out[i, :n] = t[:n]

    # Build batch-first once; the time-first layout is just its transpose.
    return out if batch_first else out.t().contiguous()


@MODELS.register_module()
class SlipmaeEspnetVSRTrainer(BaseTrainerModel):
    """Joint CTC / attention trainer on top of a SlipMAE visual encoder.

    Per-frame features from ``slipmae_encoder`` are linearly projected
    (1024 -> 768), passed through a Conformer encoder, and supervised with a
    CTC loss plus a label-smoothed Transformer-decoder loss. Target token ids
    come from ``TextTransform``.

    Args:
        slipmae_encoder: Config passed to ``HF_MODELS.build`` to create the
            visual encoder.
        train_visual_encoder: When False (default), ``train()`` freezes every
            sub-module except ``encoder``, ``decoder``, ``proj_encoder`` and
            ``ctc``, and visual features are computed without gradient
            tracking. When True nothing is frozen and gradients flow through
            ``encode_visual``.
        ctc_weight: Weight of the CTC term; the attention term is weighted by
            ``1 - ctc_weight``.
        ignore_id: Padding id used for the label tensor and ignored by the
            attention loss.
        raw_sapiens: Optional config (built via ``HF_MODELS.build``) for the
            backbone used to estimate heatmaps when the batch carries no
            keypoints. Requires ``heatmap_head``.
        heatmap_head: Config (built via ``HF_MODELS.build``) for the heatmap
            head applied to the ``raw_sapiens`` features.
        init_cfg: Forwarded to the base-class constructor.
    """

    def __init__(
        self,
        slipmae_encoder,
        train_visual_encoder: bool = False,
        ctc_weight=0.1,
        ignore_id=-1,
        # keypoint detection
        raw_sapiens: SapiensVisionTransformer = None,
        heatmap_head: HeatmapHead = None,
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.text_transform = TextTransform()
        self.token_list = self.text_transform.token_list
        odim = len(self.token_list)
        self.train_visual_encoder = train_visual_encoder
        self.slipmae_encoder: SlipmaeEncoder = HF_MODELS.build(slipmae_encoder)
        self.encoder = ConformerEncoder(
            attention_dim=768,
            attention_heads=12,
            linear_units=3072,
            num_blocks=12,
            cnn_module_kernel=31,
        )

        self.decoder = TransformerDecoder(
            odim=odim,
            attention_dim=768,
            attention_heads=12,
            linear_units=3072,
            num_blocks=6,
        )

        # Maps the 1024-d visual features to the 768-d Conformer input.
        self.proj_encoder = torch.nn.Linear(1024, 768)

        self.blank = 0
        # sos and eos share the last vocabulary index.
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id

        # loss
        self.ctc_weight = ctc_weight
        self.ctc = CTC(odim, 768, 0.1, reduce="mean")
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, 0.1, False)

        # Optional on-the-fly pose estimation. The attributes always exist so
        # the keypoint-less path can fail with a clear message instead of an
        # AttributeError.
        self.raw_sapiens = None
        self.heatmap_head = None
        self.realtime_pose_estimate = False
        if raw_sapiens is not None:
            assert (
                heatmap_head is not None
            ), "heatmap_head is required when raw_sapiens is given"
            self.raw_sapiens: SapiensVisionTransformer = HF_MODELS.build(raw_sapiens)
            self.heatmap_head: HeatmapHead = HF_MODELS.build(heatmap_head)
            self.realtime_pose_estimate = True

        # Apply the freezing policy implemented in ``train``.
        self.train()

    def train(self, mode: bool = True):
        """Switch train/eval mode and apply the freezing policy.

        With ``train_visual_encoder=False`` and ``mode=True`` the whole model
        is put in eval mode with gradients disabled, then only ``encoder``,
        ``decoder``, ``proj_encoder`` and ``ctc`` are switched back to train
        mode with gradients enabled.

        Returns:
            ``self``, so calls can be chained (presumably matching the
            ``torch.nn.Module.train`` contract of the base class -- confirm).
        """
        super().train(mode)
        if self.train_visual_encoder or not mode:
            # NOTE(review): with train_visual_encoder=True the optional pose
            # modules also stay in train mode although they are only used
            # under no_grad -- confirm this is intended.
            return self

        # Freeze everything, then re-enable only the recognition modules.
        # ``eval()`` re-enters this method with mode=False, which returns above.
        self.eval()
        self.requires_grad_(False)
        for module in (self.encoder, self.decoder, self.proj_encoder, self.ctc):
            module.train()
            module.requires_grad_(True)
        return self

    def scorers(self):
        """Return the decoder and a CTC prefix scorer, keyed by name."""
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))

    def forward_loss(self, batch: Dict[str, Dict[str, Tensor]]) -> Dict[str, Tensor]:
        """Compute the weighted CTC and attention losses for one batch.

        Args:
            batch: Must contain ``"video"`` laid out as ``(b, t, c, h, w)`` and
                ``"speech_script"`` (one transcript per sample). ``"keypoint"``
                laid out as ``(b, t, k, c)`` is optional; ``"video_metadata"``
                is only read for diagnostics.

        Returns:
            ``{"ctc_loss": ctc_weight * L_ctc,
            "att_loss": (1 - ctc_weight) * L_att}``.

        Raises:
            Exception: Whatever the CTC module raises, re-raised after the
                offending sample information has been printed.
        """
        video = rearrange(batch["video"], "b t c h w -> b c t h w")
        keypoint = batch.get("keypoint", None)

        # gt process
        text = batch["speech_script"]
        labels = [self.text_transform.tokenize(t) for t in text]
        labels = pad_to_max_length(labels, self.ignore_id).to(video.device)
        # One input length per sample. Videos arrive stacked, so every sample
        # has the same number of frames.
        # NOTE(review): assumes the Conformer encoder keeps the time
        # resolution -- confirm.
        lengths = torch.full(
            (video.shape[0],),
            video.shape[2],
            dtype=torch.long,
            device=video.device,
        )

        # pred from video
        feats = self.encode_visual(video, keypoint)
        feats = self.proj_encoder(feats)
        feats, _ = self.encoder(feats, None)
        try:
            loss_ctc, _ = self.ctc(feats, lengths, labels)
        except Exception as e:
            # Dump what is needed to identify the offending sample, then
            # propagate: swallowing the error would only surface later as an
            # unbound ``loss_ctc``.
            print(e)
            print(text)
            print(self.text_transform.post_process(labels))
            print(batch.get("video_metadata"))
            raise

        # decoder loss
        ys_in_pad, ys_out_pad = add_sos_eos(labels, self.sos, self.eos, self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, feats, None)
        loss_att = self.criterion(pred_pad, ys_out_pad)

        return {
            "ctc_loss": self.ctc_weight * loss_ctc,
            "att_loss": (1 - self.ctc_weight) * loss_att,
        }

    def encode_visual(self, video, kpt=None, batch_size=32):
        """Encode every frame with the SlipMAE encoder.

        Frames are processed in chunks of ``batch_size``. For each chunk a
        keep-face mask probability is derived from the given keypoints or,
        when ``kpt`` is None, from heatmaps estimated on the fly. The last
        prompt token of the encoder output is used as the frame feature.

        Gradients are tracked only when ``train_visual_encoder`` is True and
        the caller has not disabled them.

        Args:
            video: Tensor laid out as ``(b, c, t, h, w)``.
            kpt: Optional keypoints laid out as ``(b, t, k, c)``.
            batch_size: Number of frames per encoder call.

        Returns:
            Tensor laid out as ``(b, t, c)``.
        """
        B = video.shape[0]
        frames = rearrange(video, "b c t h w -> (b t) c h w")
        if kpt is not None:
            kpt = rearrange(kpt, "b t k c -> (b t) k c")

        # Loop-invariant: the face keypoint subset used for the mask.
        face_indices = _fallback_face_indices_ex_ear_308()
        track_grad = self.train_visual_encoder and torch.is_grad_enabled()

        all_vocal_feature = []
        with torch.set_grad_enabled(track_grad):
            for start in range(0, frames.shape[0], batch_size):
                video_batch = frames[start : start + batch_size]
                if kpt is not None:
                    kpt_batch = kpt[start : start + batch_size].to(frames.dtype)
                    keepface_mask_prob = keepface_mask_prob_from_kpts(
                        kpt_batch[:, face_indices],
                        img_hw=video_batch.shape[-2:],
                    )
                else:
                    heatmap_batch = self.get_sapiens_heatmap(video_batch)
                    keepface_mask_prob = keepface_mask_prob_from_heatmap(
                        heatmap_batch[:, face_indices]
                    )

                encoder_output = self.slipmae_encoder(
                    video_batch, do_mask=True, mask_prob=keepface_mask_prob
                )
                # (b t) c -- last prompt token of every frame
                all_vocal_feature.append(encoder_output["prompt_tokens"][:, -1])

        # (b t) c
        all_vocal_feature = torch.cat(all_vocal_feature, dim=0)
        return rearrange(all_vocal_feature, "(b t) c -> b t c", b=B)

    @torch.no_grad()
    def get_sapiens_heatmap(self, imgs: Tensor, max_batch_size: int = 32) -> Tensor:
        """
        Get sapiens heatmap from imgs.

        Inputs larger than ``max_batch_size`` are processed in chunks of that
        size and concatenated along the batch dimension.

        Args:
            imgs (Tensor):
                A batch of imgs: [B, C, H, W]
            max_batch_size (int):
                Maximum number of images per forward pass.
        Returns:
            heatmap (Tensor):
                A batch of heatmap: [B, N, H, W]
        Raises:
            RuntimeError: If the trainer was built without ``raw_sapiens`` /
                ``heatmap_head``.
        """
        if self.raw_sapiens is None or self.heatmap_head is None:
            raise RuntimeError(
                "get_sapiens_heatmap requires raw_sapiens and heatmap_head; "
                "pass them to the constructor or provide keypoints in the batch"
            )
        # [B, C, H, W]
        if imgs.shape[0] > max_batch_size:
            # Forward max_batch_size so the caller's chunk size is honoured.
            chunks = [
                self.get_sapiens_heatmap(imgs[i : i + max_batch_size], max_batch_size)
                for i in range(0, imgs.shape[0], max_batch_size)
            ]
            return torch.cat(chunks, dim=0)
        feats = self.raw_sapiens(imgs, out_type="featmap").last_hidden_state
        heatmap, _ = self.heatmap_head(feats, decode_kpt=False)
        return heatmap
