from typing import List, Dict, Any, Callable, Sequence, Union
import torch
import torchvision.transforms.v2.functional as TF
import pytorch_lightning as pl
from diffusion.data.mini_videoloader import MinimalVideoLoader


def resize(video: torch.Tensor, width: int, height: int, keep_aspect_ratio: bool = True) -> torch.Tensor:
    """Resize the last two (height, width) dimensions of *video* to ``(height, width)``.

    Any leading dimensions are left as-is; the tensor is sliced with ``...`` so
    every rank whose last two dims are (H, W) is accepted.

    When *keep_aspect_ratio* is True, the input is first center-cropped to the
    target aspect ratio ``width / height``. The subsequent resize then scales
    without stretching. When it is False, the input is resized directly, which
    may change its aspect ratio.

    Args:
        video: Tensor whose last two dimensions are interpreted as (H, W).
        width: Target output width in pixels.
        height: Target output height in pixels.
        keep_aspect_ratio: Center-crop to the target aspect ratio before resizing.

    Returns:
        The resized tensor, produced by ``TF.resize`` with antialiasing enabled.
    """
    if keep_aspect_ratio:
        in_h, in_w = video.shape[-2], video.shape[-1]
        # Compare in_w / in_h with width / height by cross-multiplying, and compute
        # crop sizes with integer floor division. The previous float form could
        # truncate one pixel too many (e.g. int(100 * 0.29) == 28, not 29).
        if in_w * height > in_h * width:
            # Input is wider than the target: keep the central new_w columns.
            new_w = in_h * width // height
            offset = (in_w - new_w) // 2
            video = video[..., offset : offset + new_w]
        else:
            # Input is taller than (or as wide as) the target: keep the central new_h rows.
            new_h = in_w * height // width
            offset = (in_h - new_h) // 2
            video = video[..., offset : offset + new_h, :]
    video = TF.resize(video, size=(height, width), antialias=True)
    return video


def get_resize_transform(width: int, height: int, keep_aspect_ratio: bool) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a single-argument transform that resizes frames with fixed settings.

    Args:
        width: Target width forwarded to ``resize``.
        height: Target height forwarded to ``resize``.
        keep_aspect_ratio: Forwarded to ``resize``; controls center-cropping.

    Returns:
        A callable ``f(frames)`` equivalent to
        ``resize(frames, width, height, keep_aspect_ratio)``.
    """

    def _resize_frames(frames: torch.Tensor) -> torch.Tensor:
        # The settings are captured from the enclosing call; only the frames vary.
        return resize(frames, width, height, keep_aspect_ratio)

    return _resize_frames


class VideoDataModule(pl.LightningDataModule):
    """Lightning data module that builds a ``MinimalVideoLoader`` for each split.

    Each split config (``train``, ``validation``, ``test``) is a mapping of keyword
    arguments forwarded to ``MinimalVideoLoader``. The mapping can be anything
    ``dict()`` accepts. The optional keys ``width``, ``height`` (both default 256)
    and ``keep_aspect_ratio`` (default True) are removed. They are replaced by a
    ``transform`` entry built with ``get_resize_transform``. Any ``transform``
    already present in the config is overwritten.
    """

    _DEFAULT_WIDTH = 256
    _DEFAULT_HEIGHT = 256
    _DEFAULT_KEEP_ASPECT_RATIO = True

    def __init__(self, train=None, validation=None, test=None):
        super().__init__()
        self.dataset_configs = {
            "train": self._with_resize_transform(train),
            "validation": self._with_resize_transform(validation),
            "test": self._with_resize_transform(test),
        }

    @classmethod
    def _with_resize_transform(cls, config):
        """Return a copy of *config* whose resize keys are folded into ``transform``.

        Returns None when *config* is None. The caller's mapping is never mutated.
        """
        if config is None:
            return None
        # Shallow copy so the pop() calls below do not modify the caller's config.
        config = dict(config)
        config["transform"] = get_resize_transform(
            config.pop("width", cls._DEFAULT_WIDTH),
            config.pop("height", cls._DEFAULT_HEIGHT),
            config.pop("keep_aspect_ratio", cls._DEFAULT_KEEP_ASPECT_RATIO),
        )
        return config

    def _make_loader(self, split: str):
        """Construct a ``MinimalVideoLoader`` from the stored config for *split*.

        Raises:
            ValueError: if no config was supplied for *split* at construction time.
        """
        config = self.dataset_configs[split]
        if config is None:
            # Previously this surfaced as an opaque TypeError from ``**None``.
            raise ValueError(f"VideoDataModule was created without a {split!r} dataset config")
        return MinimalVideoLoader(**config)

    def setup(self, stage: str):
        # Nothing to prepare per stage; loaders are built on demand below.
        pass

    def train_dataloader(self):
        return self._make_loader("train")

    def val_dataloader(self, device=None):
        # `device` is accepted for caller compatibility but is not used.
        return self._make_loader("validation")

    def test_dataloader(self):
        return self._make_loader("test")

    def teardown(self, stage: str):
        # Nothing to release.
        pass
