from typing import Dict, List, Union, Tuple
import torch
from mmcv import BaseTransform
from mmhug.registry import TRANSFORMS


@TRANSFORMS.register_module(force=True)
class NormalizeVideo(BaseTransform):
    """
    Normalize or rescale video tensors using a per-channel mean/std.

    For every key in ``video_keys`` that is present in ``results``, the video
    is cast to ``float32`` and then:

    - Rescale (only if ``scaling`` is True **and** ``mean.max()`` ≤ 1.0): when
      the video has an integer dtype (e.g. uint8) or its max is > 1.0, it is
      treated as 0–255 pixel data, divided by 255.0 and clamped to [0, 1].
    - Always applies ``(video - mean) / std``, broadcast over the channel axis.

    Examples:
      * Gaussian normalize [0,1]→zero-mean unit-variance:
        mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225).
      * Rescale [0,1]→[−1,1]: mean=(0.5,…), std=(0.5,…), so 0→(0−0.5)/0.5=−1; 1→+1.
      * Rescale [0,255]→[−1,1]: mean=(127.5,…), std=(127.5,…), so 0→(0−127.5)/127.5=−1; 255→+1.
        Because mean.max() > 1 here, no /255 rescale is applied.

    Args:
        mean: Per-channel mean, expressed in the value range the video has
            after the optional rescale.
        std: Per-channel standard deviation (must be non-zero).
        video_keys: Key, or list of keys, in ``results`` holding (T, C, H, W)
            video tensors.
        scaling: Whether the automatic 0–255 → [0, 1] rescale may be applied.
        **kwargs: Forwarded to ``BaseTransform``.

    Raises:
        ValueError: If any ``std`` entry is zero.
    """

    def __init__(
        self,
        mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
        std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
        video_keys: Union[str, List[str]] = "video",
        scaling: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Stored as (1, C, 1, 1) so they broadcast against (T, C, H, W) videos.
        self.mean = torch.tensor(mean, dtype=torch.float32).view(1, -1, 1, 1)
        self.std = torch.tensor(std, dtype=torch.float32).view(1, -1, 1, 1)
        if bool((self.std == 0).any()):
            # Dividing by zero would silently fill the video with inf/NaN.
            raise ValueError(f"std must be non-zero, got {tuple(std)}")
        self.scaling = scaling
        # Accept a single key for convenience; always iterate over a list.
        self.video_keys = [video_keys] if isinstance(video_keys, str) else list(video_keys)

    def _needs_rescale(self, vid: torch.Tensor) -> bool:
        """Return True if ``vid`` should be divided by 255 before standardizing."""
        if not self.scaling:
            return False
        # Stats above 1 (e.g. 127.5) already describe 0–255 data; rescaling
        # the input would then make (video - mean)/std meaningless.
        if bool(self.mean.max() > 1.0):
            return False
        # Integer videos (uint8 frames) are 0–255 by construction.
        if not vid.is_floating_point():
            return True
        # Float videos: use the value range as the heuristic; guard empty
        # tensors because .max() raises on them.
        return vid.numel() > 0 and bool(vid.max() > 1.0)

    def transform(self, results: Dict) -> Dict:
        """
        Apply normalization/rescale to all videos in `results` under `video_keys`.

        Keys listed in ``video_keys`` but absent from ``results`` are skipped.
        Each selected entry must be a 4-D ``torch.Tensor`` laid out as
        (T, C, H, W); it is replaced by a ``float32`` normalized tensor on the
        same device.

        ``results["video_metadata"]`` is created if missing, then updated with:
          - "mean": the mean used, flattened to shape (C,)
          - "std": the std used, flattened to shape (C,)
        This metadata update happens even if no video key was present.

        Args:
            results: Pipeline result dict; modified in place.

        Returns:
            The same ``results`` dict.

        Raises:
            AssertionError: If a selected entry is not a tensor or not 4-D.
                (Kept as ``assert`` for compatibility with existing callers;
                note these checks are skipped under ``python -O``.)
        """
        for key in self.video_keys:
            if key not in results:
                continue
            vid = results[key]
            assert torch.is_tensor(vid), f"{key} must be a torch.Tensor"
            assert vid.ndim == 4, f"{key} must have shape (T,C,H,W)"

            # Decide before casting: the original dtype tells us uint8 input.
            rescale = self._needs_rescale(vid)
            vid = vid.float()
            if rescale:
                vid = torch.clamp(vid / 255.0, 0, 1)

            # Move stats to the video's device so CUDA inputs do not fail.
            mean = self.mean.to(vid.device)
            std = self.std.to(vid.device)
            results[key] = (vid - mean) / std

        # setdefault: callers that never created metadata must not KeyError.
        # clone(): hand out independent tensors so downstream in-place edits
        # cannot corrupt this transform's stored statistics.
        results.setdefault("video_metadata", {}).update(
            mean=self.mean.view(-1).clone(),
            std=self.std.view(-1).clone(),
        )
        return results

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(mean={self.mean.view(-1).tolist()}, "
            f"std={self.std.view(-1).tolist()}, "
            f"video_keys={self.video_keys}, scaling={self.scaling})"
        )


if __name__ == "__main__":
    # Lightweight self-tests for NormalizeVideo; run with `python <this file>`.

    def make_dummy(T=2, C=3, H=4, W=5, high=False):
        """Generate a dummy float video: values in [0,255) if high=True, else [0,1)."""
        vid = torch.rand(T, C, H, W)
        return vid * 255.0 if high else vid

    # 1) Test mapping ranges to [-1,1]
    def test_range_rescale():
        # [0,1] -> [-1,1]
        norm = NormalizeVideo(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        out = norm.transform({"video": make_dummy(high=False), "video_metadata": {}})
        v = out["video"]
        # Inclusive bounds: torch.rand can return exactly 0, which maps to -1.
        assert v.min() >= -1.0, v.min()
        assert v.max() <= 1.0, v.max()

        # [0,255] -> [-1,1]
        norm2 = NormalizeVideo(mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5))
        out2 = norm2.transform({"video": make_dummy(high=True), "video_metadata": {}})
        v2 = out2["video"]
        assert v2.min() >= -1.0, v2.min()
        assert v2.max() <= 1.0, v2.max()

        print("✔ test_range_rescale passed")

    # 2) Test Gaussian-style normalization output and metadata
    def test_gaussian_stats():
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        norm = NormalizeVideo(mean=mean, std=std)
        vid = make_dummy(high=False).repeat(10, 1, 1, 1)  # more data
        out = norm.transform({"video": vid.clone(), "video_metadata": {}})
        v = out["video"]
        assert v.shape == vid.shape, v.shape
        assert torch.isfinite(v).all()
        assert torch.allclose(out["video_metadata"]["mean"], torch.tensor(mean))
        assert torch.allclose(out["video_metadata"]["std"], torch.tensor(std))
        print(f"mean={v.mean().item():.4f} std={v.std().item():.4f}")
        print("✔ test_gaussian_stats passed")

    # 3) Test that only specified keys are normalized
    def test_key_filtering():
        vid = make_dummy()
        other = torch.ones(1)
        results = {"video": vid.clone(), "foo": other.clone(), "video_metadata": {}}
        norm = NormalizeVideo(
            mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), video_keys=["video"]
        )
        out = norm.transform(results)
        assert torch.allclose(out["foo"], other)
        assert not torch.allclose(out["video"], vid)
        print("✔ test_key_filtering passed")

    # 4) Test invalid shape raises
    def test_invalid_shape():
        norm = NormalizeVideo()
        try:
            norm.transform({"video": torch.rand(3, 4, 5), "video_metadata": {}})  # 3D tensor
        except AssertionError:
            print("✔ test_invalid_shape passed")
        else:
            # Raised outside the try so it is not swallowed by the except above.
            raise AssertionError("Expected shape assertion")

    # Run all tests
    test_range_rescale()
    test_gaussian_stats()
    test_key_filtering()
    test_invalid_shape()

    print("=== All NormalizeVideo tests passed! ===")
