import argparse
import itertools
import os
from datetime import datetime
from pathlib import Path

from slurm_utils import create_and_submit_batch_job

# Two directory levels above this file's resolved (symlink-free) location.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Command-line interface. --platform is mandatory: without it the script
# cannot pick cluster settings, so fail at parse time with a usage message
# instead of later with "Platform None not supported".
parser = argparse.ArgumentParser(
    description='Create and submit SLURM batch jobs that run search.py.',
)
parser.add_argument(
    '--interactive',
    action='store_true',
    help='forwarded to the job submitter as interactive=True; '
         'also sets search.dummy_inventory to true',
)
parser.add_argument(
    '--platform',
    type=str,
    required=True,
    help="cluster to submit to, e.g. 'lumi' or 'puhti'",
)
args = parser.parse_args()

# Directory name passed to the job builder as the location of the scripts.
SCRIPT_DIR = 'scripts'
platform = args.platform

# Cluster-specific settings, keyed by the value given to --platform.
_PLATFORM_SETTINGS = {
    'lumi': {
        'project': 'project_462000833',
        'partition': 'small-g',
        'with_containers': True,
        'puhti_module': None,
        'venv_path': 'property-funnel',
        'container': 'property-funnel.sif',
    },
    'puhti': {
        'project': 'project_2007775',
        'partition': 'gpu',
        'with_containers': False,
        'puhti_module': 'pytorch/2.1',
        'venv_path': 'multiguide',
        'container': None,
    },
}

# Reject anything that has no entry in the table above.
if platform not in _PLATFORM_SETTINGS:
    raise ValueError(f'Platform {platform} not supported')

# Expose the selected settings under the module-level names used below.
_selected = _PLATFORM_SETTINGS[platform]
project = _selected['project']
partition = _selected['partition']
with_containers = _selected['with_containers']
puhti_module = _selected['puhti_module']
venv_path = _selected['venv_path']
container = _selected['container']

# Where job files, the list of submitted job ids and job output are written.
_job_bookkeeping = {
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
}

# Allocation request for each job.
_allocation = {
    'platform': platform,
    'project': project,
    'time': '02:00:00',
    'partition': partition,
    'nodes': 1,
    'gpus-per-node': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'mem': '100G',  # 50G not enough for uspto_full
}

# Software environment selected by the platform settings.
_environment = {
    'with_containers': with_containers,
    'puhti_module': puhti_module,
    'container': container,
    'venv_path': venv_path,
}

# Array-job index range.
_array_range = {
    'start_array_job': 0,  # original note: 5 to 37
    'end_array_job': 0,  # original note: 37
}

# Merged in this order so the key order matches a single flat literal.
slurm_args = {
    **_job_bookkeeping,
    **_allocation,
    **_environment,
    **_array_range,
}
# --- Search / dataset selection -------------------------------------------
# Boolean-like options are kept as the strings 'true'/'false' because they
# are forwarded verbatim as override values to the launched script.
search_guided = 'true'
dataset = 'uspto_hard'
experiment_file = 'sa_lambda01.yaml'
search_experiment_name = 'call1000_time1800'
vocab_file = 'vocab_rsmiles_50k.txt'

# --- Swept guidance hyper-parameters (one job per combination) -------------
guidance_scales = [0.1]
n_candidates_to_evaluates = [72]
sigmoid_steepnesses = [2]

# Active guidance target(s).
# Other values previously listed here: 'np_score', 'toxicity', 'yield'.
classifier_guidances = ['sa_score']
prediction_thresholds = [3]
property_weights = [5]

# Guidance-model experiments and their checkpoints. The two lists are paired
# element-wise, so they must have the same length.
classifier_guidance_experiment_names = ['sa_score_startum4_frac0.2_max_augmentations_5_min_samples_5_lossmse_wlosstrue_curriculumfalse_20250514_124301']
classifier_guidance_checkpoint_paths = ['checkpoint_50.pt']
loss = 'mse'
model_log_var = 'false'

# One timestamp shared by every job created in this run.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Single-step models and their checkpoint directories, paired element-wise
# (so the two lists must have the same length).
# problem: chemformer, retroknn, gln, localretro
# ok: megan, mhnreact, graph2edits
# Other models previously listed here: graph2edits, megan, mhnreact, neuralsym.
single_step_models = ['root_aligned']
model_dirs = ['rsmiles_50k_checkpoints']

# The paired lists below are combined with zip(), which silently drops the
# trailing entries of the longer list. Fail loudly instead of quietly
# skipping configurations.
if len(classifier_guidance_experiment_names) != len(classifier_guidance_checkpoint_paths):
    raise ValueError(
        'classifier_guidance_experiment_names and '
        'classifier_guidance_checkpoint_paths must have the same length'
    )
if len(single_step_models) != len(model_dirs):
    raise ValueError('single_step_models and model_dirs must have the same length')

# Values that do not depend on the swept parameters.
dummy_inventory = 'true' if args.interactive else 'false'
task = 'search'

# Cartesian product of all swept parameters. The rightmost iterable varies
# fastest, which is the same order as the equivalent nested for-loops.
for (
    classifier_guidance,
    (classifier_guidance_experiment_name, classifier_guidance_checkpoint_path),
    property_weight,
    prediction_threshold,
    guidance_scale,
    n_candidates_to_evaluate,
    sigmoid_steepness,
    (single_step_model, model_dir),
) in itertools.product(
    classifier_guidances,
    zip(classifier_guidance_experiment_names, classifier_guidance_checkpoint_paths),
    property_weights,
    prediction_thresholds,
    guidance_scales,
    n_candidates_to_evaluates,
    sigmoid_steepnesses,
    zip(single_step_models, model_dirs),
):
    experiment_name = f'{dataset}_{single_step_model}_{classifier_guidance}_{prediction_threshold}_{time_stamp}'

    # NOTE: here should not modify any of the hydra configs manually...
    # NOTE: unless maybe running ablations in a loop. For now define all configurations of params as experiment files.
    script_args = {
        "script_dir": SCRIPT_DIR,
        "use_torchrun": 'false',
        # Override values forwarded to the launched script.
        "args": {
            '+experiment': experiment_file,
            'single_step_model': single_step_model,
            'single_step_model.model_dir': model_dir,
            'general.experiment_name': experiment_name,
            'search.experiment_name': search_experiment_name,
            'search.guided': search_guided,
            'classifier_guidance': classifier_guidance,
            'classifier_guidance.eval.property_weight': property_weight,
            'classifier_guidance.train.loss': loss,
            'classifier_guidance.train.model_log_var': model_log_var,
            "classifier_guidance.dataset.vocab_file": vocab_file,
            'classifier_guidance.prediction_threshold': prediction_threshold,
            'classifier_guidance.guidance_scale': guidance_scale,
            'classifier_guidance.n_candidates_to_evaluate': n_candidates_to_evaluate,
            'classifier_guidance.sigmoid_steepness': sigmoid_steepness,
            'classifier_guidance.experiment_name': classifier_guidance_experiment_name,
            'classifier_guidance.eval.checkpoint_path': classifier_guidance_checkpoint_path,
            # $start_idx / $end_idx refer to the shell variables defined
            # under "variables" below.
            'dataset.start_idx': '$start_idx',
            'dataset.end_idx': '$end_idx',
            'dataset': dataset,
            'search.dummy_inventory': dummy_inventory,
        },
        # Shell variables: each array task covers the index window
        # [start_idx, end_idx) derived from its SLURM_ARRAY_TASK_ID.
        "variables": {
            'targets_per_job': 1,  # number of targets handled by each array task
            'offset': 0,
            'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
            'end_idx': '$((start_idx+targets_per_job))',
        },
        'script_name': 'search.py',
    }

    # slurm_args is shared across iterations; only the job name changes.
    slurm_args['job_name'] = f'{task}_{experiment_name}'
    create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)