import logging
from os.path import isdir
from typing import Dict, List, Union
import mmengine
from mmengine.dataset import BaseDataset
from mmengine.dataset import Compose
from mmengine import print_log
import os
from typing import Dict, List, Union

import numpy as np
import torch

from mmhug.registry import DATASETS

# BaseTransform is referenced in the dataset's ``pipeline`` type annotation.
from mmcv import BaseTransform
import random
from tqdm import tqdm

from mmhug.utils.vis_utils import draw_keypoints_sequence
from torchvision.io import write_video


@DATASETS.register_module(force=True)
class TextVideoAudioKeypointDataset(BaseDataset):
    """A dataset of paired text, video, audio and keypoint annotations.

    Inherits from :class:`mmengine.dataset.BaseDataset` but does not call its
    constructor: the annotation file is loaded eagerly in ``__init__`` and
    :meth:`full_init` is a no-op.

    Attributes:
        data_dir (str): Root directory that relative paths from the
            annotation file are joined onto.
        anno_file (str): Path to the annotation file.
        data_list (List[dict]): Flattened list of per-sample annotation dicts.
        pipeline (Compose): A composition of data transformation operations.
        refetch (bool): Whether to fetch another, randomly chosen sample when
            an error is raised while preparing a sample in ``__getitem__``.
            If False, the error is re-raised and the whole process terminates.
        max_refetch (int): Maximum number of replacement samples tried within
            a single ``__getitem__`` call before the last error is re-raised.

    Args:
        data_dir (str): Root directory of the data files.
        anno_file (str): Path to the annotation file, loaded with
            ``mmengine.load``.
        pipeline (Union[Dict, BaseTransform, List[Union[Dict, BaseTransform]]]):
            Data transformation pipeline: a transform config dict, a
            transform object, or a list of either.
        refetch (bool): See ``refetch`` above. Defaults to True.
        max_refetch (int): See ``max_refetch`` above. Defaults to 1000.
    """

    def __init__(
        self,
        data_dir: str,
        anno_file: str,
        pipeline: Union[Dict, BaseTransform, List[Union[Dict, BaseTransform]]],
        refetch: bool = True,
        max_refetch: int = 1000,
    ):
        # BaseDataset.__init__ is not called; the attributes this class
        # relies on are initialised here instead.
        self.data_dir = data_dir
        self.anno_file = anno_file
        # Filled from the annotation file's ``meta_info`` in load_data_list.
        self._metainfo = {}
        self.refetch = refetch
        self.max_refetch = max_refetch

        self.pipeline = Compose(pipeline)
        self.data_list = self.load_data_list()

    @staticmethod
    def _is_main_process() -> bool:
        """Return True unless running distributed on a non-zero rank."""
        try:
            from mmengine import dist

            return (not dist.is_distributed()) or (dist.get_rank() == 0)
        except ImportError:
            return True

    def load_data_list(self) -> List[dict]:
        """Load and flatten the annotations stored in ``self.anno_file``.

        Adapted from ``mmengine.dataset.BaseDataset.load_data_list``. The
        file must deserialize to a dict with two keys:

        - ``meta_info``: a dict whose entries are merged into
          ``self._metainfo`` (keys already present there are kept).
        - ``data_list``: either a dict whose values are annotations or a
          list of annotations. Each annotation is a dict, or a list of dicts
          that is flattened into individual samples.

        A progress bar is shown only on the main process.

        Returns:
            List[dict]: Flattened list of per-sample annotation dicts.

        Raises:
            TypeError: If the file content, ``data_list`` or any annotation
                has an unexpected type.
            ValueError: If ``data_list`` or ``meta_info`` is missing.
        """
        annotations = mmengine.load(self.anno_file)
        if not isinstance(annotations, dict):
            raise TypeError(
                f"The annotations loaded from annotation file "
                f"should be a dict, but got {type(annotations)}!"
            )
        if "data_list" not in annotations or "meta_info" not in annotations:
            raise ValueError("Annotation must have data_list and meta_info keys")

        for key, value in annotations["meta_info"].items():
            self._metainfo.setdefault(key, value)

        raw_data_list = annotations["data_list"]
        # Accept both a mapping (entries keyed by some id) and a plain list.
        if isinstance(raw_data_list, dict):
            entries = raw_data_list.values()
        elif isinstance(raw_data_list, list):
            entries = raw_data_list
        else:
            raise TypeError(
                f"data_list should be a dict or a list, but got {type(raw_data_list)}"
            )

        iterator = (
            tqdm(entries, desc="Loading data_list")
            if self._is_main_process()
            else entries
        )

        data_list = []
        for data_info in iterator:
            if isinstance(data_info, dict):
                data_list.append(data_info)
            elif isinstance(data_info, list):
                for item in data_info:
                    if not isinstance(item, dict):
                        raise TypeError(
                            f"data_info must be list of dict, but got {type(item)}"
                        )
                data_list.extend(data_info)
            else:
                raise TypeError(
                    f"data_info should be a dict or list of dict, but got {type(data_info)}"
                )

        return data_list

    def prepare_data(self, idx: int) -> dict:
        """Build the pipeline input dict for sample ``idx``.

        Relative paths from the annotation are joined onto ``self.data_dir``.
        Optional paths that are missing or empty become None, as do missing
        ``fps``, ``sr``, ``caption``, ``speech_script`` and ``language``.

        Raises:
            KeyError: If the annotation has no ``video_path``.
            AssertionError: If ``video_path`` is a directory of frames but the
                annotation lacks ``fps`` or ``audio_path``.
        """
        raw_data_info = self.data_list[idx]

        def _optional_path(key: str):
            # Join onto data_dir only when the annotation provides a value.
            rel_path = raw_data_info.get(key)
            return os.path.join(self.data_dir, rel_path) if rel_path else None

        data_info = {
            # Mandatory: a KeyError here means a malformed annotation.
            "video_path": os.path.join(self.data_dir, raw_data_info["video_path"]),
            "video_metadata": {
                "fps": raw_data_info.get("fps"),
            },
            "audio_path": _optional_path("audio_path"),
            "audio_metadata": {
                "sr": raw_data_info.get("sr"),
            },
            "caption_path": _optional_path("caption_path"),
            "caption": raw_data_info.get("caption"),
            "caption_metadata": {
                "caption_path": None,
            },
            "keypoint_path": _optional_path("keypoint_path"),
            "keypoint_metadata": {},
            "speech_script": raw_data_info.get("speech_script"),
            "language": raw_data_info.get("language"),
        }
        # The video may be a single file or a directory of frame images; a
        # frame directory carries no fps/audio, so both must be annotated.
        if os.path.isdir(data_info["video_path"]):
            assert (
                data_info["video_metadata"]["fps"] is not None
            ), "If video is provided in a folder of frames, please provide the FPS of the video in annotation file"
            assert (
                data_info["audio_path"] is not None
            ), "If video is provided in a folder of frames, please provide the audio path"
        return data_info

    def __len__(self) -> int:
        """Return the number of flattened samples."""
        return len(self.data_list)

    def _random_other_index(self, idx: int) -> int:
        """Pick a random index, shifted by one if it equals ``idx``."""
        new_idx = random.randint(0, len(self.data_list) - 1)
        if new_idx == idx:
            new_idx = (idx + 1) % len(self.data_list)
        return new_idx

    def __getitem__(self, idx: int) -> dict:
        """Return the pipeline output for sample ``idx``.

        If preparing or transforming the sample raises and ``self.refetch``
        is True, a warning is logged and a random other sample is tried,
        at most ``self.max_refetch`` times. After that, or immediately when
        ``refetch`` is False, the error is re-raised.

        Example layout of the returned dict (the actual keys depend on the
        configured pipeline)::

            {
                "video": torch.rand([T, C, H, W]),
                "video_metadata": {
                    "video_path": video_path,
                    "duration": T / fps,
                    "num_frames": T,
                    "fps": fps,
                    "width": W,
                    "height": H,
                },
                "audio": torch.rand([T]),
                "audio_metadata": {
                    "audio_path": audio_path,
                    "sr": SR,
                    "duration": T,
                    "num_frames": T * SR,
                },
                "keypoint": torch.rand([T, K, H, W]),
                "keypoint_metadata": {
                    "keypoint_path": keypoint_path,
                    "num_frames": T * fps,
                    "num_joints": K,
                    "height": H,
                    "width": W,
                },
                "caption": "A person dressed in white and wearing a black hat is giving a speech",
                "caption_metadata": {
                    # None if the caption is given in the annotation file
                    "caption_path": caption_path,
                },
            }
        """
        max_refetch = max(self.max_refetch, 0)
        # Bounded loop instead of recursion: repeated failures (or a
        # single-sample dataset) can no longer exhaust the call stack.
        for attempt in range(max_refetch + 1):
            try:
                return self.pipeline(self.prepare_data(idx))
            except Exception as e:
                if not self.refetch or attempt == max_refetch:
                    raise
                print_log(
                    f"Get error when loading idx={idx}: fetching another instead, error: {e}",
                    level=logging.WARNING,
                )
                idx = self._random_other_index(idx)

    def full_init(self):
        """No-op: annotations are already loaded in ``__init__``."""
        pass


if __name__ == "__main__":
    # Manual smoke test: build the dataset from a sample config and dump every
    # processed sample as an mp4 (masked video with keypoints drawn on it,
    # placed side by side with the reference frames) for visual inspection.
    from mmengine import init_default_scope

    init_default_scope("mmhug")

    fps = 25  # used both for `assert_fps` and for the written videos
    out_dir = "output/dataset_test/"

    pipeline = [
        dict(
            type="LoadVideoAudioSegmentWithKeypointRef",
            video_path_key="video_path",
            keypoint_path_key="keypoint_path",
            audio_path_key="audio_path",
            filter_min_num_frames=17,
            segment_num_frames=17,
            sampling_rate=16000,
            segment_rule="random",
            video_only=False,
            frame_multiple=8,
            frame_multiple_add=1,
            use_ref_img=True,
            assert_fps=fps,
            num_ref_img=17,
            ref_img_rule="random_video",
        ),
        dict(
            type="ResizeVideo",
            video_keys=["video", "ref_img"],
            size_candidates=[(512, 512)],
            keep_ratio=True,
        ),
        dict(
            type="CenterCropVideo",
            video_keys=["video", "ref_img"],
            crop_size=(512, 512),
        ),
        dict(
            type="SapiensKeypoint2Mask",
            mask_area="lower_face",
            mask_expand=(0, 0, 0, 20),
        ),
        dict(
            type="NormalizeVideo",  # w.r.t dinov2
            video_keys=["video", "ref_img"],
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        ),
    ]

    dataset_cfg = dict(
        type="TextVideoAudioKeypointDataset",
        data_dir="data/",
        anno_file="data/annotations/test_anno.json",
        pipeline=pipeline,
        refetch=True,
    )

    os.makedirs(out_dir, exist_ok=True)
    dataset = DATASETS.build(dataset_cfg)
    indices = list(range(len(dataset)))
    random.shuffle(indices)

    for i in tqdm(indices):
        data = dataset[i]
        # NormalizeVideo above uses mean=std=0.5, so `x * 0.5 + 0.5` undoes it;
        # the permutes below assume (T, C, H, W) tensors.
        video = data["video"]
        ref_img = data["ref_img"]
        keypoint = data["keypoint"]
        mask = data["mask"].unsqueeze(1)

        # Denormalize to roughly [0, 255] and blank out pixels where mask is 1.
        video = (video * 0.5 + 0.5) * 255 * (1 - mask)
        video = draw_keypoints_sequence(video.permute(0, 2, 3, 1).numpy(), keypoint)
        ref_img = ((ref_img * 0.5 + 0.5) * 255).permute(0, 2, 3, 1).numpy()
        # Put the reference frames to the right of the video (width axis of
        # the channels-last arrays).
        video = np.concatenate((video, ref_img), axis=-2)
        # write_video converts to uint8; clip and round first so values
        # slightly outside [0, 255] saturate instead of wrapping around.
        video = np.clip(np.rint(video), 0, 255).astype(np.uint8)

        out_path = os.path.join(out_dir, f"{i}.mp4")
        write_video(out_path, video, fps=fps)
        print(f"save to {out_path}")
