from math import e
import random
from pathlib import Path
import subprocess
from typing import List, Optional, Tuple, Union

import decord
from einops import rearrange
import torch
import torchaudio
from torchaudio.transforms import Resample
import imageio.v3 as iio

from mmcv import TRANSFORMS, BaseTransform  # adjust import to your project structure
import sys
import os
import numpy as np
from ..utils.kpt2face import check_face
from mmhug.datasets.utils.kpt_classes_and_palettes import GOLIATH_KEYPOINTS

from mmhug.datasets.transforms.load_video import (
    IMG_POSTFIX,
)  # list of allowed image file extensions


@TRANSFORMS.register_module()
class LoadVideoAudioSegmentWithKeypointRef(BaseTransform):
    """
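    Load a segment of video frames (and optionally audio and keypoints) with support for:
      1) cropping the frame count to the form `frame_multiple * n + frame_multiple_add`;
      2) optionally loading “reference” image frame(s) chosen relative to the segment.

    Args:
        video_path_key (str):
            Key in `results` dict where the video file or folder path is stored.
        audio_path_key (Optional[str]):
            Key in `results` dict for audio path. If None, or if the key is missing
            from `results`, the video path is used.
        keypoint_path_key (str):
            Key in `results` dict for the keypoint array file (read with `np.load`).
            Keypoints are loaded only when this key is present in `results`.
        filter_max_num_frames (Optional[int]):
            When `segment_rule` is set, a video with this many frames or more is
            rejected with a ValueError.
        filter_min_num_frames (Optional[int]):
            When `segment_rule` is set, a video with fewer frames than this is
            rejected with a ValueError.
        segment_num_frames (Optional[int]):
            Maximum number of frames in the segment. Must be None or negative when
            `segment_rule` is None, and a non-negative int otherwise.
            The base segment length is min(len(video), segment_num_frames).
        segment_rule (Optional[str]):
            One of {None, "random", "center", "head", "tail"}:
              - None: keep all frames (no cropping);
              - "random": select a random contiguous block;
              - "center": center block;
              - "head": first block;
              - "tail": last block.
        frame_multiple (int):
            If >1, the base segment length F is rounded down to the largest value of
            the form `frame_multiple * n + frame_multiple_add` that does not exceed F.
            If <=1, no rounding is applied.
        frame_multiple_add (int):
            Constant offset used in the rounding rule above.
        sampling_rate (int):
            Target audio sampling rate in Hz.
        video_only (bool):
            If True, do not attempt to load audio.
        use_ref_img (bool):
            Whether to also load reference image frame(s).
        num_ref_img (int):
            Number of reference frames for the "random", "random_out" and
            "random_video" rules (for "random_video", a value <= 0 means "use the
            segment length"). The "prev" rule always yields a single frame.
        ref_img_rule (Optional[str]):
            If `use_ref_img` is True, must be one of:
              - "random": pick random frames anywhere in the video;
              - "random_out": pick random frames outside the selected segment;
              - "random_video": pick a random contiguous block of frames;
              - "prev": pick the frame immediately preceding the segment start.
        body_part (str):
            Which keypoint subset to keep: "full", "body", "face" or "head".
        assert_fps (Optional[int]):
            If not None, a video file whose rounded average fps differs from this
            value fails an assertion. Not checked for image folders.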
    """

    def __init__(
        self,
        # keys for loading
        video_path_key: str = "video_path",
        audio_path_key: Optional[str] = "audio_path",
        keypoint_path_key: str = "keypoint_path",
        # filter rule, videos with length in [filter_min_num_frames, filter_max_num_frames] will be kept
        filter_max_num_frames: Optional[int] = 1e7,
        filter_min_num_frames: Optional[int] = 0,
        # segment rule, the video will be segmented into clips with no more than segment_max_num_frames frames.
        segment_num_frames: Optional[int] = None,
        segment_rule: Optional[str] = None,
        frame_multiple: int = 8,
        frame_multiple_add: int = 1,
        # audio sampling rate
        sampling_rate: int = 16000,
        # whether load audio
        video_only: bool = False,
        # The cropped video length will be rounded down to the nearest multiple of this value.
        use_ref_img: bool = False,
        num_ref_img: int = 1,
        ref_img_rule: Optional[str] = None,
        # keypoint
        body_part: str = "full",  # full, body, face, head
        assert_fps: Optional[int] = 25,
    ) -> None:
        super().__init__()
        self.video_path_key = video_path_key
        self.audio_path_key = audio_path_key
        self.keypoint_path_key = keypoint_path_key

        self.filter_max_num_frames = filter_max_num_frames
        self.filter_min_num_frames = filter_min_num_frames

        self.segment_num_frames = segment_num_frames
        self.sampling_rate = sampling_rate
        self.video_only = video_only

        # Frame-multiple cropping parameter
        self.frame_multiple = frame_multiple
        self.frame_multiple_add = frame_multiple_add

        # Reference-image logic
        self.num_ref_img = num_ref_img
        self.use_ref_img = use_ref_img
        if use_ref_img:
            assert ref_img_rule in (
                "random",
                "random_video",
                "random_out",
                "prev",
            ), "ref_img_rule must be 'random', 'random_out' or 'prev' when use_ref_img=True"
        self.ref_img_rule = ref_img_rule

        # Validate segment_rule vs. max_num_frames
        if segment_rule is None:
            assert (
                self.segment_num_frames is None or self.segment_num_frames < 0
            ), "If segment_rule is None, the entire video will be used as a segment, so please set segment_num_frames to None or negative"
        else:
            assert (
                self.segment_num_frames is not None and self.segment_num_frames >= 0
            ), "If you set segment_rule, segment_num_frames must be a non-negative int"

        assert segment_rule in [
            None,
            "random",
            "center",
            "head",
            "tail",
        ], f"segment_rule must be None or one of 'random','center','head','tail'; got {segment_rule}"
        self.segment_rule = segment_rule
        self.assert_fps = assert_fps

        self.body_part = body_part
        assert body_part in [
            "full",
            "body",
            "face",
            "head",
        ], f"body_part must be one of 'full', 'body', 'face', 'head'; got {body_part}"

    def _sample_indices(self, total: int) -> List[int]:
        """
        Determine which frame indices to sample from a video of length `total`.

        Ensures that, if use_ref_img=True, at least one frame outside the segment remains.
        If segment would cover all frames, iteratively reduce F by frame_multiple until
        a candidate remains or fail.
        """
        # Case: no cropping requested
        if self.segment_rule is None:
            if self.use_ref_img and total <= 1:
                raise ValueError(
                    f"Video too short ({total} frames) to supply a reference image."
                )
            return list(range(total))

        # Base target frame count
        F = min(self.segment_num_frames, total)

        # Enforce strict length if requested
        if total < self.filter_min_num_frames:
            raise ValueError(
                f"The video is too short({total} frames), lower than {self.filter_min_num_frames} limit."
            )
        elif total >= self.filter_max_num_frames:
            raise ValueError(
                f"The video is too long ({total} frames), over than {self.filter_max_num_frames} limit"
            )

        # Apply multiple-rounding
        if self.frame_multiple > 1:
            ori_len = F
            F = (
                F - self.frame_multiple_add
            ) // self.frame_multiple * self.frame_multiple + self.frame_multiple_add
            assert (
                F >= 0 and F >= self.filter_min_num_frames
            ), f"Got {ori_len} frames input, cannot find a possible segment length with format {self.frame_multiple} * n + {self.frame_multiple_add} and longer than {self.filter_min_num_frames}"

        # If reference image required, ensure at least one outside-frame candidate
        if self.use_ref_img and self.ref_img_rule != "random":
            assert (
                F < total
            ), f"Cannot get an reference img outside segment, because the segment length is {F} but the entire video length is {total}"

        # Choose start index
        start = {
            "random": random.randint(0, total - F),
            "center": (total - F) // 2,
            "head": 0,
            "tail": total - F,
        }[self.segment_rule]

        return list(range(start, start + F))

    def _load_video_file(
        self, filepath: Union[str, Path]
    ) -> Tuple[torch.Tensor, float, List[int]]:
        """
        Decode the sampled segment of a video file with decord.

        Args:
            filepath: path of the video file.

        Returns:
            A tuple `(frames, fps, indices)`: the decoded frames as a float32
            tensor permuted from (len, H, W, 3) to (len, 3, H, W), the average
            fps rounded to an integer, and the sampled frame indices.

        Raises:
            AssertionError: if `assert_fps` is set and differs from the rounded fps.
        """
        reader = decord.VideoReader(str(filepath))
        # adding 0.5 before truncation rounds the average fps to the nearest integer
        fps = int(reader.get_avg_fps() + 0.5)
        assert (
            self.assert_fps is None or fps == self.assert_fps
        ), f"{filepath} has fps {fps} != {self.assert_fps}"

        frame_ids = self._sample_indices(len(reader))
        decoded = reader.get_batch(frame_ids).asnumpy()  # (len, H, W, 3)
        clip = torch.from_numpy(decoded)
        # channels-first layout, float32
        clip = clip.permute(0, 3, 1, 2).to(torch.float32)
        return clip, fps, frame_ids

    def _load_video_folder(
        self, folder: Union[str, Path]
    ) -> Tuple[torch.Tensor, None, List[int]]:
        """
        Load the sampled segment from a folder of image frames.

        Files whose lower-cased suffix is in `IMG_POSTFIX` are sorted by path and
        treated as the frame sequence.

        Args:
            folder: directory containing the frame images.

        Returns:
            A tuple `(frames, None, indices)`: the stacked frames as a float
            tensor with each image permuted from (H, W, C) to (C, H, W), `None`
            in place of an fps (a folder carries none), and the sampled indices.
        """
        frame_paths = sorted(
            entry
            for entry in Path(folder).iterdir()
            if entry.suffix.lower() in IMG_POSTFIX
        )
        indices = self._sample_indices(len(frame_paths))
        frames = [
            torch.from_numpy(iio.imread(str(frame_paths[i]))).permute(2, 0, 1).float()
            for i in indices
        ]
        return torch.stack(frames, dim=0), None, indices

    def _load_audio_data(
        self,
        filepath: Union[str, Path],
        indices: List[int],
        fps: Optional[float],
    ) -> Tuple[torch.Tensor, int, int, int]:
        """
        Load the audio slice that corresponds to the sampled video frames.

        1) Read native sample_rate & channel count via torchaudio.info
        2) Compute exact original-sample offset & count for the segment
        3) Load only that slice with torchaudio.load(frame_offset, num_frames)
        4) **Pad** raw wav2d to exactly ori_num samples
        5) Mixdown (any multi-channel input) and Resample
        6) Pad/trim final wav to exactly target_num samples

        Args:
            filepath: audio (or video) file to read audio from.
            indices: sampled frame indices; the first one gives the segment start
                and the count gives the segment duration.
            fps: video frame rate. If None (or `indices` is empty) the whole file
                is loaded and no final pad/trim is applied.

        Returns:
            A tuple `(wav, ori_sr, ori_ch, final_sr)`: the mono waveform, the
            native sample rate, the native channel count, and the sample rate of
            the returned waveform.

        Raises:
            ValueError: wraps any exception raised while loading or processing.
        """
        try:
            # --- 1) native metadata ---
            info = torchaudio.info(str(filepath))
            ori_sr = info.sample_rate
            ori_ch = info.num_channels

            # --- 2) compute offset & count ---
            if fps is None or not indices:
                # fallback to full file
                start_offset = 0
                ori_num = None
            else:
                num_frames = len(indices)
                duration_sec = num_frames / fps
                start_sec = indices[0] / fps

                start_offset = int(round(start_sec * ori_sr))
                ori_num = int(round(duration_sec * ori_sr))

            # --- 3) load slice ---
            if ori_num is None:
                wav2d, _ = torchaudio.load(str(filepath))  # [C, N_full]
                ori_num = wav2d.size(1)
            else:
                wav2d, _ = torchaudio.load(
                    str(filepath),
                    frame_offset=start_offset,
                    num_frames=ori_num,
                )  # [C, ≤ori_num]

            # --- 4) pad raw wav2d so every channel has length = ori_num ---
            cur_len = wav2d.size(1)
            if cur_len < ori_num:
                pad_amt = ori_num - cur_len
                # use the loaded tensor's own channel count so the shapes always match
                pad = wav2d.new_zeros((wav2d.size(0), pad_amt))
                wav2d = torch.cat([wav2d, pad], dim=1)

            # --- 5) mixdown to mono & resample ---
            # Average every multi-channel layout (not only stereo); otherwise a
            # >2-channel tensor stays 2-D and step 6 would pad/trim the channel axis.
            if wav2d.size(0) > 1:
                wav = torch.mean(wav2d, dim=0)
            else:
                wav = wav2d.squeeze(0)

            final_sr = ori_sr
            if ori_sr != self.sampling_rate:
                resampler = Resample(orig_freq=ori_sr, new_freq=self.sampling_rate)
                wav = resampler(wav.unsqueeze(0)).squeeze(0)
                final_sr = self.sampling_rate

            # --- 6) pad/trim final to exact target ---
            if fps is not None and indices:
                target_num = int(round((len(indices) / fps) * final_sr))
                cur = wav.size(0)
                if cur < target_num:
                    pad = wav.new_zeros((target_num - cur,))
                    wav = torch.cat([wav, pad], dim=0)
                elif cur > target_num:
                    wav = wav[:target_num]
            return wav, ori_sr, ori_ch, final_sr
        except Exception as err:
            raise ValueError(f"Error loading audio {filepath}: {err}") from err

    def _load_keypoint_data(
        self,
        filepath: Union[str, Path],
        indices: List[int],
    ):
        """
        Load keypoints for the sampled frames and keep the configured body part.

        The file is read with `np.load` and indexed as (T, K, C): frames,
        keypoints, and at least 2 channels (x, y[, score]).

        Args:
            filepath: path of the keypoint array file.
            indices: frame indices to keep along the first axis.

        Returns:
            A tensor of shape (T, K_sel, C) with the keypoints selected by
            `self.body_part`.

        Raises:
            FileNotFoundError: if `filepath` does not exist.
            ValueError: if `self.body_part` is not a known part.
            RuntimeError: if the selected keypoint set is empty.
            AssertionError: if the last dimension is smaller than 2.
        """
        # T K 3
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Keypoint file not found: {filepath}")
        keypoint = torch.from_numpy(np.load(filepath))
        keypoint = keypoint[indices]

        # NOTE(review): check_face is a project helper; presumably it validates the
        # face keypoints and raises on failure — confirm against its definition.
        check_face(keypoint)

        T, K, C = keypoint.shape
        assert C >= 2, f"Expect last dim >=2 (x,y[,score]), got {C}"

        # ---- 2) Build index sets: prefer official names; otherwise use fallback ranges ----
        # 2.1 Use the official names (Sapiens / GOLIATH 308 keypoint names) when
        #     the name list matches the number of keypoints
        names = None
        if isinstance(GOLIATH_KEYPOINTS, (list, tuple)) and len(GOLIATH_KEYPOINTS) == K:
            names = [str(x).lower() for x in GOLIATH_KEYPOINTS]

        # 2.2 Define three sets: head_all (including ears), ear_only, face_no_ear
        if names is not None:
            # Name-based selection by substring match
            FACE_TOKENS = (
                "glabella",
                "nose",
                "nasal",
                "philtrum",
                "cupid",
                "lip",
                "labial",
                "chin",
                "menton",
                "gonion",
                "canthus",
                "brow",
                "eyebrow",
                "eyelid",
                "lash",
                "eye",
                "iris",
                "pupil",
                "cheek",
                "malar",
                "alar",
                "columella",
                "subnasale",
                "trichion",
                "forehead",
            )
            EAR_TOKENS = (
                "ear",
                "tragus",
                "antitragus",
                "intertragic",
                "helix",
                "antihelix",
                "concha",
                "scapha",
                "cymba",
                "lobule",
                "lobe",
                "crus_of_helix",
                "earlobe",
            )

            is_face_like = [any(tok in n for tok in FACE_TOKENS) for n in names]
            is_ear_like = [any(tok in n for tok in EAR_TOKENS) for n in names]

            head_all_idx = [
                i for i, (f, e) in enumerate(zip(is_face_like, is_ear_like)) if (f or e)
            ]
            ear_only_idx = [i for i, e in enumerate(is_ear_like) if e]
            face_no_ear_idx = sorted(set(head_all_idx) - set(ear_only_idx))

            # Special case: the basic nose/eye/ear points 0-4 also count as head
            for i in (0, 1, 2, 3, 4):
                if i not in head_all_idx and i < K:
                    head_all_idx.append(i)
            head_all_idx = sorted(set(head_all_idx))
        else:
            # Fallback rule (assumes the official keypoint ordering):
            # 0..4: nose/eyes/ears (basic); 5..69: torso, limbs and hands (body)
            # >=70: dense head keypoints (face + ears)
            head_all_idx = list(range(min(5, K))) + list(range(min(70, K), K))
            # Rough ear split (basic ear points 3, 4; dense auricle points are
            # usually in the tail section)
            ear_only_idx = [i for i in (3, 4) if i < K]
            # If K is large (GOLIATH 308), the tail contains the dense auricle
            # points; empirically ~280:308
            if K >= 281:
                ear_only_idx += list(range(280, K))
            ear_only_idx = sorted(set(ear_only_idx))
            face_no_ear_idx = sorted(set(head_all_idx) - set(ear_only_idx))

        # body = everything not in head_all; build the lookup set once, not per index
        head_all_set = set(head_all_idx)
        body_idx = [i for i in range(K) if i not in head_all_set]

        # ---- 3) Select according to self.body_part ----
        part = self.body_part
        if part == "full":
            sel = list(range(K))
        elif part == "body":
            sel = body_idx
        elif part == "face":
            sel = face_no_ear_idx
        elif part == "head":
            sel = head_all_idx
        else:
            raise ValueError(
                "body_part must be one of 'full', 'body', 'face', 'head'; "
                f"got {part}"
            )

        # Fail early on an empty selection instead of crashing later in training/eval
        if len(sel) == 0:
            raise RuntimeError(
                f"Selected index set for body_part='{part}' is empty. "
                "Please check keypoint name list or fallback ranges."
            )

        # ---- 4) Return the cropped (T, K_sel, 3) tensor ----
        return keypoint[:, sel, :]

    def transform(self, results: dict) -> dict:
        """
        Load the video segment and, as configured, reference frames, audio and keypoints.

        The path keys (`video_path_key`, `audio_path_key`, `keypoint_path_key`)
        are popped from `results`. The following keys are written:
        `video` and `video_metadata` always; `ref_img` (plus
        `video_metadata["ref_indice"]`) when `use_ref_img` is set; `audio` and
        `audio_metadata` unless `video_only` is set; `keypoint` and
        `keypoint_metadata` when a keypoint path is present.

        Args:
            results: sample dict holding at least the video path.

        Returns:
            The same `results` dict, updated in place.

        Raises:
            ValueError: if no reference frame can be selected with the
                configured `ref_img_rule`.
            AssertionError: if the video is a folder and no fps is available in
                `results["video_metadata"]`.
        """
        video_path = Path(results.pop(self.video_path_key))
        is_folder = video_path.is_dir()
        if is_folder:
            video, fps, indices = self._load_video_folder(video_path)
            # a folder has no intrinsic fps; it must come from existing metadata
            fps = results.get("video_metadata", {}).get("fps", fps)
            assert fps is not None, f"Missing fps for folder {video_path}"
        else:
            video, fps, indices = self._load_video_file(video_path)

        results["video"] = video
        results["video_metadata"] = dict(
            video_path=str(video_path),
            num_frames=video.shape[0],
            fps=fps,
            duration=(video.shape[0] / fps) if fps else None,
            height=video.shape[2],
            width=video.shape[3],
            frame_indices=tuple(indices),
        )

        # Load reference image(s) if needed
        if self.use_ref_img:
            k = self.num_ref_img
            # Open the frame source once: it provides both the size of the
            # reference pool and the frames decoded below.
            if is_folder:
                ref_files = sorted(
                    p for p in video_path.iterdir() if p.suffix.lower() in IMG_POSTFIX
                )
                ref_reader = None
                total = len(ref_files)
            else:
                ref_files = None
                ref_reader = decord.VideoReader(str(video_path))
                total = len(ref_reader)

            # Compute candidate set and pick ref_idx (None means "no candidate")
            ref_idx: Optional[List[int]]
            if self.ref_img_rule == "random_out":
                pool = set(range(total)) - set(indices)
                ref_idx = random.choices(list(pool), k=k) if pool else None
            elif self.ref_img_rule == "random":
                pool = set(range(total))
                ref_idx = random.choices(list(pool), k=k)
            elif self.ref_img_rule == "random_video":
                if k <= 0:
                    # non-positive count: use as many frames as the segment has
                    k = video.shape[0]
                if total <= k:
                    raise ValueError(
                        f"Not enough frames in video {video_path} for random_video ref_img_rule, total={total}, k={k}"
                    )
                start_frame = random.randint(0, total - k)
                ref_idx = list(range(start_frame, start_frame + k))
            else:
                # "prev": the frame right before the segment, if there is one
                ref_idx = [indices[0] - 1] if indices[0] > 0 else None

            if ref_idx is None:
                raise ValueError(
                    f"Unable to find reference frame with rule={self.ref_img_rule} in video {video_path}, dataset size={total}, segment indices range={indices}"
                )

            # Load the reference frame(s) as a single (k, H, W, C) array; stacking
            # first avoids building a tensor from a Python list of arrays.
            if is_folder:
                ref_np = np.stack([iio.imread(str(ref_files[idx])) for idx in ref_idx])
            else:
                ref_np = ref_reader.get_batch(ref_idx).asnumpy()

            ref_imgs = torch.from_numpy(ref_np).float()
            ref_imgs = rearrange(ref_imgs, "k h w c -> k c h w")
            results["ref_img"] = ref_imgs  # [k, C, H, W]
            results["video_metadata"]["ref_indice"] = tuple(ref_idx)

        if not self.video_only:
            if self.audio_path_key:
                audio_path = results.pop(self.audio_path_key, None)
                if audio_path is None:
                    # no dedicated audio file: read the audio track of the video
                    audio_path = video_path
            else:
                audio_path = video_path
            wav, ori_sr, ori_ch, sr = self._load_audio_data(audio_path, indices, fps)

            results["audio"] = wav
            results["audio_metadata"] = dict(
                audio_path=audio_path,
                ori_sr=ori_sr,
                ori_channels=ori_ch,
                sr=sr,
                num_samples=wav.numel(),
                duration=wav.numel() / sr,
            )

        # load keypoint
        keypoint_path = results.pop(self.keypoint_path_key, None)
        if keypoint_path is not None:
            keypoint = self._load_keypoint_data(keypoint_path, indices=indices)
            results["keypoint"] = keypoint
            results["keypoint_metadata"] = dict(
                keypoint_path=keypoint_path,
                num_frames=keypoint.shape[0],
                num_joints=keypoint.shape[1],
                body_part=self.body_part,
                have_score=keypoint.shape[-1] == 3,
            )

        return results
