"""
Get source domains and shifted (target) domains from a dataset.

func:
    get_dataset_shape: Get the input shape and number of classes of a dataset
    get_gradual_domains: Get domains from a dataset with gradual changes (also accepts corruption_domains together with a corruption type)
    get_source_domains: Get source (training and test) domains from a dataset
    get_corruption_domains: Get corruption domains from a dataset

data_name:
    gradual_domains: "rotate_mnist", "color_mnist", "portraits", "covertype"
    corruption_domains: "cifar10", "cifar100", "imagenet"

"""

from torch.utils.data import Dataset
from typing import Tuple, Union
from pathlib import Path

from config import get_config
# Project configuration, loaded once at import time. The loaders below read
# dataset locations from it as cf.dataset[<name>]["dir"] (and "<name>c" for
# corrupted variants).
# NOTE(review): "config/path.yaml" is a relative path; presumably resolved
# against the current working directory by get_config — confirm.
cf = get_config(["config/path.yaml"])


# Datasets whose domains are generated by a gradual transformation.
gradual_domains = [
    "rotate_mnist",
    "color_mnist",
    "portraits",
    "covertype",
]
# Datasets whose shifted domains come from a corrupted ("<name>c") variant.
corruption_domains = [
    "cifar10",
    "cifar100",
    "imagenet",
]

# (input_shape, num_classes) for every supported dataset. Built once at import
# time; the values are immutable tuples, so sharing them between calls is safe.
_DATASET_SHAPES = {
    "rotate_mnist": ((1, 28, 28), 10),
    "color_mnist": ((1, 28, 28), 10),
    "portraits": ((3, 32, 32), 2),
    "covertype": ((54,), 2),
    "cifar10": ((3, 32, 32), 10),
    "cifar100": ((3, 32, 32), 100),
    "imagenet": ((3, -1, -1), 1000),
}


def get_dataset_shape(data_name: str) -> Tuple[Tuple, int]:
    """Get the input shape and number of classes of a dataset.

    Args:
        data_name (str): Name of dataset, one of ["rotate_mnist", "color_mnist", "portraits", "covertype", "cifar10", "cifar100", "imagenet"]

    Raises:
        ValueError: If dataset name is not supported

    Returns:
        Tuple[Tuple, int]: ``(input_shape, num_classes)``. For "imagenet" the
        spatial entries are ``-1`` (presumably "size not fixed here" — TODO confirm).
    """
    try:
        return _DATASET_SHAPES[data_name]
    except KeyError:
        # ``from None``: the KeyError is an implementation detail; don't chain it.
        raise ValueError(f"Dataset {data_name} not supported.") from None

# Extra keyword arguments forwarded to each gradual dataset's
# ``get_domains`` / ``get_domain`` loader.
_GRADUAL_DOMAIN_KWARGS = {
    "rotate_mnist": {"max_degree": 45},
    "color_mnist": {"max_shift": 1},
    "portraits": {"target_size": (32, 32)},
    "covertype": {},
}


def get_gradual_domains(
    data_name: str,
    domains_num: int,
    idx: Union[int, None] = None,
    corruption: Union[str, None] = None,
) -> Union[list[Dataset], Dataset]:
    """Get domains from a dataset with gradual changes.

    The dataset-specific ``get_domains`` / ``get_domain`` functions live in the
    sibling submodule named after ``data_name`` (e.g. ``.rotate_mnist``) and are
    imported lazily, only for the requested dataset.

    Args:
        data_name (str): Name of dataset, one of ["rotate_mnist", "color_mnist", "portraits", "covertype", "cifar10", "cifar100", "imagenet"]
        domains_num (int): Number of domains to generate
        idx (int, optional): Index of the domain to get. Defaults to None, which returns all domains.
        corruption (str, optional): Corruption type. Required for "cifar10", "cifar100" and "imagenet". Defaults to None.

    Raises:
        ValueError: If dataset name is not supported, or if ``corruption`` is not given for a corruption dataset

    Returns:
        Union[list[Dataset], Dataset]: List of all domain datasets if ``idx`` is None, otherwise the single domain dataset at ``idx``
    """
    import importlib

    if data_name not in gradual_domains and data_name not in corruption_domains:
        raise ValueError(f"Dataset {data_name} not supported.")
    data_dir = cf.dataset[data_name]["dir"]

    if data_name in _GRADUAL_DOMAIN_KWARGS:
        args = (data_dir, domains_num)
        kwargs = _GRADUAL_DOMAIN_KWARGS[data_name]
    else:
        # Corruption datasets additionally need the directory of the corrupted
        # variant (config entry "<name>c") and a corruption type.
        if corruption is None:
            raise ValueError(f"Corruption type is not specified for '{data_name}'.")
        data_c_dir = cf.dataset[data_name + "c"]["dir"]
        args = (data_dir, data_c_dir, domains_num, corruption)
        kwargs = {}

    # Every supported dataset is implemented in the sibling submodule of the same name.
    module = importlib.import_module(f".{data_name}", __package__)
    if idx is None:
        return module.get_domains(*args, **kwargs)
    return module.get_domain(*args, idx=idx, **kwargs)

# Per-dataset source loader: (function name in the same-named sibling
# submodule, extra keyword arguments passed after the data directory).
_SOURCE_LOADERS = {
    "cifar10": ("load_cifar10", {}),
    "cifar100": ("load_cifar100", {}),
    "imagenet": ("load_imagenet", {}),
    "color_mnist": ("get_source", {}),
    "rotate_mnist": ("get_source", {}),
    "portraits": ("get_source", {"target_size": (32, 32)}),
    "covertype": ("get_source", {}),
}


def get_source_domains(data_name: str, examples_num: Union[int, None] = None) -> Tuple[Dataset, Dataset]:
    """Get source domains from a dataset.

    Args:
        data_name (str): Name of dataset, one of ["rotate_mnist", "color_mnist", "portraits", "covertype", "cifar10", "cifar100", "imagenet"]
        examples_num (int, optional): Number of examples to load from each domain. Not used yet: the whole
            dataset is always loaded, and a warning is printed when this is set.

    Raises:
        ValueError: If dataset name is not supported

    Returns:
        Tuple[Dataset, Dataset]: Training and test datasets
    """
    import importlib

    # Validate before reading the config so unsupported names raise the
    # documented ValueError instead of whatever the config lookup raises.
    if data_name not in _SOURCE_LOADERS:
        raise ValueError(f"Dataset {data_name} not supported.")
    data_dir = cf.dataset[data_name]["dir"]
    if examples_num is not None:
        print("[Warn] examples_num is not used yet, same as the whole dataset")
    loader_name, loader_kwargs = _SOURCE_LOADERS[data_name]
    # Imported lazily, only for the requested dataset.
    module = importlib.import_module(f".{data_name}", __package__)
    tr, ts = getattr(module, loader_name)(data_dir, **loader_kwargs)
    return tr, ts

# Corrupted-variant config key -> (sibling submodule, loader function name).
_CORRUPTION_LOADERS = {
    "cifar10c": ("cifar10", "load_cifar10c"),
    "cifar100c": ("cifar100", "load_cifar100c"),
    "imagenetc": ("imagenet", "load_imagenetc"),
}


def get_corruption_domains(data_name: str, examples_num: int, severity: int, corruptions: list[str]) -> Dataset:
    """Get corruption domains from a dataset.

    Args:
        data_name (str): Name of dataset, one of ["cifar10", "cifar100", "imagenet"]; the "c"-suffixed
            names ("cifar10c", "cifar100c", "imagenetc") are accepted as well
        examples_num (int): Number of examples to load from each domain
        severity (int): Severity of corruption
        corruptions (list[str]): List of corruption types

    Raises:
        ValueError: If dataset name is not supported

    Returns:
        Dataset: Corrupted domain dataset arranged in the order of corruptions
    """
    import importlib

    # Accept both the clean name ("cifar10") and the corrupted-variant name
    # ("cifar10c"); ``endswith`` also copes with an empty string.
    key = data_name if data_name.endswith("c") else data_name + "c"
    # Validate before reading the config so the documented ValueError is raised.
    if key not in _CORRUPTION_LOADERS:
        raise ValueError(f"Dataset {data_name} not supported.")
    data_dir = cf.dataset[key]["dir"]
    module_name, loader_name = _CORRUPTION_LOADERS[key]
    # Imported lazily, only for the requested dataset.
    loader = getattr(importlib.import_module(f".{module_name}", __package__), loader_name)
    # The trailing positional False is passed exactly as before; its meaning is
    # defined by the loader (not visible here).
    return loader(data_dir, corruptions, severity, examples_num, False)


# Public API of this module.
__all__ = [
    "get_dataset_shape",
    "get_gradual_domains",
    "get_source_domains",
    "get_corruption_domains",
    "gradual_domains",
    "corruption_domains",
]