"""Defining the architectures used for benchmarking algorithms"""
from woods.models.lstm import LSTM, LSTM_Freq
from woods.models.BENDR import BENDR
from woods.models.CRNN import CRNN
from woods.models.deep4 import deep4, deep4_Freq
from woods.models.EEGNet import EEGNet
from woods.models.ForecastingTransformer import ForecastingTransformer
from woods.models.mnist import MNIST_CNN, MNIST_LSTM
from woods.models.BiModel import BiModel
from woods.models.FreTS import FreTS
from woods.models.PatchTST import PatchTST
from woods.models.FEDNet import FEDNet
from woods.models.AdaRNN import AdaRNN
from woods.models.GILE import GILE
from woods.models.ActNetwork import ActNetwork, ActFreqNetwork
from woods.models.FEDNet_wo_con import FEDNet_wo_con
from woods.models.FEDNet_wo_det import FEDNet_wo_det
from woods.models.FEDNet_wo_sto import FEDNet_wo_sto
from woods.models.FEDNet_fc import FEDNet_fc
# from woods.models.FEDNet_contrastive import FEDNet_contrastive
from woods.models.FEDNet import FEDNet
from woods.models.FEDNet_v2 import FEDNet_v2
from woods.models.FEDNet_v1 import FEDNet_v1


# Registry of available architectures: maps the name given in
# ``model_hparams['model']`` to the callable that builds that model.
# ``get_model`` looks names up in this mapping.
MODELS = {
    'LSTM': LSTM,
    'BENDR': BENDR,
    'BiModel': BiModel,
    'CRNN': CRNN,
    'deep4': deep4,
    'EEGNet': EEGNet,
    'MNIST_CNN': MNIST_CNN,
    'MNIST_LSTM': MNIST_LSTM,
    'FreTS': FreTS,
    'PatchTST': PatchTST,
    'FEDNet': FEDNet,
    'AdaRNN': AdaRNN,
    'GILE': GILE,
    'ActNetwork': ActNetwork,
    'LSTM_Freq': LSTM_Freq,
    'deep4_Freq': deep4_Freq,
    'ActFreqNetwork': ActFreqNetwork,
    'FEDNet_wo_con': FEDNet_wo_con,
    'FEDNet_wo_det': FEDNet_wo_det,
    'FEDNet_wo_sto': FEDNet_wo_sto,
    'FEDNet_fc': FEDNet_fc,
    'FEDNet_v1': FEDNet_v1,
    'FEDNet_v2': FEDNet_v2,
}

def get_model(dataset, model_hparams, *, registry=None):
    """Instantiate the model architecture named in ``model_hparams['model']``.

    Looks up the constructor registered under ``model_hparams['model']`` and
    returns the result of calling it as ``constructor(dataset, model_hparams)``.

    Args:
        dataset: forwarded unchanged to the model constructor (presumably the
            dataset object the model is built for — TODO confirm with callers).
        model_hparams (dict): model hyperparameters. Must contain the key
            ``'model'`` naming a registered architecture; the whole dict is
            also forwarded to the constructor.
        registry (dict, optional): mapping from model name to constructor.
            Defaults to the module-level ``MODELS`` registry.

    Returns:
        Whatever the selected constructor returns (the model instance).

    Raises:
        KeyError: if ``model_hparams`` has no ``'model'`` key.
        NotImplementedError: if the requested model name is not registered.
    """
    if registry is None:
        registry = MODELS

    model_name = model_hparams['model']
    try:
        model_fn = registry[model_name]
    except KeyError:
        # Keep the historical exception type for callers, but name the
        # missing model and list the valid choices to ease debugging.
        raise NotImplementedError(
            "Model not found: {}. Available models: {}".format(
                model_name, ", ".join(sorted(registry))
            )
        ) from None

    # Called outside the try so a KeyError raised inside a constructor is not
    # misreported as an unknown model name.
    return model_fn(dataset, model_hparams)

















