import os
from pathlib import Path
from datetime import datetime
from slurm_utils import create_and_submit_batch_job, get_platform_info

# --- Dataset / evaluation slice ----------------------------------------------
# Processed dataset directory and the CSV split that gets evaluated.
data_dir: str = 'uspto_50k/processed'
subset: str = 'test_with_tanimoto_weight1.csv'
# Label used to group this run's outputs (also part of the output path).
experiment_params: str = 'no_guidance'

# --- SLURM array range --------------------------------------------------------
# First / last array task ids handed to the submitter (51 was an alternative
# value used for both in earlier runs).
start_array_job: int = 0
end_array_job: int = 0

# --- Evaluation sizes ---------------------------------------------------------
default_num_results: int = 100       # results requested from the single-step model
targets_per_job: int = 50            # size of the target slice each array task handles
n_candidates_to_evaluate: int = 72   # forwarded to classifier guidance
seed: int = 42                       # other seeds used: 101, 90
offset: int = 0                      # added to every task's start index

# Repository root: the parent of the directory holding this (symlink-resolved) file.
PROJECT_ROOT: Path = Path(os.path.realpath(__file__)).parents[1]
SCRIPT_DIR: str = 'scripts'

# Start from the platform's GPU defaults, then layer this experiment's
# bookkeeping and resource requests on top (key order kept as-is).
slurm_args = get_platform_info(use_gpu=True)
slurm_args.update(
    {
        # where job scripts, submitted job ids and logs are written
        'use_srun': True,
        'job_dir': 'jobs',
        'job_ids_file': 'job_ids.txt',
        'output_dir': 'output',
        # resources for a single task
        'time': '24:00:00',
        'nodes': 1,
        'ntasks-per-node': 1,
        'cpus-per-task': 1,
        'gpus-per-node': 1,
        'mem': '100G',  # 50G turned out to be too little for uspto_full
        # array task range
        'start_array_job': start_array_job,
        'end_array_job': end_array_job,
    }
)
# --- Classifier guidance -------------------------------------------------------
classifier_guidance = 'reaction_type'
similarity_type = ''
with_starting_material = 'false'  # passed on as a string override, not a bool

# Timestamp embedded in the experiment name. It is pinned to an earlier run
# (presumably so outputs land in that run's directories — confirm); for a fresh
# run use datetime.now().strftime("%Y%m%d_%H%M%S") instead.
time_stamp = '20251026_111933'
# Debug toggle used previously: default_num_results = 1
single_step_evaluation = '50k'
steered = 'false'
guidance_scales = [0.]  # TODO: original note says "10 is done" (scale 10 presumably already run)
min_lengths = [0]
eos_penalty = -10.0
num_classes = 11
search_batch_size = 32768  # values tried: 1024, 2048, 8192, 16384, 32768

# --- Single-step model ---------------------------------------------------------
model_type = 'diffalign'  # GRAPH2EDITS is currently broken
# No model directory by default. Alternative used before:
#   os.path.join('checkpoints', 'rsmiles_50k_checkpoints')
#   (also: rsmiles_uspto_190_fraction0.4_thresh10_dropped_PtoR_aug5)
model_dir = None

onmt_checkpoint_path = os.path.join(PROJECT_ROOT, 'checkpoints', 'rsmiles_50k_checkpoints', 'USPTO_50K_PtoR.pt')
classifier_checkpoint = os.path.join(PROJECT_ROOT, 'checkpoints', 'reaction_type_completion0.8_augmentation5', 'checkpoint_308.pt')

# NOTE: Hydra configs should not be modified here by hand (except possibly when
# running ablations in a loop); for now every parameter configuration is
# defined as an experiment file.
# Debug dataset alternative: data_dir = 'uspto_50k_debug/processed'

# --- Experiment naming ---------------------------------------------------------
experiment_group = 'single_step_50k'
# Underscore-joined run name; only the FIRST guidance scale and minimum length
# are encoded here.
experiment_name = '_'.join([
    data_dir.partition('/')[0],
    f'seed{seed}',
    f'model{model_type}',
    f'steered{steered}',
    f'guidance{guidance_scales[0]}',
    f'length{min_lengths[0]}',
    f'results{default_num_results}',
    f'candidates{n_candidates_to_evaluate}',
    f'time{time_stamp}',
])

# Submit one job (or run interactively) per (guidance scale, minimum length)
# combination.

# Captured once so each job's output directory is derived from the same base
# rather than from the previous iteration's (already extended) value.
_base_output_dir = slurm_args['output_dir']
_interactive = slurm_args['interactive']

if _interactive:
    # Interactive runs evaluate only the slice array task `start_array_job`
    # would have handled.
    _start_idx = offset + start_array_job * targets_per_job
    _end_idx = _start_idx + targets_per_job
else:
    # Batch runs defer to the shell variables defined in `variables` below.
    _start_idx, _end_idx = '$start_idx', '$end_idx'

for guidance_scale in guidance_scales:
    for min_length_for_guidance in min_lengths:
        # Encode the *current* combination so that distinct combinations do
        # not share a job name / output directory. With single-entry lists
        # this equals the module-level default name.
        experiment_name = (
            f'{data_dir.split("/")[0]}_seed{seed}_model{model_type}_steered{steered}'
            f'_guidance{guidance_scale}_length{min_length_for_guidance}'
            f'_results{default_num_results}_candidates{n_candidates_to_evaluate}'
            f'_time{time_stamp}'
        )
        script_args = {
            "script_dir": SCRIPT_DIR,
            "use_torchrun": 'false',
            "args": {
                "general.seed": seed,
                "general.experiment_group": experiment_group,
                "general.experiment_params": experiment_params,
                "general.experiment_name": experiment_name,
                "classifier_guidance": classifier_guidance,
                "classifier_guidance.similarity_type": similarity_type,
                "classifier_guidance.experiment_name": experiment_name,
                "classifier_guidance.checkpoint_path": classifier_checkpoint,
                "classifier_guidance.onmt_checkpoint_path": onmt_checkpoint_path,
                "classifier_guidance.guidance_scale": guidance_scale,
                "classifier_guidance.n_candidates_to_evaluate": n_candidates_to_evaluate,
                "classifier_guidance.min_length_for_guidance": min_length_for_guidance,
                "classifier_guidance.eos_penalty": eos_penalty,
                "single_step_evaluation": single_step_evaluation,
                "classifier_guidance.with_starting_material": with_starting_material,
                "single_step_evaluation.data_dir": data_dir,
                "single_step_evaluation.subset": subset,
                "single_step_evaluation.start_idx": _start_idx,
                "single_step_evaluation.end_idx": _end_idx,
                "single_step_evaluation.forward_model_dir": "forward_rsmiles_mit",
                "search.steered": steered,
                "single_step_model.model_type": model_type,
                "single_step_model.model_dir": model_dir,
                "single_step_model.default_num_results": default_num_results,
                "classifier_guidance.model.num_classes": num_classes,
                "classifier_guidance.search_batch_size": search_batch_size,
            },
            # Shell-side definitions: each array task evaluates the target slice
            # [offset + task_id * targets_per_job, start_idx + targets_per_job).
            "variables": {
                'targets_per_job': targets_per_job,
                'offset': offset,
                'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
                'end_idx': '$((start_idx+targets_per_job))',
            },
            "script_name": 'evaluate_single_step_model_in_batch.py',
        }
        slurm_args['job_name'] = experiment_name
        slurm_args['output_dir'] = os.path.join(
            _base_output_dir, experiment_group, experiment_params, experiment_name
        )
        output = create_and_submit_batch_job(slurm_args, script_args, interactive=_interactive)
