from email.mime import audio
from pathlib import Path
import random
from typing import Dict, List, Optional, Sequence, Tuple, Union

import decord
import torch
import torchaudio
from torchaudio.transforms import Resample
from mmcv import BaseTransform, imread
from mmengine import print_log
import imageio.v3 as iio
from mmhug.registry import TRANSFORMS
from mmhug.datasets.utils.key_define import APPEARANCE, ACTION, EMOTION

IMG_POSTFIX = [f".{ext}" for ext in ("jpg", "jpeg", "png", "bmp")]

# Emotion name -> class index. Every inner tuple lists the spellings that share
# one class index; datasets spell some emotions differently (e.g. CelebV-HQ
# uses "disgust" while MEAD uses "disgusted", "happy", "sad", "surprised").
EMOTION_MAPPING = {
    alias: index
    for index, aliases in enumerate(
        (
            ("neutral",),
            ("anger", "angry"),
            ("contempt",),
            ("disgust", "disgusted"),
            ("fear",),
            ("happiness", "happy"),
            ("sadness", "sad"),
            ("surprise", "surprised"),
        )
    )
    for alias in aliases
}

# CelebV-HQ appearance attributes. An attribute's class index is its position
# in the tuple below, so new attributes must be appended to keep indices stable.
APPEARANCE_MAPPING = {
    attribute: index
    for index, attribute in enumerate(
        (
            "blurry",
            "male",
            "young",
            "chubby",
            "pale_skin",
            "rosy_cheeks",
            "oval_face",
            "receding_hairline",
            "bald",
            "bangs",
            "black_hair",
            "blonde_hair",
            "gray_hair",
            "brown_hair",
            "straight_hair",
            "wavy_hair",
            "long_hair",
            "arched_eyebrows",
            "bushy_eyebrows",
            "bags_under_eyes",
            "eyeglasses",
            "sunglasses",
            "narrow_eyes",
            "big_nose",
            "pointy_nose",
            "high_cheekbones",
            "big_lips",
            "double_chin",
            "no_beard",
            "5_o_clock_shadow",
            "goatee",
            "mustache",
            "sideburns",
            "heavy_makeup",
            "wearing_earrings",
            "wearing_hat",
            "wearing_lipstick",
            "wearing_necklace",
            "wearing_necktie",
            "wearing_mask",
        )
    )
}

# CelebV-HQ actions. An action's class index is its position in the tuple
# below, so new actions must be appended to keep indices stable.
ACTION_MAPPING = {
    action: index
    for index, action in enumerate(
        (
            "blow",
            "chew",
            "close_eyes",
            "cough",
            "cry",
            "drink",
            "eat",
            "frown",
            "gaze",
            "glare",
            "head_wagging",
            "kiss",
            "laugh",
            "listen_to_music",
            "look_around",
            "make_a_face",
            "nod",
            "play_instrument",
            "read",
            "shake_head",
            "shout",
            "sigh",
            "sing",
            "sleep",
            "smile",
            "smoke",
            "sneer",
            "sneeze",
            "sniff",
            "talk",
            "turn",
            "weep",
            "whisper",
            "wink",
            "yawn",
        )
    )
}


def emotion_to_one_hot(
    labels: Union[str, Sequence[str]],
    mapping: dict = EMOTION_MAPPING,
    num_classes: Optional[int] = None,
    *,
    unknown: str = "raise",  # "raise" | "zero" | "ignore"
    framework: str = "torch",  # "list" | "numpy" | "torch"
):
    """Map emotion label string(s) to one-hot vector(s).

    Args:
        labels: A single label string or a sequence of label strings. Each
            label is stripped and lower-cased before it is looked up.
        mapping: Label string -> class index. Only read, never modified.
        num_classes: Length of each one-hot vector; defaults to
            ``max(mapping.values()) + 1``.
        unknown: How to treat a label missing from ``mapping``:
            ``"raise"`` raises ``KeyError``; ``"zero"`` emits an all-zero
            vector; ``"ignore"`` drops the label from a batch (a single
            string input yields ``None``).
        framework: Output type: ``"torch"`` (float32 tensor), ``"numpy"``
            (float32 array) or ``"list"`` (list of ints).

    Returns:
        For a single string: one vector of length ``num_classes`` (or ``None``
        with ``unknown="ignore"``). For a sequence: a ``(N, num_classes)``
        tensor/array (``(0, num_classes)`` when nothing is left), or a list of
        lists for ``framework="list"``.

    Raises:
        ValueError: If ``unknown`` or ``framework`` is not a supported value.
        KeyError: For a label missing from ``mapping`` when ``unknown="raise"``.
    """
    # Reject misspelled modes instead of silently falling back to
    # "zero"-like behavior (unknown) or plain lists (framework).
    if unknown not in ("raise", "zero", "ignore"):
        raise ValueError(
            f"unknown must be one of 'raise', 'zero', 'ignore'; got {unknown!r}"
        )
    if framework not in ("list", "numpy", "torch"):
        raise ValueError(
            f"framework must be one of 'list', 'numpy', 'torch'; got {framework!r}"
        )

    if num_classes is None:
        num_classes = max(mapping.values()) + 1

    # Choose the zero-vector factory once rather than per label. numpy is
    # imported lazily because only the "numpy" output needs it.
    if framework == "numpy":
        import numpy as np

        def _zeros():
            return np.zeros(num_classes, dtype=np.float32)

    elif framework == "torch":

        def _zeros():
            return torch.zeros(num_classes, dtype=torch.float32)

    else:

        def _zeros():
            return [0] * num_classes

    def _encode_one(s: str):
        idx = mapping.get(s.strip().lower())
        if idx is None:
            if unknown == "raise":
                raise KeyError(f"Unknown emotion label: {s!r}")
            if unknown == "ignore":
                return None
            # unknown == "zero": keep the vector all zeros.
        v = _zeros()
        if idx is not None:
            v[idx] = 1  # stored as 1.0 in float32 tensors/arrays
        return v

    # Single label.
    if isinstance(labels, str):
        return _encode_one(labels)

    # Batch: only unknown="ignore" can produce None entries.
    outs = [v for v in map(_encode_one, labels) if v is not None]
    if framework == "torch":
        if outs:
            return torch.stack(outs)
        return torch.zeros((0, num_classes), dtype=torch.float32)
    if framework == "numpy":
        if outs:
            return np.stack(outs)
        return np.zeros((0, num_classes), dtype=np.float32)
    return outs


@TRANSFORMS.register_module()
class LoadVideoWithLabelSegment(BaseTransform):
    """Load a video clip, its label and (unless ``video_only``) the aligned audio.

    ``transform`` pops ``results[label_key]`` and turns it into a label tensor.
    An emotion annotation may also select a ``[start, end)`` segment in
    seconds. It then drops the other label keys and decodes the frames
    referenced by ``results[video_path_key]``. That path is either a video
    file read with decord, or a folder of images whose fps is taken from
    ``results["video_metadata"]["fps"]``. Finally it loads the audio samples
    that span exactly the sampled frames.

    Args:
        video_path_key: Key in ``results`` holding the video file or frame folder.
        audio_path_key: Key in ``results`` holding the audio file. If falsy, or
            if the key is absent, audio is read from the video path.
        label_key: Label to load; one of ``APPEARANCE``, ``ACTION``, ``EMOTION``.
        max_num_frames: Frames to sample when ``segment_rule`` is set. Must be
            ``None`` or negative when ``segment_rule`` is ``None``.
        sampling_rate: Target audio sample rate used for resampling.
        segment_rule: ``None`` keeps every frame of the window. Otherwise it is
            one of ``"random"``, ``"center"``, ``"head"``, ``"tail"``, which
            chooses where the ``max_num_frames`` contiguous frames start.
        video_only: Skip audio loading and drop any audio entries.
        strict_length: With a ``segment_rule`` set, raise if the window holds
            fewer than ``max_num_frames`` frames instead of returning fewer.
        assert_fps: If not None, the rounded fps of a video file must equal it.
    """

    def __init__(
        self,
        video_path_key: str = "video_path",
        audio_path_key: Optional[str] = "audio_path",
        label_key: str = APPEARANCE,  # one of APPEARANCE, ACTION, EMOTION
        max_num_frames: Optional[int] = 16,
        sampling_rate: int = 16000,
        segment_rule: Optional[str] = None,
        video_only: bool = False,
        strict_length: bool = True,
        assert_fps: Optional[int] = 25,
    ) -> None:
        super().__init__()
        # NOTE(review): the defaults (segment_rule=None, max_num_frames=16) do
        # not pass the first check below, so callers must override one of them.
        if segment_rule is None:
            assert (
                max_num_frames is None or max_num_frames < 0
            ), "If segment_rule is None, max_num_frames must be None or negative"
        else:
            assert (
                max_num_frames is not None and max_num_frames >= 0
            ), "If you set segment_rule, max_num_frames must be a non-negative int"
        assert segment_rule in [
            None,
            "random",
            "center",
            "head",
            "tail",
        ], f"segment_rule must be None or one of 'random','center','head','tail'; got {segment_rule}"
        # Validate against the same constants load_label_info()/transform()
        # dispatch on, so every accepted key is handled consistently.
        assert label_key in (
            APPEARANCE,
            ACTION,
            EMOTION,
        ), f"label_key must be one of {APPEARANCE!r}, {ACTION!r}, {EMOTION!r}; got {label_key}"

        self.video_path_key = video_path_key
        self.audio_path_key = audio_path_key
        self.label_key = label_key
        self.max_num_frames = max_num_frames
        self.sampling_rate = sampling_rate
        self.segment_rule = segment_rule
        self.video_only = video_only
        # With a segment_rule set: if True, raise when the frame window holds
        # fewer than max_num_frames frames instead of returning a shorter clip.
        self.strict_length = strict_length
        self.assert_fps = assert_fps

    @staticmethod
    def _clamp_window(total: int, start_frame: int, end_frame: int) -> Tuple[int, int]:
        """Clip ``[start_frame, end_frame)`` into ``[0, total]`` with start <= end."""
        start_frame = max(0, min(start_frame, total))
        end_frame = max(start_frame, min(end_frame, total))
        return start_frame, end_frame

    def _frame_window(
        self,
        total: int,
        fps: float,
        start: Optional[float],
        end: Optional[float],
    ) -> Tuple[int, int]:
        """Convert an optional ``[start, end)`` in seconds to a clamped frame window."""
        start_frame = int(start * fps) if start is not None else 0
        end_frame = int(end * fps) if end is not None else total
        return self._clamp_window(total, start_frame, end_frame)

    def _check_window(self, window_len: int, source: Union[str, Path]) -> None:
        """Raise ``ValueError`` if a clamped frame window cannot be loaded.

        An empty window is always an error: there are no frames to stack, and
        the audio loader would otherwise fall back to the whole file. A window
        shorter than ``max_num_frames`` is an error only when a
        ``segment_rule`` is set and ``strict_length`` is True.
        ``max_num_frames`` may be None when ``segment_rule`` is None, so it is
        not compared in that mode.
        """
        if window_len <= 0:
            raise ValueError(f"Video {source} has no frames in the requested window")
        if (
            self.segment_rule is not None
            and self.strict_length
            and window_len < self.max_num_frames
        ):
            raise ValueError(
                f"Video {source} has only {window_len} frames, "
                f"but max_num_frames is {self.max_num_frames}"
            )

    def _sample_indices(
        self, total: int, start_frame: int = 0, end_frame: Optional[int] = None
    ) -> List[int]:
        """Sample global frame indices inside the window ``[start_frame, end_frame)``.

        Args:
            total: Total number of frames in the video.
            start_frame: First frame of the window (frame index, not seconds).
            end_frame: Exclusive end frame; defaults to ``total``.

        Rules:
            - The window is clamped into ``[0, total]``; an empty window gives ``[]``.
            - ``segment_rule=None``: every frame of the window, ascending.
            - Otherwise: ``min(max_num_frames, window length)`` contiguous frames
              placed by the random/center/head/tail rule.

        Raises:
            ValueError: If ``strict_length`` is True and the window is shorter
                than ``max_num_frames``, or for an unknown ``segment_rule``.
        """
        if end_frame is None:
            end_frame = total
        start_frame, end_frame = self._clamp_window(total, start_frame, end_frame)

        window_len = end_frame - start_frame
        if window_len == 0:
            return []

        # No segment sampling: return every frame in the window.
        if self.segment_rule is None:
            return list(range(start_frame, end_frame))

        target = self.max_num_frames
        num = min(target, window_len)
        if target > window_len and self.strict_length:
            raise ValueError(
                f"Window [{start_frame}, {end_frame}) has only {window_len} frames, "
                f"but target {target} frames."
            )

        # Offset of the clip start inside the window.
        offset_max = window_len - num
        rule = self.segment_rule
        if rule == "random":
            offset = random.randint(0, offset_max)  # inclusive upper bound
        elif rule == "center":
            offset = offset_max // 2
        elif rule == "head":
            offset = 0
        elif rule == "tail":
            offset = offset_max
        else:
            raise ValueError(f"Unknown segment_rule: {rule}")

        start = start_frame + offset
        return list(range(start, start + num))

    def _load_video_file(
        self, filepath: Union[str, Path], start: float = None, end: float = None
    ) -> Tuple[torch.Tensor, int, List[int]]:
        """Decode sampled frames from a video file with decord.

        Args:
            filepath: Path to the video file.
            start: Optional segment start in seconds; frames before it are skipped.
            end: Optional segment end in seconds; frames from it on are skipped.

        Returns:
            ``(video, fps, indices)``: float32 frames ``[T, C, H, W]``, the
            average fps rounded to an int, and the sampled frame indices,
            which are global (relative to the uncropped video).

        Raises:
            AssertionError: If ``assert_fps`` is set and does not match.
            ValueError: If the window is empty, or too short under
                ``strict_length``.
        """
        vr = decord.VideoReader(str(filepath))
        fps = int(vr.get_avg_fps() + 0.5)  # round the float average fps
        if self.assert_fps is not None:
            assert (
                fps == self.assert_fps
            ), f"{filepath} has fps {fps} != {self.assert_fps}"
        total = len(vr)
        start_frame, end_frame = self._frame_window(total, fps, start, end)
        # Validate before decoding anything.
        self._check_window(end_frame - start_frame, filepath)
        idxs = self._sample_indices(total, start_frame=start_frame, end_frame=end_frame)
        frames = vr.get_batch(idxs).asnumpy()  # (T, H, W, 3)
        video = torch.from_numpy(frames).permute(0, 3, 1, 2).to(torch.float32)
        return video, fps, idxs

    def _load_video_folder(
        self,
        folder: Union[str, Path],
        fps: float,
        start: float = None,
        end: float = None,
    ) -> Tuple[torch.Tensor, float, List[int]]:
        """Read sampled frames from a folder of images sorted by file name.

        Only files whose lower-cased suffix is in ``IMG_POSTFIX`` are used.
        ``fps`` converts ``start``/``end`` (seconds) into frame indices and is
        returned unchanged alongside the float32 ``[T, C, H, W]`` video and
        the sampled indices.

        Raises:
            ValueError: If the window is empty, or too short under
                ``strict_length``.
        """
        folder = Path(folder)
        files = sorted(p for p in folder.iterdir() if p.suffix.lower() in IMG_POSTFIX)
        total = len(files)
        start_frame, end_frame = self._frame_window(total, fps, start, end)
        self._check_window(end_frame - start_frame, folder)
        idxs = self._sample_indices(total, start_frame=start_frame, end_frame=end_frame)
        # ``iio`` is already the ``imageio.v3`` module, so call ``iio.imread``
        # (``iio.v3.imread`` raised AttributeError). Images are read as HWC.
        frames = [
            torch.from_numpy(iio.imread(str(files[i]))).permute(2, 0, 1).to(torch.float32)
            for i in idxs
        ]
        video = torch.stack(frames, dim=0)
        return video, fps, idxs

    def _load_audio_data(
        self,
        filepath: Union[str, Path],
        indices: List[int],
        fps: Optional[float],
    ) -> Tuple[torch.Tensor, int, int, int]:
        """Load the mono audio covering the sampled frames.

        1) Read the native sample rate and channel count via ``torchaudio.info``.
        2) Convert the frame span starting at ``indices[0]`` with
           ``len(indices)`` frames into a native-rate sample offset and count.
           This assumes contiguous indices, as ``_sample_indices`` returns.
        3) Load only that slice with ``torchaudio.load(frame_offset, num_frames)``.
           Without ``fps`` or ``indices``, load the whole file instead.
        4) Zero-pad the raw slice to exactly the requested sample count.
        5) Keep the first channel only and resample to ``self.sampling_rate``.
        6) Pad/trim to exactly ``round(len(indices) / fps * final_sr)`` samples.

        Returns:
            ``(wav, ori_sr, ori_ch, final_sr)``, where:
            - ``wav`` is the 1-D waveform;
            - ``ori_sr`` is the native sample rate;
            - ``ori_ch`` is the channel count reported by ``torchaudio.info``;
            - ``final_sr`` is the sample rate of ``wav``.
        """
        # --- 1) native metadata ---
        info = torchaudio.info(str(filepath))
        ori_sr = info.sample_rate
        ori_ch = info.num_channels

        # --- 2) + 3) compute offset & count, load slice ---
        segmented = fps is not None and bool(indices)
        if segmented:
            start_offset = int(round(indices[0] / fps * ori_sr))
            ori_num = int(round(len(indices) / fps * ori_sr))
            wav2d, _ = torchaudio.load(
                str(filepath),
                frame_offset=start_offset,
                num_frames=ori_num,
            )  # [C, <= ori_num]
        else:
            wav2d, _ = torchaudio.load(str(filepath))  # [C, N_full]
            ori_num = wav2d.size(1)

        # --- 4) pad raw audio to exactly ori_num samples ---
        cur_len = wav2d.size(1)
        if cur_len < ori_num:
            # Use the loaded channel count so the concat shapes always agree.
            pad = wav2d.new_zeros((wav2d.size(0), ori_num - cur_len))
            wav2d = torch.cat([wav2d, pad], dim=1)

        # --- 5) first channel only (no averaging), then resample ---
        # Indexing handles any channel count; squeeze(0) left >2-channel audio 2-D.
        wav = wav2d[0]
        final_sr = ori_sr
        if ori_sr != self.sampling_rate:
            resampler = Resample(orig_freq=ori_sr, new_freq=self.sampling_rate)
            wav = resampler(wav.unsqueeze(0)).squeeze(0)
            final_sr = self.sampling_rate

        # --- 6) pad/trim to the exact length implied by the video clip ---
        if segmented:
            target_num = int(round((len(indices) / fps) * final_sr))
            cur = wav.size(0)
            if cur < target_num:
                wav = torch.cat([wav, wav.new_zeros((target_num - cur,))], dim=0)
            elif cur > target_num:
                wav = wav[:target_num]

        return wav, ori_sr, ori_ch, final_sr

    def load_label_info(self, label_key: str, raw_label: Union[str, Dict, List]):
        """Build the label tensor and an optional ``[start, end)`` segment in seconds.

        - ``APPEARANCE`` / ``ACTION``: ``raw_label`` is converted with
          ``torch.tensor(..., dtype=torch.long)``. It presumably holds
          per-class 0/1 flags or indices; TODO confirm against the annotations.
        - Emotion, str (MEAD style): one-hot of that emotion.
        - Emotion, dict: if ``raw_label["sep_flag"]`` is truthy, one entry of
          ``raw_label["labels"]`` is picked at random. Its ``"emotion"`` is
          one-hot encoded and its ``"start_sec"``/``"end_sec"`` are returned.
          Otherwise ``raw_label["labels"]`` itself is encoded.

        Returns:
            ``(label, start, end)``; ``start``/``end`` are None unless a
            segment was selected.

        Raises:
            ValueError: If a selected segment has ``start_sec >= end_sec``, or
                an emotion label is neither a str nor a dict.
        """
        start = None
        end = None
        if label_key in (APPEARANCE, ACTION):
            label = torch.tensor(raw_label, dtype=torch.long)
        elif isinstance(raw_label, str):
            # MEAD style: a single emotion name
            label = emotion_to_one_hot(raw_label)
        elif isinstance(raw_label, dict):
            if raw_label["sep_flag"]:
                # Several emotion segments: pick one and crop the clip to it.
                raw_label = random.choice(raw_label["labels"])
                label = emotion_to_one_hot(raw_label["emotion"])
                start = raw_label["start_sec"]
                end = raw_label["end_sec"]
                if start >= end:
                    raise ValueError(f"Check emotion label for {raw_label}")
            else:
                label = emotion_to_one_hot(raw_label["labels"])
        else:
            # ``label`` is unbound here, so report the offending input's type.
            raise ValueError(f"Unknown label type {type(raw_label)}")
        return label, start, end

    def transform(self, results: Dict) -> Dict:
        """Load label, video and audio for one sample and update ``results``.

        Sets ``results[label_key]``, ``"video"`` (float32 ``[T, C, H, W]``) and
        ``"video_metadata"``. Unless ``video_only``, it also sets ``"audio"``
        (1-D) and ``"audio_metadata"``. The video/audio path keys and the
        other label keys are popped.
        """
        label, start, end = self.load_label_info(
            self.label_key, results.pop(self.label_key)
        )
        # Keep only the requested label in the output.
        for key in (APPEARANCE, ACTION, EMOTION):
            results.pop(key, None)
        results[self.label_key] = label

        video_path = Path(results.pop(self.video_path_key))
        if video_path.is_dir():
            # A frame folder carries no fps of its own; it must come from metadata.
            fps = results.get("video_metadata", {}).get("fps")
            assert fps is not None, f"Missing fps for folder {video_path}"
            video, _, indices = self._load_video_folder(video_path, fps, start, end)
        else:
            video, fps, indices = self._load_video_file(video_path, start, end)

        results["video"] = video  # [T, C, H, W]
        results["video_metadata"] = dict(
            video_path=str(video_path),
            num_frames=video.shape[0],
            fps=fps,
            duration=video.shape[0] / fps,
            height=video.shape[-2],
            width=video.shape[-1],
            ori_height=video.shape[-2],
            ori_width=video.shape[-1],
            frame_indices=tuple(indices),
        )

        if self.video_only:
            results.pop("audio", None)
            results.pop("audio_metadata", None)
            results.pop(self.audio_path_key, None)
            return results

        # Fall back to the video container when no separate audio path is given.
        audio_path = (
            results.pop(self.audio_path_key, None) if self.audio_path_key else None
        )
        if audio_path is None:
            audio_path = video_path
        wav, ori_sr, ori_ch, sr = self._load_audio_data(audio_path, indices, fps)

        results["audio"] = wav  # [N]
        results["audio_metadata"] = dict(
            audio_path=str(audio_path),  # str, consistent with video_metadata
            ori_sr=ori_sr,
            ori_channels=ori_ch,
            sr=sr,
            num_samples=wav.numel(),
            duration=wav.numel() / sr,
        )
        return results
