from torchvision import datasets
from torchvision.transforms import v2 as transforms

from .image_utils import IndexedDataset
import logging
txt_logger = logging.getLogger("sfda_reg")

def get_digits(fetch_dset, domain, path):
    """Build the four digit datasets (train / augmented train / val / augmented val).

    Args:
        fetch_dset: Config object read through ``.get``; the optional keys
            ``"train_aug_type"`` and ``"aug_type"`` both default to ``"basic"``.
        domain: Dataset name, ``"svhn"`` or ``"mnist"``.
        path: Root directory forwarded to the dataset loader.

    Returns:
        ``(train_ds, train_aug_ds, val_ds, val_aug_ds)`` as produced by
        ``get_svhn`` / ``get_mnist``.

    Raises:
        ValueError: from ``get_digits_transforms`` when ``domain`` is unknown.
    """
    train_aug_type = fetch_dset.get("train_aug_type", 'basic')
    aug_type = fetch_dset.get("aug_type", "basic")

    # Transforms are built before anything else; this also validates `domain`,
    # because get_digits_transforms raises for an unrecognised dataset name.
    train_transform, val_transform, aug_transform = (
        get_digits_transforms(domain, kind)
        for kind in (train_aug_type, "val", aug_type)
    )
    txt_logger.info(f"Train aug is [{train_aug_type}] | additional aug [{aug_type}] | val aug type is val.")

    loader_by_domain = {"svhn": get_svhn, "mnist": get_mnist}
    loader = loader_by_domain.get(domain)
    if loader is None:
        # Not reachable in practice: unknown domains were rejected above.
        return None

    train_ds, train_aug_ds, val_ds, val_aug_ds = loader(
        train_transform, val_transform, aug_transform, path
    )
    return train_ds, train_aug_ds, val_ds, val_aug_ds


def get_svhn(train_transform, val_transform, aug_transform, svhn_path):
    """Load SVHN with a plain and an augmented view of each split.

    Args:
        train_transform: Transform for the plain training view.
        val_transform: Transform for the plain validation view.
        aug_transform: Transform shared by both augmented views.
        svhn_path: Root directory passed to ``datasets.SVHN``.

    Returns:
        ``(train, train_aug, val, val_aug)``, each wrapped in ``IndexedDataset``.
        The validation views use SVHN's ``"test"`` split.
    """
    plan = (
        ("train", train_transform),
        ("train", aug_transform),
        ("test", val_transform),
        ("test", aug_transform),
    )
    # Construct every underlying dataset first, then wrap them, in the same
    # order as the returned tuple.
    raw_datasets = [
        datasets.SVHN(svhn_path, split=split_name, transform=tfm)
        for split_name, tfm in plan
    ]
    return tuple(IndexedDataset(ds) for ds in raw_datasets)


def get_mnist(train_transform, val_transform, aug_transform, mnist_path):
    """Load MNIST with a plain and an augmented view of each split.

    Args:
        train_transform: Transform for the plain training view.
        val_transform: Transform for the plain validation view.
        aug_transform: Transform shared by both augmented views.
        mnist_path: Root directory passed to ``datasets.MNIST``.

    Returns:
        ``(train, train_aug, val, val_aug)``, each wrapped in ``IndexedDataset``.
        The validation views use the ``train=False`` split.
    """
    plan = (
        (True, train_transform),
        (True, aug_transform),
        (False, val_transform),
        (False, aug_transform),
    )
    # Construct every underlying dataset first, then wrap them, in the same
    # order as the returned tuple.
    raw_datasets = [
        datasets.MNIST(mnist_path, train=is_train, transform=tfm)
        for is_train, tfm in plan
    ]
    return tuple(IndexedDataset(ds) for ds in raw_datasets)


def get_digits_transforms(name: str, aug_type="basic"):
    match name:
        case "svhn":
            match aug_type:
                case "basic":
                    return transforms.Compose(
                        [transforms.RandomCrop(32, padding=4),
                         transforms.ToTensor()])
                case "val":
                    return transforms.ToTensor()
        case "mnist":
            match aug_type:
                case "basic":
                    return transforms.Compose(
                        [
                            transforms.Grayscale(3),
                            transforms.Resize(32),
                            transforms.RandomCrop(32, padding=4),
                            transforms.ToTensor()
                        ])
                case "val":
                    return transforms.Compose(
                        [
                            transforms.Grayscale(3),
                            transforms.Resize(32),
                            transforms.ToTensor()
                        ])
        case _:
            raise ValueError(f"Invalid dataset: {name!r}")
