from argparse import ArgumentParser, FileType

import yaml


def _str_to_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool("False")``,
    which is ``True`` because every non-empty string is truthy.

    Args:
        value: The raw string from the command line (an actual ``bool`` is
            passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is not a recognised spelling of true/false.
            argparse turns this into a regular usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise ValueError("Expected a boolean value, got {!r}".format(value))


def _apply_config_overrides(arg_dict, config_dict):
    """Merge the entries of ``config_dict`` into ``arg_dict`` in place.

    A list value is appended to the current value only when the current value
    is itself a list; in every other case the config value replaces it. Keys
    that do not correspond to a parser argument are added as new entries.
    """
    for key, value in config_dict.items():
        current = arg_dict.get(key)
        if isinstance(value, list) and isinstance(current, list):
            current.extend(value)
        else:
            arg_dict[key] = value


def parse_training_args(argv=None):
    """Parse the training options from the command line and an optional YAML config.

    The options cover data generation/loading, the model (encoder, extrinsic
    graph representation, propagator), logging, loss weights, optimisation and
    the output directory. If ``--config`` is given, the YAML file is loaded
    after command-line parsing and its values take precedence over both the
    defaults and any values passed on the command line.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding all options.

    Raises:
        ValueError: If the config file does not contain a mapping at its top
            level, or if ``no_propagator`` is set while ``tau`` is not 0.
        SystemExit: Raised by argparse on invalid command-line input.
    """
    parser = ArgumentParser()
    # train data arguments
    parser.add_argument('--config', type=FileType(mode='r'), default=None)
    parser.add_argument('--data_dir', type=str, default='data/mdsim/', help='Folder containing the simulation data')
    parser.add_argument('--data_size', type=int, default=1000000, help='Number of simulation steps performed, 1 step corresponds to 1 femtosecond')
    parser.add_argument('--md_device', type=str, default='CUDA', help='CUDA or CPU')
    parser.add_argument('--data_save_frequency', type=int, default=120, help='Frequency after which the state is saved')
    parser.add_argument('--data_temperature', type=int, default=300, help='Temperature of the system in K')
    parser.add_argument('--testsystem', type=str, default='implicit', help='Testsystem for the Simulation, can be one of vacuum,implicit,explicit')
    parser.add_argument('--num_dataloader_workers', type=int, default=0, help='Number of workers for the dataloader')
    # Parsed with _str_to_bool so that "--dataloader_drop_last False" really yields False.
    parser.add_argument("--dataloader_drop_last", type=_str_to_bool, default=True, help="Drop last batch if it is smaller than the batch size")
    parser.add_argument("--save_pdb",  action="store_true", default=False, help="Save the md trajectory as pdb file")
    parser.add_argument("--feature_transform", type=str, default="standardize", help="Can be either none/standardize/normalize . Wether to transform the features before training")

    # model arguments
    parser.add_argument("--no_vae", action="store_true", default=False, help="Train a simple autoencoder instead of a VAE, i.e. no KL loss")
    parser.add_argument("--no_propagator", action="store_true", default=False, help="Only train the encoder and decoder, no propagator is created")

    # Encoder
    parser.add_argument("--encoder_embedding_size", type=int, default=32, help="Size of the embedding vector for the encoder")
    parser.add_argument("--edge_embedding_size", type=int, default=2, help="Size of the embedding vector for the bond-bond edges")
    parser.add_argument("--latent_embedding_size", type=int, default=3, help="Size of the embedding vector for the latent space")
    parser.add_argument("--graph_representation", type=str, default="internal", help="Can be either internal/extrinsic . Wether to represent the system as graph of extrinsic or internal coordinates")

    # Extrinsic representation
    parser.add_argument("--num_conv_layers", type=int, default=2, help="Number of convolutional layers in the encoder")
    parser.add_argument("--sh_lmax", type=int, default=1, help="Maximum spherical harmonics degree")
    parser.add_argument("--ns", type=int, default=2, help="Number of scalar features, corresponds to atom atttributes, i.e. if ns=2 the first two attributes will be used as scalar features")
    parser.add_argument("--nv", type=int, default=2, help="Number of vector features")
    parser.add_argument("--in_edge_features", type=int, default=1, help="Number of edge features for radius graph, should be equal to the number of node attributes")
    parser.add_argument("--use_set2set_pooling", action="store_true", default=False, help="Use set2set pooling instead of global pooling")

    # Decoder

    # Propagator
    parser.add_argument("--tau", type=int, default=10, help="Offset between 2 states (tau),i.e. number of frames between t and t+tau")
    parser.add_argument("--sequence_length", type=int, default=5, help="Length of the sequence that is used to predict the next state, i.e. 'warmup' length")
    parser.add_argument("--propagator_type", type=str, default="linear", help="Type of the propagator, can be one of linear, lstm, tbd")
    parser.add_argument("--propagator_hidden_size", type=int, default=20, help="Size of the hidden state of the propagator")
    parser.add_argument("--propagator_num_layers", type=int, default=3, help="Number of layers of the propagator")
    parser.add_argument("--propagator_dropout", type=float, default=0.1, help="Dropout rate of the propagator")

    ######################################################################################################################################################
    # Training arguments

    # wandb
    parser.add_argument("--run_name", type=str, default="default", help="Name of the run for Weights and Biases")
    parser.add_argument("--wandb_project", type=str, default="latent-md", help="Name of the project for Weights and Biases")
    parser.add_argument("--wandb", action="store_true", default=False, help="Use Weights and Biases for logging")
    parser.add_argument("--log_energy", action="store_true", default=False, help="Log the energy and free energy difference of the predicted conformations")
    parser.add_argument("--log_rmsd", action="store_true", default=False, help="Log the rmsd of the predicted conformations to ground truth")
    parser.add_argument("--log_latent_tica", action="store_true", default=False, help="Log the tica projection of the latent space")

    # Loss
    # TODO: to tune
    parser.add_argument("--bond_width_weight", type=float, default=1.0, help="Weight of the bond width loss")
    parser.add_argument("--bond_angles_weight", type=float, default=1.0, help="Weight of the bond angles loss")
    parser.add_argument("--torsion_angles_weight", type=float, default=1.0, help="Weight of the torsion loss")
    parser.add_argument("--kl_weight", type=float, default=1.0, help="Weight of the KL loss")
    parser.add_argument("--propagator_weight", type=float, default=1.0, help="Weight of the propagator loss")

    # Training
    parser.add_argument("--num_epochs", type=int, default=100, help="Number of epochs for training")
    # batch_norm defaults to True, so "--batch_norm" alone can never change it;
    # "--no_batch_norm" writes to the same destination to allow switching it off.
    parser.add_argument("--batch_norm", action="store_true", default=True, help="Use batch normalization")
    parser.add_argument("--no_batch_norm", dest="batch_norm", action="store_false", help="Disable batch normalization")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate")
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate for training')
    parser.add_argument('--use_ema', action='store_true', default=False, help='Whether or not to use ema for the model weights')
    parser.add_argument('--ema_rate', type=float, default=0.999, help='decay rate for the exponential moving average model parameters ')
    parser.add_argument('--adamw', action='store_true', default=False, help='Whether or not to use adamw optimizer, if False adam is used')
    parser.add_argument('--scheduler', type=str, default='none', help='Scheduler for the learning rate, can be one of [none, plateau]')
    parser.add_argument('--scheduler_patience', type=int, default=5, help='Patience for the scheduler')
    parser.add_argument('--w_decay', type=float, default=0.0, help='Weight decay added to loss')

    # Model saving
    parser.add_argument("--log_dir", type=str, default="workdir", help="Folder for saving the model")

    args = parser.parse_args(argv)

    if args.config:
        # NOTE(security): FullLoader is not meant for untrusted input; if config
        # files can come from outside the project, switch to yaml.safe_load.
        config_dict = yaml.load(args.config, Loader=yaml.FullLoader)
        if config_dict is None:
            # An empty YAML document loads as None: nothing to override.
            config_dict = {}
        if not isinstance(config_dict, dict):
            raise ValueError("The config file must contain a mapping of argument names to values")
        # Applied after parsing, so config values win over command-line values.
        _apply_config_overrides(args.__dict__, config_dict)

    # Validated after the config merge so that values from the file are checked too.
    if args.no_propagator and args.tau != 0:
        raise ValueError("If no propagator is used, tau should be 0")

    return args
