import argparse


def csv_list(string):
    return string.split(',')


def get_args_common(parser):
    parser.add_argument('--wandb', type=bool, default=False)

    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=0)

    # Path Parameters
    parser.add_argument('--data_path', type=str, default='/root/autodl-tmp/cache_data')
    parser.add_argument("--model_path", type=str, default='ckpts/pretrain_model')
    parser.add_argument("--root_path", type=str, default='/root/autodl-tmp')

    # Graph Token Parameters
    parser.add_argument('--graph_llm_name', type=str, default='ST')
    parser.add_argument("--llm_b_size", type=int, default=1)

    parser.add_argument('--pretrain_dataset', '--pt_data', type=str, default='all')
    parser.add_argument('--pretrain_epochs', '--pt_epochs', '--epochs', type=int, default=5)

    parser.add_argument('--num_expert', type=int, default=7)
    parser.add_argument('--codebook_size', type=int, default=128)
    parser.add_argument('--codebook_heads', type=int, default=1)
    parser.add_argument('--topk', type=int, default=2)
    
    # Encoder Parameters
    parser.add_argument("--input_dim", type=int, default=768)
    parser.add_argument("--hidden_dim", type=int, default=768)
    parser.add_argument('--num_neighbors', '--neighbors', type=int, default=10)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--activation", '--act', type=str, default="relu")
    parser.add_argument('--normalize', type=str, default="batch", choices=['none', 'batch', 'layer'])
    parser.add_argument('--dropout', type=float, default=0.15)

    return parser


def get_args_pretrain():
    parser = argparse.ArgumentParser('Pretrain')
    get_args_common(parser)

    # VQ Parameters
    parser.add_argument('--kmeans_init', type=bool, default=False)

    # Pretrain Dataset
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--pretrain_weight_decay', '--pt_decay', '--decay', type=float, default=1e-5)
    parser.add_argument('--pretrain_batch_size', '--pt_batch', type=int, default=1024)
    parser.add_argument('--feat_p', type=float, default=0.1)
    parser.add_argument('--edge_p', type=float, default=0.1)
    parser.add_argument('--topo_recon_ratio', type=float, default=0.1)
    
    parser.add_argument('--feat_lambda', type=float, default=100)
    parser.add_argument('--topo_lambda', type=float, default=0.01)
    parser.add_argument('--field_lambda', type=float, default=0.01)
    parser.add_argument('--contrastive_lambda', type=float, default=1e-3)
    parser.add_argument('--use_schedular', type=bool, default=True)

    args = parser.parse_args()
    return vars(args)


def get_args_finetune():
    parser = argparse.ArgumentParser('Finetune')
    get_args_common(parser)

    parser.add_argument("--n_trials", type=int, default=50)

    # General Parameters
    parser.add_argument("--setting", type=str, default="few_shot")  # "standard", "few_shot", "zero_shot", "in_context"

    # Path Parameters
    parser.add_argument("--GQA_eval_path", type=str, default='cache_GQA_eval')

    # Few-shot Parameters
    parser.add_argument("--n_task", type=int, default=20)
    parser.add_argument("--n_way", type=int, default=5)
    parser.add_argument("--n_train", type=int, default=1)
    parser.add_argument("--n_shot", type=int, default=3)
    parser.add_argument("--n_query", type=int, default=3)

    # Fine-Tune Parameters
    parser.add_argument("--finetune_dataset", type=str, default="WN18RR")
    parser.add_argument("--repeat", type=int, default=5)
    parser.add_argument("--finetune_epochs", type=int, default=1000)
    parser.add_argument("--early_stop", type=int, default=200)
    parser.add_argument("--train_batch_size", type=int, default=0)
    parser.add_argument("--eval_batch_size", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--query_node_code_first", action="store_true", help="Use node code to form edge/graph code")

    # Model Parameters
    parser.add_argument("--use_z_in_predict", type=bool, default=True)
    parser.add_argument("--use_cosine_sim", type=bool, default=True)
    parser.add_argument("--lambda_proto", type=float, default=0.1)
    parser.add_argument("--lambda_act", type=float, default=1)
    parser.add_argument("--trade_off", type=float, default=0.1)
    parser.add_argument("--num_instances_per_class", type=int, default=0)
    parser.add_argument('--no_lin_clf', type=bool, default=False)
    parser.add_argument('--no_proto_clf', type=bool, default=True)

    # GQA Parameters
    # LLM related
    parser.add_argument("--llm_model_name", type=str, default='st')
    parser.add_argument("--llm_model_path", type=str, default='')
    parser.add_argument("--llm_frozen", type=str, default='True')
    parser.add_argument("--llm_num_virtual_tokens", type=int, default=10)
    parser.add_argument("--output_dir", type=str, default='output')
    parser.add_argument("--max_txt_len", type=int, default=512)
    parser.add_argument("--max_new_tokens", type=int, default=32)
    parser.add_argument("--max_memory", type=csv_list, default=[80, 80])

    args = parser.parse_args()
    return vars(args)
