from woods.modelparams.basic_fourier import Basic_Fourier_model
from woods.modelparams.spurious_fourier import Spurious_Fourier_model
from woods.modelparams.pcl import PCL_model
from woods.modelparams.lsa64 import LSA64_model
from woods.modelparams.hhar import HHAR_model
from woods.modelparams.tmnist import TMNIST_model, TCMNIST_Source_model, TCMNIST_Time_model
from woods.modelparams.dsads import DSADS_CROSSPOSITION_model
from woods.modelparams.cap import CAP_model
from woods.modelparams.sedfx import SEDFx_model
from woods.modelparams.auselectricity import AusElectricityUnbalanced_model, AusElectricity_model
from woods.modelparams.iemocap import IEMOCAPOriginal_model, IEMOCAPUnbalanced_model, IEMOCAP_model



# Registry of per-dataset model hyper parameter functions.
# Keys follow the convention ``<dataset_name> + '_model'``; get_model_hparams
# builds the key that way and calls the mapped function with the model name.
DATASET_MODEL_HPARAMS = {
    'Basic_Fourier_model': Basic_Fourier_model,
    'Spurious_Fourier_model': Spurious_Fourier_model,
    'TMNIST_model': TMNIST_model,
    'TCMNIST_Source_model': TCMNIST_Source_model,
    'TCMNIST_Time_model': TCMNIST_Time_model,
    'PCL_model': PCL_model,
    'LSA64_model': LSA64_model,
    'HHAR_model': HHAR_model,
    'DSADS_CROSSPOSITION_model': DSADS_CROSSPOSITION_model,
    # The entries below are imported at the top of this module but were
    # previously never registered, so looking them up always failed.
    'CAP_model': CAP_model,
    'SEDFx_model': SEDFx_model,
    'AusElectricityUnbalanced_model': AusElectricityUnbalanced_model,
    'AusElectricity_model': AusElectricity_model,
    'IEMOCAPOriginal_model': IEMOCAPOriginal_model,
    'IEMOCAPUnbalanced_model': IEMOCAPUnbalanced_model,
    'IEMOCAP_model': IEMOCAP_model,
}


def get_model_hparams(dataset_name, model_name):
    """ Get the model related hyper parameters for a dataset.

    Each dataset has its own model hyper parameters definition, registered in
    ``DATASET_MODEL_HPARAMS`` under the key ``dataset_name + '_model'``.

    Args:
        dataset_name (str): dataset that is going to be trained on for the run
        model_name: forwarded unchanged to the dataset's hyper parameter
            function (presumably selects which model's hyper parameters to
            return; see the per-dataset definition)

    Raises:
        NotImplementedError: Dataset name not found in DATASET_MODEL_HPARAMS

    Returns:
        Whatever the dataset's hyper parameter function returns for
        ``model_name`` (expected to be a dictionary of hyper parameter values).
    """
    # Keep only the registry lookup inside the try, so that a KeyError raised
    # by the hyper parameter function itself is not misreported as
    # "dataset not found".
    try:
        hyper_function = DATASET_MODEL_HPARAMS[dataset_name + '_model']
    except KeyError:
        raise NotImplementedError("dataset not found: {}".format(dataset_name)) from None

    return hyper_function(model_name)