import argparse

parser = argparse.ArgumentParser()

# about seed and basic info
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--runseed', type=int, default=0)
parser.add_argument('--device', type=int, default=0)

# about dataset and dataloader
parser.add_argument('--input_data_dir', type=str, default='')
parser.add_argument('--dataset', type=str, default='bace')
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--data_dir_chirality', type=str)
parser.set_defaults(
    data_dir_chirality='../datasets/chirality/d4_docking/d4_docking_rs.csv')
parser.add_argument('--PCQM4M_sample_size', type=int, default=-1)

# about training strategies
parser.add_argument('--split', type=str, default='scaffold')
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--lr_scale', type=float, default=1)
parser.add_argument('--decay', type=float, default=0)
parser.add_argument('--split_path', type=str)  # used in chirality
parser.set_defaults(
    split_path='../datasets/chirality/d4_docking/rs/split0.npy')
# about molecule GNN
parser.add_argument('--gnn_type', type=str, default='gin')
parser.add_argument('--num_layer', type=int, default=5)
parser.add_argument('--emb_dim', type=int, default=300)
parser.add_argument('--dropout_ratio', type=float, default=0.5)
parser.add_argument('--graph_pooling', type=str, default='mean')
parser.add_argument('--JK', type=str, default='last')
parser.add_argument('--gnn_lr_scale', type=float, default=1)
parser.add_argument('--model_3d', type=str, default='schnet',
                    choices=['schnet', 'egnn', 'se3', 'spherenet'])

# for AttributeMask
parser.add_argument('--mask_rate', type=float, default=0.15)
parser.add_argument('--mask_edge', type=int, default=0)

# for ContextPred
parser.add_argument('--csize', type=int, default=3)
parser.add_argument('--contextpred_neg_samples', type=int, default=1)

# for EGNN
parser.add_argument('--m_dim', type=int, default=50)
parser.add_argument('--engg_n_layers', type=int, default=4)

# for SchNet
parser.add_argument('--num_filters', type=int, default=128)
parser.add_argument('--num_interactions', type=int, default=6)
parser.add_argument('--num_gaussians', type=int, default=51)
parser.add_argument('--cutoff', type=float, default=10)
parser.add_argument('--readout', type=str, default='mean',
                    choices=['mean', 'add'])
parser.add_argument('--schnet_lr_scale', type=float, default=1)

# for SE(3)-Transformer
parser.add_argument("--se3_transformer_num_layers", type=int, default=7)
parser.add_argument("--se3_transformer_num_channels", type=int, default=32)
parser.add_argument("--se3_transformer_num_degrees", type=int, default=4)
parser.add_argument("--se3_transformer_num_nlayers", type=int, default=1)
parser.add_argument("--se3_transformer_div", type=int, default=2)
parser.add_argument("--se3_transformer_n_heads", type=int, default=8)

# for SphereNet
parser.add_argument("--spherenet_cutoff", type=float, default=3.0)
parser.add_argument("--spherenet_num_layers", type=int, default=4)
parser.add_argument("--spherenet_int_emb_size", type=int, default=64)
parser.add_argument("--spherenet_basis_emb_size_dist", type=int, default=8)
parser.add_argument("--spherenet_basis_emb_size_angle", type=int, default=8)
parser.add_argument("--spherenet_basis_emb_size_torsion", type=int, default=8)
parser.add_argument("--spherenet_out_emb_channels", type=int, default=256)
parser.add_argument("--spherenet_num_spherical", type=int, default=3)
parser.add_argument("--spherenet_num_radial", type=int, default=6)
parser.add_argument("--spherenet_envelope_exponent", type=int, default=5)
parser.add_argument("--spherenet_num_before_skip", type=int, default=1)
parser.add_argument("--spherenet_num_after_skip", type=int, default=2)
parser.add_argument("--spherenet_num_output_layers", type=int, default=3)

# for 2D-3D Contrastive CL
parser.add_argument('--CL_neg_samples', type=int, default=1)
parser.add_argument('--CL_similarity_metric', type=str, default='InfoNCE_dot_prod',
                    choices=['InfoNCE_dot_prod', 'EBM_dot_prod'])
parser.add_argument('--T', type=float, default=0.1)
parser.add_argument('--normalize', dest='normalize', action='store_true')
parser.add_argument('--no_normalize', dest='normalize', action='store_false')
parser.add_argument('--SSL_masking_ratio', type=float, default=0)
parser.add_argument('--AE_model', type=str, default='AE', choices=[
    'AE',
    'VAE', 'Flow_VAE', 'IWAE',
    'Flow_Exact_VAE',
    'Energy_VAE'
])
parser.set_defaults(AE_model='AE')

# for 2D-3D AutoEncoder
parser.add_argument('--AE_loss', type=str, default='l2', choices=['l1', 'l2', 'cosine'])
parser.add_argument('--detach_target', dest='detach_target', action='store_true')
parser.add_argument('--no_detach_target', dest='detach_target', action='store_false')
parser.set_defaults(detach_target=True)

# for 2D-3D Variational AutoEncoder
parser.add_argument('--beta', type=float, default=1)

# for 2D-3D Variational AutoEncoder with Flow
parser.add_argument('--flow_model', type=str, default='planar',
                    choices=['planar', 'radial', 'mlp'])
parser.add_argument('--flow_length', type=int, default=8)

# for Importance Weighted AutoEncoder
parser.add_argument('--iw_samples', type=int, default=5)

# for 2D-3D Contrastive CL and AE/VAE
# parser.add_argument('--alpha', type=float, default=1)
parser.add_argument('--alpha_1', type=float, default=1)
parser.add_argument('--alpha_2', type=float, default=1)

# for 2D SSL and 3D-2D SSL
parser.add_argument('--SSL_2D_mode', type=str, default='AM')
parser.add_argument('--alpha_3', type=float, default=0.1)
parser.add_argument('--gamma_joao', type=float, default=0.1)
parser.add_argument('--gamma_joaov2', type=float, default=0.1)

# about if we would print out eval metric for training data
parser.add_argument('--eval_train', dest='eval_train', action='store_true')
parser.add_argument('--no_eval_train', dest='eval_train', action='store_false')
parser.set_defaults(eval_train=True)

# about loading and saving
parser.add_argument('--input_model_file', type=str, default='')
parser.add_argument('--output_model_dir', type=str, default='')

# verbosity
parser.add_argument('--verbose', dest='verbose', action='store_true')
parser.add_argument('--no_verbose', dest='verbose', action='store_false')
parser.set_defaults(verbose=False)

args = parser.parse_args()
print('arguments\t', args)
