import argparse


def args_parser(argv=None):
    """Build the experiment's command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``argparse`` falls back to ``sys.argv[1:]``, so
            existing callers that pass nothing behave exactly as before.
            Passing an explicit list lets tests, notebooks, or other
            scripts build a configuration without touching ``sys.argv``.

    Returns:
        argparse.Namespace holding every option defined below.

    Raises:
        SystemExit: raised by ``argparse`` on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser()

    # exp
    parser.add_argument("--exp_name", default="run", type=str,
                        help="Experiment name")
    parser.add_argument("--dump_path", default="dump/", type=str,
                        help="Experiment dump path")
    parser.add_argument("--exp_id", default="", type=str,
                        help="Experiment ID")
    # Kept as a string; how it is interpreted (device index / visibility
    # mask) is up to the caller — not determined here.
    parser.add_argument("--gpu", default='2', type=str)
    parser.add_argument("--random_seed", default=0, type=int)
    parser.add_argument("--load_path", default=None, type=str)

    # dataset
    parser.add_argument("--data_root", default='data', type=str)
    parser.add_argument("--config_path", default='configs', type=str)
    parser.add_argument("--dataset", default='GOODZINC', type=str)
    parser.add_argument("--domain", default='scaffold', type=str)
    parser.add_argument("--shift", default='concept', type=str)

    # Encoder
    parser.add_argument("--emb_dim", default=128, type=int)
    parser.add_argument("--layer", default=4, type=int)
    # Float default so the attribute type matches ``type=float`` even when
    # the flag is not given (argparse does not convert non-string defaults).
    parser.add_argument("--dropout", default=0.0, type=float)
    parser.add_argument("--gnn_type", default='gin', type=str, choices=['gcn', 'gin'])
    parser.add_argument("--pooling_type", default='mean', type=str)

    # Semantic Prototype Module
    parser.add_argument("--num_prototypes", default=20, type=int,
                        help="Number of semantic prototypes K")
    parser.add_argument("--prototype_temperature", default=3.0, type=float,
                        help="Temperature parameter tau for prototype assignment")

    # Graph-level Semantic Module
    parser.add_argument("--top_k", default=8, type=int,
                        help="Top-k neighbors for inter-molecular semantic alignment")
    parser.add_argument("--inter_temperature", default=1.0, type=float,
                        help="Temperature parameter rho for inter-molecular weights")

    # Contrastive Learning
    # NOTE(review): the None default is returned unchanged; the fallback to
    # emb_dim described in the help text is presumably applied by the
    # caller — confirm.
    parser.add_argument("--proj_dim", default=None, type=int,
                        help="Projection dimension for contrastive learning (default: emb_dim)")
    parser.add_argument("--contrastive_temperature", default=0.08, type=float,
                        help="Temperature parameter gamma for InfoNCE loss")

    # Adversarial Perturbation (Inner Maximization)
    parser.add_argument("--epsilon", default=0.12, type=float,
                        help="Perturbation bound epsilon")
    parser.add_argument("--inner_steps", default=3, type=int,
                        help="Number of steps T for inner maximization")
    parser.add_argument("--inner_lr", default=0.1, type=float,
                        help="Learning rate eta_delta for inner maximization")

    # Loss Weights
    parser.add_argument("--lambda_intra", default=0.002, type=float,
                        help="Weight lambda_intra for intra-molecular semantic consistency")
    parser.add_argument("--lambda_inter", default=0.01, type=float,
                        help="Weight lambda_inter for inter-molecular semantic alignment")
    parser.add_argument("--lambda_inv", default=0.1, type=float,
                        help="Weight lambda_inv for semantic invariance loss (L_inv)")
    parser.add_argument("--lambda_mu", default=0.02, type=float,
                        help="Weight lambda_mu for prototype regularization in L_intra")
    parser.add_argument("--lambda_orth", default=0.0, type=float,
                        help="Weight lambda_orth for prototype orthogonal constraint (optional)")

    # Warm-up Strategy
    parser.add_argument("--warmup_epochs", default=10, type=int,
                        help="Number of warm-up epochs before starting adversarial training")
    parser.add_argument("--warmup_contrastive", action='store_true', default=False,
                        help="Whether to use progressive warm-up for contrastive loss (gradually increase weight)")

    # Training
    parser.add_argument("--lr", default=0.001, type=float)
    parser.add_argument("--bs", default=256, type=int)
    parser.add_argument("--epoch", default=200, type=int)

    return parser.parse_args(argv)
