from laplace import DiagLaplace, KronLaplace


# Lookup table from the ``--laplace`` command-line value to the Laplace
# approximation class imported from ``laplace``; its keys are also used as the
# allowed choices for that option.
LAPLACE_CLASSES = dict(
    diaglaplace=DiagLaplace,
    kronlaplace=KronLaplace,
)

def none_or_str(value):
    """Convert a command-line string to ``None`` or pass it through unchanged.

    The literal ``none`` (any capitalisation) becomes ``None``; every other
    string is returned exactly as given.
    """
    return None if value.lower() == 'none' else value

def none_or_int(value):
    """Convert a command-line string to ``None`` or an ``int``.

    The literal ``none`` (any capitalisation) becomes ``None``; anything else
    is handed to ``int()``, which raises ``ValueError`` for non-integers.
    """
    if value.lower() != 'none':
        return int(value)
    return None

def none_or_float(value):
    """Convert a command-line string to ``None`` or a ``float``.

    The literal ``none`` (any capitalisation) becomes ``None``; anything else
    is handed to ``float()``, which raises ``ValueError`` for non-numbers.
    """
    wants_none = value.lower() == 'none'
    if wants_none:
        return None
    return float(value)

def str_to_bool(value):
    """Parse a command-line string into a boolean.

    Matching ignores case and surrounding whitespace:
    ``true``/``t``/``1``/``yes``/``y`` give ``True`` and
    ``false``/``f``/``0``/``no``/``n`` give ``False``.

    Raises:
        ValueError: if *value* is not one of the recognised spellings. When
            used as an argparse ``type=`` this becomes a usage error
            ("invalid str_to_bool value: ...") instead of silently turning a
            typo such as ``Ture`` into ``False``.
    """
    normalized = value.strip().lower()
    if normalized in ('true', 't', '1', 'yes', 'y'):
        return True
    if normalized in ('false', 'f', '0', 'no', 'n'):
        return False
    raise ValueError(f'Expected a boolean string, got {value!r}.')


def parse_add_args(parser):
    """Register all command-line options of the experiment on *parser*.

    Typical use::

        parser = argparse.ArgumentParser("Marglik optimization on graphs.")
        args = parse_add_args(parser).parse_args()

    Options whose value may be ``none`` on the command line use the
    ``none_or_*`` converters, so ``none`` parses to ``None``. Boolean options
    use ``str_to_bool``.

    Args:
        parser: The parser to extend; only its ``add_argument`` method is used.

    Returns:
        The same *parser*, so the call can be chained with ``parse_args()``.
    """
    # Data and model architecture.
    parser.add_argument('--dataset', type=str,
                        help='Name of the dataset.')
    parser.add_argument('--model', type=str, default='gcn',
                        choices=['gcn', 'graphsage'],
                        help='Model to use.')
    parser.add_argument('--hidden_channels', type=int, default=128,
                        help='Number of hidden channels.')
    parser.add_argument('--num_layers', type=int, default=2,
                        help='Number of layers.')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='Dropout rate.')
    # NOTE(review): none_or_str maps 'none' to None, but None is not listed in
    # choices, so `--act none` is rejected by argparse. Add None to choices
    # only if the model supports running without an activation.
    parser.add_argument('--act', type=none_or_str, default='relu',
                        choices=['relu', 'tanh', 'gelu', 'softplus'],
                        help='Activation function.')
    parser.add_argument('--norm', type=none_or_str, default=None,
                        choices=[None, 'batch', 'layer'],
                        help='Normalization layer.')
    parser.add_argument('--feat_norm', type=str_to_bool, default=False,
                        help='Normalize features.')
    parser.add_argument('--jk', type=none_or_str, default=None,
                        choices=[None, 'cat', 'max', 'lstm'],
                        help='Jumping knowledge.')
    parser.add_argument('--res', type=str_to_bool, default=False,
                        help='Residual connection.')

    # Optimization and marginal-likelihood schedule.
    parser.add_argument('--lr', type=float, default=0.05,
                        help='Learning rate.')
    parser.add_argument('--weight_decay', type=float, default=5e-5,
                        help='Weight decay.')
    parser.add_argument('--n_epochs', type=int, default=200,
                        help='Number of epochs.')
    parser.add_argument('--n_epochs_burnin', type=int, default=10,
                        help='Number of burn-in epochs.')
    parser.add_argument('--n_hypersteps', type=int, default=0,
                        help='Number of hyperparameter optimization steps.')
    parser.add_argument('--marglik_frequency', type=int, default=1,
                        help='Frequency of marginal likelihood computation.')
    # A concrete list (rather than a live dict_keys view) of the registered
    # Laplace approximation names.
    parser.add_argument('--laplace', type=str, default='kronlaplace',
                        choices=list(LAPLACE_CLASSES),
                        help='Laplace approximation.')
    parser.add_argument('--lr_graph', type=float, default=0.1,
                        help='Learning rate for the graph.')
    parser.add_argument('--lr_graph_min', type=none_or_float, default=None,
                        help='Minimum learning rate for the graph.')
    parser.add_argument('--graph_grad_norm', type=str_to_bool, default=False,
                        help='Normalize the graph gradients.')
    # None (the default) is also accepted explicitly via `--early_stop_crit none`.
    parser.add_argument('--early_stop_crit', type=none_or_str, default=None,
                        choices=[None, 'valid_loss', 'valid_acc', 'marglik'],
                        help='Early stopping criteria.')

    # Logging and run identification.
    parser.add_argument('--console', type=str_to_bool, default=True,
                        help='Log to console.')
    parser.add_argument('--job_id', type=str, default='default',
                        help='Job ID.')

    # Graph relaxation and graph priors.
    parser.add_argument('--cont_relax_temp', type=float, default=0.1,
                        help='Temperature for binary concrete sampling (graph discrete function).')
    parser.add_argument('--graph_prior', type=none_or_str, default=None,
                        choices=[None, 'bernoulli'],
                        help='Prior on graph based on the original graph (adds KL term to loss to marglik).')
    parser.add_argument('--obs_prior_edge_prob', type=float, default=0.5,
                        help='Prior probability for Bernoulli for existing edges in the original graph. '
                             'Higher values biases towards the original graph.')
    parser.add_argument('--prior_non_edge_prob', type=float, default=0.0001,
                        help='Prior probability for Bernoulli for non-existing edges in the original graph.')
    parser.add_argument('--knn_prior_edge_k', type=int, default=0,
                        help='Number of k for KNN graph prior on the edges.')
    parser.add_argument('--knn_prior_edge_prob', type=float, default=0.5,
                        help='Prior probability for Bernoulli for KNN graph prior on the edges. '
                             'It is NOT additive to the graph prior edge probability (of the original graph).')
    parser.add_argument('--knn_prior_edge_dist_metric', type=str, default='euclidean',
                        choices=['euclidean', 'manhattan', 'cosine'],
                        help='Distance metric for KNN graph prior on the edges.')
    parser.add_argument('--graph_kl_weight', type=float, default=1.,
                        help='Weight for the KL term in the graph loss function. '
                             '< 1 leads to downweighting the KL term.')
    parser.add_argument('--log_det_weight', type=float, default=1.,
                        help='Weight for the log determinant term in the graph loss function. '
                             '< 1 leads to downweighting the log determinant term.')

    # Evaluation, repetition and runtime.
    parser.add_argument('--n_samples', type=int, default=1,
                        help='Number of samples for prediction.')
    parser.add_argument('--n_repeats', type=int, default=1,
                        help='Number of repeats for the experiment.')
    parser.add_argument('--log_frequency', type=int, default=20,
                        help='Frequency of logging.')
    parser.add_argument('--checkpoint_dir', type=none_or_str, default=None,
                        help='Directory to save checkpoints.')
    parser.add_argument('--split', nargs='+', type=int, default=None,
                        help='Split arguments (accepts multiple values).')
    parser.add_argument('--gpu', type=none_or_int, default=None,
                        help='GPU to use.')
    return parser
    