from prompt_graph.pretrain import Edgepred_GPPT, SimGRACE, GraphMAE, DGI
from prompt_graph.utils import seed_everything
from prompt_graph.utils import mkdir, get_args
# Parse command-line options and seed the RNGs at import time.
# This is kept at module level to preserve the original script's behaviour.
args = get_args()
seed_everything(args.seed)

# Pre-training methods this script knows how to launch (value of args.task).
SUPPORTED_TASKS = ('SimGRACE', 'Edgepred_GPPT', 'DGI', 'GraphMAE')

if __name__ == '__main__':
    print('Dataset: {}, Pretrain: {}, GNN: {}, Seed: {}'.format(args.dataset_name, args.task, args.gnn_type, args.seed))

    # Constructor arguments shared by every pre-training method.
    # graph_list / input_dim are passed as None; presumably each class then
    # loads the data itself from dataset_name -- TODO confirm in prompt_graph.pretrain.
    common_kwargs = dict(
        graph_list=None,
        input_dim=None,
        dataset_name=args.dataset_name,
        gnn_type=args.gnn_type,
        hid_dim=args.hid_dim,
        gln=args.num_layer,
        num_epoch=args.epochs,
        device=args.device,
        seed=args.seed,
    )

    if args.task == 'SimGRACE':
        pt = SimGRACE(**common_kwargs)
    elif args.task == 'Edgepred_GPPT':
        pt = Edgepred_GPPT(**common_kwargs)
    elif args.task == 'DGI':
        pt = DGI(**common_kwargs)
    elif args.task == 'GraphMAE':
        # GraphMAE-specific hyper-parameters, hard-coded as in the original script.
        pt = GraphMAE(**common_kwargs,
                      mask_rate=0.75, drop_edge_rate=0.0, replace_rate=0.1, loss_fn='sce', alpha_l=2)
    else:
        # Without this branch, an unknown task leaves `pt` unbound and the
        # script crashes below with an uninformative NameError.
        raise ValueError('Unsupported pretrain task {!r}; expected one of: {}'.format(
            args.task, ', '.join(SUPPORTED_TASKS)))

    pt.pretrain()

