"""NYUv2 dataset wrapper and train/validation/test DataLoader factory."""

from __future__ import annotations

from typing import Tuple

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from datasets import load_dataset
from PIL import Image

from ..collection import nyuv2 as nyuv2_collection


class NYUv2Dataset(Dataset):
    """Map-style wrapper around one split of the ``0jl/NYUv2`` Hugging Face dataset.

    Indexing yields ``(inputs, label)`` where ``inputs`` is a dict holding the
    resized RGB tensor under ``"rgb"`` and the depth tensor under ``"depth"``.

    Only the RGB image goes through the resize transform; the depth map is
    passed through at whatever resolution the underlying sample stores it.
    NOTE(review): this means ``rgb`` and ``depth`` may disagree spatially --
    confirm downstream code expects that.
    """

    def __init__(self, split: str, cache_dir: str, image_size: int = 224):
        """Load ``split`` of the dataset and prepare the RGB preprocessing.

        Args:
            split: Split name forwarded to ``load_dataset``.
            cache_dir: Directory forwarded to ``load_dataset`` as its cache.
            image_size: Side length of the square the RGB image is resized to.
        """
        self.ds = load_dataset("0jl/NYUv2", cache_dir=cache_dir, split=split)

        target_hw = (image_size, image_size)
        rgb_steps = [transforms.Resize(target_hw), transforms.ToTensor()]
        self.image_tf = transforms.Compose(rgb_steps)

    def __len__(self):
        """Return the number of samples in the loaded split."""
        n_samples = len(self.ds)
        return n_samples

    def __getitem__(self, idx):
        """Return ``({"rgb": ..., "depth": ...}, label)`` for sample ``idx``."""
        record = self.ds[idx]

        # presumably ``record["image"]`` is array-like, as ``Image.fromarray``
        # requires -- TODO confirm against the dataset schema.
        pil_rgb = Image.fromarray(record["image"])

        # Prepend a singleton channel axis (the depth map is assumed to be
        # H x W, giving 1 x H x W), then cast to float32.
        depth_map = torch.tensor(record["depth"])
        depth_map = depth_map.unsqueeze(0).float()

        target = torch.tensor(record["label"], dtype=torch.long)

        inputs = {"rgb": self.image_tf(pil_rgb), "depth": depth_map}
        return inputs, target


def create_dataloaders(
    root=None,
    batch_size: int = 4,
    num_workers: int = 2,
    download: bool = True,
    image_size: int = 224,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, validation and test ``DataLoader`` objects for NYUv2.

    Args:
        root: Forwarded to ``nyuv2_collection.download_and_prepare``.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker count used by all three loaders.
        download: Forwarded to ``nyuv2_collection.download_and_prepare``.
        image_size: Side length the RGB images are resized to.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles; all three pin memory.
    """
    prepared = nyuv2_collection.download_and_prepare(root=root, download=download)
    hf_cache = prepared["hf_cache"]

    # Instantiate every split before any loader is created.
    split_names = ("train", "validation", "test")
    train_set, val_set, test_set = (
        NYUv2Dataset(name, hf_cache, image_size) for name in split_names
    )

    # Settings common to all three loaders; only ``shuffle`` varies.
    shared = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
    }

    def _loader(dataset, shuffle):
        return DataLoader(dataset, shuffle=shuffle, **shared)

    return _loader(train_set, True), _loader(val_set, False), _loader(test_set, False)


# Names exported when this module is star-imported.
__all__ = ["create_dataloaders", "NYUv2Dataset"]
