#!/usr/bin/env python
import os.path
import shutil
from deeperwin.cli import main
from deeperwin.run_tools.geometry_database import load_geometries, load_datasets, Geometry
from deeperwin.configuration import Configuration
from deeperwin.checkpoints import load_run
import ruamel.yaml

# Settings
# When True, " --dry-run" is appended to every setup command issued below.
dry_run = False
# Entries that are keys of the geometry database are used directly as
# geometry hashes; all other entries are looked up as dataset names.
# Uncomment further entries to include them in the submission.
datasets = [
    "TinyMol_CNO_rot_dist_test_out_of_distribution_4geoms",
    # "Kinal_Piecuch_Bicyclobutane",
    # "Test_Set_5heavy_atoms_no_ring_4geoms_dist",
    # "Test_Set_6heavy_atoms_no_ring_4geoms_dist",
    # "Test_Set_7heavy_atoms_no_ring_4geoms_dist",
]

# Checkpoint whose model config is reused, and the PhisNet checkpoint path
# that is written into the reuse section of the config.
checkpoint = "midimol_2023-05-01_699torsion_nc_by_std_256k.zip"
phisnet_checkpoint = "phisnet_3LayerL2_47kGeoms_174Epochs.zip"

# Constants
reuse_config_fname = "config_reuse_template.yml"
# Checkpoint name with everything up to the last "/" dropped and ".zip" removed.
calc_name = checkpoint.rpartition("/")[2].replace(".zip", "")
# Working directory that collects all runs reusing this checkpoint.
calc_dir = f"reuse_{calc_name}"

# Start from the reuse template and overwrite its "model" section with the
# model config stored in the checkpoint, so both describe the same model.
with open(reuse_config_fname) as template_file:
    reuse_config = ruamel.yaml.YAML().load(template_file)
# Only the config of the checkpoint is needed here (load_pkl=False), and it
# is kept unparsed (parse_config=False) so it can be indexed like a mapping.
checkpoint_run = load_run(checkpoint, load_pkl=False, parse_config=False)
reuse_config["model"] = checkpoint_run.config["model"]
# NOTE(review): both paths are stored exactly as given above, while the
# script changes its working directory afterwards -- confirm that relative
# paths still resolve from wherever the submitted jobs run.
reuse_section = reuse_config["reuse"]
reuse_section["path"] = checkpoint
reuse_section["path_phisnet"] = phisnet_checkpoint

# All runs reusing this checkpoint live in one dedicated directory; create it
# on first use and make it the working directory for everything that follows.
os.makedirs(calc_dir, exist_ok=True)
os.chdir(calc_dir)

# Resolve the requested entries to a flat list of geometry hashes.
all_geometries = load_geometries()
all_datasets = load_datasets()
geometry_hashes = []
for entry in datasets:
    if entry not in all_geometries:
        # Not a known geometry hash: treat it as a dataset name and take
        # every hash that dataset reports.
        geometry_hashes.extend(all_datasets[entry].get_hashes())
    else:
        # Already a geometry hash, use it as-is.
        geometry_hashes.append(entry)

# Submit the actual calculations
# Walker count per number of heavy atoms; any count not listed uses 2048.
walkers_by_heavy_atoms = {5: 1500, 6: 800, 7: 500}
for geom_hash in geometry_hashes:
    n_heavy = all_geometries[geom_hash].n_heavy_atoms

    # Geometries with 5 or more heavy atoms go to the "a100" queue,
    # everything smaller to "a40".
    reuse_config["dispatch"]["queue"] = "a100" if n_heavy >= 5 else "a40"

    # Optimization and intermediate evaluation share the same walker count.
    walker_count = walkers_by_heavy_atoms.get(n_heavy, 2048)
    reuse_config["optimization"]["mcmc"]["n_walkers"] = walker_count
    reuse_config["optimization"]["intermediate_eval"]["mcmc"]["n_walkers"] = walker_count
    reuse_config["experiment_name"] = f"{calc_dir}_{n_heavy}heavy"

    # Write config and submit job; config.yml is overwritten for each geometry
    with open("config.yml", "w") as config_file:
        ruamel.yaml.YAML().dump(reuse_config, config_file)
    print(f"Generating job for {geom_hash}")
    dry_run_flag = " --dry-run" if dry_run else ""
    main("setup -i config.yml" + f" -p physical {geom_hash}" + dry_run_flag)
# Step back out to the parent of the calculation directory.
os.chdir("..")
  
