import os
from pathlib import Path
from datetime import datetime
from slurm_utils import create_and_submit_batch_job, get_platform_info

# --- Dataset / experiment selection ----------------------------------------
data_dir = 'route_similarity_data/processed'
subset = 'Atorvastatin_reactions.csv'
experiment_params = 'no_guidance'

# SLURM array task-ID range (forwarded into slurm_args).
start_array_job, end_array_job = 0, 2

# --- Search / model settings -----------------------------------------------
default_num_results = 100
targets_per_job = 10  # number of targets each array task evaluates
seed = 42
model_type = 'rootaligned'  # NOTE: the GRAPH2EDITS model is currently broken

# Project root: the parent of the directory that contains this script
# (resolved through symlinks).
PROJECT_ROOT = Path(os.path.realpath(__file__)).parent.parent
SCRIPT_DIR = 'scripts'

# Start from the platform's GPU defaults, then apply per-run overrides given
# as (key, value) pairs.
slurm_args = get_platform_info(use_gpu=True)
slurm_args.update([
    ('use_srun', True),
    ('job_dir', 'jobs'),
    ('job_ids_file', 'job_ids.txt'),
    ('output_dir', 'output'),
    ('time', '06:00:00'),  # 72 hours were needed for all 50 targets
    ('nodes', 1),
    ('ntasks-per-node', 1),
    ('cpus-per-task', 1),
    ('gpus-per-node', 1),
    ('mem', '100G'),  # 50G was not enough for uspto_full
    ('start_array_job', start_array_job),
    ('end_array_job', end_array_job),
])
# --- Classifier-guidance settings (passed as classifier_guidance.* overrides) --
classifier_guidance = 'reaction_type'
similarity_type = ''
with_starting_material = 'false'
guidance_scales = [0]  # swept by the submission loop; TODO: scale 10 is already done
min_lengths = [0]  # swept by the submission loop
eos_penalty = -10.0
n_candidates_to_evaluate = 53
num_classes = 11
search_batch_size = 1024

# --- Evaluation / search settings --------------------------------------------
time_stamp = f'{datetime.now():%Y%m%d_%H%M%S}'
# default_num_results = 1
single_step_evaluation = '50k'
steered = 'false'
# Alternative single-step model locations (kept for quick switching):
# model_dir = os.path.join(PROJECT_ROOT, 'checkpoints', 'neuralsym', 'model_retro.pt')
# model_dir = None

# Single-step model directory (used as single_step_model.model_dir).
model_dir = os.path.join(PROJECT_ROOT, 'checkpoints', 'rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5')

# OpenNMT checkpoint used by classifier guidance (classifier_guidance.onmt_checkpoint_path).
onmt_checkpoint_path = os.path.join(PROJECT_ROOT, 'checkpoints', 'rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5', 'model.product-reactants_step_250000.pt')

# Reaction-type classifier checkpoint (classifier_guidance.checkpoint_path).
classifier_checkpoint = os.path.join(PROJECT_ROOT, 'checkpoints', 'reaction_type_routes_fraction1.0_thresh500_completion0.8_augmentation5', 'checkpoint_25.pt')

# Alternative USPTO-50k configuration (kept for quick switching):
# model_dir = os.path.join(PROJECT_ROOT, 'checkpoints', 'rsmiles_50k_checkpoints')  # rsmiles_uspto_190_fraction0.4_thresh10_dropped_PtoR_aug5
# onmt_checkpoint_path = os.path.join(PROJECT_ROOT, 'checkpoints', 'rsmiles_50k_checkpoints', 'USPTO_50K_PtoR.pt')
# classifier_checkpoint = os.path.join(PROJECT_ROOT, 'checkpoints', 'reaction_type_completion0.8_augmentation5', 'checkpoint_308.pt')

# NOTE: do not modify any of the hydra configs manually here...
# NOTE: ...unless running ablations in a loop. For now, define every parameter configuration as an experiment file.
#data_dir = 'uspto_50k_debug/processed'

experiment_group = 'manual_synthesis'


def _make_experiment_name(guidance_scale, min_length_for_guidance):
    """Return the run name for one point of the guidance sweep.

    The name encodes the subset, seed, model type, steering flag, the swept
    guidance scale and minimum length, the result/candidate counts and the
    launch timestamp. Each sweep point therefore gets its own job name and
    output directory.
    """
    return (
        f'{subset.split(".")[0]}_routes190_seed{seed}_model{model_type}'
        f'_steered{steered}_guidance{guidance_scale}'
        f'_length{min_length_for_guidance}_results{default_num_results}'
        f'_candidates{n_candidates_to_evaluate}_time{time_stamp}'
    )


# Name of the first sweep point, kept as a module-level name.
experiment_name = _make_experiment_name(guidance_scales[0], min_lengths[0])

# Capture the base output directory once. Each job's directory is then a
# sibling under it, instead of being joined onto the previous job's
# (already extended) output_dir.
base_output_dir = slurm_args['output_dir']

for guidance_scale in guidance_scales:
    for min_length_for_guidance in min_lengths:
        # Build the name per iteration so every sweep point is distinguishable.
        run_name = _make_experiment_name(guidance_scale, min_length_for_guidance)
        script_args = {
            "script_dir": SCRIPT_DIR,
            "use_torchrun": 'false',
            # Overrides forwarded to the evaluation script.
            "args": {
                "general.experiment_group": experiment_group,
                "general.experiment_params": experiment_params,
                "general.experiment_name": run_name,
                "classifier_guidance": classifier_guidance,
                "classifier_guidance.similarity_type": similarity_type,
                "classifier_guidance.experiment_name": run_name,
                "classifier_guidance.checkpoint_path": classifier_checkpoint,
                "classifier_guidance.onmt_checkpoint_path": onmt_checkpoint_path,
                "classifier_guidance.guidance_scale": guidance_scale,
                "classifier_guidance.n_candidates_to_evaluate": n_candidates_to_evaluate,
                "classifier_guidance.min_length_for_guidance": min_length_for_guidance,
                "classifier_guidance.eos_penalty": eos_penalty,
                "classifier_guidance.model.num_classes": num_classes,
                "classifier_guidance.search_batch_size": search_batch_size,
                "classifier_guidance.with_starting_material": with_starting_material,
                "single_step_evaluation": single_step_evaluation,
                "single_step_evaluation.data_dir": data_dir,
                "single_step_evaluation.subset": subset,
                # Batch jobs take their target range from the shell variables
                # defined below; interactive runs evaluate only target 0.
                "single_step_evaluation.start_idx": '$start_idx' if not slurm_args['interactive'] else 0,
                "single_step_evaluation.end_idx": '$end_idx' if not slurm_args['interactive'] else 1,
                "single_step_evaluation.forward_model_dir": "forward_rsmiles_mit",
                "single_step_model.model_type": model_type,
                "single_step_model.model_dir": model_dir,
                "single_step_model.default_num_results": default_num_results,
                "search.steered": steered
            },
            # Each array task covers targets_per_job consecutive targets
            # starting at offset + SLURM_ARRAY_TASK_ID * targets_per_job.
            "variables": {'targets_per_job': targets_per_job,
                          'offset': 0,
                          'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
                          'end_idx': '$((start_idx+targets_per_job))'}
        }
        script_args['script_name'] = 'evaluate_single_step_model_in_batch.py'
        slurm_args['job_name'] = run_name
        slurm_args['output_dir'] = os.path.join(base_output_dir, experiment_group, experiment_params, run_name)
        output = create_and_submit_batch_job(slurm_args, script_args, interactive=slurm_args['interactive'])
