import os
from pathlib import Path
from datetime import datetime
from slurm_utils import create_and_submit_batch_job, get_platform_info

# ---------------------------------------------------------------------------
# Dataset and SLURM array sizing
# ---------------------------------------------------------------------------
data_dir = 'uspto_50k/processed'
default_num_results = 100
targets_per_job = 50
# Inclusive range of SLURM array task IDs to launch. The submission loop below
# gives each task the slice starting at SLURM_ARRAY_TASK_ID * targets_per_job.
# NOTE(review): because the range starts at 1, targets [0, targets_per_job)
# are not covered by this submission — confirm that is intentional.
start_array_job, end_array_job = 1, 98

# Repository root = parent of the directory containing this file.
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]
SCRIPT_DIR = 'scripts'

# Resource request layered on top of the platform's GPU defaults.
_resource_overrides = {
    'use_srun': True,
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
    'time': '24:00:00',
    'nodes': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
    'mem': '100G',  # 50G was not enough for uspto_full
    'start_array_job': start_array_job,
    'end_array_job': end_array_job,
}
slurm_args = get_platform_info(use_gpu=True)
slurm_args.update(_resource_overrides)
# ---------------------------------------------------------------------------
# Classifier-guided single-step evaluation settings
# ---------------------------------------------------------------------------
classifier_guidance = 'reaction_type'
similarity_type = ''
with_starting_material = 'false'
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
single_step_evaluation = '50k'
steered = 'true'
guidance_scales = [1.5]  # TODO: guidance scale 10 has already been run.
min_lengths = [10]
eos_penalty = -10.0
n_candidates_to_evaluate = 72
num_classes = 11
search_batch_size = 1024

model_dir = os.path.join(PROJECT_ROOT, 'checkpoints', 'rsmiles_50k_checkpoints')
onmt_checkpoint_path = os.path.join(model_dir, 'USPTO_50K_PtoR.pt')
classifier_checkpoint = os.path.join(
    PROJECT_ROOT,
    'checkpoints',
    'reaction_type_completion0.8_augmentation5',
    'checkpoint_308.pt',
)

# NOTE: Do not hand-edit hydra configs from this script. Define each parameter
# configuration as an experiment file instead. Sweeping a few values in the
# loop below (e.g. for ablations) is the one exception.
# For a quick debug run, point data_dir at 'uspto_50k_debug/processed' and/or
# set default_num_results = 1.
subset = 'test_with_tanimoto_weight1.csv'

experiment_group = 'single_step_50k'
experiment_params = 'reaction_type'

# Interactive runs evaluate only target index 0. Batch runs pass the literal
# '$start_idx' / '$end_idx' references, presumably expanded by the generated
# job script from the "variables" entries below (verify in slurm_utils).
interactive = slurm_args['interactive']
# Capture the base output directory once. Each job's directory is then built
# from it, rather than nested inside the previous job's directory.
base_output_dir = slurm_args['output_dir']

for guidance_scale in guidance_scales:
    for min_length_for_guidance in min_lengths:
        # Derive the name from this iteration's sweep values, so jobs in the
        # sweep do not share a job name, experiment name or output directory.
        experiment_name = (
            f'50k_steered{steered}_guidance{guidance_scale}'
            f'_length{min_length_for_guidance}_results{default_num_results}'
            f'_candidates{n_candidates_to_evaluate}_time{time_stamp}'
        )
        script_args = {
            "script_dir": SCRIPT_DIR,
            "script_name": 'evaluate_single_step_model.py',
            "use_torchrun": 'false',
            # Hydra-style overrides for the evaluation script.
            "args": {
                "general.experiment_group": experiment_group,
                "general.experiment_params": experiment_params,
                "general.experiment_name": experiment_name,
                "classifier_guidance": classifier_guidance,
                "classifier_guidance.similarity_type": similarity_type,
                "classifier_guidance.experiment_name": experiment_name,
                "classifier_guidance.checkpoint_path": classifier_checkpoint,
                "classifier_guidance.onmt_checkpoint_path": onmt_checkpoint_path,
                "classifier_guidance.guidance_scale": guidance_scale,
                "classifier_guidance.n_candidates_to_evaluate": n_candidates_to_evaluate,
                "classifier_guidance.min_length_for_guidance": min_length_for_guidance,
                "classifier_guidance.eos_penalty": eos_penalty,
                "single_step_evaluation": single_step_evaluation,
                "classifier_guidance.with_starting_material": with_starting_material,
                "single_step_evaluation.data_dir": data_dir,
                "single_step_evaluation.subset": subset,
                "single_step_evaluation.start_idx": '$start_idx' if not interactive else 0,
                "single_step_evaluation.end_idx": '$end_idx' if not interactive else 1,
                "single_step_evaluation.forward_model_dir": "forward_rsmiles_mit",
                "search.steered": steered,
                "single_step_model.model_dir": model_dir,
                "single_step_model.default_num_results": default_num_results,
                "classifier_guidance.model.num_classes": num_classes,
                "classifier_guidance.search_batch_size": search_batch_size,
            },
            # Shell arithmetic: each array task covers
            # [offset + task_id * targets_per_job, start_idx + targets_per_job).
            "variables": {
                'targets_per_job': targets_per_job,
                'offset': 0,
                'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
                'end_idx': '$((start_idx+targets_per_job))',
            },
        }
        slurm_args['job_name'] = experiment_name
        slurm_args['output_dir'] = os.path.join(
            base_output_dir, experiment_group, experiment_params, experiment_name
        )
        output = create_and_submit_batch_job(slurm_args, script_args, interactive=interactive)
