from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse
import json

# Repository root: the parent of the directory that contains this
# (symlink-resolved) script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

parser = argparse.ArgumentParser(
    description='Submit search.py jobs (one per single-step model) via slurm_utils.')
parser.add_argument('--interactive', action='store_true',
                    help='Forwarded to create_and_submit_batch_job as interactive=...; '
                         'also enables search.dummy_inventory.')
# Validate the cluster name up front so a missing/misspelled value fails with a
# usage message instead of a late "Platform None not supported" ValueError.
parser.add_argument('--platform', type=str, required=True,
                    choices=('puhti', 'mahti', 'lumi'),
                    help='Cluster to submit to.')
args = parser.parse_args()

# Directory holding the entry-point scripts (presumably resolved by slurm_utils;
# TODO confirm whether it is relative to PROJECT_ROOT).
SCRIPT_DIR = 'scripts'
platform = args.platform

# Select the Slurm account/partition and the runtime environment (module +
# virtualenv, or a container image) for the requested cluster.
if platform == 'puhti':
    project = 'project_2007775'
    partition = 'gpu'
    with_containers = False
    puhti_module = 'pytorch/2.4'
    venv_path = '/projappl/project_2007775/syntheseus-python-10'
    container = None
elif platform == 'mahti':
    # Previously unreachable: the first branch tested `puhti or mahti`, so Mahti
    # jobs were silently sent to the 'gpu' partition instead of 'gpusmall'.
    project = 'project_2007775'
    partition = 'gpusmall'
    with_containers = False
    puhti_module = 'pytorch/2.4'
    venv_path = '/projappl/project_2007775/syntheseus-python-10'
    container = None
elif platform == 'lumi':
    project = 'project_462000833'
    partition = 'small-g'
    with_containers = True
    container = 'multiguide-lumi.sif'
    venv_path = 'multiguide-lumi-container'
    puhti_module = None
else:
    raise ValueError(f'Platform {platform} not supported')

use_srun = False

# Slurm settings forwarded to create_and_submit_batch_job, assembled from
# purpose-specific groups. The merge below keeps the original key order.

# Where job scripts, the list of submitted job ids and job output go.
_job_paths = {
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
}

# Target cluster/account and the per-job resource request.
_resources = {
    'platform': platform,
    'project': project,
    'time': '10:00:00',
    'partition': partition,
    'nodes': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
    'mem': '100G',  # 50G was not enough for uspto_full
}

# How the job's software environment is provided.
_environment = {
    'with_containers': with_containers,
    'use_srun': use_srun,
    'container': container,
    'venv_path': venv_path,
    'puhti_module': puhti_module,
}

# Array-job index bounds (earlier notes used 5 to 37); exact inclusivity is
# decided by slurm_utils - TODO confirm.
_array_range = {
    'start_array_job': 0,
    'end_array_job': 0,
}

slurm_args = {**_job_paths, **_resources, **_environment, **_array_range}
# Suffix appended to every experiment name in this submission. Replace with a
# fixed earlier value (e.g. '20250513_212436') to reproduce that run's names.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Single-step models to run; the submission loop creates one job per entry.
# Earlier notes - problematic: chemformer, retroknn, gln, localretro;
#                 ok: megan, mhnreact, graph2edits.
# Previously used sets:
#   ['megan', 'mhnreact', 'chemformer', 'graph2edits', 'localretro', 'retroknn', 'gln']
#   ['graph2edits', 'megan', 'mhnreact', 'neuralsym', 'root_aligned']
single_step_models = ['root_aligned']
model_dir = 'rsmiles_50k_checkpoints'
default_num_results = 100

# Classifier guidance settings. String values such as 'true' are passed through
# verbatim as override values (assumed Hydra-style CLI overrides - see search.py).
classifier_guidance = 'similarity'
classifier_guidance_similarity_type = 'enforce_starting_material'
classifier_guidance_enforce_starting_material_at_depth = 18
use_ground_truth_node_depth = 'true'
classifier_guidance_enforce_starting_material_scale = 100
classifier_guidance_enforce_starting_material_min_length = 0
separator = '.'
guidance_scale = 30  # earlier value: 2
min_length_for_guidance = 10
onmt_checkpoint_path = os.path.join(
    PROJECT_ROOT, 'checkpoints', 'rsmiles_50k_checkpoints', 'USPTO_50K_PtoR.pt')
checkpoint_path = os.path.join(
    PROJECT_ROOT, 'checkpoints',
    'tanimoto_routes_dot_separator_fraction0.4_completion0.8_augmentation5',
    'checkpoint_90.pt')
n_candidates_to_evaluate = 72
num_classes = 11

# Multi-step search settings.
search = 'retro_star'
steered = 'true'  # TODO: also run the remaining configurations with 'false'
value_function = 'retro_star'
heuristic = 'value_function'
dummy_inventory = 'true' if args.interactive else 'false'
strategy = 'null'

# Evaluation settings.
single_step_evaluation = '50k'
multi_step_evaluation = 'route'
forward_model_dir = 'forward_rsmiles_mit'

# Target routes to search for.
dataset = 'uspto_190'
dataset_type = 'uspto_hard'
route_path = 'desp_data/uspto_190_targets.txt'

# Submit one search.py job per single-step model.
for single_step_model in single_step_models:
    experiment_name = f'{search}_{strategy}_{single_step_model}_steered{steered}_{dataset_type}_enforce_guidance{guidance_scale}_length{min_length_for_guidance}_{time_stamp}'
    # NOTE: here should not modify any of the hydra configs manually...
    # NOTE: unless maybe running ablations in a loop. For now define all configurations of params as experiment files.
    script_args = {"script_dir": SCRIPT_DIR,
                   "use_torchrun": 'false',
                   # Overrides forwarded to the script; '$start_idx'/'$end_idx'
                   # refer to the shell variables defined under "variables".
                   "args": {'general.experiment_name': experiment_name,
                            'single_step_model': single_step_model,
                            'single_step_model.model_dir': model_dir,
                            'single_step_model.default_num_results': default_num_results,
                            'search': search,
                            'search.steered': steered,
                            'search.dummy_inventory': dummy_inventory,
                            'search.heuristic': heuristic,
                            'search.strategy': strategy,
                            'value_function': value_function,
                            'multi_step_evaluation': multi_step_evaluation,
                            'single_step_evaluation': single_step_evaluation,
                            'single_step_evaluation.forward_model_dir': forward_model_dir,
                            'route_dataset': dataset,
                            'route_dataset.type': dataset_type,
                            'route_dataset.path': route_path,
                            'route_dataset.start_idx': '$start_idx',
                            'route_dataset.end_idx': '$end_idx',
                            'classifier_guidance': classifier_guidance,
                            'classifier_guidance.checkpoint_path': checkpoint_path,
                            'classifier_guidance.onmt_checkpoint_path': onmt_checkpoint_path,
                            'classifier_guidance.model.num_classes': num_classes,
                            'classifier_guidance.min_length_for_guidance': min_length_for_guidance,
                            'classifier_guidance.guidance_scale': guidance_scale,
                            'classifier_guidance.n_candidates_to_evaluate': n_candidates_to_evaluate,
                            'classifier_guidance.dataset.separator': separator,
                            'classifier_guidance.similarity_type': classifier_guidance_similarity_type,
                            'classifier_guidance.search_batch_size': 1024,
                            'classifier_guidance.use_ground_truth_node_depth': use_ground_truth_node_depth,
                            'classifier_guidance.enforce_starting_material_at_depth': classifier_guidance_enforce_starting_material_at_depth,
                            'classifier_guidance.enforce_starting_material_scale': classifier_guidance_enforce_starting_material_scale,
                            'classifier_guidance.enforce_starting_material_min_length': classifier_guidance_enforce_starting_material_min_length},
                   # Shell variables (bash arithmetic) for the job script; order
                   # matters because later entries reference earlier ones.
                   # Array task k covers targets
                   # [offset + k*targets_per_job, offset + (k+1)*targets_per_job).
                   "variables": {'targets_per_job': 1,  # targets handled by each array task
                                 'offset': 2,  # index of the first target to process
                                 'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
                                 'end_idx': '$((start_idx+targets_per_job))'},
                   "script_name": 'search.py'}
    slurm_args['job_name'] = experiment_name
    create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)