import argparse
from xmlrpc.client import boolean
import torch
import os

# Command-line configuration for the QGT experiments.
#
# Conventions used below:
#   * action='store_true'  -> option is OFF by default, passing the flag turns it ON.
#   * action='store_false' -> option is ON by default, passing the flag turns it OFF.
#   * type=_str2bool       -> option takes an explicit value, e.g. ``--emb_save false``.


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: the raw string is passed to
    ``bool()``, and every non-empty string (including ``'False'`` and ``'0'``)
    is truthy, so such an option could never be switched off from the CLI.

    Args:
        value: the raw string from the command line (an actual ``bool`` is
            returned unchanged).

    Returns:
        ``True`` for 'true', 't', 'yes', 'y', 'on', '1';
        ``False`` for 'false', 'f', 'no', 'n', 'off', '0' and the empty string.
        Matching is case-insensitive and ignores surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean,
            which argparse reports as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string stays False, as it was under the old ``bool('')`` behaviour.
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false, yes/no, on/off, 1/0), got {!r}'.format(value))


parser = argparse.ArgumentParser(description='QGT')


# 1. Dataset
parser.add_argument('--dataset', type=str, default='cora', help='dataset name')
parser.add_argument('--datapath', type=str, default='', help='path to dataset')
parser.add_argument('--use_feats', action='store_true', help='use initial features or not')
parser.add_argument('--normalize_feats', action='store_false', help='normalize initial features or not (enabled by default; pass this flag to disable)')
parser.add_argument('--split_seed', type=int, default=1000, help='seed for split dataset')
parser.add_argument('--number_sample_A_set', type=int, default=50, help='number of samples in set A')
parser.add_argument('--number_sample_neighbors', type=int, default=50, help='number of samples in neighbor')
parser.add_argument('--num_nodes', type=int, default=-1, help='num of nodes')

# 2. Experiment Setup (Training/Testing)

parser.add_argument('--task', type=str, default='nc', help='nc or lrbm')
parser.add_argument('--device', type=str, default='cpu', help='training device')
parser.add_argument('--device_id', type=str, default='0', help='device id for gpu')
parser.add_argument('--seed', type=int, default=998877, help='training/model_initializing seed')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight for L2 loss on basic models.')
parser.add_argument('--weight_decay_2', type=float, default=5e-4, help='weight for L2 loss on gnn.')
parser.add_argument('--weight_decay_3', type=float, default=5e-4, help='weight for L2 loss on other.')
parser.add_argument('--epochs', type=int, default=1000, help='training epoch')
parser.add_argument('--patience', type=int, default=100, help='patience for early stop')
parser.add_argument('--min_epoch', type=int, default=100, help='min epoch')
parser.add_argument('--using_riemannianAdam', type=_str2bool, default=False, help='use RiemannianAdam or standard Adam')
parser.add_argument('--using_pretrained_feat', type=_str2bool, default=False, help='use pretrained_feat')
parser.add_argument('--log_freq', type=int, default=1, help='how often to print train/val metrics (in epochs)')
parser.add_argument('--eval_freq', type=int, default=1, help='how often to compute val metrics (in epochs)')
parser.add_argument('--test_freq', type=int, default=1, help='how often to compute test metrics (in epochs)')
parser.add_argument('--emb_save', type=_str2bool, default=True, help='save embedding or not')
parser.add_argument('--model_save', type=_str2bool, default=True, help='save model checkpoint or not')
parser.add_argument('--step_lr', type=int, default=1000, help='reduce lr after step_lr epochs')
parser.add_argument('--gamma', type=float, default=0.8, help='gamma to reduce lr after step_lr epochs')

# 3. Model configs

parser.add_argument('--model', type=str, default='GCN', help='model name')
parser.add_argument('--nfeat', type=int, default=128, help='dim of input feature')
parser.add_argument('--nhid', type=int, default=16, help='dim of hidden embedding')
parser.add_argument('--nout', type=int, default=16, help='dim of output embedding')
parser.add_argument('--dropout', type=float, default=0.0, help='dropout rate (1 - keep probability).')
parser.add_argument('--graph_dropout', type=float, default=0.0, help='dropout rate (1 - keep probability).')
parser.add_argument('--last_dropout', type=float, default=0.0, help='dropout rate (1 - keep probability).')
parser.add_argument('--bias', action='store_false', help='use bias or not (enabled by default; pass this flag to disable)')
parser.add_argument('--act', type=str, default='leakyrelu', choices=['relu', 'sigmoid', 'tanh', 'leakyrelu', 'gelu', 'elu'], help="Activation function to use")
parser.add_argument('--pos_weight', type=_str2bool, default=False, help='use weight for positive samples or not')
parser.add_argument('--curvature', type=float, default=1.0, help='curvature for non-Euclidean models')
parser.add_argument('--curvature_out', type=float, default=3.0, help='for Hypformer')
parser.add_argument('--prod_manifold_e', type=int, default=0, help='product manifold, number of dimensions of Euclidean manifold')
parser.add_argument('--prod_manifold_s', type=int, default=8, help='product manifold, number of dimensions of Spherical manifold')
parser.add_argument('--prod_manifold_h', type=int, default=8, help='product manifold, number of dimensions of Hyperbolic manifold')
parser.add_argument('--time_dim', type=int, default=8, help='time dimension in the pseudo hyperboloid')
parser.add_argument('--space_dim', type=int, default=8, help='space dimension in the pseudo hyperboloid')
parser.add_argument('--beta', type=float, default=-1.0, help='curvature the pseudo hyperboloid')
parser.add_argument('--attention_type', type=str, default='linear_focused', help='linear_focused or full')
parser.add_argument('--power_k', type=int, default=2, help='for Hypformer')
parser.add_argument('--add_positional_encoding', type=_str2bool, default=True, help='use positional encoding')
parser.add_argument('--trans_num_layers', type=int, default=2, help='number of transformer layer')
parser.add_argument('--graph_num_layers', type=int, default=2, help='number of transformer layer')
parser.add_argument('--trans_num_heads', type=int, default=4, help='number of transformer head')
parser.add_argument('--trans_use_bn', action='store_false', help='for Hypformer (enabled by default; pass this flag to disable)')
parser.add_argument('--trans_use_residual', action='store_false', help='for Hypformer (enabled by default; pass this flag to disable)')
parser.add_argument('--trans_use_weight', action='store_false', help='for Hypformer (enabled by default; pass this flag to disable)')
parser.add_argument('--trans_use_act', action='store_false', help='for Hypformer (enabled by default; pass this flag to disable)')
parser.add_argument('--use_hyperdecoder', action='store_true', help='use hyperdecoder')
parser.add_argument('--trans_heads_concat', type=_str2bool, default=True, help='for Hypformer, QGT')
parser.add_argument('--graph_weight', type=float, default=0.8, help='graph encoder weight in QGT')
parser.add_argument('--use_graph', action='store_false', help='use graph encoder in QGT (enabled by default; pass this flag to disable)')
parser.add_argument('--use_pe', action='store_false', help='use positional encoding in QGT (enabled by default; pass this flag to disable)')
parser.add_argument('--lamda', type=float, default=0.1, help='weight for edge reg loss')
parser.add_argument('--alpha', type=float, default=0.5, help='balance add in QGT')
parser.add_argument('--rb_order', type=int, default=2, help='for NodeFormer')
parser.add_argument('--dropout_time', type=float, default=0.0, help='dropout rate (1 - keep probability).')
parser.add_argument('--dropout_space', type=float, default=0.0, help='dropout rate (1 - keep probability).')
parser.add_argument('--g_dropout_time', type=float, default=0.0, help='dropout rate (1 - keep probability).')
parser.add_argument('--g_dropout_space', type=float, default=0.0, help='dropout rate (1 - keep probability).')

# Parse the command line once at import time; other modules read the shared
# ``args`` namespace from here.
args = parser.parse_args()

# Resolve the torch device. Note that the string given via ``--device`` is
# overwritten here: the choice is driven only by ``--device_id`` and by whether
# CUDA is available. A negative ``--device_id`` forces CPU.
if int(args.device_id) >= 0 and torch.cuda.is_available():
    # The original ``"cuda".format(...)`` had no placeholder, so the id was
    # silently dropped and the default GPU was always used.
    # NOTE(review): if CUDA_VISIBLE_DEVICES is also set from device_id elsewhere,
    # the visible device is re-indexed to 0 — confirm against the launcher.
    args.device = torch.device("cuda:{}".format(args.device_id))
else:
    args.device = torch.device("cpu")