""""""

from __future__ import annotations

from typing import Tuple, Optional, Dict

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

from ..collection import mmimdb as mmimdb_collection


class MMIMDbDataset(Dataset):
    """Map-style dataset over one split of the Hugging Face ``"mm_imdb"`` dataset.

    Every item is a ``(features, labels)`` pair. ``labels`` is the sample's
    ``"label"`` field converted to a ``float32`` tensor; ``features`` is the
    raw ``"text"`` field when no tokenizer was given, otherwise a dict built
    from the tokenizer's output.
    """

    def __init__(self, split: str, cache_dir: str, tokenizer=None, max_length: int = 256):
        """Load ``split`` of ``"mm_imdb"`` and store the tokenizer settings.

        Args:
            split: Split name handed to ``load_dataset``.
            cache_dir: Cache directory handed to ``load_dataset``.
            tokenizer: Optional callable applied to each text in
                ``__getitem__``; with ``None`` the raw text is returned.
            max_length: Value forwarded to the tokenizer as ``max_length``.
        """
        self.ds = load_dataset("mm_imdb", split=split, cache_dir=cache_dir)
        self.tokenizer = tokenizer
        self.max_length = max_length
        # NOTE(review): the names are sorted here, so ``genres[i]`` is only the
        # name of class index ``i`` if the feature already lists its names in
        # sorted order -- confirm against the dataset's label feature.
        genre_names = list(self.ds.features["label"].names)
        genre_names.sort()
        self.genres = genre_names

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(text, labels)`` or ``(encoding, labels)`` for sample ``idx``."""
        record = self.ds[idx]
        raw_text = record["text"]
        target = torch.tensor(record["label"], dtype=torch.float32)

        if self.tokenizer is not None:
            encoded = self.tokenizer(
                raw_text,
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            # squeeze(0) drops a leading size-1 dimension from every returned
            # value -- presumably the batch axis the tokenizer adds for a
            # single text; confirm against the tokenizer in use.
            squeezed = {}
            for key, value in encoded.items():
                squeezed[key] = value.squeeze(0)
            return squeezed, target

        return raw_text, target


def _collate_encoded(batch):
    """Stack a list of ``(encoding, labels)`` samples into one batch.

    Kept at module level rather than as a closure inside
    ``create_dataloaders``: a local function cannot be pickled, which breaks
    DataLoader workers started with the ``spawn`` or ``forkserver``
    multiprocessing start methods.

    Args:
        batch: Non-empty list of ``(encoding, labels)`` pairs, where
            ``encoding`` is a dict of tensors sharing the same keys.

    Returns:
        ``(inputs, labels)`` where ``inputs`` maps each key of the first
        sample's encoding to the stacked tensors and ``labels`` is the stacked
        label tensor.
    """
    first_encoding = batch[0][0]
    inputs = {k: torch.stack([b[0][k] for b in batch]) for k in first_encoding}
    labels = torch.stack([b[1] for b in batch])
    return inputs, labels


def create_dataloaders(
    root=None,
    batch_size: int = 8,
    num_workers: int = 2,
    download: bool = True,
    tokenizer=None,
    max_length: int = 256,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build train, validation and test ``DataLoader`` objects for MM-IMDb.

    Args:
        root: Forwarded to ``mmimdb_collection.download_and_prepare``.
        batch_size: Batch size used by all three loaders.
        num_workers: Number of worker processes used by all three loaders.
        download: Forwarded to ``mmimdb_collection.download_and_prepare``.
        tokenizer: Optional tokenizer passed to each ``MMIMDbDataset``. When
            given, batches are ``(dict of stacked tensors, labels)``; when
            ``None``, PyTorch's default collation is used.
        max_length: Tokenizer ``max_length`` passed to each ``MMIMDbDataset``.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles.
    """
    meta = mmimdb_collection.download_and_prepare(root=root, download=download)
    # The prepared metadata supplies the cache directory used for every split.
    cache_dir = meta["hf_cache"]

    train_ds = MMIMDbDataset("train", cache_dir, tokenizer, max_length)
    val_ds = MMIMDbDataset("validation", cache_dir, tokenizer, max_length)
    test_ds = MMIMDbDataset("test", cache_dir, tokenizer, max_length)

    # Raw-text samples go through the default collate function; tokenized
    # samples use the picklable module-level helper.
    collate_fn = _collate_encoded if tokenizer is not None else None
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "collate_fn": collate_fn,
    }

    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader


# Public names exported by a wildcard import of this module.
__all__ = ["create_dataloaders", "MMIMDbDataset"]
