from email.mime import audio
from pathlib import Path
import random
from typing import Dict, List, Optional, Tuple, Union

import decord
import numpy as np
import torch
import torchaudio
from torchaudio.transforms import Resample
from mmcv import BaseTransform, imread
from mmengine import print_log
import imageio.v3 as iio
from mmhug.registry import TRANSFORMS

IMG_POSTFIX = [".jpg", ".jpeg", ".png", ".bmp"]


@TRANSFORMS.register_module()
class LoadVideoAudioWithHeatmapSegment(BaseTransform):
    """Load a (optionally segmented) video clip with its heatmaps and audio.

    The video is read either from a video file (via ``decord``) or from a
    folder of image frames. A contiguous window of frames is selected
    according to ``segment_rule``. The same frame indices are used to slice
    the heatmap array and to cut the matching time span out of the audio.

    Keys read from ``results``:
        - ``video_path_key`` (popped): path to a video file or frame folder.
        - ``heatmap_path_key`` (popped, optional): path loaded with
          ``np.load`` and indexed by frame.
        - ``audio_path_key`` (popped, optional): audio source. It falls back
          to the video path when the key is unset or missing.
        - ``video_metadata["fps"]``: required when the video is a folder.

    Keys written to ``results``:
        ``video`` (``[T, C, H, W]`` float32) and ``video_metadata``.
        When available/enabled, also ``heatmap``/``heatmap_metadata`` and
        ``audio`` (1-D)/``audio_metadata``.

    Args:
        video_path_key: Key holding the video path.
        heatmap_path_key: Key holding the heatmap path.
        audio_path_key: Key holding the audio path.
        max_num_frames: Segment length in frames. Must be None or negative
            when ``segment_rule`` is None, and non-negative otherwise.
        sampling_rate: Target audio sampling rate in Hz.
        segment_rule: None (keep all frames), ``"random"``, ``"center"``,
            ``"head"`` or ``"tail"``.
        video_only: If True, audio is not loaded.
        strict_length: If True, raise ``ValueError`` when the video has
            fewer than ``max_num_frames`` frames. Otherwise all frames are
            used.
    """

    def __init__(
        self,
        video_path_key: str = "video_path",
        heatmap_path_key: str = "heatmap_path",
        audio_path_key: Optional[str] = "audio_path",
        max_num_frames: Optional[int] = None,
        sampling_rate: int = 16000,
        segment_rule: Optional[str] = None,
        video_only: bool = False,
        strict_length: bool = True,
    ) -> None:
        super().__init__()
        self.video_path_key = video_path_key
        self.audio_path_key = audio_path_key
        self.heatmap_path_key = heatmap_path_key

        self.max_num_frames = max_num_frames
        self.sampling_rate = sampling_rate
        self.video_only = video_only
        # If strict_length is True, throw error if the video length is less than max_num_frames
        self.strict_length = strict_length
        if segment_rule is None:
            assert (
                max_num_frames is None or max_num_frames < 0
            ), "If segment_rule is None, max_num_frames must be None or negative"
        else:
            assert (
                max_num_frames is not None and max_num_frames >= 0
            ), "If you set segment_rule, max_num_frames must be a non-negative int"
        assert segment_rule in [
            None,
            "random",
            "center",
            "head",
            "tail",
        ], f"segment_rule must be None or one of 'random','center','head','tail'; got {segment_rule}"
        self.segment_rule = segment_rule

    def _sample_indices(self, total: int) -> List[int]:
        """Return the frame indices to load from a sequence of ``total`` frames.

        With ``segment_rule=None`` every frame is returned. Otherwise a
        contiguous window of ``min(max_num_frames, total)`` frames is chosen
        according to the rule.

        Raises:
            ValueError: if ``strict_length`` is set and ``total`` is shorter
                than ``max_num_frames``.
        """
        if self.segment_rule is None:
            return list(range(total))
        if self.max_num_frames > total and self.strict_length:
            raise ValueError(
                f"The video is {total} frames long, but target {self.max_num_frames} frames."
            )
        # If the video is shorter than the target and not strict, the whole
        # video is used (no cropping).
        num = min(self.max_num_frames, total)
        slack = total - num
        # Only the "random" rule may draw from the global RNG. The
        # deterministic rules must not advance it, or they would perturb
        # reproducibility of any other code seeded from `random`.
        if self.segment_rule == "random":
            start = random.randint(0, slack)
        elif self.segment_rule == "center":
            start = slack // 2
        elif self.segment_rule == "head":
            start = 0
        else:  # "tail"
            start = slack
        return list(range(start, start + num))

    def _load_video_file(
        self, filepath: Union[str, Path]
    ) -> Tuple[torch.Tensor, int, List[int]]:
        """Decode the selected frames of a video file.

        Returns:
            ``(frames, fps, indices)``. ``frames`` is a float32 tensor of
            shape ``[T, 3, H, W]``, ``fps`` is the average frame rate rounded
            to the nearest integer, and ``indices`` are the decoded frame
            indices.
        """
        vr = decord.VideoReader(str(filepath))
        # NOTE(review): fps is rounded to an int (e.g. 29.97 -> 30). The
        # audio offsets in `_load_audio_data` are derived from this value, so
        # for non-integer frame rates the audio window drifts relative to the
        # frames as the start index grows. Confirm this is intended.
        fps = int(vr.get_avg_fps() + 0.5)
        idxs = self._sample_indices(len(vr))
        frames = vr.get_batch(idxs).asnumpy()  # (len, H, W, 3)
        video = torch.from_numpy(frames).permute(0, 3, 1, 2).to(torch.float32)
        return video, fps, idxs

    def _load_video_folder(
        self, folder: Union[str, Path]
    ) -> Tuple[torch.Tensor, None, List[int]]:
        """Read the selected frames from a folder of images.

        Images whose suffix (case-insensitive) is in ``IMG_POSTFIX`` are
        sorted by path and treated as consecutive frames.

        Returns:
            ``(frames, None, indices)``. ``frames`` is a float32 tensor of
            shape ``[T, C, H, W]``. The fps is unknown here and is returned
            as None.
        """
        folder = Path(folder)
        files = sorted(p for p in folder.iterdir() if p.suffix.lower() in IMG_POSTFIX)
        idxs = self._sample_indices(len(files))
        # `iio` is already the `imageio.v3` module, so call `iio.imread`
        # directly. `iio.v3` does not exist.
        # NOTE(review): assumes every image decodes to an HWC array;
        # grayscale (HW) images would fail the permute below.
        tensors = [
            torch.from_numpy(iio.imread(str(files[i]))).permute(2, 0, 1).to(torch.float32)
            for i in idxs
        ]
        return torch.stack(tensors, dim=0), None, idxs

    def _load_heatmap_data(
        self,
        filepath: Union[str, Path],
        indices: List[int],
    ) -> torch.Tensor:
        """Load the heatmap array and keep only the rows at ``indices``.

        The stored array is expected to be ``[T, K, H, W]``, as read back in
        ``transform``'s metadata.
        """
        # T K H W
        heatmap = torch.from_numpy(np.load(filepath))
        return heatmap[indices]

    def _load_audio_data(
        self,
        filepath: Union[str, Path],
        indices: List[int],
        fps: Optional[float],
    ) -> Tuple[torch.Tensor, int, int, int]:
        """Load the audio span matching the selected video frames.

        1) Read the native sample rate and channel count via ``torchaudio.info``.
        2) Compute the exact sample offset and count for the frame segment.
        3) Load only that slice with ``torchaudio.load(frame_offset, num_frames)``.
        4) Zero-pad the raw slice to exactly ``ori_num`` samples.
        5) Keep the first channel for multi-channel audio, then resample.
        6) Pad/trim the final waveform to exactly ``target_num`` samples.

        If ``fps`` is None or ``indices`` is empty, the whole file is loaded
        and no final pad/trim is applied.

        Returns:
            ``(wav, ori_sr, ori_ch, final_sr)``. ``wav`` is a 1-D tensor at
            ``final_sr``; ``ori_sr``/``ori_ch`` describe the source file.
        """
        # --- 1) native metadata ---
        info = torchaudio.info(str(filepath))
        ori_sr = info.sample_rate
        ori_ch = info.num_channels

        segmented = fps is not None and bool(indices)

        # --- 2) + 3) compute offset & count, then load ---
        if segmented:
            duration_sec = len(indices) / fps
            start_sec = indices[0] / fps
            start_offset = int(round(start_sec * ori_sr))
            ori_num = int(round(duration_sec * ori_sr))
            wav2d, _ = torchaudio.load(
                str(filepath),
                frame_offset=start_offset,
                num_frames=ori_num,
            )  # [C, <= ori_num]
        else:
            # fallback to the full file
            wav2d, _ = torchaudio.load(str(filepath))  # [C, N_full]
            ori_num = wav2d.size(1)

        # --- 4) zero-pad so every channel has length ori_num ---
        # (a slice near the end of the file may come back short)
        cur_len = wav2d.size(1)
        if cur_len < ori_num:
            pad = wav2d.new_zeros((wav2d.size(0), ori_num - cur_len))
            wav2d = torch.cat([wav2d, pad], dim=1)

        # --- 5) single channel & resample ---
        # Keep only the first channel. This matches the previous stereo
        # behaviour (left channel), is identical for mono, and keeps the
        # result 1-D for files with more than two channels.
        wav = wav2d[0]

        final_sr = ori_sr
        if ori_sr != self.sampling_rate:
            resampler = Resample(orig_freq=ori_sr, new_freq=self.sampling_rate)
            wav = resampler(wav.unsqueeze(0)).squeeze(0)
            final_sr = self.sampling_rate

        # --- 6) pad/trim to the exact target length ---
        # Resampling can be off by a sample due to rounding.
        if segmented:
            target_num = int(round((len(indices) / fps) * final_sr))
            cur = wav.size(0)
            if cur < target_num:
                wav = torch.cat([wav, wav.new_zeros((target_num - cur,))], dim=0)
            elif cur > target_num:
                wav = wav[:target_num]

        return wav, ori_sr, ori_ch, final_sr

    def transform(self, results: Dict) -> Dict:
        """Load video, optional heatmap and optional audio into ``results``.

        Raises:
            AssertionError: if the video is a frame folder and
                ``results["video_metadata"]["fps"]`` is missing.
        """
        video_path = Path(results.pop(self.video_path_key))
        if video_path.is_dir():
            video, _, indices = self._load_video_folder(video_path)
            # Frame folders carry no timing information, so fps must be
            # supplied by an earlier pipeline step.
            fps = results.get("video_metadata", {}).get("fps")
            assert fps is not None, f"Missing fps for folder {video_path}"
        else:
            video, fps, indices = self._load_video_file(video_path)

        results["video"] = video  # [T, C, H, W]
        results["video_metadata"] = dict(
            video_path=str(video_path),
            num_frames=video.shape[0],
            fps=fps,
            duration=video.shape[0] / fps,
            height=video.shape[-2],
            width=video.shape[-1],
            ori_height=video.shape[-2],
            ori_width=video.shape[-1],
            frame_indices=tuple(indices),
        )

        heatmap_path = results.pop(self.heatmap_path_key, None)
        if heatmap_path is not None:
            heatmap = self._load_heatmap_data(heatmap_path, indices)
            results["heatmap"] = heatmap
            results["heatmap_metadata"] = dict(
                heatmap_path=heatmap_path,
                num_frames=heatmap.shape[0],
                num_joints=heatmap.shape[1],
                height=heatmap.shape[2],
                width=heatmap.shape[3],
            )

        if not self.video_only:
            # Audio comes from the dedicated key if present, else from the
            # video container itself.
            audio_path = None
            if self.audio_path_key:
                audio_path = results.pop(self.audio_path_key, None)
            if audio_path is None:
                audio_path = video_path
            wav, ori_sr, ori_ch, sr = self._load_audio_data(audio_path, indices, fps)

            results["audio"] = wav  # [N]
            results["audio_metadata"] = dict(
                audio_path=audio_path,
                ori_sr=ori_sr,
                ori_channels=ori_ch,
                sr=sr,
                num_samples=wav.numel(),
                duration=wav.numel() / sr,
            )

        return results