from email.mime import audio
import os
from pathlib import Path
import random
from typing import Dict, List, Optional, Tuple, Union

import decord
import numpy as np
import torch
import torchaudio
from torchaudio.transforms import Resample
from mmcv import BaseTransform
import imageio.v3 as iio
from ..utils.kpt2face import check_face
from mmhug.datasets.utils.kpt_classes_and_palettes import GOLIATH_KEYPOINTS
from mmhug.registry import TRANSFORMS

# Lower-case file suffixes (leading dot included) that are treated as image
# frames when a "video" is given as a folder; compared against
# ``Path.suffix.lower()`` by the folder loader.
IMG_POSTFIX = [".jpg", ".jpeg", ".png", ".bmp"]


@TRANSFORMS.register_module()
class LoadVideoAudioWithKeypointSegment(BaseTransform):
    """Load a (possibly cropped) video segment with matching keypoints and audio.

    The transform reads a video (a file decoded with ``decord`` or a folder of
    image frames), optionally picks a contiguous window of frames, and then
    loads the keypoints and the audio samples that correspond to that window.

    Args:
        video_path_key: Key in ``results`` holding the path of the video file
            or of a folder of image frames. The key is popped from ``results``.
        keypoint_path_key: Key in ``results`` holding the path of a NumPy array
            file with the keypoints. If the key is absent, keypoints are
            skipped.
        audio_path_key: Key in ``results`` holding the audio path. If it is
            falsy, or the key is absent from ``results``, audio is read from
            the video path instead.
        max_num_frames: Length of the segment to sample. Must be ``None`` or
            negative when ``segment_rule`` is ``None`` and a non-negative int
            otherwise.
        sampling_rate: Target audio sampling rate; audio with a different
            native rate is resampled to it.
        segment_rule: ``None`` (use every frame) or one of ``"random"``,
            ``"center"``, ``"head"``, ``"tail"`` — where the window is placed.
        video_only: If ``True``, audio is not loaded.
        strict_length: If ``True``, raise when the video has fewer than
            ``max_num_frames`` frames; otherwise use all available frames.
        body_part: Which keypoint subset to keep: ``"full"`` (all keypoints),
            ``"body"`` (everything that is not head), ``"face"`` (head without
            ears) or ``"head"`` (face plus ears).
    """

    # Substrings that mark a (lower-cased) keypoint name as a face landmark.
    _FACE_TOKENS = (
        "glabella",
        "nose",
        "nasal",
        "philtrum",
        "cupid",
        "lip",
        "labial",
        "chin",
        "menton",
        "gonion",
        "canthus",
        "brow",
        "eyebrow",
        "eyelid",
        "lash",
        "eye",
        "iris",
        "pupil",
        "cheek",
        "malar",
        "alar",
        "columella",
        "subnasale",
        "trichion",
        "forehead",
    )
    # Substrings that mark a (lower-cased) keypoint name as an ear landmark.
    _EAR_TOKENS = (
        "ear",
        "tragus",
        "antitragus",
        "intertragic",
        "helix",
        "antihelix",
        "concha",
        "scapha",
        "cymba",
        "lobule",
        "lobe",
        "crus_of_helix",
        "earlobe",
    )

    def __init__(
        self,
        video_path_key: str = "video_path",
        keypoint_path_key: str = "keypoint_path",
        audio_path_key: Optional[str] = "audio_path",
        max_num_frames: Optional[int] = None,
        sampling_rate: int = 16000,
        segment_rule: Optional[str] = None,
        video_only: bool = False,
        strict_length: bool = True,
        body_part: str = "full",  # full, body, face, head
    ) -> None:
        super().__init__()
        self.video_path_key = video_path_key
        self.audio_path_key = audio_path_key
        self.keypoint_path_key = keypoint_path_key

        self.max_num_frames = max_num_frames
        self.sampling_rate = sampling_rate
        self.video_only = video_only
        # If strict_length is True, throw error if the video length is less than max_num_frames
        self.strict_length = strict_length
        if segment_rule is None:
            assert (
                max_num_frames is None or max_num_frames < 0
            ), "If segment_rule is None, max_num_frames must be None or negative"
        else:
            assert (
                max_num_frames is not None and max_num_frames >= 0
            ), "If you set segment_rule, max_num_frames must be a non-negative int"
        assert segment_rule in [
            None,
            "random",
            "center",
            "head",
            "tail",
        ], f"segment_rule must be None or one of 'random','center','head','tail'; got {segment_rule}"
        self.segment_rule = segment_rule
        self.body_part = body_part
        assert body_part in [
            "full",
            "body",
            "face",
            "head",
        ], f"body_part must be one of 'full', 'body', 'face', 'head'; got {body_part}"

    def _sample_indices(self, total: int) -> List[int]:
        """Return the contiguous frame indices to load from ``total`` frames.

        Args:
            total: Number of frames available.

        Returns:
            All indices when ``segment_rule`` is ``None``; otherwise a run of
            ``min(max_num_frames, total)`` consecutive indices placed
            according to ``segment_rule``.

        Raises:
            ValueError: If ``strict_length`` is set and fewer than
                ``max_num_frames`` frames are available.
            KeyError: If ``segment_rule`` is not a known rule.
        """
        if self.segment_rule is None:
            return list(range(total))
        if self.max_num_frames > total and self.strict_length:
            raise ValueError(
                f"The video is {total} frames long, but target {self.max_num_frames} frames."
            )
        length = min(self.max_num_frames, total)
        slack = total - length  # number of valid non-zero start offsets
        # Only the "random" rule touches the global RNG, so deterministic
        # rules do not perturb the random state of other transforms.
        if self.segment_rule == "random":
            start = random.randint(0, slack)
        elif self.segment_rule == "center":
            start = slack // 2
        elif self.segment_rule == "head":
            start = 0
        elif self.segment_rule == "tail":
            start = slack
        else:
            raise KeyError(self.segment_rule)
        return list(range(start, start + length))

    def _load_video_file(
        self, filepath: Union[str, Path]
    ) -> Tuple[torch.Tensor, int, List[int]]:
        """Decode the sampled frames of a video file.

        Args:
            filepath: Path of the video file, opened with ``decord``.

        Returns:
            ``(frames, fps, indices)`` where ``frames`` is a float32 tensor
            with the channel axis moved to position 1, ``fps`` is the average
            frame rate rounded to the nearest integer, and ``indices`` are the
            frame indices that were read.
        """
        vr = decord.VideoReader(str(filepath))
        fps = int(vr.get_avg_fps() + 0.5)
        idxs = self._sample_indices(len(vr))
        frames = vr.get_batch(idxs).asnumpy()  # (len, H, W, 3)
        # convert once to torch and permute
        return (
            torch.from_numpy(frames).permute(0, 3, 1, 2).to(torch.float32),
            fps,
            idxs,
        )

    def _load_video_folder(
        self, folder: Union[str, Path]
    ) -> Tuple[torch.Tensor, None, List[int]]:
        """Load the sampled frames from a folder of image files.

        Files whose lower-cased suffix is in ``IMG_POSTFIX`` are sorted by
        path and treated as consecutive frames.

        Args:
            folder: Directory containing the frame images.

        Returns:
            ``(frames, None, indices)``; the frame rate is unknown for a
            folder, so ``None`` is returned in its place.

        Raises:
            RuntimeError: If no frame was selected (e.g. the folder holds no
                supported image file).
        """
        folder = Path(folder)
        files = sorted(p for p in folder.iterdir() if p.suffix.lower() in IMG_POSTFIX)
        idxs = self._sample_indices(len(files))
        if not idxs:
            # torch.stack would fail on an empty list with an opaque message.
            raise RuntimeError(
                f"No image frames selected from folder {folder} "
                f"(found {len(files)} files with suffix in {IMG_POSTFIX})"
            )
        # ``iio`` is already ``imageio.v3``; each image is read as HWC.
        tensors = [
            torch.from_numpy(iio.imread(str(files[i])))
            .permute(2, 0, 1)
            .to(torch.float32)
            for i in idxs
        ]
        return torch.stack(tensors, dim=0), None, idxs

    @classmethod
    def _keypoint_index_groups(
        cls, num_keypoints: int
    ) -> Tuple[List[int], List[int], List[int]]:
        """Split ``range(num_keypoints)`` into head / face / body index lists.

        When ``GOLIATH_KEYPOINTS`` is a list or tuple with exactly
        ``num_keypoints`` entries, keypoints are classified by substring match
        on their lower-cased names. Otherwise fixed index ranges are used,
        which assume the keypoint order follows the official layout.

        Returns:
            ``(head_idx, face_idx, body_idx)``, each sorted ascending, where
            ``head_idx`` is face plus ears, ``face_idx`` is head without ears
            and ``body_idx`` is everything not in ``head_idx``.
        """
        K = num_keypoints
        names = None
        if isinstance(GOLIATH_KEYPOINTS, (list, tuple)) and len(GOLIATH_KEYPOINTS) == K:
            names = [str(x).lower() for x in GOLIATH_KEYPOINTS]

        if names is not None:
            # Name-based selection.
            ear = {
                i
                for i, n in enumerate(names)
                if any(tok in n for tok in cls._EAR_TOKENS)
            }
            face_like = {
                i
                for i, n in enumerate(names)
                if any(tok in n for tok in cls._FACE_TOKENS)
            }
            head = face_like | ear
            face_idx = sorted(head - ear)
            # The basic keypoints 0-4 (nose/eyes/ears) always count as head.
            # They are added after ``face_idx`` is computed, so they do not
            # extend the face set.
            head.update(i for i in range(5) if i < K)
        else:
            # Fallback ranges (assume the official keypoint order):
            #   0..4  basic nose/eye/ear points,
            #   5..69 torso, limbs and hands (body),
            #   >=70  dense head keypoints (face + ears).
            head = set(range(min(5, K))) | set(range(min(70, K), K))
            # Basic ear points are 3 and 4; for large layouts (GOLIATH 308)
            # the dense ear points are, empirically, roughly 280..307.
            ear = {i for i in (3, 4) if i < K}
            if K >= 281:
                ear.update(range(280, K))
            face_idx = sorted(head - ear)

        head_idx = sorted(head)
        # body = everything that is not head; ``head`` is a set, so the
        # membership test is O(1) per keypoint.
        body_idx = [i for i in range(K) if i not in head]
        return head_idx, face_idx, body_idx

    def _load_keypoint_data(
        self,
        filepath: Union[str, Path],
        indices: List[int],
    ):
        """Load keypoints for the sampled frames and keep ``self.body_part``.

        Args:
            filepath: Path of the keypoint array, read with ``np.load``. The
                array is unpacked as ``(T, K, C)`` with ``C >= 2``.
            indices: Frame indices to select along the first axis.

        Returns:
            Tensor of shape ``(len(indices), K_sel, C)`` holding only the
            keypoints that belong to ``self.body_part``.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            ValueError: If ``self.body_part`` is not a known part.
            RuntimeError: If the selected keypoint set is empty.
        """
        # T K 3
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Keypoint file not found: {filepath}")
        keypoint = torch.from_numpy(np.load(filepath))
        keypoint = keypoint[indices]

        # NOTE(review): check_face is a project helper; presumably it
        # validates the face keypoints of the segment — confirm its contract.
        check_face(keypoint)

        T, K, C = keypoint.shape
        assert C >= 2, f"Expect last dim >=2 (x,y[,score]), got {C}"

        head_all_idx, face_no_ear_idx, body_idx = self._keypoint_index_groups(K)

        # Select the index set that matches self.body_part.
        part = self.body_part
        if part == "full":
            sel = list(range(K))
        elif part == "body":
            sel = body_idx
        elif part == "face":
            sel = face_no_ear_idx
        elif part == "head":
            sel = head_all_idx
        else:
            raise ValueError(
                "body_part must be one of 'full', 'body', 'face', 'head'; "
                f"got {part}"
            )

        # Fail early: an empty selection would otherwise crash later stages
        # (training / evaluation) in a less obvious place.
        if len(sel) == 0:
            raise RuntimeError(
                f"Selected index set for body_part='{part}' is empty. "
                "Please check keypoint name list or fallback ranges."
            )

        # Return the cropped (T, K_sel, C) tensor.
        return keypoint[:, sel, :]

    def _load_audio_data(
        self,
        filepath: Union[str, Path],
        indices: List[int],
        fps: Optional[float],
    ) -> Tuple[torch.Tensor, int, int, int]:
        """Load the mono audio that corresponds to the sampled video frames.

        Steps:
            1) Read native sample rate & channel count via ``torchaudio.info``.
            2) Compute the original-sample offset & count of the segment.
            3) Load only that slice with ``torchaudio.load``.
            4) Zero-pad the raw slice to exactly the requested sample count.
            5) Keep the first channel and resample to ``self.sampling_rate``.
            6) Pad/trim the result to exactly the target sample count.

        Args:
            filepath: Path of the audio (or video) file to read audio from.
            indices: Sampled frame indices; treated as a contiguous run that
                starts at ``indices[0]``.
            fps: Video frame rate. If ``None`` (or ``indices`` is empty) the
                whole file is loaded.

        Returns:
            ``(wav, ori_sr, ori_ch, final_sr)``: the 1-D waveform, the native
            sample rate, the native channel count and the sample rate of
            ``wav``.
        """
        # --- 1) native metadata ---
        info = torchaudio.info(str(filepath))
        ori_sr = info.sample_rate
        ori_ch = info.num_channels

        has_segment = fps is not None and bool(indices)

        # --- 2) + 3) compute offset & count, then load the slice ---
        if has_segment:
            duration_sec = len(indices) / fps
            start_sec = indices[0] / fps
            start_offset = int(round(start_sec * ori_sr))
            ori_num = int(round(duration_sec * ori_sr))
            wav2d, _ = torchaudio.load(
                str(filepath),
                frame_offset=start_offset,
                num_frames=ori_num,
            )  # [C, <=ori_num]
        else:
            # fallback to full file
            wav2d, _ = torchaudio.load(str(filepath))  # [C, N_full]
            ori_num = wav2d.size(1)

        # --- 4) pad the raw slice so every channel has length = ori_num ---
        cur_len = wav2d.size(1)
        if cur_len < ori_num:
            # Size the padding from the loaded tensor itself so torch.cat
            # cannot fail on a channel-count mismatch with the metadata.
            pad = wav2d.new_zeros((wav2d.size(0), ori_num - cur_len))
            wav2d = torch.cat([wav2d, pad], dim=1)

        # --- 5) reduce to mono & resample ---
        # The first channel is used for any channel count; this keeps ``wav``
        # 1-D even for files with more than two channels.
        wav = wav2d[0]

        final_sr = ori_sr
        if ori_sr != self.sampling_rate:
            resampler = Resample(orig_freq=ori_sr, new_freq=self.sampling_rate)
            wav = resampler(wav.unsqueeze(0)).squeeze(0)
            final_sr = self.sampling_rate

        # --- 6) pad/trim final to exact target ---
        if has_segment:
            target_num = int(round((len(indices) / fps) * final_sr))
            cur = wav.size(0)
            if cur < target_num:
                pad = wav.new_zeros((target_num - cur,))
                wav = torch.cat([wav, pad], dim=0)
            elif cur > target_num:
                wav = wav[:target_num]

        return wav, ori_sr, ori_ch, final_sr

    def transform(self, results: Dict) -> Dict:
        """Load video, keypoints and audio into ``results``.

        Pops the configured path keys from ``results`` and adds ``video`` /
        ``video_metadata``, plus ``keypoint`` / ``keypoint_metadata`` when a
        keypoint path is present and ``audio`` / ``audio_metadata`` unless
        ``video_only`` is set.

        Args:
            results: Sample dict. When the video path is a folder,
                ``results["video_metadata"]["fps"]`` must provide the frame
                rate.

        Returns:
            The same ``results`` dict, updated in place.
        """
        # load video
        video_path = Path(results.pop(self.video_path_key))
        if video_path.is_dir():
            video, _, indices = self._load_video_folder(video_path)
            fps = results.get("video_metadata", {}).get("fps")
            assert fps is not None, f"Missing fps for folder {video_path}"
        else:
            video, fps, indices = self._load_video_file(video_path)

        results["video"] = video  # [T, C, H, W]
        results["video_metadata"] = dict(
            video_path=str(video_path),
            num_frames=video.shape[0],
            fps=fps,
            duration=video.shape[0] / fps,
            height=video.shape[-2],
            width=video.shape[-1],
            ori_height=video.shape[-2],
            ori_width=video.shape[-1],
            frame_indices=tuple(indices),
        )

        # load keypoint
        keypoint_path = results.pop(self.keypoint_path_key, None)
        if keypoint_path is not None:
            keypoint = self._load_keypoint_data(keypoint_path, indices=indices)
            results["keypoint"] = keypoint
            results["keypoint_metadata"] = dict(
                keypoint_path=keypoint_path,
                num_frames=keypoint.shape[0],
                num_joints=keypoint.shape[1],
                body_part=self.body_part,
                have_score=keypoint.shape[-1] == 3,
            )

        # load audio
        if not self.video_only:
            audio_path = None
            if self.audio_path_key:
                audio_path = results.pop(self.audio_path_key, None)
            if audio_path is None:
                # No dedicated audio file: read the audio track of the video.
                audio_path = video_path
            wav, ori_sr, ori_ch, sr = self._load_audio_data(audio_path, indices, fps)

            results["audio"] = wav  # [N]
            results["audio_metadata"] = dict(
                # Stored as str, like video_path, even when falling back to
                # the video's Path object.
                audio_path=str(audio_path),
                ori_sr=ori_sr,
                ori_channels=ori_ch,
                sr=sr,
                num_samples=wav.numel(),
                duration=wav.numel() / sr,
            )

        return results


if __name__ == "__main__":
    # Example configuration for manual experiments. The two dicts below are
    # only built here; nothing in this guard instantiates or runs them.

    pipeline = [
        {
            "type": "LoadVideoAudioWithKeypointSegment",
            "video_path_key": "video_path",
            "audio_path_key": "audio_path",
            "keypoint_path_key": "keypoint_path",
            "max_num_frames": 16,
            "sampling_rate": 16000,
            "segment_rule": "random",
            "video_only": True,
            "strict_length": True,
        },
        {
            "type": "ResizeVideo",
            "video_keys": ["video"],
            "size_candidates": [(512, 512)],
            "keep_ratio": True,
        },
        {
            "type": "CenterCropVideo",
            "video_keys": ["video"],
            "crop_size": (512, 512),
        },
        {
            "type": "NormalizeVideo",  # w.r.t dinov2
            "video_keys": ["video", "ref_img"],
            "mean": [0.5, 0.5, 0.5],
            "std": [0.5, 0.5, 0.5],
        },
    ]

    dataset = {
        "type": "TextVideoAudioKeypointDataset",
        "data_dir": "data/",
        "anno_file": "data/RAVDNESS/annotations/test_anno.json",
        "pipeline": pipeline,
        "refetch": True,
    }

    
