import numpy as np
import torch
from PIL import Image
from torchvision import datasets, transforms
from typing import Any, Tuple
from typing_extensions import override


# Dataset-level constants for SVHN. `mean`/`std` feed the normalisation in
# SVHN_TRANSFORM, `root` is the directory handed to torchvision, and `n_data`
# (presumably the size of the 'train' split — confirm) / `shape` describe the samples.
SVHN_PARAMS = dict(
    dataset_name='SVHN',
    n_data=73_257,
    shape=(3, 32, 32),
    mean=(0.43768212, 0.44376972, 0.47280444),
    std=(0.19803013, 0.20101563, 0.19703615),
    root='SVHN',
)

# Per-sample preprocessing for every SVHN image: convert the PIL image to a
# tensor, then standardise each channel with the statistics in SVHN_PARAMS.
SVHN_TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=SVHN_PARAMS['mean'],
            std=SVHN_PARAMS['std'],
        ),
    ]
)


class _IndexedSVHN(datasets.SVHN):
    """One SVHN split, normalised with ``SVHN_TRANSFORM``, optionally index-tagged.

    Images are kept in ``self.data`` as a channel-first float tensor and are
    converted back to a ``uint8`` PIL image on every access, so that
    ``SVHN_TRANSFORM`` (which starts with ``ToTensor``) can be applied.

    Target layout:
        * ``indexed=False``: ``self.targets`` is a 1-D ``int32`` label tensor
          and ``__getitem__`` yields a 0-d label tensor.
        * ``indexed=True``: ``self.targets`` is an ``(N, 2)`` tensor whose rows
          are ``[label, position_in_split]`` and ``__getitem__`` yields that row.
    """

    def __init__(self, train=True, indexed=None):
        """Load one SVHN split through torchvision (``download=True``).

        Args:
            train: Load the ``'train'`` split when true, ``'test'`` otherwise.
            indexed: Whether each target also carries the sample's position in
                the split. ``None`` (the default) follows ``train``: training
                samples are indexed, test samples are not.
        """
        super().__init__(root=SVHN_PARAMS['root'],
                         split='train' if train else 'test',
                         transform=SVHN_TRANSFORM,
                         target_transform=None,
                         download=True)

        if indexed is None:
            indexed = train

        # NOTE(review): float32 storage is larger than the raw pixel values
        # need; kept as-is for compatibility with code that reads `self.data`.
        self.data = torch.from_numpy(self.data).float()
        labels = torch.from_numpy(self.labels).int()

        if indexed:
            # Pair every label with its row number so a sample can be traced
            # back to its position in the split. Both columns are int64
            # explicitly rather than relying on mixed-dtype `torch.cat` promotion.
            positions = torch.arange(len(labels), dtype=torch.int64)
            self.targets = torch.stack((labels.reshape(-1).long(), positions), dim=1)
        else:
            self.targets = labels

    @override
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at ``index``.

        The stored CHW float image is turned back into an HWC ``uint8`` PIL
        image before ``self.transform`` runs. ``target`` is cast to ``int32``
        and is either a label scalar or a ``[label, index]`` pair, depending on
        how the dataset was built.
        """
        img = self.data[index]
        target = self.targets[index].int()

        # PIL expects height x width x channels with 8-bit pixels.
        img = Image.fromarray(img.permute(1, 2, 0).numpy().astype(np.uint8))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
    

def load_svhn(which: str):
    """Build the requested SVHN split.

    Args:
        which (str): ``'train'`` for the training split (targets paired with
            their sample index) or ``'test'`` for the test split (plain labels).

    Returns:
        An ``_IndexedSVHN`` dataset for the requested split.

    Raises:
        ValueError: If ``which`` is neither ``'train'`` nor ``'test'``.
    """
    if which not in ('train', 'test'):
        raise ValueError("Can only choose between 'train' and 'test'.")

    # The training split is the only one whose targets are index-tagged.
    is_train = which == 'train'
    return _IndexedSVHN(train=is_train, indexed=is_train)