from networks.M3Dphormer import M3Dphormer
from networks.M3DphormerL import M3DphormerL

def parse_model(args, n, c, d, global_nodes, device):
    """Construct the M3Dphormer variant selected by ``args.dataset`` on ``device``.

    ``ogbn-arxiv`` is routed to ``M3DphormerL``; every other dataset uses
    ``M3Dphormer``. Both variants are built with exactly the same keyword
    arguments, so they are assembled once and shared.

    Args:
        args: parsed CLI namespace. Reads ``dataset``, ``num_clusters``,
            ``h_dim``, ``n_head``, ``layers``, ``dropout``, ``attn_dropout``,
            ``local_type``, ``learn_global``, ``use_cache``, ``use_res``,
            ``norm_type`` and ``norm_pos``.
        n: forwarded as ``n_ori_nodes`` (presumably the graph's node count).
        c: forwarded as ``n_cls`` (presumably the number of classes).
        d: forwarded as ``x_dim`` (presumably the input feature width).
        global_nodes: forwarded as ``n_global``.
        device: target passed to ``.to(device)``.

    Returns:
        The constructed model, already moved with ``.to(device)``.
    """
    # Only ogbn-arxiv uses the "L" variant.
    model_cls = M3DphormerL if args.dataset in ('ogbn-arxiv',) else M3Dphormer

    # A single call site ensures both variants always receive identical
    # hyper-parameters; previously the argument list was duplicated per branch.
    model = model_cls(
        n_ori_nodes=n,
        n_cluster=args.num_clusters,
        n_global=global_nodes,
        x_dim=d,
        h_dim=args.h_dim,
        n_cls=c,
        n_head=args.n_head,
        layers=args.layers,
        dropout=args.dropout,
        attn_dropout=args.attn_dropout,
        local_type=args.local_type,
        learn_global=args.learn_global,
        use_cache=args.use_cache,
        use_res=args.use_res,
        norm_type=args.norm_type,
        norm_pos=args.norm_pos,
    ).to(device)
    return model


def parser_add_main_args(parser):
    """Register the command-line options used by the M3Dphormer scripts.

    Adds dataset/evaluation, model, training, and display/utility options to
    ``parser`` in place. Nothing is returned.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``) to extend.
    """
    # dataset and evaluation
    parser.add_argument('--dataset', type=str, default='amazon-computer')
    parser.add_argument('--data_dir', type=str, default='./data/')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--cpu', action='store_true')
    # NOTE(review): default device index is 1, not 0 -- confirm this is intended.
    parser.add_argument('--device', type=int, default=1)
    parser.add_argument('--epochs', type=int, default=2000,
                        help='training epochs')
    parser.add_argument('--runs', type=int, default=5,
                        help='number of distinct runs')
    parser.add_argument('--metric', type=str, default='acc', choices=['acc', 'rocauc'],
                        help='evaluation metric')
    parser.add_argument('--split', type=str, default='random')

    # model
    parser.add_argument('--num_clusters', type=int, default=224)
    parser.add_argument('--global_nodes_per_class', type=int, default=1)
    parser.add_argument('--learn_global', action='store_true',
                        help='learnable global features')
    parser.add_argument('--h_dim', type=int, default=64)
    parser.add_argument('--layers', type=int, default=5,
                        help='number of layers for attention')
    parser.add_argument('--local_type', type=str, default='GAT',
                        choices=['GAT', 'GCN', 'GATv2', 'Trans'])
    parser.add_argument('--use_cache', action='store_true',
                        help='use cached norm-A for GCN')
    parser.add_argument('--n_head', type=int, default=4,
                        help='number of heads for attention')
    parser.add_argument('--use_res', action='store_true',
                        help='using residual connection')
    # 'no' disables normalization entirely.
    parser.add_argument('--norm_type', type=str,
                        default='rms', choices=['rms', 'ln', 'bn', 'no'])
    parser.add_argument('--norm_pos', type=str, default='pre',
                        choices=['pre', 'post', 'no'])

    # training
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=1e-4)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--attn_dropout', type=float, default=0.3)
    parser.add_argument('--semi', action='store_true',
                        help='train model by semi-supervised learning')

    # display and utility
    parser.add_argument('--display_step', type=int,
                        default=1, help='how often to print')
    parser.add_argument('--save_model', action='store_true',
                        help='whether to save model')
    parser.add_argument('--postfix', type=str,
                        default='M3Dphormer')
    parser.add_argument('--model_dir', type=str,
                        default='./model/', help='where to save model')
    parser.add_argument('--save_result', action='store_true',
                        help='whether to save result')
