from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]

# Platforms that have a configuration branch below. Keep this tuple in sync
# with the if/else chain that selects per-platform settings.
SUPPORTED_PLATFORMS = ('puhti',)

parser = argparse.ArgumentParser()
parser.add_argument('--interactive', action='store_true',
                    help='Forwarded as `interactive` to create_and_submit_batch_job.')
# required + choices make argparse reject a missing or unknown platform with a
# usage message, instead of failing later with "Platform None not supported".
parser.add_argument('--platform', type=str, required=True,
                    choices=SUPPORTED_PLATFORMS,
                    help='Platform whose SLURM settings are used.')
args = parser.parse_args()
# DATA_DIR = 'uspto_full'
SCRIPT_DIR = 'scripts'
platform = args.platform

if platform == 'puhti':
    project = 'project_2007775'
    partition = 'gpu'
    with_containers = False
    puhti_module = 'pytorch/2.1'
    venv_path = '/projappl/project_2007775/multiguide'
    container = None
else:
    # Defensive: only reachable if SUPPORTED_PLATFORMS gains an entry
    # without a matching branch above.
    raise ValueError(f'Platform {platform} not supported')

# Compute resources requested for each job. These are spliced into
# slurm_args below at the same position they originally occupied, so the
# key order of slurm_args is unchanged.
_slurm_resources = {
    'time': '00:30:00',
    'partition': partition,
    'nodes': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
    # NOTE(review): an earlier note said 50G was not enough for uspto_full,
    # yet 20G is requested here -- presumably enough for a single-molecule
    # run; confirm.
    'mem': '20G',
}

# SLURM submission settings shared by every job submitted by this script.
slurm_args = {
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
    'platform': platform,
    'project': project,
    **_slurm_resources,
    'with_containers': with_containers,
    'container': container,
    'venv_path': venv_path,
    'puhti_module': puhti_module,
    # Array-job bounds; both are 0 here. An earlier note mentioned a 5..37
    # range (presumably for a larger sweep) -- confirm before reusing.
    'start_array_job': 0,
    'end_array_job': 0,
}

# Timestamp embedded in every experiment name of this submission run.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Classifier used for guidance (selected via the 'classifier_guidance'
# override) and the class index it is asked to target.
classifier_guidance = 'reaction_type'
classifier_checkpoint_files = ['checkpoint_308.pt']
classifier_checkpoint = str(
    PROJECT_ROOT
    / 'checkpoints'
    / 'reaction_type_completion0.8_augmentation5'
    / classifier_checkpoint_files[0]
)
target_class_index = 2

# Sweep grid. A larger grid, kept for reference:
#   min_lengths = [10, 14, 15, 18, 20, 25]
#   guidance_scales = [0.1, 1, 10, 20, 100]
min_lengths = [25]
guidance_scales = [100]

default_num_results = 5
molecule_name = 'benzoin'
onmt_checkpoint_path = str(
    PROJECT_ROOT / 'checkpoints' / 'rsmiles_50k_checkpoints' / 'USPTO_50K_PtoR.pt'
)

# Submit one job per (min_length_for_guidance, guidance_scale) combination.
# NOTE: Hydra configs should not be edited by hand here, except perhaps when
# running ablations in a loop. For now, define every parameter configuration
# as an experiment file.
for min_length_for_guidance in min_lengths:
    for guidance_scale in guidance_scales:
        experiment_name = (
            f'{molecule_name}_target{target_class_index}'
            f'_guidance{guidance_scale}_length{min_length_for_guidance}'
            f'_results{default_num_results}_time{time_stamp}'
        )
        # Dotted keys look like Hydra overrides; presumably rendered as
        # `key=value` by create_and_submit_batch_job -- confirm.
        script_args = {
            "script_dir": SCRIPT_DIR,
            "use_torchrun": 'false',
            "args": {
                "classifier_guidance": classifier_guidance,
                "classifier_guidance.experiment_name": experiment_name,
                "classifier_guidance.checkpoint_path": classifier_checkpoint,
                "classifier_guidance.onmt_checkpoint_path": onmt_checkpoint_path,
                "search.steered": "true",
                "single_step_model.model_dir": "rsmiles_50k_checkpoints",
                "single_step_model.default_num_results": default_num_results,
                "classifier_guidance.model.num_classes": 11,
                "classifier_guidance.target_class_index": target_class_index,
                "classifier_guidance.min_length_for_guidance": min_length_for_guidance,
                "classifier_guidance.guidance_scale": guidance_scale,
                "classifier_guidance.n_candidates_to_evaluate": 72,
                "classifier_guidance.search_batch_size": 1024,
                "single_step_evaluation.forward_model_dir": "forward_rsmiles_mit",
            },
            "script_name": 'evaluate_single_step_model_on_one_molecule.py',
        }
        # Give each job its own settings dict instead of mutating the shared
        # slurm_args, so jobs cannot alias each other's job_name if the
        # submitter keeps a reference. Key order matches the original
        # (job_name last).
        job_slurm_args = {**slurm_args, 'job_name': experiment_name}
        create_and_submit_batch_job(job_slurm_args, script_args,
                                    interactive=args.interactive)
