from argparse import ArgumentParser, FileType, Namespace

import yaml
from openmmtools.testsystems import (AlanineDipeptideExplicit,
                                     AlanineDipeptideImplicit,
                                     AlanineDipeptideVacuum)

from datasets.aldp import ALDPDataset

# Maps each accepted ``--testsystem`` value to the openmmtools alanine
# dipeptide test-system class that the ``__main__`` section instantiates.
testsystems = dict(
    vacuum=AlanineDipeptideVacuum,
    implicit=AlanineDipeptideImplicit,
    explicit=AlanineDipeptideExplicit,
)

def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``type=bool`` must not be used for flags: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so such a flag could never be
    turned off from the command line.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "on", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "off", "0"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface. Values may later be overridden by a YAML file
# passed via --config.
parser = ArgumentParser()
parser.add_argument('--config', type=FileType(mode='r'), default=None)
parser.add_argument('--data_dir', type=str, default='data/mdsim/', help='Folder containing the simulation data')
parser.add_argument('--data_size', type=int, default=1000000, help='Number of simulation steps performed, 1 step corresponds to 1 femtosecond')
parser.add_argument('--md_device', type=str, default='CUDA', help='CUDA or CPU')
parser.add_argument('--data_save_frequency', type=int, default=120, help='Frequency after which the state is saved')
parser.add_argument('--data_temperature', type=int, default=300, help='Temperature of the system in K')
parser.add_argument('--testsystem', type=str, default='implicit', help='Testsystem for the Simulation, can be one of vacuum,implicit,explicit')
# Boolean flags: accept an explicit value (``--save_pdb false``) or, thanks to
# nargs='?' + const=True, the bare flag (``--save_pdb``) meaning True.
parser.add_argument("--remove_hydrogens", type=_str2bool, nargs="?", const=True, default=False, help="Remove hydrogens from the system")
parser.add_argument("--save_pdb", type=_str2bool, nargs="?", const=True, default=False, help="Save the md trajectory as pdb file")
parser.add_argument("--graph_representation", type=str, default="internal", help="Can be either internal/extrinsic . Wether to represent the system as graph of extrinsic or internal coordinates")
parser.add_argument("--tau", type=int, default=10, help="Time offset between 2 frames (tau)")

args = parser.parse_args()

if args.config:
    # Merge the YAML config into the parsed namespace; config values take
    # precedence over command-line values.
    # NOTE(review): FullLoader can construct some Python objects; if configs
    # are plain key/value files, yaml.safe_load would be sufficient.
    with args.config as config_file:  # FileType opened it; make sure it gets closed
        config_dict = yaml.load(config_file, Loader=yaml.FullLoader)
    if config_dict is None:  # empty YAML document
        config_dict = {}
    if not isinstance(config_dict, dict):
        raise TypeError(
            f"Config file must contain a mapping of option names to values, "
            f"got {type(config_dict).__name__}"
        )
    arg_dict = vars(args)
    for key, value in config_dict.items():
        if isinstance(value, list) and isinstance(arg_dict.get(key), list):
            # Extend an existing list option instead of replacing it.
            arg_dict[key].extend(value)
        else:
            # New keys, scalars, and lists for options without a list value
            # are assigned directly (the old code crashed on the latter).
            arg_dict[key] = value

# Validate only after the config has been merged, so a testsystem supplied via
# the YAML file is checked too and a CLI value it overrides is not rejected.
if args.testsystem not in testsystems:
    raise KeyError("The specified testsystem is not implemented for the Alanine Dipeptide Simulation")


if __name__ == "__main__":
    # Resolve the class for the requested --testsystem and build it with
    # constraints=None (presumably disabling openmmtools' default
    # constraints — confirm), then hand it to the dataset with the CLI args.
    # The dataset object is not used further in this script.
    system_cls = testsystems[args.testsystem]
    testsystem = system_cls(constraints=None)
    dataset = ALDPDataset(args, testsystem)
