import os
from pathlib import Path
from datetime import datetime
from slurm_utils import create_and_submit_batch_job, get_platform_info

# First and last SLURM array task IDs to launch (IDs used previously: 9, 36, 50, 73).
start_array_job, end_array_job = 9, 9
# Passed to the single-step model as `single_step_model.default_num_results` (usually 100).
default_num_results = 100
# Number of targets each array task processes (start_idx .. start_idx + targets_per_job).
targets_per_job = 1
# Global random seed (seeds used so far: 42, 101, 90).
seed = 42
# Offset of the first target index.
offset = 0

# Parent of the directory containing this launcher, with symlinks resolved.
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]
# Directory that holds the entry-point scripts (e.g. search.py).
SCRIPT_DIR = 'scripts'

# Start from the platform's GPU defaults and override the per-job settings in place.
# Keys are assigned one by one, in the same order as before, so insertion order is unchanged.
slurm_args = get_platform_info(use_gpu=True)
slurm_args['use_srun'] = True
slurm_args['job_dir'] = 'jobs'
slurm_args['job_ids_file'] = 'job_ids.txt'
slurm_args['output_dir'] = 'output'
slurm_args['time'] = '01:00:00'
slurm_args['nodes'] = 1
slurm_args['ntasks-per-node'] = 1
slurm_args['cpus-per-task'] = 1
slurm_args['gpus-per-node'] = 1
slurm_args['mem'] = '100G'  # 50G was not enough for uspto_full
slurm_args['start_array_job'] = start_array_job
slurm_args['end_array_job'] = end_array_job

# Launch timestamp; appended to experiment_name so each launch gets its own run name.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Single-step retrosynthesis model.
# Models available on Puhti: gln, localretro, retroknn.
# Models available on LUMI: megan, mhnreact, graph2edits, neuralsym, root_aligned, chemformer.
# NOTE(review): the value below is 'rootaligned' while the list above says 'root_aligned';
# confirm which spelling the model registry expects.
single_step_model = 'rootaligned'

# Alternative checkpoints (uncomment one to switch):
# model_dir = None
# model_dir = os.path.join(
#     PROJECT_ROOT,
#     'checkpoints',
#     'neuralsym',
#     'model_retro.pt'
# )
# model_dir = os.path.join(
#     PROJECT_ROOT,
#     'checkpoints',
#     'rsmiles_50k_checkpoints' # rsmiles_uspto_190_fraction0.4_thresh10_dropped_PtoR_aug5
# )
model_dir = os.path.join(
    PROJECT_ROOT,
    'checkpoints',
    'rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5'
)

# Short checkpoint identifier used in experiment_name. basename(normpath(...)) is
# separator-agnostic, tolerates a trailing slash, and accepts str or PathLike,
# unlike split('/')[-1]. A falsy model_dir (e.g. None) yields the literal 'None'.
model_dir_id = os.path.basename(os.path.normpath(model_dir)) if model_dir else 'None'

# Multi-step search algorithm, its value function, and the heuristic it ranks nodes with.
search = 'retro_star'
value_function = 'retro_star'
heuristic = 'value_function'

# Interactive sessions use the dummy inventory; batch jobs do not.
if slurm_args['interactive']:
    dummy_inventory = 'true'
else:
    dummy_inventory = 'false'
# --- search.* settings -------------------------------------------------------
strategy = 'null'
limit_reaction_model_calls = 100  # usual: 100
time_limit_s = 600  # usual: 600
limit_iterations = 500  # usual: 500
max_routes_to_extract = 100
# TODO: also run the remaining experiments with steered = 'false'.
steered = 'true'
filtered = 'true'
guided = 'true'

# --- evaluation settings -----------------------------------------------------
single_step_evaluation = '50k'
multi_step_evaluation = 'route'
forward_model_dir = 'forward_rsmiles_mit'

# --- route_dataset.* settings ------------------------------------------------
dataset = 'uspto_190'
dataset_type = 'uspto_hard'
# Alternative target list: 'desp_data/uspto_190_targets.txt'
route_path = 'uspto_190/in_json/test_processed.json'

# --- classifier_guidance.* settings ------------------------------------------
classifier_guidance = 'reaction_type'
similarity_type = ''
with_starting_material = 'false'
readjust_translations = 'false'
adjusted_original_score_weight = 1
adjusted_score_weight = 0
original_score_weight = 1
property_weight = 50
guidance_scale = 1.5  # TODO: the guidance_scale = 10 run is already done
min_length_for_guidance = 15
eos_penalty = -10.0
n_candidates_to_evaluate = 53  # usual: 53
num_classes = 11
search_batch_size = 1024
# Translation-model weights handed to classifier guidance
# (classifier_guidance.onmt_checkpoint_path). It lives in the same checkpoint directory
# as model_dir above; the two are kept in sync by hand.
onmt_checkpoint_path = os.path.join(
    PROJECT_ROOT, 'checkpoints',
    'rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5',
    'model.product-reactants_step_250000.pt',
)
# Guidance classifier weights (classifier_guidance.checkpoint_path).
classifier_checkpoint = os.path.join(
    PROJECT_ROOT, 'checkpoints',
    'reaction_type_routes_fraction1.0_thresh500_completion0.8_augmentation5',
    'checkpoint_25.pt',
)

# Run name built from the model, checkpoint id, search toggles, call budget, dataset type
# and launch time. It is also used for the SLURM job name and the output directory.
experiment_name = (
    f'{single_step_model}_{model_dir_id}_steered{steered}_filtered{filtered}'
    f'_guided{guided}_numModelCalls{limit_reaction_model_calls}_{dataset_type}_{time_stamp}'
)
# NOTE: Ideally the Hydra configs would not be overridden by hand here (except perhaps
# when running ablations in a loop); the plan is to define every parameter
# configuration as a Hydra experiment file instead.
script_args = {
    "script_dir": SCRIPT_DIR,
    "use_torchrun": 'false',
    # Hydra overrides for the script (presumably rendered as `key=value` CLI args by
    # create_and_submit_batch_job; confirm there).
    "args": {
        'general.experiment_name': experiment_name,
        'general.seed': seed,
        'single_step_model.model_type': single_step_model,
        'single_step_model.model_dir': model_dir,
        'single_step_model.default_num_results': default_num_results,
        'search': search,
        'search.steered': steered,
        'search.filtered': filtered,
        'search.guided': guided,
        'search.dummy_inventory': dummy_inventory,
        'search.heuristic': heuristic,
        'search.strategy': strategy,
        'search.limit_reaction_model_calls': limit_reaction_model_calls,
        'search.time_limit_s': time_limit_s,
        'search.limit_iterations': limit_iterations,
        'search.max_routes_to_extract': max_routes_to_extract,
        'value_function': value_function,
        'multi_step_evaluation': multi_step_evaluation,
        'single_step_evaluation': single_step_evaluation,
        'single_step_evaluation.forward_model_dir': forward_model_dir,
        "classifier_guidance": classifier_guidance,
        "classifier_guidance.similarity_type": similarity_type,
        "classifier_guidance.experiment_name": experiment_name,
        "classifier_guidance.checkpoint_path": classifier_checkpoint,
        "classifier_guidance.onmt_checkpoint_path": onmt_checkpoint_path,
        "classifier_guidance.guidance_scale": guidance_scale,
        "classifier_guidance.n_candidates_to_evaluate": n_candidates_to_evaluate,
        "classifier_guidance.min_length_for_guidance": min_length_for_guidance,
        "classifier_guidance.eos_penalty": eos_penalty,
        "classifier_guidance.model.num_classes": num_classes,
        "classifier_guidance.search_batch_size": search_batch_size,
        "classifier_guidance.with_starting_material": with_starting_material,
        "classifier_guidance.eval.property_weight": property_weight,
        "classifier_guidance.eval.original_score_weight": original_score_weight,
        "classifier_guidance.adjusted_score_weight": adjusted_score_weight,
        "classifier_guidance.adjusted_original_score_weight": adjusted_original_score_weight,
        "classifier_guidance.readjust_translations": readjust_translations,
        'route_dataset': dataset,
        'route_dataset.type': dataset_type,
        'route_dataset.path': route_path,
        # Shell variables from "variables" below, presumably expanded per array task.
        'route_dataset.start_idx': '$start_idx',
        'route_dataset.end_idx': '$end_idx',
    },
    # Shell variables for the job script, evaluated with bash arithmetic:
    #   start_idx = offset + SLURM_ARRAY_TASK_ID * targets_per_job
    #   end_idx   = start_idx + targets_per_job  (presumably end-exclusive; confirm in search.py)
    "variables": {
        'targets_per_job': targets_per_job,  # targets processed by each array task
        # Use the module-level `offset` setting; a hard-coded 0 here made that knob ineffective.
        'offset': offset,
        'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
        'end_idx': '$((start_idx+targets_per_job))',
    },
}
# Entry point run inside the job.
script_args['script_name'] = 'search.py'
# NOTE(review): `task` is not used anywhere in this file; kept for compatibility.
task = 'search'
slurm_args['job_name'] = f'search_{experiment_name}'
# Nest outputs by search algorithm, model and run name.
# NOTE(review): f'search{search}' has no separator (e.g. 'searchretro_star'); confirm intended.
search_output_dir = os.path.join(
    slurm_args['output_dir'], f'search{search}', single_step_model, experiment_name
)
slurm_args['output_dir'] = search_output_dir
del search_output_dir
# Run interactively or submit as a batch job, depending on the platform setting.
output = create_and_submit_batch_job(
    slurm_args,
    script_args,
    interactive=slurm_args['interactive'],
)
