import argparse

def str2bool(s):
    """Convert a literal boolean string into a ``bool``.

    Meant to be used as an ``argparse`` ``type=`` callable. Only the exact
    spellings ``'True'``, ``'true'``, ``'False'`` and ``'false'`` are
    accepted. Any other value raises ``ValueError``, which argparse turns
    into an "invalid value" error message.
    """
    true_spellings = frozenset(('True', 'true'))
    false_spellings = frozenset(('False', 'false'))
    if s in true_spellings:
        return True
    if s in false_spellings:
        return False
    raise ValueError('Not a valid boolean string')

def get_parser(args=None):
    """Build the command-line parser for the experiments.

    Some options are registered only conditionally:

    * the node2vec options (``--walk_length`` ... ``--q``) when the value of
      ``--embedder`` contains ``node2vec`` (case-insensitive);
    * ``--mask_rate`` when ``--label_trick`` is true;
    * the TWIRLS options when ``--embedder`` is exactly ``TWIRLS``.

    Those two switches are read from the command line up front with a small
    throw-away parser. The returned parser still has to be parsed by the
    caller.

    Args:
        args: Argument list inspected for the conditional switches. ``None``
            (the default) means ``sys.argv[1:]``, i.e. the previous behavior.

    Returns:
        The fully populated ``argparse.ArgumentParser``.
    """
    # Read the switches that decide which optional groups exist with a
    # separate minimal parser. Probing the half-built main parser with
    # parse_known_args() is broken in two ways:
    #   * -h/--help fires during the probe and prints only the options
    #     registered so far, then exits;
    #   * prefix (abbreviation) matching runs against a partial option set.
    #     For example, TWIRLS's "--norm" is ambiguous with "--normalize_*"
    #     before "--norm" itself is registered, which aborts the program.
    probe = argparse.ArgumentParser(add_help=False)
    probe.add_argument("--embedder", type=str, default="GNN")
    probe.add_argument("--label_trick", type=str2bool, default=False)
    switches, _ = probe.parse_known_args(args)

    parser = argparse.ArgumentParser()
    # NOTE(review): several help strings below were copy-pasted from other
    # options and do not describe their flag (e.g. --server, --col_init,
    # --max_only, --ver, --n_reuse); confirm their meaning and update.
    parser.add_argument("--device", type=int, help="Indexes of gpu to run program on", default=1)
    parser.add_argument("--server", type=int, help="Indexes of gpu to run program on", default=68)
    parser.add_argument("--dataset", type=str, help="Name of Dataset", default='adult')
    parser.add_argument("--corr", type=float, help="Threshold of correlation", default=0.0)
    parser.add_argument("--missing_type", type=str, help="Type of missing feature mask", default="uniform", choices=["uniform", "row", "col"])
    parser.add_argument("--missing_rate", type=float, help="Rate of node features missing", default=0.0)
    parser.add_argument("--train_ratio", type=float, help="Ratio of training samples", default=0.1)
    parser.add_argument("--lp_alpha", type=float, help="Alpha parameter of label propagation", default=0.99)
    parser.add_argument("--lamb", type=float, help="Control loss between metric loss and cross entropy loss", default=1.0)  # 0.0, 0.1, 0.2
    parser.add_argument("--filling_method", type=str, help="Method to solve the missing feature problem", default="fp", choices=["random", "zero", "mean", "neighborhood_mean", "fp"])
    parser.add_argument("--embedder", type=str, help="Type of model to make a prediction on the downstream task", default="GNN")  # GOODIE_best
    # NOTE(review): parsed as float, yet the trailing comment lists string
    # values; confirm the intended type before relying on it.
    parser.add_argument("--pseudo_type", type=float, help="pseudo contrastive type", default=-1)  # 'train', 'all_1', 'all_y', 'strong_1'
    parser.add_argument("--initial_filling", type=str, help="initial filling of missing values", default='None')  # 'None', 'median', 'mode'

    if 'node2vec' in switches.embedder.lower():
        parser.add_argument("--walk_length", type=int, help="Control loss between metric loss and cross entropy loss", default=20)
        parser.add_argument("--context_size", type=int, help="Control loss between metric loss and cross entropy loss", default=10)
        parser.add_argument("--walks_per_node", type=int, help="Control loss between metric loss and cross entropy loss", default=10)
        parser.add_argument("--p", type=float, help="Likelihood of immediately revisiting a node in the walk", default=1.0)
        parser.add_argument("--q", type=float, help="Control parameter to interpolate between BFS and DFS", default=1.0)

    # GNN options (registered unconditionally).
    parser.add_argument("--gnn", type=str, help="Type of model to make a prediction on the downstream task", default="GCN", choices=["SGC", "SAGE", "GCN", "GAT"])
    parser.add_argument("--scaled", type=str2bool, help="Wheter to utilize scaled PseudoCon loss, True for large datasets", default=False)
    parser.add_argument("--col_init", type=str2bool, help="Wheter to utilize scaled PseudoCon loss, True for large datasets", default=True)
    parser.add_argument("--max_only", type=int, help="Wheter to utilize scaled PseudoCon loss, True for large datasets", default=0)
    parser.add_argument("--ver", type=int, help="Wheter to utilize scaled PseudoCon loss, True for large datasets", default=0)

    # for label trick
    parser.add_argument("--label_trick", type=str2bool, help="Wheter to utilize scaled PseudoCon loss, True for large datasets", default=False)
    if switches.label_trick:
        parser.add_argument("--mask_rate", type=float, help="Control loss between metric loss and cross entropy loss", default=1.0)
    parser.add_argument("--n_reuse", type=int, help="Control loss between metric loss and cross entropy loss", default=0)
    parser.add_argument("--use_coef", type=str2bool, help="Wheter to utilize scaled PseudoCon loss, True for large datasets", default=False)

    parser.add_argument("--k", type=int, help="for knn-graph", default=10)
    parser.add_argument("--epoch_start", type=int, help="# of initial hop in FP", default=0)
    parser.add_argument("--lp_temp", type=float, help="Temperature for Contrastive Learning", default=1.0)
    parser.add_argument("--include_neighbors", type=str2bool, help="add coefficient loss", default=False)
    parser.add_argument("--normalize_label_mat", type=str2bool, help="add coefficient loss", default=False)
    parser.add_argument("--attn_type", type=str, help="atteion type", default='mr')  # mean(sum), random, concat
    parser.add_argument("--metric", type=str, help="encoder", default='gt')
    parser.add_argument("--pca", type=str2bool, help="PCA", default=False)
    parser.add_argument("--coef_loss", type=str2bool, help="add coefficient loss", default=False)
    parser.add_argument("--hop", type=int, help="# of initial hop in FP", default=0)
    parser.add_argument("--replace", type=str2bool, help="replace 1.0 in in hop in FP", default=False)
    parser.add_argument("--normalize_feature", type=str2bool, help="Normalize Feature Matrix before FP", default=False)
    parser.add_argument("--model_1", type=str, help="encoder", default='gcn')
    parser.add_argument("--model_2", type=str, help="classifier", default='mlp')
    parser.add_argument("--design", type=int, help="Indexes of gpu to run program on", default=3)
    parser.add_argument("--leaky_alpha", type=float, help="Control slope of leaky relu", default=0.3)
    parser.add_argument("--temp", type=float, help="Temperature for Contrastive Learning", default=0.01)
    parser.add_argument("--autoscale", type=str2bool, help="replace 1.0 in in hop in FP", default=True)

    parser.add_argument("--print_result", type=int, help="Patience for early stopping", default=100)
    parser.add_argument("--patience_ogbn", type=int, help="Patience for early stopping", default=5)
    parser.add_argument("--patience", type=int, help="Patience for early stopping", default=200)
    parser.add_argument("--lr", type=float, help="Learning Rate", default=0.005)
    parser.add_argument("--epochs", type=int, help="Max number of epochs", default=10000)
    parser.add_argument("--n_runs", type=int, help="Max number of runs", default=10)
    parser.add_argument("--hidden_dim", type=int, help="Hidden dimension of model", default=64)
    parser.add_argument("--num_layers", type=int, help="Number of GNN layers", default=2)
    parser.add_argument("--num_heads", type=int, help="Number of GAT heads", default=2)
    parser.add_argument("--num_iterations", type=int, help="Number of diffusion iterations for feature reconstruction", default=40)
    parser.add_argument("--num_iterations_lp", type=int, help="Number of diffusion iterations for feature reconstruction", default=40)
    parser.add_argument("--dropout", type=float, help="Feature dropout", default=0.5)
    parser.add_argument("--jk", action="store_true", help="Whether to use the jumping knowledge scheme")
    parser.add_argument("--batch_size", type=int, help="Batch size for models trained with neighborhood sampling", default=1024)
    # A store_true flag whose default is already True can never be switched
    # off. Each such flag gets an explicit --no_* companion writing to the
    # same dest; the original flag is kept for backward compatibility.
    parser.add_argument("--batch_norm", help="Applying Batch Normalizetion", action="store_true", default=True)
    parser.add_argument("--no_batch_norm", dest="batch_norm", action="store_false", help="Disable batch normalization")
    parser.add_argument("--graph_sampling", help="Set if you want to use graph sampling (always true for large graphs)", action="store_true", default=False)
    parser.add_argument("--node_sampling", help="Set if you want to use node sampling", action="store_true", default=True)
    parser.add_argument("--no_node_sampling", dest="node_sampling", action="store_false", help="Disable node sampling")
    parser.add_argument("--homophily", type=float, help="Level of homophily for synthetic datasets", default=None)
    parser.add_argument("--log", type=str, help="Log Level", default="INFO", choices=["DEBUG", "INFO", "WARNING"])

    if switches.embedder == 'TWIRLS':
        # model
        parser.add_argument("--norm", type=str, default="none")
        parser.add_argument("--attn_dropout", type=float, default=0.0)
        parser.add_argument("--inp_dropout", type=float, default=0.0)
        parser.add_argument("--learn_emb", type=int, default=0)
        # Was type=int with a float default, which rejected values like 0.5.
        parser.add_argument("--lam", type=float, default=1.0)

        parser.add_argument("--mlp_bef", type=int, default=1)
        parser.add_argument("--mlp_aft", type=int, default=0)

        # propagation
        parser.add_argument("--no_precond", type=str2bool, default=False)
        parser.add_argument("--prop_step", type=int, default=8)  # 2
        parser.add_argument("--alp", type=float, default=1)  # 0 for alpha = 1 / (1 + lambda)

        # attention
        parser.add_argument("--attention", action="store_true", default=False)
        parser.add_argument("--attn_p", type=float, default=1)
        parser.add_argument("--use_eta", action="store_true", default=False)
        parser.add_argument("--tau", type=float, default=1)
        parser.add_argument("--T", type=float, default=0)
        parser.add_argument("--attn_bef", action="store_true", default=False)

    return parser