from typing import Dict, List, Optional, Tuple, Union
import torch
import sys
import os
import torch.nn.functional as F
from torchvision.transforms.functional import InterpolationMode
from mmcv import BaseTransform

sys.path.append(os.curdir)

from mmhug.registry import TRANSFORMS


@TRANSFORMS.register_module()
class ResizeVideo(BaseTransform):
    """Resize all video tensors in ``results`` to a size picked from candidates.

    The candidate whose short side is relatively closest to the input's short
    side is selected (ties go to the earliest candidate; with no candidates
    the input size is kept). Then:

    - ``keep_ratio=True``: scale isotropically so the output *covers* the
      candidate box, i.e. ``scale = max(cand_h / orig_h, cand_w / orig_w)``,
      which guarantees ``new_h >= cand_h`` and ``new_w >= cand_w``.
    - ``keep_ratio=False``: ignore the aspect ratio and stretch the output to
      exactly ``(cand_h, cand_w)``.

    Every present key is resized to the same output size, and
    ``results["video_metadata"]`` is updated with the new height/width.

    Args:
        video_keys: A single key or a sequence of keys to resize. Keys missing
            from ``results`` are skipped, but at least one must be present.
        size_candidates: A single ``(h, w)`` pair or a sequence of pairs.
        keep_ratio: Cover-scale with preserved aspect ratio (see above) when
            True; stretch to the candidate size when False.
        interpolation: Name of a torchvision ``InterpolationMode`` that
            ``torch.nn.functional.interpolate`` supports: ``"nearest"``,
            ``"nearest_exact"``, ``"bilinear"`` or ``"bicubic"``.

    Raises:
        ValueError: On an unknown or unsupported interpolation mode, on
            mismatched spatial sizes between keys, or on an empty spatial size.
        KeyError: When none of ``video_keys`` is present in ``results``.

    Note:
        The same interpolation mode is applied to every key; smooth modes
        blend neighbouring values, which may be undesirable for label-like
        tensors.
    """

    # InterpolationMode values that map directly onto F.interpolate modes.
    _SUPPORTED_MODES = ("nearest", "nearest-exact", "bilinear", "bicubic")
    # F.interpolate raises ValueError if align_corners is not None for any
    # other mode, so it is only passed for these.
    _ALIGN_CORNERS_MODES = ("bilinear", "bicubic")

    def __init__(
        self,
        video_keys: Union[str, List[str], Tuple[str, ...]] = (
            "video",
            "landmark_video",
            "mask_video",
        ),
        size_candidates: Union[
            List[Tuple[int, int]], Tuple[int, int], Tuple[Tuple[int, int], ...]
        ] = (
            (256, 256),
            (320, 320),
            (384, 384),
            (512, 512),
            (480, 640),
            (720, 1280),
            (1080, 1920),
        ),
        keep_ratio: bool = True,
        interpolation: str = "bicubic",
    ):
        super().__init__()
        self.keep_ratio = keep_ratio

        # Accept "nearest-exact" as well as "nearest_exact".
        mode_name = interpolation.upper().replace("-", "_")
        if not hasattr(InterpolationMode, mode_name):
            raise ValueError(f"Unknown interpolation mode '{interpolation}'")
        self.interpolation = getattr(InterpolationMode, mode_name)
        self._mode = self.interpolation.value.lower()
        # box / hamming / lanczos exist in torchvision but not in F.interpolate;
        # fail at construction instead of on the first sample.
        if self._mode not in self._SUPPORTED_MODES:
            raise ValueError(
                f"Interpolation mode '{interpolation}' is not supported by "
                f"torch.nn.functional.interpolate; use one of "
                f"{self._SUPPORTED_MODES}"
            )

        # A bare string would otherwise be iterated character by character.
        if isinstance(video_keys, str):
            video_keys = [video_keys]
        self.video_keys = list(video_keys)

        # Accept a single (h, w) pair as well as a sequence of pairs.
        if len(size_candidates) == 2 and all(
            isinstance(v, int) for v in size_candidates
        ):
            size_candidates = [size_candidates]
        self.size_candidates = [(int(h), int(w)) for h, w in size_candidates]

    def transform(self, results: Dict) -> Dict:
        """Resize every configured key present in ``results`` in place and return it."""
        # 1) Collect the video tensors that are actually present.
        available = [k for k in self.video_keys if k in results]
        if not available:
            raise KeyError(
                f"No valid video keys in results, got {list(results.keys())}"
            )
        hws = [tuple(results[k].shape[-2:]) for k in available]
        # Compare full (H, W) tuples, not just the heights.
        if len(set(hws)) != 1:
            raise ValueError(f"All videos must have identical HxW, got {hws}")
        orig_h, orig_w = hws[0]

        # 2) Pick a candidate and 3) derive the final output size.
        cand_h, cand_w = self._select_candidate(orig_h, orig_w)
        new_h, new_w = self._target_size(orig_h, orig_w, cand_h, cand_w)

        # 4) Resize every present key to the same size.
        for k in available:
            results[k] = self._resize(results[k], new_h, new_w)

        # 5) Record the new size.
        results.setdefault("video_metadata", {})
        results["video_metadata"].update({"height": new_h, "width": new_w})
        return results

    def _select_candidate(self, orig_h: int, orig_w: int) -> Tuple[int, int]:
        """Return the candidate whose short side changes least, relative to the input's."""
        orig_short = min(orig_h, orig_w)
        if orig_short <= 0:
            raise ValueError(
                f"Cannot resize a video with empty spatial size {orig_h}x{orig_w}"
            )
        # min() keeps the first minimal element, so ties go to the earliest
        # candidate; with no candidates the input size is kept.
        return min(
            self.size_candidates,
            key=lambda hw: abs(min(hw) / orig_short - 1.0),
            default=(orig_h, orig_w),
        )

    def _target_size(
        self, orig_h: int, orig_w: int, cand_h: int, cand_w: int
    ) -> Tuple[int, int]:
        """Return (new_h, new_w): cover-scaled when keep_ratio, else the candidate."""
        if not self.keep_ratio:
            return cand_h, cand_w
        # Cover scaling: the larger factor guarantees both sides reach the box.
        scale = max(cand_h / orig_h, cand_w / orig_w)
        return int(round(orig_h * scale)), int(round(orig_w * scale))

    def _resize(self, vid: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
        """Resize the last two dims of ``vid`` to (new_h, new_w), keeping its dtype.

        Leading dims are flattened into the batch dim, so any rank >= 2 is
        accepted (e.g. (T, C, H, W) or (T, H, W)). The interpolation modes
        used here act on each channel independently, so results match a
        direct 4-D call.
        """
        dtype = vid.dtype
        *lead, h, w = vid.shape
        flat = vid.reshape(-1, 1, h, w)
        # Several F.interpolate kernels lack integer/bool support; compute in
        # float and convert back below.
        if not flat.is_floating_point():
            flat = flat.float()
        out = F.interpolate(
            flat,
            size=(new_h, new_w),
            mode=self._mode,
            # True is kept for bilinear/bicubic to preserve existing outputs;
            # nearest modes require None.
            align_corners=True if self._mode in self._ALIGN_CORNERS_MODES else None,
        )
        if not dtype.is_floating_point:
            # Bicubic overshoots the input range; round and saturate so the
            # cast back does not wrap around.
            if dtype == torch.bool:
                out = out.round().clamp_(0, 1)
            else:
                info = torch.iinfo(dtype)
                out = out.round().clamp_(info.min, info.max)
        return out.to(dtype).reshape(*lead, new_h, new_w)


# Manual smoke test: resize a demo clip with ResizeVideo and write the result to disk.

if __name__ == "__main__":
    from torchvision.io import read_video, write_video

    # 1. Load the demo clip (path may be overridden by the first CLI argument).
    video_path = (
        sys.argv[1] if len(sys.argv) > 1 else "demo_assets/___OJkS9RK0_0.mp4"
    )
    video, _, info = read_video(video_path, pts_unit="sec")  # [T, H, W, C]
    video = video.permute(0, 3, 1, 2).contiguous()  # [T, C, H, W]

    # 2. Wrap it in a results dict; "ref_img" is absent and is simply skipped.
    results = {"video": video}
    transform = ResizeVideo(
        video_keys=["video", "ref_img"],
        size_candidates=[(256, 256)],
        keep_ratio=True,
    )
    resized = transform(results)

    # 3. Convert back to [T, H, W, C] and write at the source frame rate.
    os.makedirs("output_videos", exist_ok=True)
    save_path = os.path.join("output_videos", "resized_demo.mp4")
    video_to_save = resized["video"].permute(0, 2, 3, 1).cpu()
    write_video(
        save_path,
        video_to_save,
        # Fall back to 25 fps when the container does not report a rate.
        fps=info.get("video_fps", 25),
        video_codec="libx264",
    )
    print(f"Resized video saved to: {save_path}")
