from typing import Any
import logging
import os.path as osp
import os
from torch.utils.data import Dataset

from .digits_datasets import get_digits
from .biwi_kinect import get_biwi_kinect
from .utkface import get_utkface
from .california_housing import get_california_housing

txt_logger = logging.getLogger("sfda_reg")


def get_datasets(basic_info_config: dict[str, Any], domain) -> tuple[Dataset, Dataset]:
    dset_name = basic_info_config['dset']
    datasource_dir = basic_info_config['data_dir']
    path = osp.join(osp.abspath(datasource_dir), dset_name)
    txt_logger.info(f"Dataset is [{dset_name}].")
    apply_normalization = basic_info_config.get('apply_normalization', True)

    match dset_name:
        case "digits":
            train_ds, train_aug_ds, val_ds, val_aug_ds = get_digits(basic_info_config, domain, path)
            return train_ds, train_aug_ds, val_ds, val_aug_ds
        case "biwi_kinect":
            train_ds, train_aug_ds, val_ds, val_aug_ds = get_biwi_kinect(
                basic_info_config, domain, path, apply_normalization)
            return train_ds, train_aug_ds, val_ds, val_aug_ds
        case "UTKface":
            train_ds, train_aug_ds, val_ds, val_aug_ds = get_utkface(basic_info_config, domain, path, apply_normalization)
        case "california":
            train_ds, train_aug_ds, val_ds, val_aug_ds = get_california_housing(
                basic_info_config, domain, path)
        case _:
            raise ValueError(f"Invalid dataset: {dset_name!r}")


    return train_ds, train_aug_ds, val_ds, val_aug_ds
