from utils.constants import Cte


def get_dict_carefl(dataset_name, experiment_name):
    """Return the dataset and model option lists for the CAREFL model.

    Every value in the returned dicts is a list of candidate settings
    (presumably expanded into a hyper-parameter grid by the caller).

    Args:
        dataset_name: a ``Cte`` dataset identifier; TRIANGLE, COLLIDER, CHAIN,
            MGRAPH and LOAN are supported.
        experiment_name: experiment identifier; only ``'all'`` is supported.

    Returns:
        A ``(dataset_dict, model_dict)`` tuple.

    Raises:
        NotImplementedError: if the dataset or the experiment is unsupported.
            (Previously an unsupported experiment fell through and crashed
            with ``UnboundLocalError`` at the ``return``.)
    """
    if experiment_name not in ['all']:
        raise NotImplementedError

    if dataset_name in [Cte.TRIANGLE, Cte.COLLIDER, Cte.CHAIN, Cte.MGRAPH]:
        equations_type = ['linear', 'non-linear', 'non-additive']
    elif dataset_name == Cte.LOAN:
        # LOAN is only configured with linear structural equations.
        equations_type = ['linear']
    else:
        raise NotImplementedError

    dataset_dict = {'normalize': ['std'],
                    'equations_type': equations_type,
                    'num_samples_tr': [5000],
                    }
    # The model search space is identical for every supported dataset.
    model_dict = {'n_layers': [2, 4, 5],
                  'n_hidden': [10, 32],
                  }
    return dataset_dict, model_dict


def get_dict_vcause(dataset_name, experiment_name):
    """Return the dataset and model option lists for the vcause model.

    Every value in the returned dicts is a list of candidate settings
    (presumably expanded into a hyper-parameter grid by the caller).

    Args:
        dataset_name: a ``Cte`` dataset identifier; TRIANGLE, CHAIN, COLLIDER,
            MGRAPH and LOAN are supported.
        experiment_name: experiment identifier; only ``'all'`` is supported.

    Returns:
        A ``(dataset_dict, model_dict)`` tuple.

    Raises:
        NotImplementedError: if the dataset or the experiment is unsupported.
            (Previously an unsupported experiment fell through and crashed
            with ``UnboundLocalError`` at the ``return``.)
    """
    if experiment_name not in ['all']:
        raise NotImplementedError

    if dataset_name in [Cte.TRIANGLE, Cte.CHAIN]:  # Diameter D=2
        dataset_dict = {'normalize': ['std'],
                        'equations_type': ['linear', 'non-linear', 'non-additive'],
                        'num_samples_tr': [5000],
                        }
        model_dict = {'architecture': ['dgnn'],
                      'h_dim_list_dec': [[], [16], [16, 16]],
                      'm_layers': [1],
                      'h_dim_list_enc': [[16], [16, 16]],
                      'z_dim': [4],
                      'act_name': ['relu'],
                      'drop_rate': [0.0],
                      'dropout_adj_rate': [0.0, 0.1],
                      'dropout_adj_I_rate': [0.1, 0.2],
                      'dropout_adj_I_prob_keep_self': [0.0, 0.2],
                      'dropout_input_rate': [0.0],
                      'beta': [1],
                      'K': [5],
                      'lambda_kld': [0.05],
                      'distr_x': ['delta'],
                      }

    elif dataset_name == Cte.LOAN:
        dataset_dict = {'normalize': ['std'],
                        'equations_type': ['linear'],
                        'num_samples_tr': [5000],
                        }
        model_dict = {'architecture': ['dgnn'],
                      'h_dim_list_dec': [[16], [16, 16], [16, 16, 16]],
                      'm_layers': [1],
                      # NOTE(review): [16, 16] appears twice, so that setting is
                      # run twice; possibly meant [16, 16, 16] to mirror the
                      # decoder list — confirm before changing the grid.
                      'h_dim_list_enc': [[16], [16, 16], [16, 16]],
                      'z_dim': [4],
                      'act_name': ['relu'],
                      'drop_rate': [0.0],
                      'dropout_adj_rate': [0.0, 0.2],
                      'keep_self_loops': [1],
                      'dropout_adj_I_rate': [0.2],
                      'dropout_adj_I_prob_keep_self': [0.0, 0.2],
                      'dropout_input_rate': [0.0],
                      'beta': [1],
                      'residual': [0],
                      'K': [5],
                      'lambda_kld': [0.05],
                      'distr_x': ['delta'],
                      }

    elif dataset_name in [Cte.COLLIDER, Cte.MGRAPH]:
        dataset_dict = {'normalize': ['std'],
                        'equations_type': [Cte.LINEAR, Cte.NONLINEAR, Cte.NONADDITIVE],
                        'num_samples_tr': [5000],
                        }
        model_dict = {'architecture': ['dgnn'],
                      'h_dim_list_dec': [[], [16], [16, 16]],
                      'm_layers': [1],
                      'h_dim_list_enc': [[16]],
                      'z_dim': [4],
                      'act_name': ['relu'],
                      'drop_rate': [0.0],
                      'dropout_adj_rate': [0.0, 0.1],
                      'dropout_adj_I_rate': [0.1, 0.2],
                      'dropout_adj_I_prob_keep_self': [0.0],
                      'dropout_input_rate': [0.0],
                      'beta': [1],
                      'K': [5],
                      'lambda_kld': [0.05],
                      'distr_x': ['delta'],
                      }

    else:
        raise NotImplementedError

    return dataset_dict, model_dict


def get_dict_mcvae(dataset_name, experiment_name):
    """Return the dataset, model and trainer settings for the MultiCVAE model.

    Args:
        dataset_name: a ``Cte`` dataset identifier; TRIANGLE, COLLIDER, CHAIN,
            MGRAPH and LOAN are supported.
        experiment_name: experiment identifier; only ``'all'`` is supported.

    Returns:
        A ``(dataset_dict, model_dict, trainer_dict)`` tuple. The first two map
        option names to lists of candidate values; ``trainer_dict`` maps
        ``'max_epochs'`` to a single integer.

    Raises:
        NotImplementedError: if the dataset or the experiment is unsupported.
    """
    # Per-dataset training budget (300 epochs for each node) and whether the
    # dataset_dict carries an 'equations_type' entry (LOAN does not).
    if dataset_name in [Cte.TRIANGLE, Cte.COLLIDER, Cte.CHAIN]:
        epochs, has_equations = 900, True
    elif dataset_name in [Cte.MGRAPH]:
        epochs, has_equations = 1500, True
    elif dataset_name == Cte.LOAN:
        epochs, has_equations = 2100, False
    else:
        raise NotImplementedError

    if experiment_name not in ['all']:
        raise NotImplementedError

    # Built key by key so the insertion order matches the original literals.
    data_cfg = {'normalize': ['std']}
    if has_equations:
        data_cfg['equations_type'] = [Cte.LINEAR, Cte.NONLINEAR, Cte.NONADDITIVE]
    data_cfg['num_samples_tr'] = [5000]

    # Placeholder values: per the original notes they are overwritten later.
    model_cfg = dict(z_dim=[1],
                     h_dim_list_dec=[[32, 32]],
                     h_dim_list_enc=[[32, 32]],
                     lambda_kld=[0.05])

    trainer_cfg = dict(max_epochs=epochs)

    return data_cfg, model_cfg, trainer_cfg
