# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Sequence, Tuple, Union
from mmcv.cnn import build_conv_layer, build_upsample_layer
from torch import Tensor, nn
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch.nn import functional as F

from mmhug.utils.vis_utils import draw_keypoints_sequence

OptIntSeq = Optional[Sequence[int]]

from mmhug.registry import HF_MODELS
from mmhug.utils.tensor_utils import to_numpy


@HF_MODELS.register_module(force=True)
class HeatmapHead(BaseModule):
    """Top-down heatmap head introduced in `Simple Baselines`_ by Xiao et al
    (2018). The head is composed of a few deconvolutional layers followed by a
    convolutional layer to generate heatmaps from low-resolution feature maps.

    Args:
        in_channels (int | Sequence[int]): Number of channels in the input
            feature map.
        out_channels (int): Number of channels in the output heatmap.
        deconv_out_channels (Sequence[int], optional): The output channel
            number of each deconv layer. Every deconv layer upsamples the
            feature map by 2x. ``None`` or empty means no deconv layers.
            Defaults to ``(256, 256, 256)``
        deconv_kernel_sizes (Sequence[int], optional): The kernel size of each
            deconv layer. Only 2, 3 and 4 are supported. Must have the same
            length as ``deconv_out_channels``. Defaults to ``(4, 4, 4)``
        conv_out_channels (Sequence[int], optional): The output channel number
            of each intermediate conv layer. ``None`` means no intermediate
            conv layer between deconv layers and the final conv layer.
            Defaults to ``None``
        conv_kernel_sizes (Sequence[int | tuple], optional): The kernel size
            of each intermediate conv layer, either an integer for both
            dimensions or a ``(height, width)`` tuple. Must have the same
            length as ``conv_out_channels``. Defaults to ``None``
        final_layer (dict, optional): Arguments that override the config of
            the final Conv2d layer (a 1x1 conv to ``out_channels``). ``None``
            replaces the final layer with an identity.
            Defaults to ``dict(kernel_size=1)``
        use_silu (bool): If True, each (de)conv layer is followed by
            ``InstanceNorm2d`` + ``SiLU``; otherwise by ``BatchNorm2d`` +
            ``ReLU``. Defaults to ``True``
        decoder (dict, optional): Registry config of the decoder used by
            :meth:`decode`. ``None`` disables decoding.
            Defaults to ``dict(type="UDPHeatmap")``
        init_cfg (dict | list[dict], optional): Config to control the
            initialization. ``None`` falls back to :attr:`default_init_cfg`.

    .. _`Simple Baselines`: https://arxiv.org/abs/1804.06208
    """

    _version = 2

    def __init__(
        self,
        in_channels: Union[int, Sequence[int]],
        out_channels: int,
        deconv_out_channels: OptIntSeq = (256, 256, 256),
        deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
        conv_out_channels: OptIntSeq = None,
        conv_kernel_sizes: OptIntSeq = None,
        final_layer: dict = dict(kernel_size=1),
        use_silu: bool = True,
        decoder=dict(type="UDPHeatmap"),
        init_cfg=None,
    ):
        # Apply the documented default initialization when none is given;
        # otherwise ``default_init_cfg`` would never be used.
        if init_cfg is None:
            init_cfg = self.default_init_cfg

        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels
        # ``decode`` explicitly handles a missing decoder, so accept
        # ``decoder=None`` instead of passing a non-dict to the registry.
        self.decoder = HF_MODELS.build(decoder) if decoder is not None else None

        # InstanceNorm + SiLU instead of BatchNorm + ReLU after each layer.
        self.use_silu = use_silu

        if deconv_out_channels:
            if deconv_kernel_sizes is None or len(deconv_out_channels) != len(
                deconv_kernel_sizes
            ):
                raise ValueError(
                    '"deconv_out_channels" and "deconv_kernel_sizes" should '
                    "be integer sequences with the same length. Got "
                    f"mismatched lengths {deconv_out_channels} and "
                    f"{deconv_kernel_sizes}"
                )

            self.deconv_layers = self._make_deconv_layers(
                in_channels=in_channels,
                layer_out_channels=deconv_out_channels,
                layer_kernel_sizes=deconv_kernel_sizes,
            )
            in_channels = deconv_out_channels[-1]
        else:
            self.deconv_layers = nn.Identity()

        if conv_out_channels:
            if conv_kernel_sizes is None or len(conv_out_channels) != len(
                conv_kernel_sizes
            ):
                raise ValueError(
                    '"conv_out_channels" and "conv_kernel_sizes" should '
                    "be integer sequences with the same length. Got "
                    f"mismatched lengths {conv_out_channels} and "
                    f"{conv_kernel_sizes}"
                )

            self.conv_layers = self._make_conv_layers(
                in_channels=in_channels,
                layer_out_channels=conv_out_channels,
                layer_kernel_sizes=conv_kernel_sizes,
            )
            in_channels = conv_out_channels[-1]
        else:
            self.conv_layers = nn.Identity()

        if final_layer is not None:
            # 1x1 conv to ``out_channels``; ``final_layer`` may override any key.
            cfg = dict(
                type="Conv2d",
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )
            cfg.update(final_layer)
            self.final_layer = build_conv_layer(cfg)
        else:
            self.final_layer = nn.Identity()

    def _make_norm_act(self, num_channels: int) -> list:
        """Return the [normalization, activation] pair that follows a layer."""
        if self.use_silu:
            return [nn.InstanceNorm2d(num_channels), nn.SiLU(inplace=True)]
        return [nn.BatchNorm2d(num_features=num_channels), nn.ReLU(inplace=True)]

    def _make_conv_layers(
        self,
        in_channels: int,
        layer_out_channels: Sequence[int],
        layer_kernel_sizes: Sequence[Union[int, Tuple[int, int]]],
    ) -> nn.Module:
        """Create stride-1 convolutional layers, each followed by norm + act.

        Padding is ``(k - 1) // 2`` per dimension, which preserves the spatial
        size for odd kernel sizes. Kernel sizes may be ints or (h, w) tuples.
        """
        layers = []
        for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes):
            if isinstance(kernel_size, (tuple, list)):
                padding = tuple((k - 1) // 2 for k in kernel_size)
            else:
                padding = (kernel_size - 1) // 2
            cfg = dict(
                type="Conv2d",
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
            )
            layers.append(build_conv_layer(cfg))
            layers.extend(self._make_norm_act(out_channels))
            in_channels = out_channels

        return nn.Sequential(*layers)

    def _make_deconv_layers(
        self,
        in_channels: int,
        layer_out_channels: Sequence[int],
        layer_kernel_sizes: Sequence[int],
    ) -> nn.Module:
        """Create stride-2 deconvolutional layers, each followed by norm + act.

        The (padding, output_padding) pairs below make every supported kernel
        size (2, 3, 4) produce exactly 2x the input height and width.

        Raises:
            ValueError: If a kernel size other than 2, 3 or 4 is given.
        """
        layers = []
        for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes):
            if kernel_size == 4:
                padding = 1
                output_padding = 0
            elif kernel_size == 3:
                padding = 1
                output_padding = 1
            elif kernel_size == 2:
                padding = 0
                output_padding = 0
            else:
                raise ValueError(
                    f"Unsupported kernel size {kernel_size} for "
                    "deconvolutional layers in "
                    f"{self.__class__.__name__}"
                )
            cfg = dict(
                type="deconv",
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=padding,
                output_padding=output_padding,
                bias=False,
            )
            layers.append(build_upsample_layer(cfg))
            layers.extend(self._make_norm_act(out_channels))
            in_channels = out_channels

        return nn.Sequential(*layers)

    @property
    def default_init_cfg(self):
        """Initialization used when ``init_cfg`` is not given."""
        init_cfg = [
            dict(type="Normal", layer=["Conv2d", "ConvTranspose2d"], std=0.001),
            dict(type="Constant", layer="BatchNorm2d", val=1),
            # Initialize gamma to 1 and beta to 0. Note: ``nn.InstanceNorm2d``
            # is built here with its default ``affine=False``, so it has no
            # weight/bias and this entry has nothing to initialize.
            dict(type="Constant", layer="InstanceNorm2d", val=1, bias=0),
        ]
        return init_cfg

    def decode(self, batch_outputs: Union[Tensor, Tuple[Tensor]]):
        """Decode keypoints from outputs.

        Args:
            batch_outputs (Tensor | Tuple[Tensor]): The network outputs of
                a data batch.

        Returns:
            List[InstanceData]: One InstanceData per data sample, holding the
            decoded ``keypoints`` and ``keypoint_scores``.

        Raises:
            RuntimeError: If the head was built without a decoder.
        """
        if self.decoder is None:
            raise RuntimeError(
                f"The decoder has not been set in {self.__class__.__name__}. "
                "Please set the decoder configs in the init parameters to "
                "enable head methods `head.predict()` and `head.decode()`"
            )

        preds = []
        for outputs in to_numpy(batch_outputs, unzip=True):
            # The decoder takes one positional argument per network output.
            args = outputs if isinstance(outputs, tuple) else (outputs,)
            keypoints, scores = self.decoder.decode(*args)
            preds.append(InstanceData(keypoints=keypoints, keypoint_scores=scores))

        return preds

    def forward(
        self, x: Tensor, decode_kpt: bool = False
    ) -> Tuple[Tensor, Optional[Sequence[InstanceData]]]:
        """Predict heatmaps from a single backbone feature map.

        Args:
            x (Tensor): Feature map from the backbone (e.g. the last hidden
                state of the Sapiens ViT), presumably (N, C, H, W) with
                ``C == in_channels``.
            decode_kpt (bool): If True, also decode keypoints from the
                heatmaps with :meth:`decode`. Defaults to False.

        Returns:
            Tuple[Tensor, Optional[List[InstanceData]]]: The output heatmaps,
            plus one decoded InstanceData per sample when ``decode_kpt`` is
            True (``None`` otherwise).
        """
        x = self.deconv_layers(x)
        x = self.conv_layers(x)
        x = self.final_layer(x)

        preds = None
        if decode_kpt:
            # numpy has no bfloat16 (and float16 decoding is imprecise), so
            # upcast half-precision heatmaps to float32 before decoding.
            to_decode = x.detach()
            if to_decode.element_size() < 4:
                to_decode = to_decode.float()
            preds = self.decode(to_decode)

        return x, preds


if __name__ == "__main__":
    # Demo: run Sapiens-0.3B backbone + heatmap head on the first frames of a
    # video, then save a heatmap overlay video and a keypoint video.
    import numpy as np
    import torch
    from torchvision.io import read_video, write_video
    from mmengine.device import get_device
    from mmhug.utils.vis_utils import visualize_image_and_heatmap

    device = get_device()
    dtype = torch.bfloat16

    backbone = HF_MODELS.build(
        dict(
            type="SapiensVisionTransformer",
            arch="sapiens_0.3b",
            norm_in=True,  # the input is normalized with mean 0.5 and std 0.5, we need to renormalize with Sapiens mean and std
            img_size=(1024, 768),
            patch_size=16,
            in_channels=3,
            out_indices=-1,
            drop_rate=0.0,
            drop_path_rate=0.0,
            qkv_bias=True,
            norm_cfg=dict(type="LN", eps=1e-6),
            final_norm=True,
            with_cls_token=False,
            frozen_stages=-1,
            interpolate_mode="bicubic",
            layer_scale_init_value=0.0,
            patch_cfg=dict(padding=2),
            layer_cfgs=dict(),
            pre_norm=False,
            out_type="featmap",
            init_cfg=dict(
                type="Pretrained",
                checkpoint="checkpoints/sapiens-pose-0.3b/backbone.pth",
            ),
        )
    ).to(device, dtype)
    backbone.init_weights()
    backbone.eval()

    heatmap_head = HeatmapHead(
        in_channels=1024,
        out_channels=308,
        deconv_out_channels=(768, 768),  ## this will 2x at each step. so total is 4x
        deconv_kernel_sizes=(4, 4),
        conv_out_channels=(768, 768),
        conv_kernel_sizes=(1, 1),
        # NOTE(review): heatmap_size should match the heatmap actually produced
        # (backbone feature size x 4); confirm for the video resolution used.
        decoder=dict(
            type="UDPHeatmap", input_size=(512, 512), heatmap_size=(192, 256), sigma=6
        ),
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/heatmap_head.pth",
        ),
    ).to(device, dtype)
    heatmap_head.init_weights()
    heatmap_head.eval()

    video_path = "data/hallo3/hallo3_training_data/videos_cropped_new/fe0eb399d3372546b6437401d707551b.mp4"
    # First 8 frames as (T, C, H, W), scaled to [0, 1] in float32.
    frames = read_video(video_path, output_format="TCHW")[0][:8].float() / 255.0
    # Model input in [-1, 1] (mean 0.5, std 0.5), matching ``norm_in=True``.
    model_input = ((frames - 0.5) / 0.5).to(device, dtype)

    # Pure inference: no autograd graph is needed.
    with torch.inference_mode():
        feats = backbone(model_input, out_type="featmap").last_hidden_state
        heatmap, preds = heatmap_head(feats, decode_kpt=True)
        heatmap = F.interpolate(
            heatmap, size=(256, 256), mode="bilinear", align_corners=False
        )
    print(heatmap.shape)
    keypoints = np.concatenate([p.keypoints for p in preds], axis=0)
    score = np.concatenate([p.keypoint_scores for p in preds], axis=0)
    print(keypoints.shape, score.shape)

    # Visualize from the float32 frames (T, H, W, C) in [0, 1] instead of
    # round-tripping the image through bfloat16.
    img = frames.permute(0, 2, 3, 1).contiguous().numpy()
    visualize_image_and_heatmap(
        img,
        heatmap.squeeze(0).detach().float().cpu().numpy(),
        mode="max",
        cmap="hot",
        save_path="image_with_heatmap.mp4",
        dpi=200,
    )
    video = draw_keypoints_sequence(
        img * 255.0, np.concatenate([keypoints, score[..., None]], axis=2)
    )
    write_video("keypoints.mp4", video, fps=25)
