import os
from pathlib import Path
from datetime import datetime
from slurm_utils import create_and_submit_batch_job, get_platform_info

# --- Array-job layout and sampling knobs ------------------------------------
# Number of targets handled by each task of the SLURM job array.
targets_per_job = 5
# First and last array task index (alternative value used before: 51).
# If the array range is inclusive, tasks 0..37 at 5 targets each span 190 targets.
start_array_job, end_array_job = 0, 37
# Number of results requested from the single-step model per call.
default_num_results = 100
# Seeds tried so far: 42, 101, 90.
# NOTE(review): `seed` is not referenced anywhere else in the visible script.
seed = 42
# Base index shift for the per-task target range (0 = start at the first target).
offset = 0

# NOTE: do not use gpu on lumi, faiss-gpu does not work there
# Repository root: the parent of the directory that contains this script
# (symlinks resolved first).
PROJECT_ROOT = Path(os.path.realpath(__file__)).parent.parent
SCRIPT_DIR = 'scripts'

# Settings specific to this submission; they are layered on top of the
# platform defaults returned by get_platform_info below.
slurm_overrides = {
    # Launch behaviour and bookkeeping locations.
    'use_srun': True,
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
    # Resource request: one CPU-only task on a single node.
    'time': '12:00:00',
    'nodes': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 0,
    'mem': '300G',  # 50G not enough for uspto_full
    # Job-array bounds taken from the constants at the top of the file.
    'start_array_job': start_array_job,
    'end_array_job': end_array_job,
}
slurm_args = get_platform_info(use_gpu=False)
slurm_args.update(slurm_overrides)
# Local wall-clock timestamp (second resolution); it is embedded in the
# experiment name so repeated submissions get distinct names.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# in puhti: gln, localretro, retroknn
# lumi: megan, mhnreact, graph2edits, neuralsym, root_aligned, chemformer
single_step_model = 'neuralsym'
# Checkpoint handed to the single-step model. Swap in one of the commented
# alternatives below to run without an explicit checkpoint or with another model.
#model_dir = None
model_dir = os.path.join(
    PROJECT_ROOT,
    'checkpoints',
    'neuralsym',
    'model_retro.pt'
)
# model_dir = os.path.join(
#     PROJECT_ROOT,
#     'checkpoints',
#     'rsmiles_50k_checkpoints' # rsmiles_uspto_190_fraction0.4_thresh10_dropped_PtoR_aug5
# )
# Last path component of the checkpoint, used as a short tag in the experiment
# name. normpath strips a trailing separator first, so a directory written as
# '.../checkpoints/foo/' still yields 'foo' instead of an empty string, and
# basename honours the platform's path separator instead of assuming '/'.
model_dir_id = os.path.basename(os.path.normpath(model_dir)) if model_dir else 'None'

# --- Search algorithm --------------------------------------------------------
search = 'desp'
strategy = 'f2e'  # f2e, f2f
steered = 'false'  # do the remaining ones for false
heuristic = 'value_function'
value_function = 'desp'
# The dummy inventory is switched on exactly when the platform info marks the
# run as interactive.
if slurm_args['interactive']:
    dummy_inventory = 'true'
else:
    dummy_inventory = 'false'

# --- Evaluation ---------------------------------------------------------------
single_step_evaluation = '50k'
multi_step_evaluation = 'route'
forward_model_dir = 'forward_rsmiles_mit'

# --- Target dataset -----------------------------------------------------------
dataset = 'uspto_190'
dataset_type = 'uspto_hard'
route_path = 'desp_data/uspto_190_targets.txt'

# --- Search budget ------------------------------------------------------------
limit_reaction_model_calls = 100  # 100
time_limit_s = 600
limit_iterations = 500
max_routes_to_extract = 100
stop_on_first_solution = False

# Unique run name: model, checkpoint tag, steering flag, dataset type and timestamp.
experiment_name = f'{single_step_model}_{model_dir_id}_steered{steered}_{dataset_type}_{time_stamp}'
# NOTE: here should not modify any of the hydra configs manually...
# NOTE: unless maybe running ablations in a loop. For now define all configurations of params as experiment files.
script_args = {
    "script_dir": SCRIPT_DIR,
    "use_torchrun": 'false',
    # Dotted keys mirror the config groups/fields of the launched script
    # (presumably forwarded as command-line overrides — confirm in slurm_utils).
    "args": {
        'general.experiment_name': experiment_name,
        'single_step_model.model_type': single_step_model,
        'single_step_model.model_dir': model_dir,
        'single_step_model.default_num_results': default_num_results,
        'search': search,
        'search.steered': steered,
        'search.dummy_inventory': dummy_inventory,
        'search.heuristic': heuristic,
        'search.strategy': strategy,
        'search.limit_reaction_model_calls': limit_reaction_model_calls,
        'search.time_limit_s': time_limit_s,
        'search.limit_iterations': limit_iterations,
        'search.max_routes_to_extract': max_routes_to_extract,
        'search.stop_on_first_solution': stop_on_first_solution,
        'value_function': value_function,
        'multi_step_evaluation': multi_step_evaluation,
        'single_step_evaluation': single_step_evaluation,
        'single_step_evaluation.forward_model_dir': forward_model_dir,
        'route_dataset': dataset,
        'route_dataset.type': dataset_type,
        'route_dataset.path': route_path,
        # Placeholders that refer to the shell variables defined under "variables".
        'route_dataset.start_idx': '$start_idx',
        'route_dataset.end_idx': '$end_idx',
    },
    # Shell-side variables. Order matters: the arithmetic expressions below
    # refer to the entries that precede them.
    "variables": {
        'targets_per_job': targets_per_job,  # 5 molecules per job in array
        # Use the module-level setting rather than a literal 0, so changing
        # `offset` at the top of the file actually shifts the target range.
        'offset': offset,
        # Each array task covers [start_idx, start_idx + targets_per_job).
        'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
        'end_idx': '$((start_idx+targets_per_job))',
    },
}
# Script to launch and a label for this kind of job.
script_args['script_name'] = 'search.py'
task = 'search'

# Name the job after the search algorithm and experiment, and nest its output
# under <output_dir>/search<algo>/<single-step model>/<experiment>.
slurm_args['job_name'] = f'search{search}_{experiment_name}'
slurm_args['output_dir'] = os.path.join(
    slurm_args['output_dir'],
    f'search{search}',
    single_step_model,
    experiment_name,
)

# Hand everything to the submission helper; the interactive flag comes from
# the platform info stored in slurm_args.
output = create_and_submit_batch_job(
    slurm_args,
    script_args,
    interactive=slurm_args['interactive'],
)
