from typing import Any, Dict, Tuple, Union
from einops import rearrange
import fire
import numpy as np
import torch
from mmengine.config import Config
from mmengine.device import get_device
from mmengine.runner import load_checkpoint
from transformers import BitImageProcessor
from mmhug.datasets.transforms.keypoint2mask import sapiens2mask
from mmhug.datasets.utils.kpt2face import _fallback_face_indices_ex_ear_308
from mmhug.models.custom_transformers.auto_avsr.datamodule import TextTransform
from mmhug.models.custom_transformers.auto_avsr.espnet.batch_beam_search import (
    BatchBeamSearch,
)
from mmhug.models.custom_transformers.auto_avsr.espnet.pytorch_backend.ctc import CTC
from mmhug.models.custom_transformers.auto_avsr.espnet.pytorch_backend.encoder.conformer_encoder import (
    ConformerEncoder,
)
from mmhug.models.custom_transformers.auto_avsr.espnet.scorers.ctc import (
    CTCPrefixScorer,
)
from mmhug.models.custom_transformers.auto_avsr.espnet.scorers.length_bonus import (
    LengthBonus,
)
from mmhug.models.custom_transformers.sapiens.heatmap_head import HeatmapHead
from mmhug.models.custom_transformers.sapiens.vit_sapiens import (
    SapiensVisionTransformer,
)
from mmhug.models.custom_transformers.slipmae.transformer_slipmae_encoder import (
    SlipmaeEncoder,
)
from mmhug.models.custom_transformers.auto_avsr.espnet.pytorch_backend.decoder.transformer_decoder import (
    TransformerDecoder,
)
from mmhug.registry import MODELS
from mmhug.trainers.trainer_slipmae.utils import (
    keepface_mask_prob_from_kpts,
)
from mmhug.trainers.trainer_slipmae_vsr.trainer_slipmae_espnet_vsr import (
    SlipmaeEspnetVSRTrainer,
)
from decord import VideoReader

from torch import nn


class SlipmaeVsrPipeline:
    """Inference pipeline that turns a video into a text transcript.

    ``__call__`` chains the following steps: frame loading/preprocessing,
    keypoint prediction (Sapiens backbone + heatmap head), per-frame visual
    encoding with the Slipmae encoder, linear projection, conformer encoding
    and finally beam search scored by the attention decoder and the CTC head.

    All modules are moved to the device returned by ``get_device()``, cast to
    ``dtype`` and put into eval mode, because this class is inference-only.
    """

    def __init__(
        self,
        video_processor,
        slipmae_encoder: SlipmaeEncoder,
        proj_encoder: nn.Linear,
        conformer_encoder: ConformerEncoder,
        decoder: TransformerDecoder,
        sapiens: SapiensVisionTransformer,
        heatmap_head: HeatmapHead,
        ctc: CTC,
        eos: int,
        dtype=torch.bfloat16,
    ):
        """Store the components and prepare every module for inference.

        Args:
            video_processor: Callable applied to the raw frame array; its
                result must expose ``pixel_values``.
            slipmae_encoder: Per-frame visual encoder.
            proj_encoder: Projection applied to the visual features before
                the conformer encoder.
            conformer_encoder: Sequence encoder run on the projected features.
            decoder: Attention decoder used as a beam-search scorer.
            sapiens: Backbone producing feature maps for keypoint prediction.
            heatmap_head: Head decoding keypoints from the Sapiens features.
            ctc: CTC module wrapped in a ``CTCPrefixScorer`` for beam search.
            eos: End-of-sequence token id handed to the CTC prefix scorer.
            dtype: Floating point dtype used for all modules and inputs.
        """
        self.device = get_device()
        self.dtype = dtype
        self.video_processor = video_processor

        self.slipmae_encoder = self._to_inference(slipmae_encoder)
        self.conformer_encoder = self._to_inference(conformer_encoder)
        self.decoder = self._to_inference(decoder)
        self.sapiens = self._to_inference(sapiens)
        self.heatmap_head = self._to_inference(heatmap_head)
        self.proj_encoder = self._to_inference(proj_encoder)
        self.ctc = self._to_inference(ctc)

        self.text_transform = TextTransform()
        self.token_list = self.text_transform.token_list
        self.eos = eos

    def _to_inference(self, module):
        """Move ``module`` to the pipeline device/dtype and switch it to eval mode."""
        # Freshly built modules default to training mode; without eval() dropout
        # would stay active and make the transcription non-deterministic.
        return module.to(self.device, dtype=self.dtype).eval()

    def scorers(self):
        """Return a fresh dict with the decoder and CTC prefix scorers."""
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))

    @torch.no_grad()
    def sapiens_face_det(
        self, video: torch.Tensor, batch_size=8
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict keypoints for every frame and derive a lower-face mask.

        Args:
            video: Frames indexed along the first dimension; the last two
                dimensions are read as height and width.
            batch_size: Number of frames sent through the keypoint model at once.

        Returns:
            ``(keypoints, lower_face_mask)`` where ``keypoints`` has the
            keypoint score appended as an extra last channel and
            ``lower_face_mask`` is the result of ``sapiens2mask``.

        Raises:
            ValueError: If ``video`` contains no frames.
        """
        if len(video) == 0:
            raise ValueError("video contains no frames")

        H, W = video.shape[-2:]
        all_keypoints = []
        all_keypoints_scores = []
        for i in range(0, len(video), batch_size):
            video_batch = video[i : i + batch_size]
            _, keypoints, keypoints_scores = self.get_sapiens_heatmap(video_batch)
            all_keypoints.append(torch.from_numpy(keypoints))
            all_keypoints_scores.append(torch.from_numpy(keypoints_scores))
        all_keypoints = torch.cat(all_keypoints, dim=0)
        all_keypoints_scores = torch.cat(all_keypoints_scores, dim=0)
        # Append the score as an extra trailing channel next to the coordinates.
        keypoints = torch.cat(
            [all_keypoints, all_keypoints_scores.unsqueeze(-1)], dim=-1
        )
        lower_face_mask = sapiens2mask(
            H, W, keypoints, mask_area="lower_face", mask_expand=(0, 0, 0, 20)
        )
        return keypoints, lower_face_mask

    @torch.no_grad()
    def get_sapiens_heatmap(
        self,
        imgs: torch.Tensor,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the Sapiens backbone and heatmap head on a batch of images.

        Args:
            imgs: Batch of images fed to the Sapiens backbone.

        Returns:
            ``(heatmap, keypoints, keypoints_scores)`` as numpy arrays; the
            keypoints and scores of all predictions are concatenated along
            the first axis.
        """
        feats = self.sapiens(imgs, out_type="featmap").last_hidden_state
        # 256 192
        heatmap, pred = self.heatmap_head(feats, decode_kpt=True)

        keypoints = np.concatenate([p.keypoints for p in pred], axis=0)
        keypoints_scores = np.concatenate([p.keypoint_scores for p in pred], axis=0)

        # Cast to float32 first: numpy has no bfloat16 dtype.
        heatmap = heatmap.float().detach().cpu().numpy()
        return heatmap, keypoints, keypoints_scores

    def load_video(self, video: Union[str, np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Load frames and run them through the video processor.

        Args:
            video: Path to a video file (decoded completely with decord), or
                the raw frames as a numpy array or torch tensor. Tensors are
                converted to numpy and treated exactly like array input.

        Returns:
            The processor's ``pixel_values`` as a single tensor.

        Raises:
            TypeError: If ``video`` is not a str, numpy array or torch tensor.
        """
        if isinstance(video, str):
            vr = VideoReader(video)
            video = vr.get_batch(range(len(vr))).asnumpy()
        elif isinstance(video, torch.Tensor):
            video = video.detach().cpu().numpy()
        if not isinstance(video, np.ndarray):
            raise TypeError(
                "video must be a file path, numpy array or torch tensor, "
                f"got {type(video).__name__}"
            )
        pixel_values = self.video_processor(video).pixel_values
        # Stack into one array first: building a tensor directly from a list of
        # numpy arrays is very slow and triggers a PyTorch warning.
        return torch.as_tensor(np.asarray(pixel_values))

    @torch.no_grad()
    def encode_visual(self, video, kpt, batch_size=32):
        """Encode every frame into one feature vector with the Slipmae encoder.

        Args:
            video: Tensor laid out as ``b c t h w``.
            kpt: Keypoints laid out as ``b t k c``. Required, because the
                per-frame mask probability is derived from it.
            batch_size: Number of frames encoded per forward pass.

        Returns:
            Tensor laid out as ``b t c`` holding the last prompt token of
            each frame.

        Raises:
            ValueError: If ``kpt`` is ``None``.
        """
        if kpt is None:
            raise ValueError(
                "kpt is required: the keep-face mask probability is computed "
                "from the keypoints"
            )

        B = video.shape[0]
        video = rearrange(video, "b c t h w -> (b t) c h w")
        kpt = rearrange(kpt, "b t k c -> (b t) k c")
        # Loop invariant, so resolve the keypoint index selection only once.
        face_indices = _fallback_face_indices_ex_ear_308()

        all_vocal_feature = []
        for i in range(0, video.shape[0], batch_size):
            video_batch = video[i : i + batch_size]
            kpt_batch = kpt[i : i + batch_size]

            keepface_mask_prob = keepface_mask_prob_from_kpts(
                kpt_batch[:, face_indices],
                img_hw=video_batch.shape[-2:],
            )

            encoder_output = self.slipmae_encoder(
                video_batch, do_mask=True, mask_prob=keepface_mask_prob
            )
            # (b t) c
            vocal_motion = encoder_output["prompt_tokens"][:, -1]
            all_vocal_feature.append(vocal_motion)
        # (b t) c
        all_vocal_feature = torch.cat(all_vocal_feature, dim=0)
        all_vocal_feature = rearrange(all_vocal_feature, "(b t) c -> b t c", b=B)
        return all_vocal_feature

    @torch.no_grad()
    def __call__(self, video: Union[str, np.ndarray, torch.Tensor]) -> str:
        """Transcribe a single video.

        Args:
            video: Anything accepted by ``load_video``.

        Returns:
            The text of the best beam-search hypothesis with ``<eos>``
            removed, or an empty string if beam search yields no hypothesis.
        """
        beam_search = get_beam_search_decoder(self.scorers(), self.token_list)
        video = self.load_video(video).to(self.device, dtype=self.dtype)
        keypoint, _ = self.sapiens_face_det(video)
        # NOTE(review): keypoints are cast to self.dtype; with bfloat16 large
        # pixel coordinates lose precision - confirm this is acceptable.
        keypoint = keypoint.to(self.device, dtype=self.dtype).unsqueeze(0)
        video = rearrange(video, "t c h w -> 1 c t h w")
        feats = self.encode_visual(video, keypoint)
        feats = self.proj_encoder(feats)
        feats, _ = self.conformer_encoder(feats, None)

        nbest_hyps = beam_search(feats.squeeze(0))
        if not nbest_hyps:
            return ""
        best = nbest_hyps[0].asdict()
        # Skip the first token of the hypothesis (the start symbol).
        predicted_token_id = torch.tensor(list(map(int, best["yseq"][1:])))
        predicted = self.text_transform.post_process(predicted_token_id).replace(
            "<eos>", ""
        )
        return predicted


def get_beam_search_decoder(
    scorers,
    token_list,
    penalty=0,
    ctc_weight=0.1,
    lm_weight=0.0,
    beam_size=40,
):
    """Build a ``BatchBeamSearch`` from the given scorers and token list.

    Args:
        scorers: Mapping of scorer name to scorer; expected to provide the
            ``"decoder"`` and ``"ctc"`` entries that receive weights below.
            The mapping itself is left untouched.
        token_list: Output vocabulary; its length is the vocabulary size and
            its last index is used for both sos and eos.
        penalty: Weight of the length-bonus scorer.
        ctc_weight: Weight of the CTC scorer; the decoder gets
            ``1.0 - ctc_weight``.
        lm_weight: Weight of the (absent) language-model scorer.
        beam_size: Beam width.

    Returns:
        The configured ``BatchBeamSearch`` instance.
    """
    odim = len(token_list)
    # sos and eos share the last index of the token list.
    sos = eos = odim - 1

    # Work on a shallow copy so the caller's dict is not modified.
    scorers = dict(scorers)
    scorers["lm"] = None
    scorers["length_bonus"] = LengthBonus(odim)
    weights = {
        "decoder": 1.0 - ctc_weight,
        "ctc": ctc_weight,
        "lm": lm_weight,
        "length_bonus": penalty,
    }

    return BatchBeamSearch(
        beam_size=beam_size,
        vocab_size=odim,
        weights=weights,
        scorers=scorers,
        sos=sos,
        eos=eos,
        token_list=token_list,
        # With pure CTC decoding there is no decoder score to pre-select the beam.
        pre_beam_score_key=None if ctc_weight == 1.0 else "decoder",
    )


def infer(
    cfg: str = "configs/train/espnet_slipmae/espnet_slipmae.py",
    checkpoint: str = "work_dirs/espnet_slipmae/iter_25000.pth",
    video_path="data/vfhq/video_resampled/Clip+zXbxb4m_k9U+P0+C0+F3249-3415.mp4",
    dtype=torch.bfloat16,
):
    """Transcribe one video with a trained Slipmae VSR checkpoint and print it.

    Args:
        cfg: Path to the mmengine config describing the trained model.
        checkpoint: Path to the trained model checkpoint.
        video_path: Path to the video to transcribe.
        dtype: Inference dtype, either a ``torch.dtype`` or its name
            (``"float16"`` / ``"torch.float16"``) so it can be set from the
            command line.

    Raises:
        ValueError: If ``dtype`` is a string that does not name a torch dtype.
    """
    if isinstance(dtype, str):
        # Command-line values arrive as strings; accept both "float16" and
        # "torch.float16".
        resolved = getattr(torch, dtype.rsplit(".", 1)[-1], None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unknown torch dtype: {dtype!r}")
        dtype = resolved

    # sapiens
    sapiens = SapiensVisionTransformer(
        arch="sapiens_0.3b",
        norm_in=True,  # the input is normalized with mean 0.5 and std 0.5, we need to renormalize with Sapiens mean and std
        img_size=(1024, 768),
        patch_size=16,
        in_channels=3,
        out_indices=-1,
        drop_rate=0.0,
        drop_path_rate=0.0,
        qkv_bias=True,
        norm_cfg=dict(type="LN", eps=1e-6),
        final_norm=True,
        with_cls_token=False,
        frozen_stages=-1,
        interpolate_mode="bicubic",
        layer_scale_init_value=0.0,
        patch_cfg=dict(padding=2),
        layer_cfgs=dict(),
        pre_norm=False,
        out_type="featmap",
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/backbone.pth",
        ),
    )
    sapiens.init_weights()

    heatmap_head = HeatmapHead(
        in_channels=1024,
        out_channels=308,
        deconv_out_channels=(768, 768),  ## this will 2x at each step. so total is 4x
        deconv_kernel_sizes=(4, 4),
        conv_out_channels=(768, 768),
        conv_kernel_sizes=(1, 1),
        decoder=dict(
            type="UDPHeatmap", input_size=(512, 512), heatmap_size=(192, 256), sigma=6
        ),
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/heatmap_head.pth",
        ),
    )
    heatmap_head.init_weights()

    # Resize the shortest edge to 512, center-crop to 512x512 and normalize
    # with mean/std 0.5.
    input_video_processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": 512},
        do_center_crop=True,
        crop_size={"height": 512, "width": 512},
        do_rescale=True,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    config = Config.fromfile(cfg)
    model: SlipmaeEspnetVSRTrainer = MODELS.build(config.model)
    # Load on CPU first; the pipeline moves each sub-module to its device.
    load_checkpoint(model, checkpoint, map_location="cpu")

    pipeline = SlipmaeVsrPipeline(
        video_processor=input_video_processor,
        proj_encoder=model.proj_encoder,
        slipmae_encoder=model.slipmae_encoder,
        conformer_encoder=model.encoder,
        decoder=model.decoder,
        sapiens=sapiens,
        heatmap_head=heatmap_head,
        ctc=model.ctc,
        eos=model.eos,
        dtype=dtype,
    )

    transcript = pipeline(video_path)
    print(video_path)
    print(transcript)


# Script entry point: expose `infer` as a command-line interface via python-fire,
# so its keyword arguments can be overridden with --flags.
if __name__ == "__main__":
    fire.Fire(infer)
