from typing import Dict
import librosa
from mmcv import BaseTransform
from mmhug.registry import TRANSFORMS


@TRANSFORMS.register_module()
class LoadAudio(BaseTransform):
    """Load an audio file with ``librosa`` and store the waveform in ``results``.

    Required keys:
        - ``audio_path_key`` (default ``"audio_path"``): path of the file to
          load. The key is removed from ``results``; the path is kept in
          ``results["audio_metadata"]["audio_path"]``.

    Added / modified keys:
        - ``"audio"``: waveform array returned by ``librosa.load``.
        - ``"audio_metadata"``: dict (created if missing, otherwise updated in
          place) with ``audio_path``, ``sr``, ``mono`` and ``duration``
          (``num_samples / sr``).

    Args:
        audio_path_key: Key in ``results`` holding the audio file path.
        sampling_rate: Target sampling rate, forwarded to ``librosa.load``
            as ``sr``.
        mono: Whether to request mono output, forwarded to ``librosa.load``.
    """

    def __init__(
        self,
        audio_path_key: str = "audio_path",
        sampling_rate: int = 16000,
        mono: bool = True,
    ) -> None:
        super().__init__()
        self.audio_path_key = audio_path_key
        self.sampling_rate = sampling_rate
        self.mono = mono

    def transform(self, results: Dict) -> Dict:
        """Load the audio referenced by ``results`` and record its metadata.

        Args:
            results: Pipeline dict; must contain ``self.audio_path_key``.

        Returns:
            The same ``results`` dict, with ``"audio"`` and
            ``"audio_metadata"`` populated.

        Raises:
            KeyError: If ``self.audio_path_key`` is not in ``results``.
        """
        audio_path = results.pop(self.audio_path_key)
        audio, sr = librosa.load(audio_path, sr=self.sampling_rate, mono=self.mono)
        # Samples live on the last axis: librosa returns ``(n,)`` for
        # single-channel output and ``(channels, n)`` for multi-channel output.
        # With ``mono=False`` a single-channel file still yields a 1-D array,
        # so ``shape[1]`` would raise IndexError there.
        num_samples = audio.shape[-1]
        results["audio"] = audio
        # ``setdefault`` stores a new metadata dict in ``results`` when the key
        # is absent; a bare ``.get(..., {})`` would update a throwaway dict and
        # the metadata would be lost.
        metadata = results.setdefault("audio_metadata", {})
        metadata.update(
            audio_path=audio_path,
            sr=sr,
            mono=self.mono,
            duration=num_samples / sr,
        )
        return results
