import itertools
import torch.nn as nn
# Hyper-parameter search grids for the MNODE variants.
#
# Each grid maps a hyper-parameter name to the list of candidate values to
# try.  The expansion at the end of this file takes the Cartesian product of
# these lists in key order, so both the key order and the value order below
# determine the order of the generated configurations.  Values that are not
# being searched are still wrapped in a one-element list.

# 3 (a1) x 2 (a2) = 6 configurations.
MNODE_GL_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [1e-5, 1e-6, 1e-7],
    'a2': [10 ** -i for i in (6, 7)],  # 1e-6, 1e-7
    'lr': [1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],
    'name': ['MNODE_GL'],
}

# 4 (rate) configurations; a1/a2 fixed at 0.
MNODE_GLR_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [0],
    'a2': [0],
    'lr': [1e-3],
    'input_size': [5],
    'rate': [0.05, 0, -0.05, -0.1],
    'activation': [nn.ReLU],
    'name': ['MNODE_GLR'],
}

# Single configuration; a1/a2 fixed at 0.
MNODE_NR_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [0],
    'a2': [0],
    'lr': [1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],
    'name': ['MNODE_NR'],
}

# 2 (a1) x 2 (lr) = 4 configurations.
MNODE_EGL_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [10 ** -i for i in (6, 7)],  # 1e-6, 1e-7
    'a2': [0],
    'lr': [1e-2, 1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],
    'name': ['MNODE_EGL'],
}

# Same search space as MNODE_GL_p, only the name differs: 6 configurations.
MNODE_EN_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [1e-5, 1e-6, 1e-7],
    'a2': [10 ** -i for i in (6, 7)],  # 1e-6, 1e-7
    'lr': [1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],
    'name': ['MNODE_EN'],
}

# Single configuration.
MNODE_GD_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [0],
    'a2': [1e-6],
    'lr': [1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],
    'name': ['MNODE_GD'],
}

# 3 (p) configurations.
MNODE_RD_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [0],
    'a2': [1e-6],
    'lr': [1e-3],
    'p': [0.1, 0.2, 0.4],
    'input_size': [5],
    'activation': [nn.ReLU],
    'name': ['MNODE_RD'],
}

# 6 (k) configurations: k = 12, 14, 16, 18, 20, 22.
MNODE_NS_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [0],
    'a2': [1e-6],
    'lr': [1e-3],
    'k': list(range(12, 24, 2)),
    'input_size': [5],
    'activation': [nn.ReLU],
    'name': ['MNODE_NS'],
}

# Hyper-parameter search grids for the LSTM, BNODE, S4D, TCN and Transformer
# models.  Each key maps to the list of candidate values; the expansion at
# the end of this file takes the Cartesian product of the lists in key order.

# 2 (num_layers) x 3 (latent_size) x 3 (dropout) x 3 (lr) = 54 configurations.
LSTM_NR_p = {
    'num_layers': [2, 3],
    'latent_size': [6, 12, 18],
    'dropout': [0, 0.1, 0.2],
    'input_size': [4],
    'a1': [0],
    'a2': [0],
    'lr': [1e-4, 1e-3, 1e-2],
    'activation': [nn.ReLU],
    'name': ['LSTM_NR'],
}

# 3 (latent_size) x 3 (dropout) x 3 (lr) = 27 configurations.
BNODE_NR_p = {
    'num_hidden_layers': [2],
    'latent_size': [6, 12, 18],
    'mlp_size': [16],
    'dropout': [0, 0.1, 0.2],
    'input_size': [4],
    'a1': [0],
    'a2': [0],
    'lr': [1e-4, 1e-3, 1e-2],
    'activation': [nn.ReLU],
    'name': ['BNODE_NR'],
}

# 3 (d_model) x 2 (d_state) x 3 (dropout) x 3 (lr) = 54 configurations.
# No 'activation' key in this grid.
S4D_NR_p = {
    'd_model': [4, 6, 8],
    'd_state': [32, 64],
    'dropout': [0, 0.1, 0.2],
    'input_size': [5],
    'a1': [0],
    'a2': [0],
    'lr': [1e-4, 1e-3, 1e-2],
    'name': ['S4D_NR'],
}

# 2 (num_layers) x 2 (conv_size) x 3 (kernel_size) x 3 (dropout) x 3 (lr)
# = 108 configurations.  No 'activation' key in this grid.
TCN_NR_p = {
    'num_layers': [2, 3],
    'conv_size': [16, 32],
    'kernel_size': [2, 3, 4],
    'dropout': [0, 0.1, 0.2],
    'input_size': [5],
    'a1': [0],
    'a2': [0],
    'lr': [1e-4, 1e-3, 1e-2],
    'name': ['TCN_NR'],
}

# 2 (d_model) x 2 (nhead) x 2 (dim_feedforward) x 2 (dropout) = 16
# configurations.  Every d_model value is divisible by every nhead value.
# NOTE(review): unlike every other grid in this file, this one defines no
# 'a1'/'a2' entries (and no 'activation') -- confirm the consumer of these
# configs does not require them.
Transformer_TS_p = {
    'd_model': [8, 16],
    'nhead': [4, 8],
    'num_encoder_layers': [2],
    'num_decoder_layers': [2],
    'dim_feedforward': [16, 32],
    'dropout': [0, 0.1],
    'input_size': [5],
    'lr': [1e-3],
    'name': ['Transformer_TS'],
}

# Hyper-parameter search grids for the UVA, LP and SC models.  Each key maps
# to the list of candidate values; the expansion at the end of this file
# takes the Cartesian product of the lists in key order.

# Single configuration.
UVA_NR_p = {
    'input_size': [4],
    'latent_size': [21],
    'a1': [0],
    'a2': [0],
    'lr': [1e-1],
    'name': ['UVA_NR'],
}

# 3 (latent_size) x 3 (dropout) x 3 (lr) = 27 configurations.
LP_NR_p = {
    'input_size': [4],
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'latent_size': [20, 24, 28],
    'state_size': [20],
    'dropout': [0.0, 0.1, 0.2],
    'activation': [nn.ReLU],
    'a1': [0],
    'a2': [0],
    'lr': [1e-4, 1e-3, 1e-2],
    'name': ['LP_NR'],
}

# 2 (latent_size) x 3 (dropout) x 3 (lr) = 18 configurations.
SC_NR_p = {
    # Doubly nested on purpose: the product iterates the outer list, so the
    # whole inner list is passed through as ONE candidate value.
    'closure_input_sizes': [[2, 1, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2, 5, 2, 2]],
    'input_size': [4],
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'latent_size': [20, 24],
    'state_size': [20],
    'dropout': [0.0, 0.1, 0.2],
    'activation': [nn.ReLU],
    'a1': [0],
    'a2': [0],
    'lr': [1e-4, 1e-3, 1e-2],
    'name': ['SC_NR'],
}

def _expand_grid(param_grid):
    """Expand one search grid into its list of concrete configurations.

    ``param_grid`` maps each hyper-parameter name to a list of candidate
    values.  The result holds one dict per element of the Cartesian product
    of those lists, in ``itertools.product`` order (the last key varies
    fastest).  Each dict has the same keys, in the same order, as the grid.
    """
    keys = list(param_grid)
    return [dict(zip(keys, combo)) for combo in itertools.product(*param_grid.values())]


def _build_hyper_param_dicts(param_grids):
    """Map each grid's model name to its expanded list of configurations.

    The model name is the single entry of the grid's ``'name'`` list.
    Grids are processed in the given order, which is also the insertion
    order of the returned dict.

    Raises:
        ValueError: if a grid's ``'name'`` entry does not hold exactly one
            value (otherwise configs carrying different names, or the
            characters of a bare string, would be filed under the first
            element), or if two grids share a name (otherwise the later
            grid would silently replace the earlier one).
    """
    expanded = {}
    for grid in param_grids:
        names = grid['name']
        if len(names) != 1:
            raise ValueError(f"grid 'name' must hold exactly one value, got {names!r}")
        name = names[0]
        if name in expanded:
            raise ValueError(f"duplicate hyper-parameter grid name {name!r}")
        expanded[name] = _expand_grid(grid)
    return expanded


# Model name -> list of concrete hyper-parameter dicts, one per combination
# of the candidate values in that model's grid.
hyper_param_dicts = _build_hyper_param_dicts([
    MNODE_GL_p, MNODE_GLR_p, MNODE_NR_p, MNODE_EGL_p,
    MNODE_EN_p, MNODE_GD_p, MNODE_RD_p, MNODE_NS_p,
    LSTM_NR_p, BNODE_NR_p, S4D_NR_p, TCN_NR_p,
    Transformer_TS_p, UVA_NR_p, LP_NR_p, SC_NR_p,
])
