from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse
from slurm_utils import create_and_submit_batch_job, get_platform_info

# --- Slurm job-array layout --------------------------------------------------
# First and last array index; both are handed to slurm_args further down as
# 'start_array_job' / 'end_array_job'. Trailing numbers are the author's notes.
start_array_job, end_array_job = 0, 37  # 17 / 37
# Number of targets each array task works on.
targets_per_job = 5
offset = 0

# --- Sampling defaults -------------------------------------------------------
default_num_results = 100
seed = 42  # 42, 101, 90

# Parent of the directory that contains this file (symlinks resolved).
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]
SCRIPT_DIR = 'scripts'

# Start from the platform-specific defaults (GPU variant) and then set the
# job-specific entries on top of them, one key at a time.
slurm_args = get_platform_info(use_gpu=True)

# Launcher behaviour and bookkeeping locations.
slurm_args['use_srun'] = True
slurm_args['job_dir'] = 'jobs'
slurm_args['job_ids_file'] = 'job_ids.txt'
slurm_args['output_dir'] = 'output'

# Resource request.
slurm_args['time'] = '02:00:00'
slurm_args['nodes'] = 1
slurm_args['ntasks-per-node'] = 1
slurm_args['cpus-per-task'] = 1
slurm_args['gpus-per-node'] = 1
slurm_args['mem'] = '100G'  # 50G not enough for uspto_full

# Job-array bounds.
slurm_args['start_array_job'] = start_array_job
slurm_args['end_array_job'] = end_array_job

# Settings for the evaluation run. Every name is defined exactly once; the
# original repeated several assignments (single_step_evaluation,
# forward_model_dir, time_stamp) with identical values, which made it easy for
# the copies to drift apart.

# --- Run bookkeeping ---------------------------------------------------------
# Neither of these two names is referenced in the visible part of this script.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dummy_inventory = 'true' if slurm_args['interactive'] else 'false'

# --- Route dataset -----------------------------------------------------------
route_dataset_type = 'uspto_190'
route_dataset_path = 'in_json/test_processed.json'
# route_dataset_path = 'in_json/test_13.json'

# --- Evaluation --------------------------------------------------------------
single_step_evaluation = '50k'
forward_model_dir = 'forward_rsmiles_mit'

# --- Classifier guidance -----------------------------------------------------
# Forwarded to the evaluation script under the 'classifier_guidance*' keys.
# Booleans are passed as the strings 'true' / 'false'.
classifier_guidance = 'reaction_type'
similarity_type = ''
with_starting_material = 'false'
guidance_scale = 0  # TODO: 10 is done
min_length_for_guidance = 0
eos_penalty = -10.0
n_candidates_to_evaluate = 72
num_classes = 11
search_batch_size = 1024

# --- Search ------------------------------------------------------------------
# Forwarded under the 'search.*' keys.
# default_num_results = 1
max_routes_to_extract = 100
diversity_radius = 0.
# All checkpoints live under <PROJECT_ROOT>/checkpoints.
checkpoints_dir = os.path.join(PROJECT_ROOT, 'checkpoints')

# Alternative checkpoint pair, currently disabled:
# onmt_checkpoint_path = os.path.join(
#     checkpoints_dir, 'rsmiles_50k_checkpoints', 'USPTO_50K_PtoR.pt')
# classifier_checkpoint = os.path.join(
#     checkpoints_dir, 'reaction_type_completion0.8_augmentation5',
#     'checkpoint_308.pt')

# Checkpoint forwarded as 'classifier_guidance.onmt_checkpoint_path'.
onmt_checkpoint_path = os.path.join(
    checkpoints_dir,
    'rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5',
    'model.product-reactants_step_250000.pt',
)

# Checkpoint forwarded as 'classifier_guidance.checkpoint_path'.
classifier_checkpoint = os.path.join(
    checkpoints_dir,
    'reaction_type_routes_fraction1.0_thresh500_completion0.8_augmentation5',
    'checkpoint_25.pt',
)
# --- Search configuration ----------------------------------------------------
experiment_group = 'search'
search_type = 'retro_star'
search = 'retro_star'
strategy = 'None'  # the literal string 'None', not the None object
experiment_params = 'retro_star'

# Alternative search configuration, currently disabled:
# search_type = 'desp'
# strategy = 'f2e'
# experiment_params = 'desp'
# experiment_name = 'neuralsym_model_retro.pt_steeredfalse_uspto_hard_20251104_140117'

# --- Experiment history ------------------------------------------------------
# Names of earlier experiments, kept for reference only. They used to be live
# assignments that were immediately overwritten by the next line; only the
# last assignment ever took effect, so they are comments now. To evaluate one
# of them, copy its name into the single active assignment at the bottom.
# experiment_name = 'rootaligned_rsmiles_50k_checkpoints_steeredfalse_uspto_hard_20251103_160715'
# experiment_name = 'megan_None_steeredfalse_uspto_hard_20251104_180643'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredfalse_uspto_hard_20251103_160804'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_numModelCalls100_uspto_hard_20251106_154721'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_numModelCalls100_uspto_hard_20251108_183816'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredfalse_numModelCalls100_uspto_hard_20251108_190340'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredfalse_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251108_194243'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredfalse_filteredfalse_guidedfalse_numModelCalls100_uspto_hard_20251108_194632'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251111_145653'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251111_171125'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_filteredtrue_guidedtrue_numModelCalls100_uspto_hard_20251111_205343'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251112_150141'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251112_172908'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251112_181942'
# experiment_name = 'steeredtrue_filteredfalse_guidedtrue_guidance0.3_minLength10_numModelCalls100_uspto_hard_20251112_222018'
# experiment_name = 'steeredtrue_filteredfalse_guidedtrue_guidance1.5_minLength10_numModelCalls100_uspto_hard_20251112_222030'
# experiment_name = 'steeredtrue_filteredtrue_guidedtrue_guidance0.3_minLength15_numModelCalls100_uspto_hard_20251112_222047'
# experiment_name = 'steeredtrue_filteredfalse_guidedfalse_guidance0.3_length15_numModelCalls100_uspto_hard_20251113_161033'
# experiment_name = 'steeredfalse_filteredfalse_guidedtrue_guidance0.3_length15_numModelCalls100_uspto_hard_20251113_174320'
# experiment_name = 'neuralsym_model_retro.pt_steeredfalse_uspto_hard_20251103_145934'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredfalse_uspto_hard_20251103_160804'
# experiment_name = 'similarity_steeredfalse_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251114_161519'
# experiment_name = 'similarity_tanimoto_like_tango_steeredfalse_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251114_182257'
# experiment_name = 'similarity_max_tanimoto_steeredfalse_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251114_182151'
# experiment_name = 'similarity_max_tanimoto_steeredtrue_filteredfalse_guidedtrue_guidance3.0_length15_numModelCalls100_uspto_hard_20251114_203306'
# experiment_name = 'similarity_max_tanimoto_steeredtrue_filteredfalse_guidedtrue_guidance20_length5_numModelCalls100_uspto_hard_20251116_002115'
# experiment_name = 'similarity_max_tanimoto_steeredtrue_filteredfalse_guidedtrue_guidance30_length5_numModelCalls100_uspto_hard_20251116_001512'
# experiment_name = 'similarity_max_tanimoto_steeredtrue_filteredtrue_guidedtrue_guidance20_length5_numModelCalls100_uspto_hard_20251116_144109'
# experiment_name = 'similarity_max_tanimoto_steeredfalse_filteredtrue_guidedtrue_guidance20_length10_numModelCalls100_uspto_hard_20251116_144032'
# experiment_name = 'similarity_max_tanimoto_steeredfalse_filteredtrue_guidedfalse_guidance20_length10_numModelCalls100_uspto_hard_20251116_144040'
# experiment_name = 'steeredtrue_filteredtrue_guidedfalse_guidance0.3_length15_numModelCalls100_uspto_hard_20251116_144248'
# experiment_name = 'steeredfalse_filteredtrue_guidedfalse_guidance0.3_length15_numModelCalls100_uspto_hard_20251116_144239'
# experiment_name = 'test_13_steeredfalse_filteredfalse_guidedtrue_guidance0.3_length15_numModelCalls100_uspto_hard_20251117_205158'
# experiment_name = 'test_13_steeredfalse_filteredtrue_guidedtrue_guidance1.5_length10_numModelCalls100_uspto_hard_20251117_205706'
# experiment_name = 'test_13_steeredtrue_filteredfalse_guidedtrue_guidance0.3_length15_numModelCalls100_uspto_hard_20251117_204858'
# experiment_name = 'test_13_steeredtrue_filteredfalse_guidedtrue_guidance1.5_length10_numModelCalls100_uspto_hard_20251117_205655'
# experiment_name = 'test_13_steeredtrue_filteredtrue_guidedtrue_guidance0.5_length10_numModelCalls100_uspto_hard_20251117_210119'
# experiment_name = 'chemformer_steeredfalse_filteredfalse_guidedfalse_guidance0_length0_numModelCalls100_uspto_hard_20251118_181614'
# experiment_name = 'megan_steeredfalse_filteredfalse_guidedfalse_guidance0_length0_numModelCalls100_uspto_hard_20251118_181626'
# experiment_name = 'graph2edits_steeredfalse_filteredfalse_guidedfalse_guidance0_length0_numModelCalls100_uspto_hard_20251118_181705'
# experiment_name = 'megan_steeredfalse_filteredfalse_guidedfalse_guidance0_length0_numModelCalls100_uspto_hard_20251118_202145'
# experiment_name = 'mhnreact_steeredfalse_filteredfalse_guidedfalse_guidance0_length0_numModelCalls100_uspto_hard_20251118_181656'
# experiment_name = 'rootaligned_steeredfalse_filteredfalse_guidedfalse_guidance0_length0_numModelCalls100_uspto_hard_20251119_122229'
# experiment_name = 'neuralsym_steeredfalse_filteredfalse_guidedtrue_reaction_type_guidance0_length0_numModelCalls100_uspto_hard_20251119_190910'
# experiment_name = 'neuralsym_steeredfalse_filteredtrue_guidedfalse_reaction_type_guidance0_length0_numModelCalls100_uspto_hard_20251119_184701'
# experiment_name = 'neuralsym_steeredfalse_filteredtrue_guidedtrue_reaction_type_guidance0_length0_numModelCalls100_uspto_hard_20251119_190831'
# experiment_name = 'neuralsym_steeredfalse_filteredtrue_guidedfalse_reaction_type_guidance0_length0_numModelCalls100_uspto_hard_20251119_204051'
# experiment_name = 'similarity_max_tanimoto_neuralsym_steeredfalse_filteredfalse_guidedtrue_guidance0_length0_numModelCalls100_uspto_hard_20251121_062633'
# experiment_name = 'similarity_max_tanimoto_neuralsym_steeredfalse_filteredtrue_guidedfalse_guidance0_length0_numModelCalls100_uspto_hard_20251121_062711'
# experiment_name = 'neuralsym_steeredfalse_filteredtrue_guidedtrue_reaction_type_guidance0.5_length5_numModelCalls100_uspto_hard_20251119_174138'
# experiment_name = 'steeredfalse_filteredtrue_guidedtrue_guidance0.3_length15_numModelCalls100_uspto_hard_20251113_174334'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251112_160437'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251111_171338'
# experiment_name = 'rootaligned_rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5_steeredtrue_filteredfalse_guidedtrue_numModelCalls100_uspto_hard_20251108_194308'
# experiment_name = 'neuralsym_model_retro.pt_steeredfalse_uspto_hard_20251103_145934'

# --- Active experiment -------------------------------------------------------
# The one experiment this script evaluates.
experiment_name = 'similarity_max_tanimoto_neuralsym_steeredfalse_filteredtrue_guidedtrue_guidance0_length0_numModelCalls100_uspto_hard_20251121_062718'

# Build the argument bundle for evaluate_multistep.py and submit the job.
#
# Changes relative to the previous version of this section:
#   * 'offset' in "variables" now comes from the module-level `offset` constant
#     instead of a second hard-coded 0 that could silently disagree with it.
#   * the forward-model directory comes from `forward_model_dir` instead of a
#     duplicated string literal.
#   * slurm_args['interactive'] is looked up once.
interactive = slurm_args['interactive']

script_args = {
    "script_dir": SCRIPT_DIR,
    "use_torchrun": 'false',
    # Key/value overrides handed to the evaluation script. Key order is kept
    # exactly as before in case the consumer depends on it.
    "args": {
        "route_dataset.type": route_dataset_type,
        "route_dataset.path": route_dataset_path,
        "general.experiment_group": experiment_group,
        "general.experiment_params": experiment_params,
        "general.experiment_name": experiment_name,
        "classifier_guidance": classifier_guidance,
        "classifier_guidance.similarity_type": similarity_type,
        "classifier_guidance.experiment_name": experiment_name,
        "classifier_guidance.checkpoint_path": classifier_checkpoint,
        "classifier_guidance.onmt_checkpoint_path": onmt_checkpoint_path,
        "classifier_guidance.guidance_scale": guidance_scale,
        "classifier_guidance.n_candidates_to_evaluate": n_candidates_to_evaluate,
        "classifier_guidance.min_length_for_guidance": min_length_for_guidance,
        "classifier_guidance.eos_penalty": eos_penalty,
        "classifier_guidance.model.num_classes": num_classes,
        "classifier_guidance.search_batch_size": search_batch_size,
        "classifier_guidance.with_starting_material": with_starting_material,
        "single_step_evaluation": single_step_evaluation,
        # Interactive runs use the fixed route range 2..3; otherwise the range
        # is left as the placeholders '$start_idx' / '$end_idx', whose values
        # are defined in "variables" below.
        "multi_step_evaluation.route_start_idx": 2 if interactive else '$start_idx',
        "multi_step_evaluation.route_end_idx": 3 if interactive else '$end_idx',
        "single_step_evaluation.forward_model_dir": forward_model_dir,
        "single_step_model.default_num_results": default_num_results,
        "search": search,
        "search.strategy": strategy,
        "search.type": search_type,
        "search.max_routes_to_extract": max_routes_to_extract,
        "search.diversity_radius": diversity_radius,
    },
    # Presumably emitted as shell variables in the generated batch script
    # (TODO confirm in slurm_utils). Each array task derives its own
    # [start_idx, end_idx) slice from SLURM_ARRAY_TASK_ID; the order matters
    # because later entries refer to earlier ones.
    "variables": {
        'targets_per_job': targets_per_job,  # 5 molecules per job in array
        'offset': offset,
        'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
        'end_idx': '$((start_idx+targets_per_job))',
    },
    "script_name": 'evaluate_multistep.py',
}

# Not referenced anywhere else in the visible part of this script.
task = 'evaluate_multistep'

slurm_args['job_name'] = f'evaluate_{experiment_name}'
# NOTE(review): 'search' and search_type are joined without a separator
# (e.g. 'searchretro_star') - confirm this matches the intended directory layout.
slurm_args['output_dir'] = os.path.join(
    slurm_args['output_dir'],
    f'search{search_type}',
    f'evaluate_{experiment_name}',
)

output = create_and_submit_batch_job(slurm_args, script_args, interactive=interactive)

#"general.experiment_name": "root_aligned_steeredtrue_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_152427/call1000_time1800",
#"general.experiment_name": "root_aligned_steeredtrue_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_153512/call1000_time1800",
#"general.experiment_name": "root_aligned_steeredFalse_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_095204/call1000_time1800",
#"general.experiment_name": "root_aligned_steeredtrue_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_153512/call1000_time1800",
#"general.experiment_name": "root_aligned_steeredTrue_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_101022/call1000_time1800",

# multi_step_evaluations = ['root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.5_length10_20250819_160803/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.5_length10_20250820_113606/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.5_length12_20250820_113611/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.5_length15_20250820_113617/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length5_20250820_113627/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length10_20250820_113633/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length12_20250820_113638/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length15_20250820_113643/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length12_20250820_113548/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length12_20250820_113542/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length10_20250820_113534/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length5_20250820_113601/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length5_20250820_113554/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance0.5_length15_20250820_113509/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance0.5_length12_20250820_113502/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance0.5_length10_20250820_113443/call1000_time1800'
#                           ]
# #multi_step_evaluations = ['root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length10_20250820_145456/call1000_time1800']
# multi_step_evaluations = ['root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length10_20250820_160551/call1000_time1800']
# multi_step_evaluations = ['root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length10_20250821_124229/call1000_time1800']
# multi_step_evaluations = ['desp_f2f_neuralsym_steeredfalse_uspto_hard_guidance1.5_length5_20250913_150832']
# multi_step_evaluations = ['desp_f2e_neuralsym_steeredfalse_uspto_hard_guidance1.5_length5_20250913_160411']
# multi_step_evaluations = ['retro_star_null_root_aligned_steeredtrue_uspto_hard_guidance30_length15_20250915_113249/retro_star']