import re
from typing import List, Union
from mmengine.registry import VISUALIZERS
from mmengine.visualization import Visualizer
from mmengine.structures import BaseDataElement
from .video_visbackend import VideoVisBackend


@VISUALIZERS.register_module()
class VideoAudioVisualizer(Visualizer):
    """Visualizer that hands videos (optionally paired with audio) to every
    registered ``VideoVisBackend``.
    """

    def __init__(
        self,
        video_keys: Union[List[str], str],
        audio_keys: Union[List[Union[str, None]], str, None] = None,
        name: str = "visualizer",
        *args,
        audio_sr: int = 16000,
        **kwargs,
    ) -> None:
        """This visualizer is used to visualize video with (optional) audio.

        The visualizer reads every key in ``video_keys`` from a data sample and
        forwards it to the video backends. Multiple keys can be provided; each
        one is forwarded together with its own key so the backend can tell the
        outputs apart.

        If ``audio_keys`` is provided, the audio stored under those keys is
        forwarded along with the videos:

        - a single audio key is reused for all videos;
        - multiple audio keys must match ``video_keys`` one-to-one, and the
          audio at position ``i`` goes with the video at position ``i``;
        - a ``None`` entry means "no audio" for the corresponding video.

        Args:
            video_keys: Key or list of keys of the videos in the data sample.
            audio_keys: Key or list of keys of the audios in the data sample,
                or ``None`` to forward no audio at all.
            name: Visualizer name, forwarded to the base ``Visualizer``.
            *args: Extra positional arguments for the base ``Visualizer``.
            audio_sr: Keyword-only. Value forwarded to the backend as
                ``audio_sr`` (presumably the audio sample rate in Hz --
                confirm against ``VideoVisBackend``). Defaults to 16000, the
                value that was previously hard-coded.
            **kwargs: Extra keyword arguments for the base ``Visualizer``.

        Raises:
            AssertionError: If ``audio_keys`` is neither ``None``, a single
                key, nor the same length as ``video_keys``.
        """
        super().__init__(name, *args, **kwargs)

        # Normalise both arguments to lists. Copies are taken so that later
        # mutation of the caller's list cannot change this visualizer.
        video_keys = [video_keys] if isinstance(video_keys, str) else list(video_keys)
        if isinstance(audio_keys, str):
            audio_keys = [audio_keys]
        elif audio_keys is not None:
            audio_keys = list(audio_keys)

        # A single audio track is shared by every video.
        if audio_keys is not None and len(audio_keys) == 1 and len(video_keys) > 1:
            audio_keys = audio_keys * len(video_keys)

        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept unchanged
        # for backward compatibility.
        if audio_keys is not None and len(audio_keys) != len(video_keys):
            raise AssertionError(
                f"audio_keys should be None or 1 or equal to num of video_keys, but got {len(audio_keys)} and {len(video_keys)}"
            )

        self.video_keys = video_keys
        self.audio_keys = audio_keys
        self.audio_sr = audio_sr

    def add_datasample(self, data_sample: BaseDataElement, step: int = 0) -> None:
        """Forward the configured videos (and audios) of one sample to the
        video backends.

        Args:
            data_sample: The output dict of a single sample (not a batch) from
                model.forward_predict(). It must provide
                ``video_metadata["video_path"]`` and every key listed in
                ``self.video_keys``.
            step: For validation, the current training step. For testing, 0.
        """
        video_path = data_sample.get("video_metadata").get("video_path")
        # Keep only the last component after splitting on spaces, "/" and "\".
        base_name = re.split(r" |/|\\", video_path)[-1]
        # Strip only the final extension. Cutting at the first dot would map
        # "clip.001.mp4" and "clip.002.mp4" to the same name "clip".
        sample_name = base_name.rsplit(".", 1)[0]

        # The set of video backends does not depend on the key being
        # processed, so select them once up front.
        video_backends = [
            backend
            for backend in self._vis_backends.values()
            if isinstance(backend, VideoVisBackend)
        ]

        if self.audio_keys is None:
            audio_keys = [None] * len(self.video_keys)
        else:
            audio_keys = self.audio_keys

        for video_key, audio_key in zip(self.video_keys, audio_keys):
            video = data_sample[video_key]
            # A missing audio entry yields None (``get``), i.e. a silent video.
            audio = None if audio_key is None else data_sample.get(audio_key)

            for backend in video_backends:
                backend.add_image(
                    name=sample_name,
                    video=video,
                    key=video_key,
                    step=step,
                    audio=audio,
                    audio_sr=self.audio_sr,
                )  # type: ignore
