from re import S
from typing import Dict
from mmcv import BaseTransform
import numpy as np

from mmhug.registry import TRANSFORMS


@TRANSFORMS.register_module()
class LoadUnitalkerMotion(BaseTransform):
    """Load a motion array from disk and record basic metadata about it.

    The transform reads the file path stored under ``motion_path_key`` in the
    results dict, loads it with ``np.load`` and stores the array under
    ``results["motion"]``. The path, frame count and duration are written
    into ``results["motion_metadata"]``.

    Args:
        motion_path_key: Key in ``results`` that holds the path of the motion
            file. The key is removed from ``results`` once it has been read.
        min_duration: Minimum accepted duration, computed as
            ``num_frames / fps`` (seconds, assuming ``fps`` is frames per
            second). Shorter motions raise ``ValueError``. With the default
            of ``-1`` no motion with a non-negative duration is rejected.
    """

    def __init__(
        self, motion_path_key: str = "motion_path", min_duration: float = -1
    ) -> None:
        super().__init__()
        self.motion_path_key = motion_path_key
        self.min_duration = min_duration

    def transform(self, results: Dict) -> Dict:
        """Load the motion referenced by ``results`` and annotate its metadata.

        Args:
            results: Pipeline dict. Must contain ``self.motion_path_key`` and
                a ``"motion_metadata"`` dict with an ``"fps"`` entry.

        Returns:
            The same ``results`` dict, with ``"motion"`` set to the loaded
            array and ``"motion_metadata"`` updated with ``motion_path``,
            ``num_frames`` and ``duration``.

        Raises:
            KeyError: If ``"fps"`` is missing from the motion metadata, or if
                ``self.motion_path_key`` is missing from ``results``.
            ValueError: If the motion is shorter than ``self.min_duration``.
        """
        cur_metadata = results.get("motion_metadata", {})
        if "fps" not in cur_metadata:
            # Look the path up without popping so that the error message can
            # name the offending sample while leaving ``results`` untouched.
            # An explicit raise is used instead of ``assert`` so the check
            # still runs under ``python -O``.
            raise KeyError(
                f"Missing fps for motion {results.get(self.motion_path_key)}"
            )

        motion_path = results.pop(self.motion_path_key)
        motion = np.load(motion_path)
        results["motion"] = motion

        # The first axis of the loaded array is treated as the frame axis.
        num_frames = motion.shape[0]
        fps = cur_metadata["fps"]

        duration = num_frames / fps
        if duration < self.min_duration:
            raise ValueError(
                f"Motion {motion_path} duration {duration} is shorter than min_duration {self.min_duration}"
            )

        # ``cur_metadata`` is the dict stored under "motion_metadata" (the
        # fps check above guarantees it is not the throwaway default), so
        # this updates ``results`` in place.
        cur_metadata.update(
            dict(
                motion_path=motion_path,
                num_frames=num_frames,
                duration=duration,
            )
        )

        return results
