from typing import Dict, List, Optional, Tuple, Union
from mmcv import BaseTransform
from mmhug.registry import TRANSFORMS
import torch


@TRANSFORMS.register_module()
class CenterCropVideo(BaseTransform):
    """Center-crop the configured video tensors to a fixed spatial size."""

    def __init__(
        self,
        video_keys: Optional[Union[str, List[str]]] = None,
        crop_size: Union[int, Tuple[int, int]] = (512, 512),
    ) -> None:
        """Crop the patch of crop_size at the center of the video. We recommend users to use this operation after ResizeVideo
        Args:
            video_keys (Union[str, List[str]], optional): Keys of videos to be
                cropped. A single string is treated as one key. Defaults to
                ["video", "landmark_video", "mask_video"].
            crop_size (Union[int, Tuple[int, int]]): Size (height, width) of
                the cropped video. A single int is used for both dimensions.
        Raises:
            AssertionError: If crop_size is not an int or a pair of positive
                ints.
        """
        super().__init__()

        # Build a fresh list per instance so instances never share (or mutate)
        # one default list object.
        if video_keys is None:
            video_keys = ["video", "landmark_video", "mask_video"]
        elif isinstance(video_keys, str):
            video_keys = [video_keys]
        self.video_keys = list(video_keys)

        # Normalize first, validate afterwards: a bare int means a square crop
        # and a two-element list is accepted as a tuple.
        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)
        elif isinstance(crop_size, list):
            crop_size = tuple(crop_size)

        assert (
            isinstance(crop_size, tuple)
            and len(crop_size) == 2
            and all(isinstance(x, int) and x > 0 for x in crop_size)
        ), "crop_size must be an int or a tuple of two positive ints"
        self.crop_size = crop_size

    @staticmethod
    def _video_path(results: Dict) -> str:
        """Return the video path stored in the metadata, or a placeholder."""
        metadata = results.get("video_metadata") or {}
        return metadata.get("video_path", "<unknown>")

    def transform(self, results: Dict) -> Dict:
        """
        Args:
            results (dict): dict containing video tensor and metadata.
                Keys listed in self.video_keys that are missing are skipped.
        Returns:
            dict: dict containing cropped video tensor and updated metadata.
                results['video_metadata'] gets 'height' and 'width' set to the
                crop size only when at least one video was cropped; it is
                created if absent.
        Raises:
            AssertionError: If a video is not a 4-dim torch.Tensor or is
                smaller than the crop size.
        """
        ch, cw = self.crop_size
        cropped_any = False

        for key in self.video_keys:
            if key not in results:
                continue

            video = results[key]
            assert isinstance(video, torch.Tensor), f"{key} must be a torch.Tensor"
            assert video.ndim == 4, f"{key} must have shape (T, C, H, W)"
            H, W = video.shape[-2:]

            # ensure the video resolution is larger than crop_size; the path
            # lookup is tolerant so a missing metadata entry cannot mask the
            # real error with a KeyError
            assert H >= ch and W >= cw, (
                f"Cannot center-crop size {self.crop_size} from video of size "
                f"{(H, W)}, please check the video resolution at "
                f"{self._video_path(results)}"
            )

            # crop center
            top = (H - ch) // 2
            left = (W - cw) // 2

            # crop the video
            results[key] = video[:, :, top : top + ch, left : left + cw]
            cropped_any = True

        # update the metadata only if something was actually cropped, so the
        # recorded size never disagrees with the stored tensors
        if cropped_any:
            results.setdefault("video_metadata", {}).update(
                {
                    "height": ch,
                    "width": cw,
                }
            )
        return results


@TRANSFORMS.register_module()
class DivisableCropVideo(BaseTransform):
    """Center-crop videos so that height and width are multiples of a divisor."""

    def __init__(
        self,
        video_keys: Optional[Union[str, List[str]]] = None,
        divisor: int = 32,
    ) -> None:
        """For many modules, like Video VAEs, the input resolution must be divisible by a certain number(like 8 or 32).
         This transform is used to crop the video to make the resolution divisible by divisor.
        Args:
            video_keys (Union[str, List[str]], optional): Keys of videos to be
                cropped. A single string is treated as one key. Defaults to
                ["video", "landmark_video", "mask_video"].
            divisor (int): The divisor of the resolution. (32 for LTX-Video)
        Raises:
            AssertionError: If divisor is not a positive int.
        """
        super().__init__()

        # Build a fresh list per instance; a bare string must become a
        # one-element list, otherwise it would be iterated per character.
        if video_keys is None:
            video_keys = ["video", "landmark_video", "mask_video"]
        elif isinstance(video_keys, str):
            video_keys = [video_keys]
        self.video_keys = list(video_keys)

        assert (
            isinstance(divisor, int) and divisor > 0
        ), f"divisor must be a positive int, but got {divisor}"
        self.divisor = divisor

    def transform(self, results: Dict) -> Dict:
        """
        Args:
            results (dict): contains video tensors under keys in self.video_keys.
                Each tensor shape: (T, C, H, W).
        Returns:
            dict: the same dict, with each present video center-cropped so
                  that its H and W are divisible by self.divisor.
                  If the size is already divisible, results is returned
                  untouched. Otherwise results['video_metadata'] (created if
                  absent) gets 'height' and 'width' set to the new size.
        Raises:
            AssertionError: If no configured key is present, the videos do not
                share one H×W, the video is smaller than the divisor, or a
                video to be cropped is not a 4-dim torch.Tensor.
        """
        # We'll compute new_h, new_w only once since all videos share same shape
        present_keys = [k for k in self.video_keys if k in results]
        assert present_keys, "No valid video keys found in results."

        sizes = [results[k].shape[-2:] for k in present_keys]
        assert (
            len({tuple(s) for s in sizes}) == 1
        ), f"All videos must have same H×W, got {sizes}"

        H, W = sizes[0]
        # Compute largest divisible dims ≤ original
        new_h = (H // self.divisor) * self.divisor
        new_w = (W // self.divisor) * self.divisor

        # If already divisible, nothing to do
        if new_h == H and new_w == W:
            return results

        # A video smaller than the divisor would be cropped to an empty tensor
        assert new_h > 0 and new_w > 0, (
            f"Video of size {(H, W)} is smaller than divisor {self.divisor}"
        )

        # Compute center crop offsets
        top = (H - new_h) // 2
        left = (W - new_w) // 2

        # Apply crop
        for key in present_keys:
            video = results[key]
            assert isinstance(video, torch.Tensor), f"{key} must be a torch.Tensor"
            assert video.ndim == 4, f"{key} must have shape (T, C, H, W)"
            # center-crop via slicing
            results[key] = video[:, :, top : top + new_h, left : left + new_w]

        # Create the metadata dict when the caller did not supply one instead
        # of failing with a KeyError after the tensors were already replaced.
        results.setdefault("video_metadata", {}).update(
            {
                "height": new_h,
                "width": new_w,
            }
        )

        return results


if __name__ == "__main__":

    def make_video_tensor(H, W, T=16, C=3):
        """Build a random dummy clip laid out as (T, C, H, W)."""
        return torch.randn(T, C, H, W)

    def _with_meta(clip):
        """Wrap a clip in a results dict that carries dummy metadata."""
        return {"video": clip, "video_metadata": {"video_path": "dummy.mp4"}}

    def _spatial(sample, key="video"):
        """Return the trailing (H, W) of the tensor stored under key."""
        return tuple(sample[key].shape[-2:])

    print("Running CenterCropVideo self‐tests…")

    # 1) Regular center crop.
    clip = make_video_tensor(600, 800)
    center_crop = CenterCropVideo(crop_size=(400, 500))
    cropped = center_crop.transform(_with_meta(clip))
    assert _spatial(cropped) == (400, 500), "CenterCrop failed"

    # 2) Crop size equal to the input size leaves the resolution unchanged.
    same_size_clip = make_video_tensor(256, 256)
    identity_crop = CenterCropVideo(crop_size=(256, 256))
    untouched = identity_crop.transform(_with_meta(same_size_clip))
    assert _spatial(untouched) == (256, 256), "No-op crop failed"

    # 3) Configured keys that are absent from the results are skipped.
    spare_clip = make_video_tensor(300, 300)
    missing_key_crop = CenterCropVideo(
        video_keys=["foo", "bar"], crop_size=(100, 100)
    )
    # Must return without raising.
    skipped = missing_key_crop.transform(_with_meta(spare_clip))
    assert not ("foo" in skipped or "bar" in skipped), "Unexpected keys added"

    # 4) An input smaller than the crop size must raise AssertionError.
    try:
        tiny_clip = make_video_tensor(100, 100)
        CenterCropVideo(crop_size=(200, 200)).transform(_with_meta(tiny_clip))
    except AssertionError:
        pass
    else:
        raise RuntimeError("Expected AssertionError for too-large crop_size")

    print("✔ All CenterCropVideo tests passed!")

    print("Running DivisableCropVideo self‐tests…")

    # 1) Already divisible: nothing changes.
    aligned_clip = make_video_tensor(256, 320)
    aligned_crop = DivisableCropVideo(divisor=32)
    aligned_out = aligned_crop.transform({"video": aligned_clip})
    assert _spatial(aligned_out) == (256, 320), "No-op divisable crop failed"
    # Metadata should keep its old value or stay unset.
    assert aligned_out.get("video_metadata") is None

    # 2) Crop down to the nearest divisible size.
    ragged_clip = make_video_tensor(250, 330)
    ragged_crop = DivisableCropVideo(divisor=32)
    ragged_out = ragged_crop.transform({"video": ragged_clip})
    # 250 // 32 = 7 -> 7 * 32 = 224, 330 // 32 = 10 -> 10 * 32 = 320
    assert _spatial(ragged_out) == (224, 320), "Divisable crop failed"
    expected_meta = {"height": 224, "width": 320}
    assert ragged_out["video_metadata"] == expected_meta, "Metadata not updated"

    # 3) Several keys are handled consistently.
    clip_a = make_video_tensor(150, 150)
    clip_b = make_video_tensor(150, 150)
    multi_crop = DivisableCropVideo(video_keys=["video", "mask_video"], divisor=50)
    multi_out = multi_crop.transform({"video": clip_a, "mask_video": clip_b})
    # 150 // 50 = 3 -> 3 * 50 = 150, so no crop happens.
    for name in ("video", "mask_video"):
        assert _spatial(multi_out, name) == (150, 150)

    # 4) No usable key in the results must raise AssertionError.
    try:
        DivisableCropVideo(video_keys=["foo"]).transform({})
    except AssertionError:
        pass
    else:
        raise RuntimeError("Expected AssertionError for missing video keys")

    print("✔ All DivisableCropVideo tests passed!")
