from ctypes import Union
import os
from mmengine.visualization import LocalVisBackend
from mmengine import VISBACKENDS
import os.path as osp
from typing import Optional, Union
import torchvision
import torchaudio

import torch
import numpy as np
from mmhug.utils.io import merge_video_audio


@VISBACKENDS.register_module()
class VideoVisBackend(LocalVisBackend):
    """Local visualization backend that stores videos (optionally with audio).

    Overrides ``add_image`` of :class:`LocalVisBackend` so that the "image"
    is written as an ``.mp4`` file. When an audio waveform is supplied it is
    also saved as ``.wav`` and muxed with the video into a second ``.mp4``.
    """

    def __init__(
        self,
        save_dir: str,
        img_save_dir: str = "vis_image",
        config_save_file: str = "config.py",
        scalar_save_file: str = "scalars.json",
    ):
        """Forward all directory/file settings to :class:`LocalVisBackend`.

        Args:
            save_dir: Root directory for visualization outputs.
            img_save_dir: Directory name used for saved videos (passed through
                to ``LocalVisBackend``).
            config_save_file: Config file name (passed through).
            scalar_save_file: Scalars file name (passed through).
        """
        # Forward by keyword so these settings cannot be misrouted if the
        # parent's positional parameter order ever differs.
        super().__init__(
            save_dir=save_dir,
            img_save_dir=img_save_dir,
            config_save_file=config_save_file,
            scalar_save_file=scalar_save_file,
        )

    def _ensure_env_initialized(self) -> None:
        """Run the parent's lazy ``_init_env`` exactly once.

        The parent's ``add_image`` is wrapped in mmengine's ``force_init_env``
        decorator; overriding it drops that wrapper, so this reproduces the
        same check (shared ``_env_initialized`` flag) to keep ``_init_env``
        from being skipped or run twice.
        """
        if not getattr(self, "_env_initialized", False):
            self._init_env()
            self._env_initialized = True

    @staticmethod
    def _to_cpu_tensor(data):
        """Return ``data`` as a detached CPU tensor when it is array-like.

        NumPy arrays are made contiguous first because ``torch.from_numpy``
        rejects negatively-strided arrays (e.g. ``frames[..., ::-1]``).
        Tensors are detached and moved to CPU because
        ``torchvision.io.write_video`` converts via ``.numpy()``, which fails
        for CUDA or grad-requiring tensors. Any other input is returned as-is.
        """
        if isinstance(data, np.ndarray):
            return torch.from_numpy(np.ascontiguousarray(data))
        if isinstance(data, torch.Tensor):
            return data.detach().cpu()
        return data

    def add_image(
        self,
        name: str,  # video name
        video: Union[torch.Tensor, np.ndarray],
        key: str = "pred_video",
        step: Optional[int] = None,
        fps: int = 25,
        audio: Optional[Union[torch.Tensor, np.ndarray]] = None,
        audio_sr: int = 16000,
        **kwargs,
    ) -> str:
        """Save ``video`` (and optionally ``audio``) under the image directory.

        Outputs go to ``<img_save_dir>[/<step>]/<name>/``:

        * ``<key>.mp4``: the frames written with ``torchvision.io.write_video``.
        * ``<key>_audio.wav`` and ``<key>_audio.mp4``: only when ``audio`` is
          given; the waveform, and the video muxed with it by
          ``merge_video_audio``.

        Args:
            name: Video name, used as a sub-directory.
            video: Frames accepted by ``torchvision.io.write_video``
                (presumably ``(T, H, W, C)`` uint8; TODO confirm with callers).
            key: Base file name for the outputs.
            step: When not None, outputs go into a ``str(step)`` sub-directory.
            fps: Frame rate for the video and for muxing audio.
            audio: Optional waveform. A 1-D signal is treated as mono; other
                shapes are passed to ``torchaudio.save`` unchanged (assumed
                channels-first; TODO confirm).
            audio_sr: Audio sample rate in Hz.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            Path of the video-only ``<key>.mp4`` (not the muxed file).

        NOTE(review): mmengine's ``Visualizer.add_image`` appears to forward
        ``step`` positionally as the third argument, which would bind to
        ``key`` here. Confirm callers pass ``key``/``step`` by keyword.
        """
        # The directory attribute is only finalized by ``_init_env``.
        self._ensure_env_initialized()

        save_dir = self._img_save_dir
        if step is not None:
            save_dir = osp.join(save_dir, str(step))
        save_dir = osp.join(save_dir, name)
        os.makedirs(save_dir, exist_ok=True)
        save_path = osp.join(save_dir, f"{key}.mp4")

        torchvision.io.write_video(
            save_path,
            self._to_cpu_tensor(video),
            fps=fps,
        )

        if audio is None:
            return save_path

        audio = self._to_cpu_tensor(audio)
        if audio.ndim == 1:
            # torchaudio.save defaults to channels-first (channels, frames).
            audio = audio.unsqueeze(0)
        audio_path = osp.join(save_dir, f"{key}_audio.wav")
        torchaudio.save(audio_path, audio, sample_rate=audio_sr)

        merge_video_audio(
            save_path,
            audio_path,
            osp.join(save_dir, f"{key}_audio.mp4"),
            fps=fps,
            sr=audio_sr,
        )
        return save_path
