from email.mime import audio
from pathlib import Path
import random
from typing import Dict, List, Optional, Tuple, Union

import decord
import torch
import torchaudio
from torchaudio.transforms import Resample
from mmcv import BaseTransform
import imageio.v3 as iio
from mmhug.registry import TRANSFORMS

# Lower-case file suffixes treated as image frames when a video is given as a
# folder; compared against ``Path.suffix.lower()`` in ``_load_video_folder``.
IMG_POSTFIX = [".jpg", ".jpeg", ".png", ".bmp"]


@TRANSFORMS.register_module()
class LoadVideoAudioSegment(BaseTransform):
    """Load a contiguous run of video frames and the audio that spans them.

    The video comes either from a file decoded with ``decord`` or from a
    folder of image frames. The audio is read with ``torchaudio``, reduced to
    a single channel (the first one) and resampled to ``sampling_rate``.

    Args:
        video_path_key: Key in ``results`` holding the video file or frame
            folder path. The key is popped from ``results``.
        audio_path_key: Key in ``results`` holding a separate audio file. If
            this argument is falsy, or the key is absent / maps to ``None``,
            audio is read from the video path instead.
        max_num_frames: Length of the segment to keep. Must be a non-negative
            int when ``segment_rule`` is set, and ``None`` or negative
            otherwise.
        sampling_rate: Target audio sampling rate.
        segment_rule: Where to cut the segment: ``"random"``, ``"center"``,
            ``"head"``, ``"tail"``, or ``None`` to keep every frame.
        video_only: If True, skip audio loading entirely.
        strict_length: If True, raise when the video has fewer than
            ``max_num_frames`` frames; otherwise the whole video is kept.
    """

    def __init__(
        self,
        video_path_key: str = "video_path",
        audio_path_key: Optional[str] = "audio_path",
        max_num_frames: Optional[int] = None,
        sampling_rate: int = 16000,
        segment_rule: Optional[str] = None,
        video_only: bool = False,
        strict_length: bool = True,
    ) -> None:
        super().__init__()
        self.video_path_key = video_path_key
        self.audio_path_key = audio_path_key
        self.max_num_frames = max_num_frames
        self.sampling_rate = sampling_rate
        self.video_only = video_only
        # If strict_length is True, throw error if the video length is less than max_num_frames
        self.strict_length = strict_length
        # NOTE: validation deliberately stays as ``assert`` so that existing
        # callers catching AssertionError keep working.
        if segment_rule is None:
            assert (
                max_num_frames is None or max_num_frames < 0
            ), "If segment_rule is None, max_num_frames must be None or negative"
        else:
            assert (
                max_num_frames is not None and max_num_frames >= 0
            ), "If you set segment_rule, max_num_frames must be a non-negative int"
        assert segment_rule in [
            None,
            "random",
            "center",
            "head",
            "tail",
        ], f"segment_rule must be None or one of 'random','center','head','tail'; got {segment_rule}"
        self.segment_rule = segment_rule

    def _sample_indices(self, total: int) -> List[int]:
        """Pick the contiguous frame indices to load from ``total`` frames.

        Args:
            total: Number of frames available in the source.

        Returns:
            ``list(range(total))`` when no ``segment_rule`` is set, otherwise
            ``min(max_num_frames, total)`` consecutive indices positioned
            according to ``segment_rule``.

        Raises:
            ValueError: If ``strict_length`` is set and the source has fewer
                than ``max_num_frames`` frames.
        """
        if self.segment_rule is None:
            return list(range(total))

        if self.max_num_frames > total and self.strict_length:
            raise ValueError(
                f"The video is {total} frames long, but target {self.max_num_frames} frames."
            )

        # With strict_length off, a short video is simply kept whole.
        num = min(self.max_num_frames, total)
        last_start = total - num

        # Only draw from the RNG for the "random" rule, so deterministic rules
        # do not advance the global random state as a side effect.
        if self.segment_rule == "random":
            start = random.randint(0, last_start)
        elif self.segment_rule == "center":
            start = last_start // 2
        elif self.segment_rule == "tail":
            start = last_start
        else:  # "head"
            start = 0
        return list(range(start, start + num))

    def _load_video_file(
        self, filepath: Union[str, Path]
    ) -> Tuple[torch.Tensor, int, List[int]]:
        """Decode the selected frames of a video file with ``decord``.

        Args:
            filepath: Path to the video file.

        Returns:
            A tuple ``(frames, fps, indices)``: ``frames`` is a float32
            tensor permuted from decord's (T, H, W, C) layout to
            (T, C, H, W); ``fps`` is the average frame rate rounded to the
            nearest integer; ``indices`` are the loaded frame indices.
        """
        vr = decord.VideoReader(str(filepath))
        # NOTE(review): rounding (e.g. 29.97 -> 30) slightly shifts the audio
        # window computed from fps later on; confirm this is acceptable.
        fps = int(vr.get_avg_fps() + 0.5)
        idxs = self._sample_indices(len(vr))
        frames = vr.get_batch(idxs).asnumpy()  # (len, H, W, 3)
        # convert once to torch and permute
        return (
            torch.from_numpy(frames).permute(0, 3, 1, 2).to(torch.float32),
            fps,
            idxs,
        )

    def _load_video_folder(
        self, folder: Union[str, Path]
    ) -> Tuple[torch.Tensor, None, List[int]]:
        """Load the selected frames from a folder of image files.

        Files are ordered by sorted path; only suffixes listed in
        ``IMG_POSTFIX`` are considered.

        Args:
            folder: Directory containing one image per frame.

        Returns:
            A tuple ``(frames, None, indices)``: ``frames`` is a float32
            tensor stacked along a new leading dimension, each image permuted
            from HWC to CHW. The fps slot is ``None`` because a folder
            carries no frame rate.

        Raises:
            ValueError: If the folder contains no supported image files.
        """
        folder = Path(folder)
        files = sorted(p for p in folder.iterdir() if p.suffix.lower() in IMG_POSTFIX)
        if not files:
            raise ValueError(
                f"No image frames ({', '.join(IMG_POSTFIX)}) found in folder {folder}"
            )
        idxs = self._sample_indices(len(files))
        # ``iio`` is already ``imageio.v3``; it has no nested ``v3`` attribute.
        tensors = [
            torch.from_numpy(iio.imread(str(files[i])))  # HWC
            .permute(2, 0, 1)
            .to(torch.float32)
            for i in idxs
        ]
        return torch.stack(tensors, dim=0), None, idxs

    def _load_audio_data(
        self,
        filepath: Union[str, Path],
        indices: List[int],
        fps: Optional[float],
    ) -> Tuple[torch.Tensor, int, int, int]:
        """Load the audio covering the selected frames as a 1-D waveform.

        Steps:
            1) Read native sample rate and channel count via torchaudio.info.
            2) Compute the sample offset and count of the segment from
               ``indices`` and ``fps``.
            3) Load only that slice with torchaudio.load (or the full file
               when ``fps`` is None or ``indices`` is empty).
            4) Zero-pad the raw slice to exactly the requested sample count.
            5) Keep the first channel and resample to ``self.sampling_rate``.
            6) Pad/trim the result to exactly the target number of samples.

        Args:
            filepath: Audio (or audio-bearing video) file to read.
            indices: Frame indices of the loaded video segment; only the
                first index and the length are used.
            fps: Video frame rate used to convert frames to seconds.

        Returns:
            A tuple ``(wav, ori_sr, ori_ch, final_sr)``: the 1-D waveform,
            the native sample rate, the native channel count, and the sample
            rate of ``wav``.
        """
        # --- 1) native metadata ---
        info = torchaudio.info(str(filepath))
        ori_sr = info.sample_rate
        ori_ch = info.num_channels

        # --- 2) + 3) compute offset & count, then load the slice ---
        has_segment = fps is not None and bool(indices)
        if has_segment:
            start_sec = indices[0] / fps
            duration_sec = len(indices) / fps
            start_offset = int(round(start_sec * ori_sr))
            ori_num = int(round(duration_sec * ori_sr))
            wav2d, _ = torchaudio.load(
                str(filepath),
                frame_offset=start_offset,
                num_frames=ori_num,
            )  # [C, <=ori_num]
        else:
            # fallback to full file
            wav2d, _ = torchaudio.load(str(filepath))  # [C, N_full]
            ori_num = wav2d.size(1)

        # --- 4) pad the raw slice to length ori_num ---
        # Size the pad from the loaded tensor rather than the reported channel
        # count, so a mismatch between the two cannot break torch.cat.
        missing = ori_num - wav2d.size(1)
        if missing > 0:
            pad = wav2d.new_zeros((wav2d.size(0), missing))
            wav2d = torch.cat([wav2d, pad], dim=1)

        # --- 5) reduce to one channel & resample ---
        # Always keep the first channel. This matches the previous mono and
        # stereo behaviour and also yields a 1-D tensor for >2 channels.
        wav = wav2d[0]

        final_sr = ori_sr
        if ori_sr != self.sampling_rate:
            resampler = Resample(orig_freq=ori_sr, new_freq=self.sampling_rate)
            wav = resampler(wav.unsqueeze(0)).squeeze(0)
            final_sr = self.sampling_rate

        # --- 6) pad/trim final to exact target ---
        if has_segment:
            target_num = int(round((len(indices) / fps) * final_sr))
            cur = wav.size(0)
            if cur < target_num:
                pad = wav.new_zeros((target_num - cur,))
                wav = torch.cat([wav, pad], dim=0)
            elif cur > target_num:
                wav = wav[:target_num]

        return wav, ori_sr, ori_ch, final_sr

    def transform(self, results: Dict) -> Dict:
        """Load video (and optionally audio) described by ``results``.

        Pops the video path (and the audio path, if present) from ``results``
        and writes ``video`` / ``video_metadata`` and, unless ``video_only``
        is set, ``audio`` / ``audio_metadata``.

        Args:
            results: Pipeline dict. For a frame folder it must already carry
                ``results["video_metadata"]["fps"]``.

        Returns:
            The same ``results`` dict, updated in place.

        Raises:
            AssertionError: If a frame folder is given without an fps.
            ValueError: If audio is requested but the only available audio
                source is a frame folder.
        """
        video_path = Path(results.pop(self.video_path_key))
        if video_path.is_dir():
            # A folder carries no frame rate; check before the costly load.
            fps = results.get("video_metadata", {}).get("fps")
            assert fps is not None, f"Missing fps for folder {video_path}"
            video, _, indices = self._load_video_folder(video_path)
        else:
            video, fps, indices = self._load_video_file(video_path)

        results["video"] = video  # [T, C, H, W]
        results["video_metadata"] = dict(
            video_path=str(video_path),
            num_frames=video.shape[0],
            fps=fps,
            duration=video.shape[0] / fps,
            height=video.shape[-2],
            width=video.shape[-1],
            ori_height=video.shape[-2],
            ori_width=video.shape[-1],
            frame_indices=tuple(indices),
        )

        if self.video_only:
            return results

        audio_path = None
        if self.audio_path_key:
            audio_path = results.pop(self.audio_path_key, None)
        if audio_path is None:
            # No separate audio file: read the track embedded in the video.
            audio_path = video_path
        if Path(audio_path).is_dir():
            raise ValueError(
                f"Cannot load audio from directory {audio_path}; provide an "
                f"audio file via '{self.audio_path_key}' or set video_only=True"
            )
        wav, ori_sr, ori_ch, sr = self._load_audio_data(audio_path, indices, fps)

        results["audio"] = wav  # [N]
        results["audio_metadata"] = dict(
            # Always a str, matching video_metadata["video_path"].
            audio_path=str(audio_path),
            ori_sr=ori_sr,
            ori_channels=ori_ch,
            sr=sr,
            num_samples=wav.numel(),
            duration=wav.numel() / sr,
        )
        return results


if __name__ == "__main__":
    # Ad-hoc smoke tests. The pure-logic checks always run; the I/O checks run
    # only when the corresponding demo asset actually exists on disk.
    print("=== LoadVideoAudioSegment Unit Tests (please set actual paths) ===\n")

    def print_ok(message: str):
        """Print a success line for one check."""
        print(f"✔ {message}")

    def is_available(path: str) -> bool:
        """Return True when *path* is non-empty and exists on disk."""
        return bool(path) and Path(path).exists()

    # --- 1) _sample_indices tests ---
    loader_default = LoadVideoAudioSegment(video_only=True)
    assert loader_default._sample_indices(5) == [0, 1, 2, 3, 4]
    print_ok("_sample_indices default behavior")

    loader_head = LoadVideoAudioSegment(
        max_num_frames=3, segment_rule="head", video_only=True
    )
    assert loader_head._sample_indices(5) == [0, 1, 2]
    print_ok("_sample_indices 'head' rule")

    loader_tail = LoadVideoAudioSegment(
        max_num_frames=3, segment_rule="tail", video_only=True
    )
    assert loader_tail._sample_indices(5) == [2, 3, 4]
    print_ok("_sample_indices 'tail' rule")

    loader_center = LoadVideoAudioSegment(
        max_num_frames=3, segment_rule="center", video_only=True
    )
    assert loader_center._sample_indices(7) == [2, 3, 4]
    print_ok("_sample_indices 'center' rule")

    indices_random = LoadVideoAudioSegment(
        max_num_frames=3, segment_rule="random", video_only=True
    )._sample_indices(7)
    assert len(indices_random) == 3 and all(0 <= i < 7 for i in indices_random)
    print_ok("_sample_indices 'random' rule")

    # --- 2) Parameter validation tests ---
    try:
        LoadVideoAudioSegment(max_num_frames=None, segment_rule="head")
        raise RuntimeError(
            "Expected AssertionError when segment_rule is set but max_num_frames is None"
        )
    except AssertionError:
        print_ok("Parameter check: segment_rule requires max_num_frames")

    try:
        LoadVideoAudioSegment(max_num_frames=3, segment_rule="invalid_rule")
        raise RuntimeError("Expected AssertionError for invalid segment_rule")
    except AssertionError:
        print_ok("Parameter check: invalid segment_rule rejected")

    # --- Paths to fill in ---
    VIDEO_PATH = "demo_assets/___OJkS9RK0_0.mp4"  # e.g. "/path/to/your/video.mp4"
    AUDIO_PATH = "data/celebv-hq/audio/resampled____OJkS9RK0_0.wav"  # e.g. "/path/to/your/audio.wav"
    FRAME_FOLDER = "data/celebv-hq/videos_seg_ms/___OJkS9RK0_0"  # e.g. "/path/to/your/frames_folder"

    # --- 3) Test loading video file only ---
    if is_available(VIDEO_PATH):
        loader_v = LoadVideoAudioSegment(video_only=True)
        result_v = loader_v.transform({"video_path": VIDEO_PATH, "video_metadata": {}})
        video_shape = result_v["video"].shape  # Expect [T, C, H, W]
        print(f"✔ Video-only loaded with shape: {video_shape}")
    else:
        print("⚠ VIDEO_PATH not set or not found, skipping video-only load test")

    # --- 4) Test loading video with external audio ---
    if is_available(VIDEO_PATH) and is_available(AUDIO_PATH):
        loader_va = LoadVideoAudioSegment(video_only=False, audio_path_key="audio_path")
        result_va = loader_va.transform(
            {"video_path": VIDEO_PATH, "audio_path": AUDIO_PATH, "video_metadata": {}}
        )
        video_shape2 = result_va["video"].shape
        audio_shape = result_va["audio"].shape  # Expect [N_samples] (1-D waveform)
        print(f"✔ Video+audio loaded: video {video_shape2}, audio {audio_shape}")
    else:
        print(
            "⚠ VIDEO_PATH or AUDIO_PATH not set or not found, skipping video+audio test"
        )

    # --- 5) Test loading from frame folder ---
    if is_available(FRAME_FOLDER):
        loader_f = LoadVideoAudioSegment(video_only=True)
        result_f = loader_f.transform(
            {
                "video_path": FRAME_FOLDER,
                "video_metadata": {"fps": 24},  # Frame folder requires fps provided
            }
        )
        video_shape3 = result_f["video"].shape
        print(f"✔ Frames-folder video loaded with shape: {video_shape3}")
    else:
        print("⚠ FRAME_FOLDER not set or not found, skipping frame-folder test")

    print("\n=== All tests completed ===")
