import logging
from os.path import isdir
from typing import Dict, List, Union
import mmengine
from mmengine.dataset import BaseDataset
from mmengine.dataset import Compose
from mmengine import print_log
import os
from typing import Dict, List, Union

from mmhug.registry import DATASETS

# `BaseTransform` is used in the type annotation of the `pipeline` argument.
from mmcv import BaseTransform
import random
from tqdm import tqdm


@DATASETS.register_module()
class TextVideoAudioDataset(BaseDataset):
    """A dataset of paired text, video and audio samples.

    Inherits from :class:`mmengine.dataset.BaseDataset` but does not call its
    ``__init__``: annotations are loaded eagerly in :meth:`__init__` and
    :meth:`full_init` is a no-op.

    The annotation file must deserialize (via ``mmengine.load``) to a dict
    with two keys:

    * ``"meta_info"``: a dict whose items are added to :attr:`metainfo`.
    * ``"data_list"``: either a dict (only its values are used) or a list.
      Each value/element is a sample dict or a list of sample dicts; see
      :meth:`prepare_data` for the keys read from a sample dict.

    Attributes:
        data_dir (str): Directory that relative media paths are joined to.
        anno_file (str): Path of the annotation file.
        data_list (List[dict]): Flattened list of per-sample annotation dicts.
        pipeline (Compose): A composition of data transformation operations.
        refetch (bool): Whether to fetch another random sample if the
            pipeline raises an error during ``__getitem__``. If False, the
            error is re-raised (which typically terminates the process).
        max_refetch (int): Maximum number of refetches performed by a single
            ``__getitem__`` call before giving up with ``RuntimeError``.

    Args:
        data_dir (str): Directory that relative media paths are joined to.
        anno_file (str): Path to the annotation file. It is passed to
            ``mmengine.load`` as-is and is NOT joined with ``data_dir``.
        pipeline (Union[Dict, BaseTransform, List[Union[Dict, BaseTransform]]]):
            Data transformation pipeline, forwarded to
            ``mmengine.dataset.Compose``.
        refetch (bool): See the attribute of the same name. Defaults to True.
        max_refetch (int): See the attribute of the same name. Must be
            non-negative. Defaults to 1000.
    """

    def __init__(
        self,
        data_dir: str,
        anno_file: str,
        pipeline: Union[Dict, BaseTransform, List[Union[Dict, BaseTransform]]],
        refetch: bool = True,
        max_refetch: int = 1000,
    ):
        # NOTE: BaseDataset.__init__ is not called; only the state used by
        # this class is initialized here.
        if max_refetch < 0:
            raise ValueError(f"max_refetch must be non-negative, got {max_refetch}")
        self.data_dir = data_dir
        self.anno_file = anno_file
        # Populated from the annotation file's "meta_info" by load_data_list().
        self._metainfo = {}

        self.pipeline = Compose(pipeline)
        self.data_list = self.load_data_list()
        self.refetch = refetch
        self.max_refetch = max_refetch

    @staticmethod
    def _is_main_process() -> bool:
        """Return False only when running distributed on a rank other than 0."""
        try:
            from mmengine import dist

            return (not dist.is_distributed()) or (dist.get_rank() == 0)
        except ImportError:
            return True

    def load_data_list(self) -> List[dict]:
        """Load and flatten the sample annotations from ``self.anno_file``.

        Adapted from ``mmengine.dataset.BaseDataset.load_data_list``. The
        items of the file's ``"meta_info"`` dict are added to
        ``self._metainfo`` with ``setdefault``, so keys already present there
        are NOT overwritten.

        Returns:
            List[dict]: One annotation dict per sample, in file order.

        Raises:
            TypeError: If the loaded content is not a dict, if ``"data_list"``
                is neither a dict nor a list, or if an entry is not a dict or a
                list of dicts.
            ValueError: If ``"data_list"`` or ``"meta_info"`` is missing.
        """
        annotations = mmengine.load(self.anno_file)
        if not isinstance(annotations, dict):
            raise TypeError(
                f"The annotations loaded from annotation file "
                f"should be a dict, but got {type(annotations)}!"
            )
        if "data_list" not in annotations or "meta_info" not in annotations:
            raise ValueError("Annotation must have data_list and meta_info keys")

        for key, value in annotations["meta_info"].items():
            self._metainfo.setdefault(key, value)

        raw_data_list = annotations["data_list"]
        # Accept a mapping (its keys are ignored) as well as a plain list.
        if isinstance(raw_data_list, dict):
            entries = raw_data_list.values()
        elif isinstance(raw_data_list, list):
            entries = raw_data_list
        else:
            raise TypeError(
                f"data_list should be a dict or a list, but got {type(raw_data_list)}"
            )

        # Show the progress bar only on the main process.
        iterator = tqdm(
            entries, desc="Loading data_list", disable=not self._is_main_process()
        )

        data_list = []
        for data_info in iterator:
            if isinstance(data_info, dict):
                data_list.append(data_info)
            elif isinstance(data_info, list):
                for item in data_info:
                    if not isinstance(item, dict):
                        raise TypeError(
                            f"data_info must be list of dict, but got {type(item)}"
                        )
                data_list.extend(data_info)
            else:
                raise TypeError(
                    f"data_info should be a dict or list of dict, but got {type(data_info)}"
                )

        return data_list

    def prepare_data(self, idx: int) -> dict:
        """Build the pipeline input dict for sample ``idx`` (no transforms applied).

        ``video_path`` is required (``KeyError`` if missing). ``audio_path``
        and ``caption_path`` are optional and become None when absent or
        empty. Present media paths are joined with ``self.data_dir``.
        ``fps``, ``sr`` and ``caption`` are copied as-is (None if absent).

        Raises:
            AssertionError: If ``video_path`` is a directory of frames but no
                ``fps`` or no ``audio_path`` was given.
        """
        raw_data_info = self.data_list[idx]
        data_info = {
            "video_path": os.path.join(
                self.data_dir, raw_data_info["video_path"]
            ),  # cannot be None
            "video_metadata": {
                "fps": raw_data_info.get("fps"),
            },
            "audio_path": (
                os.path.join(self.data_dir, raw_data_info["audio_path"])
                if raw_data_info.get("audio_path")
                else None
            ),
            "audio_metadata": {
                "sr": raw_data_info.get("sr"),
            },
            "caption_path": (
                os.path.join(self.data_dir, raw_data_info["caption_path"])
                if raw_data_info.get("caption_path")
                else None
            ),
            "caption": raw_data_info.get("caption"),
            # Always None here; presumably filled in by a pipeline transform
            # (TODO confirm).
            "caption_metadata": {
                "caption_path": None,
            },
        }
        # A video may be a single video file or a directory of frame images.
        # NOTE: these asserts are skipped when Python runs with -O.
        if os.path.isdir(data_info["video_path"]):
            assert (
                data_info["video_metadata"]["fps"] is not None
            ), "If video is provided in a folder of frames, please provide the FPS of the video in annotation file"
            assert (
                data_info["audio_path"] is not None
            ), "If video is provided in a folder of frames, please provide the audio path"
        return data_info

    def __len__(self) -> int:
        return len(self.data_list)

    def _pick_other_index(self, idx: int) -> int:
        """Return a random index, guaranteed to differ from ``idx`` when len > 1."""
        new_idx = random.randint(0, len(self.data_list) - 1)
        if new_idx == idx:
            # Shift by one so the failed sample is not retried immediately.
            new_idx = (idx + 1) % len(self.data_list)
        return new_idx

    def __getitem__(self, idx: int) -> dict:
        """Return the pipeline output for sample ``idx``.

        If the pipeline raises and ``self.refetch`` is True, a warning is
        logged and another random index is tried. This repeats up to
        ``self.max_refetch`` times. Errors raised by :meth:`prepare_data`
        (annotation problems) are not caught and always propagate.

        The returned dict is whatever the configured pipeline produces; an
        illustrative layout (not guaranteed by this class) is::

            {
                "video": torch.rand([T, C, H, W]),
                "video_metadata": {
                    "video_path": video_path,
                    "duration": T / fps,
                    "num_frames": T,
                    "fps": fps,
                    "width": W,
                    "height": H
                },
                "audio": torch.rand([N]),
                "audio_metadata": {
                    "audio_path": audio_path,
                    'sr': SR,
                    'duration': N / SR,
                    'num_frames': N,
                },
                "caption": "A person dressed in white and wearing a black hat is giving a speech",
                "caption_metadata": {
                    'caption_path': caption_path  # None if caption is given in the annotation file
                }
            }

        Raises:
            RuntimeError: If ``self.refetch`` is True and the initial attempt
                plus ``self.max_refetch`` refetches all failed in the
                pipeline. The last pipeline error is chained as the cause.
        """
        num_refetched = 0
        while True:
            sample = self.prepare_data(idx)
            try:
                return self.pipeline(sample)
            except Exception as err:
                if not self.refetch:
                    raise
                if num_refetched >= self.max_refetch:
                    raise RuntimeError(
                        f"Failed to load a valid sample after {num_refetched} "
                        f"refetch attempts (last idx={idx})"
                    ) from err
                print_log(
                    f"Get error when loading idx={idx}: fetching another instead, error: {err}",
                    level=logging.WARNING,
                )
                # Iterative retry instead of recursion: bounded by max_refetch
                # and cannot hit Python's recursion limit.
                idx = self._pick_other_index(idx)
                num_refetched += 1

    def full_init(self):
        """No-op: annotations are already loaded eagerly in ``__init__``."""
        pass
