""""""

from __future__ import annotations

import numpy as np
import torch
from pathlib import Path
from typing import Tuple
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchaudio

from ..collection import avmnist as avmnist_collection


class AVMNISTDataset(Dataset):
    """Map-style dataset yielding paired image/audio samples with an integer label.

    On-disk layout read under ``root``::

        root/labels.npy      # one label per sample, indexed by sample number
        root/images/{i}.png  # image for sample i
        root/audio/{i}.wav   # audio clip for sample i

    Each item is ``({"image": image, "audio": waveform}, label)`` where

    * ``image`` is a float32 tensor of shape ``(1, image_size, image_size)``,
      converted to grayscale and scaled to [0, 1];
    * ``waveform`` is the tensor returned by ``torchaudio.load``, trimmed or
      right-zero-padded along its last axis to ``audio_seconds`` worth of
      samples at the file's own sample rate;
    * ``label`` is a ``torch.long`` scalar taken from ``labels.npy``.

    Args:
        root: Directory containing ``labels.npy``, ``images/`` and ``audio/``.
        split: Name of the split; stored on the instance for reference only.
        image_size: Side length that images are resized to.
        audio_seconds: Duration each waveform is trimmed/padded to (default 1 s).

    Raises:
        ValueError: If ``audio_seconds`` is not positive.
    """

    def __init__(
        self,
        root: Path,
        split: str = "train",
        image_size: int = 28,
        audio_seconds: float = 1.0,
    ):
        if audio_seconds <= 0:
            raise ValueError(f"audio_seconds must be positive, got {audio_seconds!r}")
        self.root = Path(root)
        self.split = split
        self.image_dir = self.root / "images"
        self.audio_dir = self.root / "audio"
        self.labels = np.load(self.root / "labels.npy")
        self.image_size = image_size
        self.audio_seconds = audio_seconds

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx):
        # Resolve first so file names and the label refer to the same sample,
        # and so an out-of-range index raises IndexError before any file I/O.
        i = self._resolve_index(idx)
        image = self._load_image(self.image_dir / f"{i}.png")
        waveform = self._load_audio(self.audio_dir / f"{i}.wav")
        label = torch.tensor(self.labels[i], dtype=torch.long)
        return {"image": image, "audio": waveform}, label

    def _resolve_index(self, idx) -> int:
        """Return ``idx`` as a non-negative int in ``[0, len(self))``.

        Negative indices count from the end, matching ``self.labels[idx]``.
        Raising IndexError (instead of failing on a missing file) also lets the
        legacy sequence protocol used by ``for sample in dataset`` stop at the end.
        """
        size = len(self.labels)
        try:
            # range indexing accepts any __index__-able integer (Python/NumPy
            # ints, 0-d integer tensors), normalizes negatives and bounds-checks.
            return range(size)[idx]
        except IndexError:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {size}"
            ) from None

    def _load_image(self, path: Path) -> torch.Tensor:
        """Load ``path`` as grayscale, resize to ``image_size`` square, scale to [0, 1]."""
        # The context manager closes the file handle deterministically rather
        # than leaving it to Pillow or the garbage collector.
        with Image.open(path) as img:
            gray = img.convert("L").resize((self.image_size, self.image_size))
        pixels = torch.tensor(np.array(gray), dtype=torch.float32)
        return pixels.unsqueeze(0) / 255.0

    def _load_audio(self, path: Path) -> torch.Tensor:
        """Load ``path`` and trim/zero-pad it to ``audio_seconds`` at its own sample rate."""
        waveform, sr = torchaudio.load(path)
        # NOTE(review): the target length follows each file's sample rate, so
        # clips with different rates yield different lengths and cannot be
        # stacked by the default collate — presumably all clips share one rate.
        target_len = int(round(sr * self.audio_seconds))
        return self._fit_length(waveform, target_len)

    @staticmethod
    def _fit_length(waveform: torch.Tensor, target_len: int) -> torch.Tensor:
        """Trim, or right-pad with zeros, the last axis of ``waveform`` to ``target_len``."""
        length = waveform.shape[-1]
        if length >= target_len:
            return waveform[..., :target_len]
        return torch.nn.functional.pad(waveform, (0, target_len - length))


def create_dataloaders(
    root=None,
    batch_size: int = 64,
    num_workers: int = 4,
    download: bool = True,
    image_size: int = 28,
) -> Tuple[DataLoader, DataLoader]:
    """Build the AV-MNIST train and test DataLoaders.

    Calls ``avmnist_collection.download_and_prepare`` and wraps the directories
    listed under ``meta["splits"]["train"]`` / ``["test"]`` in
    :class:`AVMNISTDataset` instances.

    Args:
        root: Passed through to ``download_and_prepare``; its meaning (and what
            ``None`` selects) is defined there — presumably a default data
            location, confirm against that function.
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes per loader.
        download: Passed through to ``download_and_prepare``.
        image_size: Side length that images are resized to.

    Returns:
        ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    meta = avmnist_collection.download_and_prepare(root=root, download=download)
    splits = meta["splits"]

    train_ds = AVMNISTDataset(Path(splits["train"]), split="train", image_size=image_size)
    test_ds = AVMNISTDataset(Path(splits["test"]), split="test", image_size=image_size)

    # Pinned host memory only pays off for copies to a CUDA device; without an
    # accelerator DataLoader just warns that pinning is ignored.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)
    return train_loader, test_loader


# Names exported by ``from <this module> import *``.
__all__ = [
    "create_dataloaders",
    "AVMNISTDataset",
]
