import argparse

def get_args(argv=None) -> argparse.Namespace:
    """Build the main command-line parser and parse the given arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour callers that pass no argument have always had.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.

    Raises:
        SystemExit: Raised by ``argparse`` on unknown/invalid arguments or
            when ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--task', type=str, default='GraphTask')
    parser.add_argument('--dataset_name', type=str, default='Cora',
                        help='Choose the dataset of pretrainor downstream task')
    parser.add_argument('--device', type=int, default=0,
                        help='Which gpu to use if any (default: 0)')
    # Raw strings: the help text contains literal backslashes, which would
    # otherwise be invalid escape sequences (SyntaxWarning on Python 3.12+).
    parser.add_argument('--gnn_type', type=str, default="GCN",
                        help=r'We support gnn like \GCN\ \GAT\ \GT\ \GCov\ \GIN\ \GraphSAGE\, please read ProG.model module')
    parser.add_argument('--pretrain_method', '--pm', type=str, default='GraphCL',
                        help='Which pretrained method was used for the GNN model')
    # NOTE: the defaults of --prompt_type and --pre_train_model_path are the
    # *string* 'None', not the None object.
    parser.add_argument('--prompt_type', type=str, default='None',
                        help=r'Choose the prompt type for node or graph task, for node task,we support \GPPT\, \All-in-one\, \Gprompt\ for graph task , \All-in-one\, \Gprompt\, \GPF\, \GPF-plus\ ')
    # The help text previously advertised a default of 300; the real default is 128.
    parser.add_argument('--hid_dim', type=int, default=128,
                        help='Hidden layer dimension of the GNN (default: 128)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of epochs to train (default: 50)')
    parser.add_argument('--shot_num', type=int, default=5, help='Number of shots')
    parser.add_argument('--pre_train_model_path', type=str, default='None',
                        help='add pre_train_model_path to the downstream task, the model is self-supervise model if the path is None and prompttype is None.')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Learning rate (default: 0.0001)')
    parser.add_argument('--decay', type=float, default=0,
                        help='Weight decay (default: 0)')
    # NOTE: --num_layer (here) and --num_layers (further down) are two
    # distinct options with different defaults and meanings.
    parser.add_argument('--num_layer', type=int, default=3,
                        help='Number of GNN message passing layers (default: 3).')

    parser.add_argument('--dropout_ratio', type=float, default=0.5,
                        help='Dropout ratio (default: 0.5)')
    parser.add_argument('--graph_pooling', type=str, default="mean",
                        help='Graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last",
                        help='How the node features across layers are combined. last, sum, max or concat')

    parser.add_argument('--seed', type=int, default=42, help="Seed for splitting dataset.")
    parser.add_argument('--runseed', type=int, default=0, help="Seed for running experiments.")
    parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for dataset loading')
    parser.add_argument('--num_layers', type=int, default=1, help='A range of [1,2,3]-layer MLPs with equal width')
    parser.add_argument('--pnum', type=int, default=5, help='The number of independent basis for GPF-plus')

    # Boolean switches (False unless the flag is present).
    # NOTE(review): the --supervised help text is identical to --finetune's;
    # looks like a copy-paste - confirm the intended wording.
    parser.add_argument('--supervised', action='store_true', help='Whether to finetune the model')
    parser.add_argument('--finetune', action='store_true', help='Whether to finetune the model')

    # Transfer options.
    # NOTE(review): --transfer is a boolean switch although its help reads
    # like it takes a dataset name, and the --source help duplicates
    # --dataset_name's - confirm the intended wording.
    parser.add_argument('--transfer', action='store_true', help='Transfer from which dataset')
    parser.add_argument('--source', '--s', type=str, default='MUTAG',
                        help='Choose the dataset of pretrainor downstream task')
    parser.add_argument('--feature_dim', type=int, default=3, help='Feature dimension for transfer learning')

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)




def get_morpher_args(argv=None) -> argparse.Namespace:
    """Build the reduced parser whose prompt type defaults to "Morpher" and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour callers that pass no argument have always had.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.

    Raises:
        SystemExit: Raised by ``argparse`` on unknown/invalid arguments or
            when ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--dataset_name', '--d', type=str, default='MUTAG',
                        help='Choose the dataset of pretrainor downstream task')
    parser.add_argument('--device', type=int, default=0,
                        help='Which gpu to use if any (default: 0)')
    # Raw strings: the help text contains literal backslashes, which would
    # otherwise be invalid escape sequences (SyntaxWarning on Python 3.12+).
    parser.add_argument('--gnn_type', '--g', type=str, default="GCN",
                        help=r'We support gnn like \GCN\ \GAT\ \GT\ \GCov\ \GIN\ \GraphSAGE\, please read ProG.model module')
    parser.add_argument('--pretrain_method', '--pm', type=str, default='GraphCL',
                        help='Which pretrained method was used for the GNN model')
    parser.add_argument('--prompt_type', type=str, default="Morpher",
                        help=r'Choose the prompt type for node or graph task, for node task,we support \GPPT\, \All-in-one\, \Gprompt\ for graph task , \All-in-one\, \Gprompt\, \GPF\, \GPF-plus\ ')
    # NOTE: the default is the *string* 'None', not the None object.
    parser.add_argument('--pre_train_model_path', type=str, default='None',
                        help='add pre_train_model_path to the downstream task, the model is self-supervise model if the path is None and prompttype is None.')

    parser.add_argument('--task', type=str, default='graph')
    parser.add_argument('--source', '--s', type=str, default='MUTAG',
                        help='Choose the dataset of pretrainor downstream task')

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)