import argparse

def parse_args(argv=None):
    """Build the E3Diffusion command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which preserves
            the original zero-argument behaviour; passing an explicit list
            lets tests and notebooks call this without touching ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options.

    Note:
        Many flags use ``type=eval`` so that strings such as ``True``/``False``
        or ``[1, 4, 1]`` on the command line become Python objects. ``eval``
        executes arbitrary expressions, so only parse trusted command lines.
        argparse only runs ``type`` on *string* defaults, so non-string
        defaults below (``False``, ``[1, 4, 1]``, ...) are used as-is.
    """
    parser = argparse.ArgumentParser(description='E3Diffusion')
    # Accepted and ignored; presumably so the parser tolerates the
    # ``-f <connection file>`` argument that Jupyter kernels pass -- confirm.
    parser.add_argument('-f')
    parser.add_argument('--exp_name', type=str, default='debug_10')
    parser.add_argument('--model', type=str, default='egnn_dynamics',
                        help='our_dynamics | schnet | simple_dynamics | '
                             'kernel_dynamics | egnn_dynamics | gnn_dynamics')
    parser.add_argument('--probabilistic_model', type=str, default='diffusion',
                        help='diffusion')
    parser.add_argument('--joint_training', type=eval, default=False,
                        help='whether to train the diffusion model and the edge model jointly')

    # Diffusion args -->
    # Training complexity is O(1) (unaffected), but sampling complexity is O(steps).
    parser.add_argument('--diffusion_steps', type=int, default=500)
    parser.add_argument('--diffusion_noise_schedule', type=str, default='polynomial_2',
                        help='learned | cosine | polynomial_2')
    parser.add_argument('--diffusion_noise_precision', type=float, default=1e-5)
    parser.add_argument('--diffusion_loss_type', type=str, default='l2',
                        help='vlb, l2')
    # <-- Diffusion args

    # Optimisation / training-loop args -->
    parser.add_argument('--n_epochs', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--brute_force', type=eval, default=False,
                        help='True | False')
    parser.add_argument('--actnorm', type=eval, default=True,
                        help='True | False')
    parser.add_argument('--break_train_epoch', type=eval, default=False,
                        help='True | False')
    parser.add_argument('--dp', type=eval, default=False,
                        help='True | False')
    parser.add_argument('--condition_time', type=eval, default=True,
                        help='True | False')
    parser.add_argument('--clip_grad', type=eval, default=True,
                        help='True | False')
    parser.add_argument('--trace', type=str, default='hutch',
                        help='hutch | exact')
    # <-- Optimisation / training-loop args

    # EGNN args -->
    parser.add_argument('--n_layers', type=int, default=6,
                        help='number of layers')
    parser.add_argument('--inv_sublayers', type=int, default=1,
                        help='number of GCL layers in each Equivariant block')
    parser.add_argument('--nf', type=int, default=128,
                        help='number of hidden features')
    parser.add_argument('--tanh', type=eval, default=True,
                        help='use tanh in the coord_mlp')
    parser.add_argument('--attention', type=eval, default=True,
                        help='use attention in the EGNN')
    parser.add_argument('--norm_constant', type=float, default=1,
                        help='diff/(|diff| + norm_constant)')
    parser.add_argument('--sin_embedding', type=eval, default=False,
                        help='whether using or not the sin embedding')
    # <-- EGNN args

    # Data / logging / bookkeeping args -->
    parser.add_argument('--ode_regularization', type=float, default=1e-3)
    parser.add_argument('--dataset', type=str, default='qm9',
                        help='qm9 | zinc250k | qm9_second_half (train only on the last 50K samples of the training dataset)')
    parser.add_argument('--datadir', type=str, default='data/',
                        help='data directory')
    parser.add_argument('--filter_n_atoms', type=int, default=None,
                        help='When set to an integer value, QM9 will only contain molecules of that amount of atoms')
    parser.add_argument('--dequantization', type=str, default='argmax_variational',
                        help='uniform | variational | argmax_variational | deterministic')
    parser.add_argument('--n_report_steps', type=int, default=1)
    parser.add_argument('--wandb_usr', type=str)
    parser.add_argument('--no_wandb', action='store_true', help='Disable wandb')
    # type=bool would treat any non-empty string (including "False") as True,
    # so the flag could never be switched off; eval matches the other
    # True/False flags in this parser.
    parser.add_argument('--online', type=eval, default=True, help='True = wandb online -- False = wandb offline')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--save_model', type=eval, default=True,
                        help='save model')
    # NOTE(review): help text below looks copied from --save_model; confirm
    # what this interval controls and update the message.
    parser.add_argument('--generate_epochs', type=int, default=1,
                        help='save model')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of worker for the dataloader')
    parser.add_argument('--test_epochs', type=int, default=10)
    parser.add_argument('--data_augmentation', type=eval, default=False,
                        help='whether to apply data augmentation')
    parser.add_argument("--conditioning", nargs='+', default=[],
                        help='arguments : homo | lumo | alpha | gap | mu | Cv')
    parser.add_argument('--resume', type=str, default=None,
                        help='')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='')
    parser.add_argument('--ema_decay', type=float, default=0.999,
                        help='Amount of EMA decay, 0 means off. A reasonable value'
                             ' is 0.999.')
    parser.add_argument('--augment_noise', type=float, default=0)
    parser.add_argument('--n_stability_samples', type=int, default=500,
                        help='Number of samples to compute the stability')
    parser.add_argument('--normalize_factors', type=eval, default=[1, 4, 1],
                        help='normalize factors for [x, categorical, integer]')
    parser.add_argument('--remove_h', action='store_true')
    parser.add_argument('--include_charges', type=eval, default=True,
                        help='include atom charge or not')
    # Integer literal so the default has the same type as parsed values;
    # argparse does not run type=int on a non-string default (1e8 is a float).
    parser.add_argument('--visualize_every_batch', type=int, default=100_000_000,
                        help="Can be used to visualize multiple times per epoch")
    parser.add_argument('--normalization_factor', type=float, default=1,
                        help="Normalize the sum aggregation of EGNN")
    parser.add_argument('--aggregation_method', type=str, default='sum',
                        help='"sum" or "mean"')
    # <-- Data / logging / bookkeeping args

    # pp_model args -->
    # ALL args related to the pp_model end with _pp
    parser.add_argument('--n_layers_pp', type=int, default=4,
                        help='number of Equivariant blocks')
    parser.add_argument('--inv_sublayers_pp', type=int, default=2,
                        help='number of GCL layers in each Equivariant block')
    parser.add_argument('--hidden_nf_pp', type=int, default=64,
                        help='number of hidden features')
    parser.add_argument('--tanh_pp', type=eval, default=False,
                        help='use tanh in the coord_mlp')
    parser.add_argument('--attention_pp', type=eval, default=True,
                        help='use attention in the EGNN')
    parser.add_argument('--norm_constant_pp', type=float, default=1,
                        help='diff/(|diff| + norm_constant)')
    parser.add_argument('--sin_embedding_pp', type=eval, default=False,
                        help='whether using or not the sin embedding')
    parser.add_argument('--normalization_factor_pp', type=float, default=1,
                        help="Normalize the sum aggregation of EGNN")
    parser.add_argument('--aggregation_method_pp', type=str, default='sum',
                        help='"sum" or "mean"')
    parser.add_argument('--encoder_pp', type=str, default='egnn',
                        help='egnn or None')
    parser.add_argument('--edge_head_pp', type=str, default='mlp',
                        help='mlp or linear')
    parser.add_argument('--edge_head_hidden_dim_pp', type=int, default=32,
                        help='hidden dimensions of the edge head MLP')
    parser.add_argument('--modify_h_pp', action='store_true', default=False,
                        help='If True, the edge model will also modify the atom types and charges')
    parser.add_argument('--lambda_pp_loss', type=float, default=1.0,
                        help='coefficient for the pp_model loss in the joint training')
    # <-- pp_model args

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
