from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]
SCRIPT_DIR = 'scripts'

# Per-cluster SLURM / environment settings. Puhti and Mahti share the same
# project, module and venv; only the GPU partition name differs.
PLATFORM_SETTINGS = {
    'lumi': {
        'project': 'project_462000833',
        'partition': 'small-g',
        'with_containers': True,
        'puhti_module': None,
        'venv_path': 'property-funnel',
        'container': 'property-funnel.sif',
    },
    'puhti': {
        'project': 'project_2007775',
        'partition': 'gpu',
        'with_containers': False,
        'puhti_module': 'pytorch/2.4',
        'venv_path': '/projappl/project_2007775/syntheseus-python-10',
        'container': None,
    },
    'mahti': {
        'project': 'project_2007775',
        'partition': 'gpusmall',
        'with_containers': False,
        'puhti_module': 'pytorch/2.4',
        'venv_path': '/projappl/project_2007775/syntheseus-python-10',
        'container': None,
    },
}

# ---------------------------------------------------------------------------
# Experiment configuration: reaction-type classifier guidance.
# ---------------------------------------------------------------------------
steered = "false"
model_dir = "rsmiles_50k_checkpoints"
num_results = 100
search = "strychnine"
dummy_inventory = "true"
classifier_guidance = "reaction_type"
# NOTE(review): this is a local macOS path; it presumably does not exist on
# lumi/puhti/mahti -- confirm the checkpoint location before submitting.
classifier_guidance_checkpoint_path = "/Users/laabidn1/multiguide/checkpoints/reaction_type_completion0.8_augmentation5/checkpoint_180.pt"
num_classes = "11"
n_candidates_to_evaluate = "72"
search_batch_size = "1024"
single_step_model = "root_aligned"
inventory_file = "fukuyama_tree_inventory"  # do not include extension (.csv)
limit_reaction_model_calls = 500
time_limit_s = 6000


def parse_args(argv=None):
    """Parse the command-line options of this submission script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``interactive`` (bool) and ``platform`` (str).
        argparse exits with a usage error if ``--platform`` is missing or
        not one of the keys of ``PLATFORM_SETTINGS``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--interactive', action='store_true')
    parser.add_argument('--platform', type=str, required=True,
                        choices=sorted(PLATFORM_SETTINGS))
    return parser.parse_args(argv)


def platform_settings(platform):
    """Return a copy of the cluster settings for ``platform``.

    Raises:
        ValueError: if ``platform`` is not a key of ``PLATFORM_SETTINGS``.
    """
    try:
        return dict(PLATFORM_SETTINGS[platform])
    except KeyError:
        raise ValueError(f'Platform {platform} not supported') from None


def build_slurm_args(platform):
    """Build the SLURM options dict passed to ``create_and_submit_batch_job``.

    Key order matches the original script. ``job_name`` is added by the caller.
    """
    settings = platform_settings(platform)
    return {
        'job_dir': 'jobs',
        'job_ids_file': 'job_ids.txt',
        'output_dir': 'output',
        'platform': platform,
        'project': settings['project'],
        'time': '05:00:00',
        'partition': settings['partition'],
        'nodes': 1,
        'ntasks-per-node': 1,
        'cpus-per-task': 1,
        'gpus-per-node': 1,
        'mem': '100G',  # 50G not enough for uspto_full
        'with_containers': settings['with_containers'],
        'container': settings['container'],
        'venv_path': settings['venv_path'],
        'puhti_module': settings['puhti_module'],
        'start_array_job': 0,  # 5 to 37
        'end_array_job': 0,  # 37
    }


def make_experiment_name(time_stamp=None):
    """Return a descriptive experiment name ending in a timestamp.

    Args:
        time_stamp: Suffix to use; defaults to the current local time
            formatted as ``%Y%m%d_%H%M%S``.
    """
    if time_stamp is None:
        time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return (f'{search}_{classifier_guidance}_steered{steered}'
            f'_calls{limit_reaction_model_calls}_time{time_limit_s}_{time_stamp}')


def build_script_args():
    """Build the script-arguments dict for ``search_for_one_molecule.py``.

    NOTE: do not modify hydra configs manually here (except perhaps when
    running ablations in a loop); for now every parameter configuration is
    defined as an experiment file.
    """
    return {
        "script_dir": SCRIPT_DIR,
        "use_torchrun": 'false',
        # The original literal also listed "single_step_model": experiment_name,
        # but a later duplicate key silently overwrote it. That dead entry is
        # removed; the experiment name only reaches the SLURM job name.
        "args": {
            "single_step_model": single_step_model,
            "single_step_model.model_dir": model_dir,
            "single_step_model.default_num_results": num_results,
            "search": search,
            "search.steered": steered,
            "search.dummy_inventory": dummy_inventory,
            "search.inventory_file": inventory_file,
            "search.limit_reaction_model_calls": limit_reaction_model_calls,
            "search.time_limit_s": time_limit_s,
            "classifier_guidance": classifier_guidance,
            "classifier_guidance.checkpoint_path": classifier_guidance_checkpoint_path,
            "classifier_guidance.model.num_classes": num_classes,
            "classifier_guidance.n_candidates_to_evaluate": n_candidates_to_evaluate,
            "classifier_guidance.search_batch_size": search_batch_size,
        },
        "script_name": 'search_for_one_molecule.py',
    }


def main(argv=None):
    """Parse options, assemble job arguments and submit the search job.

    Returns whatever ``create_and_submit_batch_job`` returns.
    """
    args = parse_args(argv)
    slurm_args = build_slurm_args(args.platform)
    experiment_name = make_experiment_name()
    task = 'search'
    slurm_args['job_name'] = f'{task}_{experiment_name}'
    return create_and_submit_batch_job(slurm_args, build_script_args(),
                                       interactive=args.interactive)


if __name__ == '__main__':
    main()
