from math import e
import random
from pathlib import Path
import subprocess
from typing import List, Optional, Tuple, Union

import decord
from einops import rearrange
import torch
import torchaudio
from torchaudio.transforms import Resample
import imageio.v3 as iio

from mmcv import TRANSFORMS, BaseTransform  # adjust import to your project structure
import sys
import os

sys.path.append(os.curdir)
from mmhug.datasets.transforms.load_video import (
    IMG_POSTFIX,
)  # list of allowed image file extensions


@TRANSFORMS.register_module()
class LoadVideoAudioSegmentWithRef(BaseTransform):
    """
    Load a contiguous segment of video frames, optionally with aligned audio and
    a single reference frame.

    Steps performed by :meth:`transform`:
      1) read the video, either a video file (via ``decord``) or a folder of images;
      2) reject videos whose frame count lies outside
         ``[filter_min_num_frames, filter_max_num_frames]`` (inclusive);
      3) if ``segment_rule`` is set, crop a contiguous segment of at most
         ``segment_num_frames`` frames, whose length (when ``frame_multiple > 1``)
         has the form ``frame_multiple * n + frame_multiple_add``;
      4) if ``use_ref_img``, load one reference frame chosen by ``ref_img_rule``;
      5) unless ``video_only``, load the audio span matching the selected frames.

    Args:
        video_path_key (str):
            Key in `results` holding the video file or image-folder path.
            The key is popped from `results`.
        audio_path_key (Optional[str]):
            Key in `results` holding the audio path. If this is None/empty, or the
            key is missing or maps to None, the video path is used for audio.
        filter_max_num_frames (Optional[int]):
            Videos with more frames than this raise ``ValueError``. None disables
            the upper bound.
        filter_min_num_frames (Optional[int]):
            Videos with fewer frames than this raise ``ValueError``. None disables
            the lower bound.
        segment_num_frames (Optional[int]):
            Upper bound on the segment length. Must be None or negative when
            `segment_rule` is None, and a non-negative int otherwise.
        segment_rule (Optional[str]):
            One of {None, "random", "center", "head", "tail"}:
              - None: keep all frames (no cropping, no rounding);
              - "random": select a random contiguous block;
              - "center": center block;
              - "head": first block;
              - "tail": last block.
        frame_multiple (int):
            If >1, the target length F = min(segment_num_frames, total_frames) is
            rounded down to the largest value of the form
            ``frame_multiple * n + frame_multiple_add`` that does not exceed F.
            If <=1, no rounding is applied.
        frame_multiple_add (int):
            Offset used by the rounding described above.
        sampling_rate (int):
            Target audio sampling rate in Hz.
        video_only (bool):
            If True, do not attempt to load audio.
        use_ref_img (bool):
            Whether to also load one reference frame into ``results["ref_img"]``.
        ref_img_rule (Optional[str]):
            If `use_ref_img` is True, must be one of:
              - "random": pick any frame of the whole video (may lie inside the segment);
              - "random_out": pick a random frame outside the selected segment;
              - "prev": pick the frame immediately preceding the segment start.
            With ``segment_rule=None`` the segment covers every frame, so only
            "random" can succeed.
        assert_fps (Optional[int]):
            If not None, a video *file* whose average fps, rounded to the nearest
            integer, differs from this value fails an assertion. Image folders are
            not checked.
    """

    def __init__(
        self,
        # keys for loading
        video_path_key: str = "video_path",
        audio_path_key: Optional[str] = "audio_path",
        # filter rule, videos with length in [filter_min_num_frames, filter_max_num_frames] will be kept
        filter_max_num_frames: Optional[int] = 10_000_000,
        filter_min_num_frames: Optional[int] = 0,
        # segment rule, the video will be segmented into clips with no more than segment_num_frames frames.
        segment_num_frames: Optional[int] = None,
        segment_rule: Optional[str] = None,
        # The cropped length is rounded down to frame_multiple * n + frame_multiple_add.
        frame_multiple: int = 8,
        frame_multiple_add: int = 1,
        # audio sampling rate
        sampling_rate: int = 16000,
        # whether to skip loading audio
        video_only: bool = False,
        # reference-image options
        use_ref_img: bool = False,
        ref_img_rule: Optional[str] = None,
        assert_fps: Optional[int] = 25,
    ) -> None:
        super().__init__()
        self.video_path_key = video_path_key
        self.audio_path_key = audio_path_key

        self.filter_max_num_frames = filter_max_num_frames
        self.filter_min_num_frames = filter_min_num_frames

        self.segment_num_frames = segment_num_frames
        self.sampling_rate = sampling_rate
        self.video_only = video_only

        # Frame-multiple cropping parameters
        self.frame_multiple = frame_multiple
        self.frame_multiple_add = frame_multiple_add

        # Reference-image logic
        self.use_ref_img = use_ref_img
        if use_ref_img:
            assert ref_img_rule in (
                "random",
                "random_out",
                "prev",
            ), "ref_img_rule must be 'random', 'random_out' or 'prev' when use_ref_img=True"
        self.ref_img_rule = ref_img_rule

        assert segment_rule in [
            None,
            "random",
            "center",
            "head",
            "tail",
        ], f"segment_rule must be None or one of 'random','center','head','tail'; got {segment_rule}"

        # Validate segment_rule vs. segment_num_frames
        if segment_rule is None:
            assert (
                self.segment_num_frames is None or self.segment_num_frames < 0
            ), "If segment_rule is None, the entire video will be used as a segment, so please set segment_num_frames to None or negative"
        else:
            assert (
                self.segment_num_frames is not None and self.segment_num_frames >= 0
            ), "If you set segment_rule, segment_num_frames must be a non-negative int"

        self.segment_rule = segment_rule
        self.assert_fps = assert_fps

    def _check_length(self, total: int) -> None:
        """Raise ValueError if `total` lies outside the inclusive filter range.

        A bound set to None is not enforced.
        """
        if self.filter_min_num_frames is not None and total < self.filter_min_num_frames:
            raise ValueError(
                f"The video is too short({total} frames), lower than {self.filter_min_num_frames} limit."
            )
        if self.filter_max_num_frames is not None and total > self.filter_max_num_frames:
            raise ValueError(
                f"The video is too long ({total} frames), over than {self.filter_max_num_frames} limit"
            )

    def _sample_indices(self, total: int) -> List[int]:
        """
        Determine which frame indices to sample from a video of length `total`.

        Returns:
            A list of consecutive frame indices forming the segment.

        Raises:
            ValueError: if `total` is outside the filter range, or if the video is
                too short to supply a reference image when `segment_rule` is None.
            AssertionError: if no positive segment length of the required form fits,
                or if the "random_out"/"prev" reference rules would have no frame
                left outside the segment.
        """
        # The length filter applies to every video, including un-cropped ones.
        self._check_length(total)

        # Case: no cropping requested -> the whole video is the segment.
        if self.segment_rule is None:
            if self.use_ref_img and total <= 1:
                raise ValueError(
                    f"Video too short ({total} frames) to supply a reference image."
                )
            return list(range(total))

        # Base target frame count
        seg_len = min(self.segment_num_frames, total)

        # Round down to frame_multiple * n + frame_multiple_add.
        if self.frame_multiple > 1:
            ori_len = seg_len
            seg_len = (
                seg_len - self.frame_multiple_add
            ) // self.frame_multiple * self.frame_multiple + self.frame_multiple_add
            min_frames = self.filter_min_num_frames or 0
            # An empty segment is rejected: it would yield no frames and make the
            # audio loader fall back to the whole file.
            assert (
                seg_len > 0 and seg_len >= min_frames
            ), f"Got {ori_len} frames input, cannot find a possible segment length with format {self.frame_multiple} * n + {self.frame_multiple_add} and longer than {self.filter_min_num_frames}"

        # "random_out" and "prev" need at least one frame outside the segment.
        if self.use_ref_img and self.ref_img_rule != "random":
            assert (
                seg_len < total
            ), f"Cannot get an reference img outside segment, because the segment length is {seg_len} but the entire video length is {total}"

        # "prev" needs a frame before the segment, so the segment must not start at 0.
        min_start = 1 if self.use_ref_img and self.ref_img_rule == "prev" else 0
        max_start = total - seg_len

        # Only the selected rule is evaluated, so the RNG is drawn from only for "random".
        if self.segment_rule == "random":
            start = random.randint(min_start, max_start)
        elif self.segment_rule == "center":
            start = max(min_start, max_start // 2)
        elif self.segment_rule == "head":
            start = min_start
        else:  # "tail"
            start = max_start

        return list(range(start, start + seg_len))

    def _load_video_file(
        self, filepath: Union[str, Path]
    ) -> Tuple[torch.Tensor, float, List[int]]:
        """Decode the selected segment of a video file.

        Returns:
            (frames as float32 [T, C, H, W], fps rounded to the nearest int, indices)
        """
        vr = decord.VideoReader(str(filepath))
        fps = int(vr.get_avg_fps() + 0.5)
        if self.assert_fps is not None:
            assert (
                fps == self.assert_fps
            ), f"{filepath} has fps {fps} != {self.assert_fps}"
        idxs = self._sample_indices(len(vr))
        frames = vr.get_batch(idxs).asnumpy()  # (len, H, W, 3)
        # convert once to torch and permute
        return (
            torch.from_numpy(frames).permute(0, 3, 1, 2).to(torch.float32),
            fps,
            idxs,
        )

    @staticmethod
    def _list_image_files(folder: Union[str, Path]) -> List[Path]:
        """Return the files in `folder` whose lowercase suffix is in IMG_POSTFIX, sorted."""
        return sorted(
            p for p in Path(folder).iterdir() if p.suffix.lower() in IMG_POSTFIX
        )

    def _load_video_folder(
        self, folder: Union[str, Path]
    ) -> Tuple[torch.Tensor, None, List[int]]:
        """Read the selected segment from a folder of frame images.

        Returns:
            (frames as float32 [T, C, H, W], None (no fps available), indices)
        """
        files = self._list_image_files(folder)
        indices = self._sample_indices(len(files))
        tensors = [
            torch.from_numpy(iio.imread(str(files[idx]))).permute(2, 0, 1).float()
            for idx in indices
        ]
        return torch.stack(tensors, dim=0), None, indices

    def _load_audio_data(
        self,
        filepath: Union[str, Path],
        indices: List[int],
        fps: Optional[float],
    ) -> Tuple[torch.Tensor, int, int, int]:
        """
        1) Read native sample_rate & channel count via torchaudio.info
        2) Compute exact original-sample offset & count for the segment
        3) Load only that slice with torchaudio.load(frame_offset, num_frames)
        4) **Pad** raw wav2d to exactly ori_num samples
        5) Mixdown (any channel count > 1) and Resample
        6) Pad/trim final wav to exactly target_num samples

        If `fps` is None or `indices` is empty, the whole file is loaded and
        step 6 is skipped.

        Returns:
            (mono waveform [N], native sample rate, native channel count, final sample rate)

        Raises:
            ValueError: wrapping any error raised while loading/processing the audio.
        """
        try:
            # --- 1) native metadata ---
            info = torchaudio.info(str(filepath))
            ori_sr = info.sample_rate
            ori_ch = info.num_channels

            aligned = fps is not None and bool(indices)

            # --- 2) + 3) compute offset & count, then load ---
            if aligned:
                start_offset = int(round(indices[0] / fps * ori_sr))
                ori_num = int(round(len(indices) / fps * ori_sr))
                wav2d, _ = torchaudio.load(
                    str(filepath),
                    frame_offset=start_offset,
                    num_frames=ori_num,
                )  # [C, <=ori_num]
            else:
                # fallback to full file
                wav2d, _ = torchaudio.load(str(filepath))  # [C, N_full]
                ori_num = wav2d.size(1)

            # --- 4) pad raw wav2d so every channel has length = ori_num ---
            cur_len = wav2d.size(1)
            if cur_len < ori_num:
                # Use the loaded tensor's channel count so the concat always matches.
                pad = wav2d.new_zeros((wav2d.size(0), ori_num - cur_len))
                wav2d = torch.cat([wav2d, pad], dim=1)

            # --- 5) mixdown to mono & resample ---
            # Average all channels (stereo or multichannel); mono is used as is.
            wav = wav2d.mean(dim=0) if wav2d.size(0) > 1 else wav2d[0]

            final_sr = ori_sr
            if ori_sr != self.sampling_rate:
                resampler = Resample(orig_freq=ori_sr, new_freq=self.sampling_rate)
                wav = resampler(wav.unsqueeze(0)).squeeze(0)
                final_sr = self.sampling_rate

            # --- 6) pad/trim final to exact target ---
            if aligned:
                target_num = int(round((len(indices) / fps) * final_sr))
                cur = wav.size(0)
                if cur < target_num:
                    pad = wav.new_zeros((target_num - cur,))
                    wav = torch.cat([wav, pad], dim=0)
                elif cur > target_num:
                    wav = wav[:target_num]
            return wav, ori_sr, ori_ch, final_sr
        except Exception as err:
            raise ValueError(f"Error loading audio {filepath}: {err}") from err

    def _pick_ref_indices(
        self, total: int, indices: List[int]
    ) -> Optional[List[int]]:
        """Choose the reference frame index (a one-element list) per `ref_img_rule`.

        Returns None when no valid reference frame exists.
        """
        k = 1
        if self.ref_img_rule == "random_out":
            # Sorted so the candidate order does not depend on set iteration order.
            pool = sorted(set(range(total)) - set(indices))
            return random.choices(pool, k=k) if pool else None
        if self.ref_img_rule == "random":
            return random.choices(range(total), k=k) if total > 0 else None
        # "prev": the frame right before the segment start.
        return [indices[0] - 1] if indices and indices[0] > 0 else None

    def _load_ref_frames(
        self, video_path: Path, indices: List[int]
    ) -> Tuple[torch.Tensor, List[int]]:
        """Pick and load the reference frame(s).

        Returns:
            (float32 tensor [k, C, H, W], chosen indices)

        Raises:
            ValueError: if `ref_img_rule` yields no valid reference frame.
        """
        is_dir = video_path.is_dir()
        # Open the source once; it is used for both the frame count and the read.
        if is_dir:
            files = self._list_image_files(video_path)
            total = len(files)
        else:
            vr = decord.VideoReader(str(video_path))
            total = len(vr)

        ref_idx = self._pick_ref_indices(total, indices)
        if ref_idx is None:
            raise ValueError(
                f"Unable to find reference frame with rule={self.ref_img_rule}, dataset size={total}, segment indices range={indices}"
            )

        if is_dir:
            frames = torch.stack(
                [torch.from_numpy(iio.imread(str(files[idx]))) for idx in ref_idx]
            )
        else:
            frames = torch.from_numpy(vr.get_batch(ref_idx).asnumpy())

        ref_imgs = rearrange(frames.float(), "k h w c -> k c h w")
        return ref_imgs, ref_idx

    def transform(self, results: dict) -> dict:
        """Load video (and optionally the reference frame and audio) into `results`.

        Writes ``video``, ``video_metadata``, and, when enabled, ``ref_img``,
        ``audio`` and ``audio_metadata``. Pops the video path key and, if present,
        the audio path key.
        """
        video_path = Path(results.pop(self.video_path_key))
        if video_path.is_dir():
            video, fps, indices = self._load_video_folder(video_path)
            # Image folders carry no fps; take it from upstream metadata if provided.
            fps = results.get("video_metadata", {}).get("fps", fps)
            assert fps is not None, f"Missing fps for folder {video_path}"
        else:
            video, fps, indices = self._load_video_file(video_path)

        results["video"] = video
        results["video_metadata"] = dict(
            video_path=str(video_path),
            num_frames=video.shape[0],
            fps=fps,
            duration=(video.shape[0] / fps) if fps else None,
            height=video.shape[2],
            width=video.shape[3],
            frame_indices=tuple(indices),
        )

        # Load reference image if needed
        if self.use_ref_img:
            ref_imgs, ref_idx = self._load_ref_frames(video_path, indices)
            results["ref_img"] = ref_imgs  # [k, C, H, W] with k = 1
            results["video_metadata"]["ref_indice"] = tuple(ref_idx)

        if not self.video_only:
            audio_path = None
            if self.audio_path_key:
                audio_path = results.pop(self.audio_path_key, None)
            if audio_path is None:
                audio_path = video_path
            wav, ori_sr, ori_ch, sr = self._load_audio_data(audio_path, indices, fps)

            results["audio"] = wav
            results["audio_metadata"] = dict(
                audio_path=audio_path,
                ori_sr=ori_sr,
                ori_channels=ori_ch,
                sr=sr,
                num_samples=wav.numel(),
                duration=wav.numel() / sr,
            )
        return results


if __name__ == "__main__":
    demo_path = "data/celebv-hq/videos/__lRwnjxeCg_4.mp4"

    # Common transform settings. Every key must be a parameter of
    # LoadVideoAudioSegmentWithRef.__init__ (unknown keys raise TypeError).
    transform_kwargs = dict(
        video_path_key="video_path",
        audio_path_key="audio_path",
        segment_num_frames=256,  # upper bound on the clip length
        segment_rule="random",  # pick a random contiguous block of frames
        frame_multiple=32,  # length rounded down to 32 * n + frame_multiple_add (default 1)
        use_ref_img=True,
        video_only=False,
        sampling_rate=16000,
    )

    for rule in ("random_out", "prev"):
        print(f"\n--- Testing ref_img_rule = {rule} ---")
        # 1) Instantiate
        loader = LoadVideoAudioSegmentWithRef(
            **transform_kwargs,
            ref_img_rule=rule,
        )

        # 2) Run transform
        results = loader.transform({"video_path": demo_path, "audio_path": demo_path})

        video: torch.Tensor = results["video"]  # [T, C, H, W], float32
        wav: torch.Tensor = results["audio"]  # [N], float32
        fps: float = results["video_metadata"]["fps"]
        sr = results["audio_metadata"]["sr"]
        ref_img: torch.Tensor = results["ref_img"]  # [1, C, H, W], float32
        ref_idx = results["video_metadata"]["ref_indice"]
        print(f"Selected ref_indice = {ref_idx}, clip length = {video.shape[0]} frames")

        # 3) Save reference image: drop the k=1 axis, CHW -> HWC, convert to uint8
        out_png = f"output_{rule}.png"
        np_ref = ref_img.squeeze(0).permute(1, 2, 0).clamp(0, 255).byte().numpy()
        iio.imwrite(out_png, np_ref)
        print(f"→ ref_img saved to {out_png}")

        # 4) Save video+audio to MP4 via temporary video-only and audio-only files
        tmp_vid = "temp_clip.mp4"
        tmp_wav = "temp_audio.wav"
        out_mp4 = f"output_{rule}.mp4"
        try:
            # Video frames as a uint8 HWC sequence: [T, H, W, C]
            np_vid = video.permute(0, 2, 3, 1).clamp(0, 255).byte().numpy()
            iio.imwrite(
                tmp_vid,
                np_vid,
                fps=fps,
                codec="libx264",
                ffmpeg_log_level="error",  # suppress too-much logging
            )

            # torchaudio.save expects [channels, N]
            torchaudio.save(tmp_wav, wav.unsqueeze(0), sample_rate=sr)

            # Mux via ffmpeg CLI (argument list, no shell)
            subprocess.run(
                [
                    "ffmpeg",
                    "-y",
                    "-i",
                    tmp_vid,
                    "-i",
                    tmp_wav,
                    "-c:v",
                    "copy",  # keep H.264 video as-is
                    "-c:a",
                    "aac",  # re-encode audio to AAC
                    out_mp4,
                ],
                check=True,
            )
            print(f"✅ Combined video+audio saved to {out_mp4}")
        finally:
            # Remove temp files even if encoding or muxing failed.
            Path(tmp_vid).unlink(missing_ok=True)
            Path(tmp_wav).unlink(missing_ok=True)
