from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]

# Clusters this launcher has settings for; anything else is rejected by argparse
# with a usage message instead of failing later with a bare ValueError.
SUPPORTED_PLATFORMS = ('puhti', 'mahti', 'lumi')

parser = argparse.ArgumentParser(
    description='Submit evaluate_multistep.py batch jobs via slurm_utils.')
parser.add_argument('--interactive', action='store_true',
                    help='Forwarded as interactive=... to create_and_submit_batch_job.')
parser.add_argument('--platform', type=str, required=True,
                    choices=SUPPORTED_PLATFORMS,
                    help='Cluster whose project/partition/environment settings to use.')
args = parser.parse_args()
# DATA_DIR = 'uspto_full'
SCRIPT_DIR = 'scripts'
platform = args.platform

# Per-platform cluster settings; every branch defines the same six names so the
# SLURM argument dict can be built uniformly afterwards.
use_srun = False
if platform == 'puhti':
    project = 'project_2007775'
    partition = 'gpu'
    with_containers = False
    # NOTE(review): the key is named after puhti but is also set on mahti;
    # presumably slurm_utils treats it as a generic module name — confirm.
    puhti_module = 'pytorch/2.4'
    venv_path = '/projappl/project_2007775/syntheseus-python-10'
    container = None
elif platform == 'mahti':
    project = 'project_2007775'
    partition = 'gpusmall'
    with_containers = False
    puhti_module = 'pytorch/2.4'
    venv_path = '/projappl/project_2007775/syntheseus-python-10'
    container = None
elif platform == 'lumi':
    # LUMI runs inside a container image instead of a loaded module.
    project = 'project_462000833'
    partition = 'small-g'
    with_containers = True
    container = 'multiguide-lumi.sif'
    venv_path = 'multiguide-lumi-container'
    puhti_module = None
else:
    # Reject any platform not handled above.
    raise ValueError(f'Platform {platform} not supported')

# SLURM settings passed as the first argument to create_and_submit_batch_job.
# Built from thematic groups merged in a fixed order so the final key order is
# stable (in case slurm_utils emits entries in dict order — not verified).
_job_bookkeeping = {
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
}
_cluster = {
    'platform': platform,
    'project': project,
    'time': '24:00:00',
    'partition': partition,
}
_resources = {
    'nodes': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
    'mem': '100G',  # 50G was not enough for uspto_full
}
_environment = {
    'with_containers': with_containers,
    'use_srun': use_srun,
    'container': container,
    'venv_path': venv_path,
    'puhti_module': puhti_module,
}
# Array task range; the full range was noted as 5..37, currently only task 22.
_array_range = {
    'start_array_job': 22,
    'end_array_job': 22,
}
slurm_args = {
    **_job_bookkeeping,
    **_cluster,
    **_resources,
    **_environment,
    **_array_range,
}
# Timestamp appended to every SLURM job name submitted by this run.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Experiments to evaluate. Each entry is '<experiment dir>/<sub-run>' and is
# passed verbatim as general.experiment_name; the text before the first '/'
# is used to build the SLURM job name.
#
# Previously evaluated experiments (kept for reference):
# 'retro_star_null_root_aligned_steeredfalse_uspto_hard_guidance0_length0_20250915_194601/retro_star'
# 'retro_star_null_root_aligned_steeredtrue_uspto_hard_guidance30_length10_20250915_113318/retro_star'
# 'retro_star_null_root_aligned_steeredtrue_uspto_hard_guidance30_length18_20250916_143341/retro_star'
# 'retro_star_null_root_aligned_steeredtrue_uspto_hard_guidance30_length20_20250916_143534/retro_star'
# 'retro_star_null_root_aligned_steeredtrue_uspto_hard_guidance30_length15_20250915_113249/retro_star'
# 'retro_star_null_root_aligned_steeredtrue_uspto_hard_guidance30_length10_20250917_084208/retro_star'
# 'retro_star_null_root_aligned_steeredfalse_uspto_hard_guidance0_length0_20250918_121408/retro_star'
multi_step_evaluations = [
    'retro_star_null_root_aligned_model_steeredtrue_uspto_hard_guidance30_length10_20250921_215757/retro_star',
]

# classifier guidance
# NOTE(review): boolean-like settings in this section are strings ('true'/'false'),
# presumably because they are forwarded as command-line overrides — confirm.
classifier_guidance = 'similarity'
classifier_guidance_similarity_type = 'tanimoto'
classifier_guidance_separator = '.'
n_candidates_to_evaluate = 72
num_classes = 11
classifier_guidance_search_batch_size = 1024
# classifier_checkpoint = os.path.join(PROJECT_ROOT, 
#                                      'checkpoints', 
#                                      'tanimoto_routes_dot_separator_fraction0.4_completion0.8_augmentation5',
#                                      'checkpoint_90.pt')
classifier_checkpoint = os.path.join(PROJECT_ROOT,
                                     'checkpoints',
                                     'tanimoto_routes_fraction1.0_thresh500_completion0.8_augmentation8',
                                     'checkpoint_100.pt')
# onmt_checkpoint_path = os.path.join(PROJECT_ROOT, 
#                                      'checkpoints', 
#                                      'rsmiles_50k_checkpoints', 
#                                      'USPTO_50K_PtoR.pt')
onmt_checkpoint_path = os.path.join(PROJECT_ROOT,
                                    'checkpoints',
                                    'rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5',
                                    'model.product-reactants_step_250000.pt')

# search
search = 'retro_star'
# NOTE(review): the active experiment above is a 'steeredtrue' run, but
# search.steered is passed as 'false' — confirm this is intended for evaluation.
steered = 'false'
diversity_radius = 0.99
# evaluation
single_step_evaluation = '50k'
forward_model_dir = 'forward_rsmiles_mit'
compute_classifier_score = 'false'
# route
route_dataset_type = 'uspto_190'
route_dataset_path = 'in_json/test_with_tanimoto_weight1.json'

# Submit one SLURM (array) job per experiment in multi_step_evaluations.
for multi_step_evaluation in multi_step_evaluations:
    # Job name = experiment directory (text before the first '/') + run timestamp.
    exp_name = multi_step_evaluation.split('/')[0]
    experiment_name = f'{exp_name}_time{time_stamp}'
    script_args = {
        "script_dir": SCRIPT_DIR,
        "use_torchrun": 'false',
        "args": {
            "general.experiment_name": multi_step_evaluation,
            "route_dataset.type": route_dataset_type,
            "route_dataset.path": route_dataset_path,
            "classifier_guidance": classifier_guidance,
            "classifier_guidance.similarity_type": classifier_guidance_similarity_type,
            "classifier_guidance.dataset.separator": classifier_guidance_separator,
            "classifier_guidance.onmt_checkpoint_path": onmt_checkpoint_path,
            "classifier_guidance.checkpoint_path": classifier_checkpoint,
            "classifier_guidance.model.num_classes": num_classes,
            "classifier_guidance.n_candidates_to_evaluate": n_candidates_to_evaluate,
            "classifier_guidance.search_batch_size": classifier_guidance_search_batch_size,
            "single_step_evaluation": single_step_evaluation,
            "single_step_evaluation.forward_model_dir": forward_model_dir,
            "single_step_evaluation.compute_classifier_score": compute_classifier_score,
            # '$start_idx' / '$end_idx' refer to the shell variables below.
            "multi_step_evaluation.route_start_idx": '$start_idx',
            "multi_step_evaluation.route_end_idx": '$end_idx',
            "search": search,
            "search.steered": steered,
            "search.diversity_radius": diversity_radius,
        },
        # Presumably emitted by slurm_utils as bash assignments in this order.
        # Array task i then gets start_idx = offset + i*targets_per_job and
        # end_idx = start_idx + targets_per_job (5 targets per array task).
        "variables": {
            'targets_per_job': 5,
            'offset': 0,
            'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
            'end_idx': '$((start_idx+targets_per_job))',
        },
        "script_name": 'evaluate_multistep.py',
    }
    # Per-job copy so the shared module-level slurm_args is not mutated from one
    # iteration to the next.
    job_slurm_args = {**slurm_args, 'job_name': 'evaluate_' + experiment_name}
    create_and_submit_batch_job(job_slurm_args, script_args, interactive=args.interactive)


#"general.experiment_name": "root_aligned_steeredtrue_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_152427/call1000_time1800",
#"general.experiment_name": "root_aligned_steeredtrue_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_153512/call1000_time1800",
#"general.experiment_name": "root_aligned_steeredFalse_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_095204/call1000_time1800",
#"general.experiment_name": "root_aligned_steeredtrue_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_153512/call1000_time1800",
#"general.experiment_name": "root_aligned_steeredTrue_uspto_hard_reaction_type_guided_rsmiles_50k_checkpoints_20250809_101022/call1000_time1800",

# multi_step_evaluations = ['root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.5_length10_20250819_160803/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.5_length10_20250820_113606/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.5_length12_20250820_113611/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.5_length15_20250820_113617/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length5_20250820_113627/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length10_20250820_113633/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length12_20250820_113638/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length15_20250820_113643/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length12_20250820_113548/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length12_20250820_113542/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length10_20250820_113534/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length5_20250820_113601/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance1.0_length5_20250820_113554/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance0.5_length15_20250820_113509/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance0.5_length12_20250820_113502/call1000_time1800',
#                           'root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance0.5_length10_20250820_113443/call1000_time1800'
#                           ]
# #multi_step_evaluations = ['root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length10_20250820_145456/call1000_time1800']
# multi_step_evaluations = ['root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length10_20250820_160551/call1000_time1800']
# multi_step_evaluations = ['root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance2_length10_20250821_124229/call1000_time1800']
# multi_step_evaluations = ['desp_f2f_neuralsym_steeredfalse_uspto_hard_guidance1.5_length5_20250913_150832']
# multi_step_evaluations = ['desp_f2e_neuralsym_steeredfalse_uspto_hard_guidance1.5_length5_20250913_160411']
# multi_step_evaluations = ['retro_star_null_root_aligned_steeredtrue_uspto_hard_guidance30_length15_20250915_113249/retro_star']