import itertools
import torch.nn as nn
# Search grid registered under the name 'MNODE_GL'.  Every key maps to a list
# of candidate values; the loop at the bottom of this module expands the lists
# into their Cartesian product (2 x 2 = 4 configurations for this grid).
# NOTE(review): a1/a2 look like regularisation weights -- confirm in the
# training code.
MNODE_GL_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [1e-6, 1e-7],
    'a2': [10 ** (-exp) for exp in range(6, 8)],  # 10**-6 and 10**-7
    'lr': [1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],  # the class itself, not an instance
    'name': ['MNODE_GL'],
}

# Search grid registered under the name 'MNODE_NR'.  a1 is pinned to 0 and
# only a2 varies, so the Cartesian expansion at the bottom of this module
# yields 2 configurations.
MNODE_NR_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [0],
    'a2': [10 ** (-exp) for exp in range(6, 8)],  # 10**-6 and 10**-7
    'lr': [1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],  # the class itself, not an instance
    'name': ['MNODE_NR'],
}

# Search grid registered under the name 'MNODE_EGL'.  Here a2 is pinned to 0
# and only a1 varies, so the Cartesian expansion at the bottom of this module
# yields 2 configurations.
MNODE_EGL_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [10 ** (-exp) for exp in range(6, 8)],  # 10**-6 and 10**-7
    'a2': [0],
    'lr': [1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],  # the class itself, not an instance
    'name': ['MNODE_EGL'],
}

# Search grid registered under the name 'MNODE_EN'.  Both a1 and a2 vary, so
# the Cartesian expansion at the bottom of this module yields 2 x 2 = 4
# configurations.
# NOTE(review): apart from 'name', the values match MNODE_GL_p exactly --
# confirm the two variants are meant to share one grid.
MNODE_EN_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [1e-6, 1e-7],
    'a2': [10 ** (-exp) for exp in range(6, 8)],  # 10**-6 and 10**-7
    'lr': [1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],  # the class itself, not an instance
    'name': ['MNODE_EN'],
}

# Search grid registered under the name 'MNODE_GD'.  Every key has a single
# candidate, so the Cartesian expansion at the bottom of this module yields
# exactly one configuration.
MNODE_GD_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [0],
    'a2': [1e-6],
    'lr': [1e-3],
    'input_size': [5],
    'activation': [nn.ReLU],  # the class itself, not an instance
    'name': ['MNODE_GD'],
}

# Search grid registered under the name 'MNODE_RD'.  Only 'p' has more than
# one candidate, so the Cartesian expansion at the bottom of this module
# yields 3 configurations.
# NOTE(review): the meaning of 'p' is not visible in this module -- confirm
# against the code that consumes this grid.
MNODE_RD_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [0],
    'a2': [1e-6],
    'lr': [1e-3],
    'p': [0.1, 0.2, 0.4],
    'input_size': [5],
    'activation': [nn.ReLU],  # the class itself, not an instance
    'name': ['MNODE_RD'],
}

# Search grid registered under the name 'MNODE_NS'.  Only 'k' has more than
# one candidate, so the Cartesian expansion at the bottom of this module
# yields 4 configurations.
# NOTE(review): a2 is 1e-3 here while the other MNODE grids use 1e-6 or
# smaller -- confirm this is intentional.
MNODE_NS_p = {
    'num_hidden_layers': [2],
    'mlp_size': [16],
    'dropout': [0],
    'a1': [0],
    'a2': [1e-3],
    'lr': [1e-3],
    'k': list(range(2, 10, 2)),  # 2, 4, 6, 8
    'input_size': [5],
    'activation': [nn.ReLU],  # the class itself, not an instance
    'name': ['MNODE_NS'],
}

# Search grid registered under the name 'LSTM_NR'.  Depth, latent width and
# dropout vary, so the Cartesian expansion at the bottom of this module
# yields 2 x 3 x 3 = 18 configurations.
# NOTE(review): input_size is 4 here but 5 in most other grids of this
# module -- confirm the difference is intentional.
LSTM_NR_p = {
    'num_layers': [2, 3],
    'latent_size': [6, 12, 18],
    'dropout': [0, 0.1, 0.2],
    'input_size': [4],
    'a1': [0],
    'a2': [0],
    'lr': [1e-3],
    'activation': [nn.ReLU],  # the class itself, not an instance
    'name': ['LSTM_NR'],
}

# Search grid registered under the name 'BNODE_NR'.  Latent width, dropout
# and learning rate vary, so the Cartesian expansion at the bottom of this
# module yields 3 x 3 x 3 = 27 configurations.
# NOTE(review): input_size is 4 here but 5 in most other grids of this
# module -- confirm the difference is intentional.
BNODE_NR_p = {
    'num_hidden_layers': [2],
    'latent_size': [6, 12, 18],
    'mlp_size': [16],
    'dropout': [0, 0.1, 0.2],
    'input_size': [4],
    'a1': [0],
    'a2': [0],
    'lr': [1e-4, 1e-3, 1e-2],
    'activation': [nn.ReLU],  # the class itself, not an instance
    'name': ['BNODE_NR'],
}

# Search grid registered under the name 'S4D_NR'.  Model width, state size,
# dropout and learning rate vary, so the Cartesian expansion at the bottom of
# this module yields 3 x 2 x 3 x 3 = 54 configurations.  Unlike the MNODE,
# LSTM and BNODE grids, this one carries no 'activation' entry.
S4D_NR_p = {
    'd_model': [4, 6, 8],
    'd_state': [32, 64],
    'dropout': [0, 0.1, 0.2],
    'input_size': [5],
    'a1': [0],
    'a2': [0],
    'lr': [1e-4, 1e-3, 1e-2],
    'name': ['S4D_NR'],
}

# Search grid registered under the name 'TCN_NR'.  Depth, channel width,
# kernel size, dropout and learning rate vary, so the Cartesian expansion at
# the bottom of this module yields 2 x 2 x 3 x 3 x 3 = 108 configurations.
TCN_NR_p = {
    'num_layers': [2, 3],
    'conv_size': [16, 32],
    'kernel_size': [2, 3, 4],
    'dropout': [0, 0.1, 0.2],
    'input_size': [5],
    'a1': [0],
    'a2': [0],
    'lr': [1e-4, 1e-3, 1e-2],
    'name': ['TCN_NR'],
}

# Search grid registered under the name 'Transformer_TS'.  Model width, head
# count, feed-forward width and dropout vary, so the Cartesian expansion at
# the bottom of this module yields 2 x 2 x 2 x 2 = 16 configurations.
# Every listed d_model is divisible by every listed nhead.  This grid has no
# 'a1'/'a2' entries, unlike the other grids in this module.
Transformer_TS_p = {
    'd_model': [8, 16],
    'nhead': [4, 8],
    'num_encoder_layers': [2],
    'num_decoder_layers': [2],
    'dim_feedforward': [16, 32],
    'dropout': [0, 0.1],
    'input_size': [5],
    'lr': [1e-3],
    'name': ['Transformer_TS'],
}



# Expand every search grid into the explicit list of configurations it
# describes.  The result maps a grid's name (the single entry of its 'name'
# list) to a list of dicts, one per element of the Cartesian product of the
# grid's candidate lists.  Each dict has the grid's keys, in the grid's
# order, bound to one concrete value apiece.
#
# Each grid appears exactly once below: LSTM_NR_p used to be listed twice,
# which expanded it a second time only to overwrite the identical result
# stored under the same key.  Results are keyed by name, so two grids that
# share a name would silently replace one another.
hyper_param_dicts = {}
for param_grid in [
    MNODE_GL_p,
    MNODE_NR_p,
    MNODE_GD_p,
    MNODE_NS_p,
    LSTM_NR_p,
    MNODE_EGL_p,
    MNODE_EN_p,
    MNODE_RD_p,
    BNODE_NR_p,
    S4D_NR_p,
    TCN_NR_p,
    Transformer_TS_p,
]:
    # keys() and values() of the same unmodified dict iterate in matching
    # order, so zipping the keys with one product tuple rebuilds a config.
    keys = param_grid.keys()
    values = param_grid.values()
    hyper_param_dicts[param_grid['name'][0]] = [
        dict(zip(keys, combo)) for combo in itertools.product(*values)
    ]
