import argparse


def parse_args_llama():
    parser = argparse.ArgumentParser(description="G-Retriever")

    parser.add_argument("--model_name", type=str, default='graph_llm')
    parser.add_argument("--project", type=str, default="project_g_retriever")
    parser.add_argument("--seed", type=int, default=0)

    parser.add_argument("--dataset", type=str, default='expla_graphs')
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--wd", type=float, default=0.05)
    parser.add_argument("--patience", type=float, default=2)

    # Model Training
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--grad_steps", type=int, default=4)

    # Learning Rate Scheduler
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--warmup_epochs", type=float, default=1)

    # Inference
    parser.add_argument("--eval_batch_size", type=int, default=16)

    # LLM related
    parser.add_argument("--llm_model_name", type=str, default='7b')
    #parser.add_argument("--llm_model_path", type=str, default='')
    parser.add_argument("--llm_frozen", type=str, default='True')
    parser.add_argument("--llm_num_virtual_tokens", type=int, default=10)
    parser.add_argument("--output_dir", type=str, default='output')
    parser.add_argument("--max_txt_len", type=int, default=512)
    parser.add_argument("--max_new_tokens", type=int, default=32)

    # GNN related
    parser.add_argument("--gnn_model_name", type=str, default='gt')
    parser.add_argument("--gnn_num_layers", type=int, default=4)
    parser.add_argument("--gnn_in_dim", type=int, default=1024)
    parser.add_argument("--gnn_hidden_dim", type=int, default=1024)
    parser.add_argument("--gnn_num_heads", type=int, default=4)
    parser.add_argument("--gnn_dropout", type=float, default=0.15)
    # QAG related
    parser.add_argument("--num_graph_token", type=int, default=1)
    parser.add_argument("--edge_feature", type=str, default='True') 
    parser.add_argument("--query_aware", type=str, default='True')
    parser.add_argument("--pooling", type=str, default='graph_token')
    
    parser.add_argument("--adding_late_fusion", type=str, default='True')
    parser.add_argument("--num_late_fusion_layer", type=int, default=2)
    parser.add_argument("--late_fusion_dropout", type=float, default=0.15)
    

    args = parser.parse_args()
    
    if args.edge_feature == 'True':
        args.edge_feature = True
    elif args.edge_feature == 'False':
        args.edge_feature = False
    else:
        raise("--edge_feature must be 'True' or 'False'")
    
    if args.query_aware == 'True':
        args.query_aware = True
    elif args.query_aware == 'False':
        args.query_aware = False
    else:
        raise("--query_aware must be 'True' or 'False'")
    
    if args.adding_late_fusion == 'True':
        args.adding_late_fusion = True
    elif args.adding_late_fusion == 'False':
        args.adding_late_fusion = False
    else:
        raise("--adding_late_fusion must be 'True' or 'False'")
    
    return args
