import argparse
from os.path import join
import pickle

import torch


def get_args(argv=None):
    """Build the GeoLDM command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse parses ``sys.argv[1:]``, which is the original
            behaviour of this function.

    Returns:
        argparse.Namespace holding every option defined below.

    Note:
        Many boolean-like options use ``type=eval`` so that e.g.
        ``--train_diffusion False`` yields ``False``. ``eval`` runs the raw
        CLI string as Python, so only use this parser with trusted command
        lines.
    """
    parser = argparse.ArgumentParser(description='GeoLDM')

    # General args -->
    parser.add_argument('--exp_name', type=str, default='ldm_training')
    parser.add_argument('--train_diffusion', type=eval, default=True,
                        help='if True, train a second stage LDM, otherwise train the VAE model.')
    parser.add_argument('--train_regressor', type=eval, default=False,
                        help='if True, train a second stage regressor, otherwise train the VAE model.')
    parser.add_argument('--verbose', action='store_true', default=False, help='Log many things')
    parser.add_argument('--wandb_usr', type=str)
    parser.add_argument('--no_wandb', action='store_true', help='Disable wandb', default=False)
    # type=bool would turn any non-empty string (including "False") into True;
    # eval matches the other boolean options in this parser.
    parser.add_argument('--online', type=eval, default=True, help='True = wandb online -- False = wandb offline')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--debug', type=eval, default=False,
                        help='If True, use only few batches and evaluate on training set')
    parser.add_argument('--resume', type=str, default=None,
                        help='if given a path, resume training from that point')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='')
    parser.add_argument("--conditioning", nargs='+', default=[],
                        help='arguments : penalized_logP | morgan_fingerprint')
    parser.add_argument('--condition_dropout', action='store_true', default=False,
                        help='Whether to randomly drop conditioning values to do classifier-free guidance.')
    parser.add_argument('--use_ghost_nodes', action='store_true', default=False,
                        help='whether to use ghost nodes to remove the need to sample the number of nodes in advance.')
    parser.add_argument('--use_extra_atomic_features', action='store_true', default=False,
                        help='whether to use extra atom features such as ring information, valency etc.')
    parser.add_argument('--use_vocab_data', action='store_true', default=False,
                        help='whether to load subgraphs from the vocabulary.')
    parser.add_argument("--seed", type=int, default=0)

    # Dataset args -->
    parser.add_argument('--dataset', type=str, default='zinc250k',
                        help='qm9 | zinc250k | zinc250k_explicitH | qm9_second_half (train only on the last 50K samples of the training dataset, '
                            'helpful for conditional generation)')
    parser.add_argument('--datadir', type=str, default='data/',
                        help='data directory')
    parser.add_argument('--filter_n_atoms', type=int, default=None,
                        help='When set to an integer value, QM9 will only contain molecules of that amount of atoms')
    parser.add_argument('--remove_h', action='store_true')
    parser.add_argument('--include_atomic_numbers', type=eval, default=False,
                        help='If True, integer-valued atomic numbers will be part of the dataset')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=0, help='Number of worker for the dataloader')
    parser.add_argument('--data_augmentation', type=eval, default=False, help='use data augmentation in the EGNN')
    parser.add_argument('--augment_noise', type=float, default=0)

    # Optimization args -->
    parser.add_argument('--n_epochs', type=int, default=1500)
    parser.add_argument('--patience', type=int, default=100000, metavar='N',
                        help='number of epochs to wait before stopping the training if val accuracy does not improve (default: no patience)')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--dp', type=eval, default=False,
                        help='Whether to use data parallel training')
    parser.add_argument('--clip_grad', type=eval, default=True,
                        help='Whether to perform gradient clipping before applying an optimizer step')
    parser.add_argument('--ema_decay', type=float, default=0.999,
                        help='Amount of EMA decay, 0 means off. A reasonable value is 0.999.')
    parser.add_argument('--ode_regularization', type=float, default=1e-3)
    parser.add_argument('--prodigyopt', type=eval, default=False,
                        help='Whether to use Prodigy as the optimizer. It is an adaptive parameter-free optimizer.')
    parser.add_argument('--prodigy_lr', type=float, default=1.0,
                        help='lr parameter to give for prodigy optimizer.')
    parser.add_argument('--prodigy_setting', type=int, default=-1)
    parser.add_argument('--d_coef', type=float, default=0.1,
                        help='d_coef parameter to give for prodigy optimizer.')

    # Training/Eval args -->
    parser.add_argument('--guacamaol_eval', type=eval, default=True,
                        help='If True, will run Guacamol evaluation once 10k valid smiles have been generated.')
    parser.add_argument('--save_model', type=eval, default=True,
                        help='save model')
    parser.add_argument('--save_model_history', type=eval, default=False,
                        help='save model anew every new best val epoch')
    parser.add_argument('--test_epochs', type=int, default=10)
    parser.add_argument('--n_stability_samples', type=int, default=1000,
                        help='Number of samples to compute the stability, validity, etc.')
    # argparse does not apply `type` to non-string defaults, so the default
    # must already be an int (the previous 1e8 produced a float).
    parser.add_argument('--visualize_every_batch', type=int, default=100_000_000,
                        help="Can be used to visualize multiple times per epoch")

    # Common args for all models: diffusion dynamics, encoder, decoder -->
    parser.add_argument('--model', type=str, default='egnn_dynamics',
                        help='our_dynamics | schnet | simple_dynamics | '
                            'kernel_dynamics | egnn_dynamics |gnn_dynamics')
    parser.add_argument('--joint_training', type=eval, default=False,
                        help='whether to train the diffusion model and the decoder jointly')
    parser.add_argument('--joint_space', type=str, default=None,
                        help='the space on which both the diffusion model and the decoder operate') # setting it to None to make sure it's not used anymore
    parser.add_argument('--use_eps_correction', type=eval, default=False,
                        help='Whether to use epsilon correction while computing xh_pred for joint training')
    parser.add_argument('--lambda_joint_loss', type=float, default=None,
                        help='coefficient for the 2D reconstruction loss in the joint training')
    parser.add_argument('--ae_path', type=str, default=None,
                        help='Specify first stage VAE model path, for use in second stage training')
    parser.add_argument('--trainable_ae', type=eval, default=False,
                        help='Train first stage VAE model during second stage training')
    # following is for all different EGNN backbones across different models:
    parser.add_argument('--tanh', type=eval, default=True,
                        help='use tanh in the coord_mlp of all EGNNs')
    parser.add_argument('--attention', type=eval, default=True,
                        help='use attention in the EGNNs')
    parser.add_argument('--norm_constant', type=float, default=1,
                        help='diff/(|diff| + norm_constant)')
    parser.add_argument('--sin_embedding', type=eval, default=False,
                        help='whether using or not the sin embedding for all EGNNs')
    parser.add_argument('--normalization_factor', type=float, default=1,
                        help="Normalize the sum aggregation of EGNN")
    parser.add_argument('--aggregation_method', type=str, default='sum',
                        help='"sum" or "mean"')
    parser.add_argument('--encode_prop', action='store_true', default=False,
                        help='Whether to use sine and cosine encodings to encode the target props in the cond gen setting.')

    # Diffusion args -->
    # Training complexity is O(1) (unaffected), but sampling complexity is O(steps).
    parser.add_argument('--probabilistic_model', type=str, default='diffusion',
                        help='diffusion')
    parser.add_argument('--diffusion_steps', type=int, default=1000)
    parser.add_argument('--diffusion_noise_schedule', type=str, default='polynomial_2',
                        help='learned, cosine')
    parser.add_argument('--diffusion_noise_precision', type=float, default=1e-5)
    parser.add_argument('--diffusion_loss_type', type=str, default='l2',
                        help='vlb, l2')
    parser.add_argument('--condition_time', type=eval, default=True,
                        help='Whether to condition the diffusion model on time')
    parser.add_argument('--n_layers', type=int, default=6,
                        help='number of layers of the diffusion dynamics')
    parser.add_argument('--inv_sublayers', type=int, default=1,
                        help='number of sublayers of each block of the diffusion dynamics')
    parser.add_argument('--nf', type=int, default=128,
                        help='number of hidden features of the diffusion dynamics')
    parser.add_argument('--normalize_factors', type=eval, default=None,
                        help='normalize factors for [x, categorical, integer]')
    parser.add_argument('--loss_weighting', action='store_true', default=False,
                        help='Whether to use time-dependent loss weighting as proposed in https://arxiv.org/pdf/2309.17296.')

    # VAE args -->
    parser.add_argument('--n_layers_decoder', type=int, default=4,
                        help='number of Equivariant blocks of the decoder of VAE (default=4)')
    parser.add_argument('--n_layers_encoder', type=int, default=1,
                        help='number of Equivariant blocks of the encoder of VAE (default=1)')
    parser.add_argument('--inv_sublayers_vae', type=int, default=2,
                        help='number of GCL layers in each Equivariant block of encoder and decoder of VAE')
    parser.add_argument('--hidden_nf_vae', type=int, default=64,
                        help='number of hidden features of encoder and decoder of VAE')
    parser.add_argument('--latent_nf', type=int, default=2,
                        help='Number of latent features of the VAE')
    parser.add_argument('--kl_weight', type=float, default=0.0,
                        help='weight of KL term in ELBO loss of VAE')
    parser.add_argument('--use_rbf', action='store_true', default=False,
                        help='Whether to use RBF Basis layers for edge prediction.')
    parser.add_argument('--noise_sigma_vae', type=float, default=None,
                        help='value between 0 and 1 controlling how much to corrupt the molecules for VAE training')
    parser.add_argument('--use_focal_loss', action='store_true', default=False,
                        help='Whether to use focal loss for VAE training.')
    parser.add_argument('--encoder_early_stopping', action='store_true', default=False,
                        help='Whether to use encoder early stopping once accuracy > 99%.')
    parser.add_argument('--encode_h_indep', type=eval, default=False,
                        help='Whether to encode the h features independently through an MLP and not through EGNN')

    # Regressor args -->
    parser.add_argument('--regression_target', nargs='+', default=[],
                        help='target for the regression model: penalized_logP | qed | drd2')
    parser.add_argument('--condition_time_regressor', type=eval, default=True,
                        help='Whether to condition the regressor on time')
    parser.add_argument('--max_step_regressor', type=int, default=1000)

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)

def setup_args(args):
    """Finalise parsed arguments, restoring the stored ones when resuming.

    Steps:
      * ``--debug`` shrinks ``n_stability_samples`` to 10.
      * ``args.cuda`` is set from ``--no-cuda`` and ``torch.cuda.is_available()``.
      * If ``args.resume`` is a directory path, the pickled arguments
        ``last_args_diffusion_model.pickle`` (when ``train_diffusion`` is
        truthy) or ``last_args_vae.pickle`` are loaded from it and replace
        ``args``. A fixed set of options from the current invocation
        (see ``override_keys``) is then written back over the loaded values.
        ``exp_name`` gets a ``_resumed`` suffix unless it already contains
        ``'resumed'``.

    Args:
        args: Namespace as returned by ``get_args``.

    Returns:
        The (possibly replaced) argument namespace.
    """
    if args.debug:
        args.n_stability_samples = 10

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.resume is None:
        return args

    # Options from the current command line that always override the stored run.
    override_keys = (
        'start_epoch', 'resume', 'wandb_usr', 'noise_sigma_vae',
        # Device availability belongs to the machine we resume on; keeping the
        # pickled value would ignore --no-cuda and break CPU-only resumes.
        'cuda',
        # for joint fine-tuning
        'trainable_ae', 'joint_training', 'use_eps_correction',
        'lambda_joint_loss', 'joint_space',
    )
    # Options only filled in when the stored run predates them.
    fallback_keys = ('normalization_factor', 'aggregation_method')

    current = {key: getattr(args, key) for key in override_keys + fallback_keys}
    if 'resumed' in args.exp_name:
        exp_name = args.exp_name
    else:
        exp_name = args.exp_name + '_resumed'

    model_name = 'diffusion_model' if args.train_diffusion else 'vae'
    # NOTE(security): pickle.load can execute arbitrary code; only resume from
    # directories you trust.
    with open(join(args.resume, f'last_args_{model_name}.pickle'), 'rb') as f:
        args = pickle.load(f)

    args.exp_name = exp_name
    for key in override_keys:
        setattr(args, key, current[key])

    # Careful with this -->
    for key in fallback_keys:
        if not hasattr(args, key):
            setattr(args, key, current[key])

    return args


# parser.add_argument('--brute_force', type=eval, default=False,
#                     help='True | False')
# parser.add_argument('--actnorm', type=eval, default=True,
#                     help='True | False')
# parser.add_argument('--break_train_epoch', type=eval, default=False,
#                     help='True | False')
# parser.add_argument('--trace', type=str, default='hutch',
#                     help='hutch | exact')
# # EGNN args -->
# # <-- EGNN args
# parser.add_argument('--dequantization', type=str, default='argmax_variational',
#                     help='uniform | variational | argmax_variational | deterministic')
# parser.add_argument('--n_report_steps', type=int, default=1)
# parser.add_argument('--generate_epochs', type=int, default=1,
#                     help='save model')
# parser.add_argument("--conditioning", nargs='+', default=[],
#                     help='arguments : homo | lumo | alpha | gap | mu | Cv' )

# # pp_model args -->
# # ALL args related to the pp_model end with _pp
# parser.add_argument('--tanh_vae', type=eval, default=False,
#                     help='use tanh in the coord_mlp')
# parser.add_argument('--attention_vae', type=eval, default=True,
#                     help='use attention in the EGNN')
# parser.add_argument('--norm_constant_vae', type=float, default=1,
#                     help='diff/(|diff| + norm_constant)')
# parser.add_argument('--sin_embedding_vae', type=eval, default=False,
#                     help='whether using or not the sin embedding')
# parser.add_argument('--normalization_factor_vae', type=float, default=1,
#                     help="Normalize the sum aggregation of EGNN")
# parser.add_argument('--aggregation_method_vae', type=str, default='sum',
#                     help='"sum" or "mean"')
# parser.add_argument('--encoder_pp', type=str, default='egnn',
#                     help='egnn or None')
# parser.add_argument('--edge_head_pp', type=str, default='mlp',
#                     help='mlp or linear')
# parser.add_argument('--edge_head_hidden_dim_pp', type=int, default=32,
#                     help='hidden dimensions of the edge head MLP')
# parser.add_argument('--modify_h_pp', type=eval, default=True,
#                     help='If True, the edge model will also modify the atom types and charges')
# parser.add_argument('--condition_time_pp', type=eval, default=False,
#                     help='True | False')
# parser.add_argument('--lr_vae', type=float, default=5e-4, metavar='N',
#                     help='learning rate')
# parser.add_argument('--use_pp_model', type=eval, default=True,
#                     help='whether to use the edge model')
# # <-- pp_model args

# # VAE args -->
# # <-- VAE args
