from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

# Repository root: the parent of the directory containing this script.
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]

parser = argparse.ArgumentParser()
parser.add_argument('--interactive', action='store_true')
parser.add_argument('--platform', type=str,
                    help="Cluster to submit to: 'puhti', 'mahti' or 'lumi'.")
args = parser.parse_args()
# DATA_DIR = 'uspto_full'
SCRIPT_DIR = 'scripts'
platform = args.platform

# Per-cluster SLURM/environment settings. A lookup table (instead of an
# if/elif chain) guarantees each platform gets exactly its own entry; the
# previous chain tested `platform == 'mahti'` in the first branch, so Mahti
# silently received Puhti's 'gpu' partition instead of 'gpusmall'.
_PLATFORM_CONFIGS = {
    'puhti': {
        'project': 'project_2006950',
        'partition': 'gpu',
        'with_containers': False,
        'puhti_module': 'pytorch/2.4',
        'venv_path': '/projappl/project_2006950/syntheseus-python-10',
        'container': None,
    },
    'mahti': {
        'project': 'project_2006950',
        'partition': 'gpusmall',
        'with_containers': False,
        'puhti_module': 'pytorch/2.4',
        'venv_path': '/projappl/project_2006950/syntheseus-python-10',
        'container': None,
    },
    'lumi': {
        'project': 'project_462000833',
        'partition': 'small-g',
        'with_containers': True,
        'container': 'multiguide-lumi.sif',
        'venv_path': 'multiguide-lumi-container',
        'puhti_module': None,
    },
}

try:
    _platform_config = _PLATFORM_CONFIGS[platform]
except KeyError:
    # Also reached when --platform is omitted (platform is None).
    raise ValueError(f'Platform {platform} not supported') from None

project = _platform_config['project']
partition = _platform_config['partition']
with_containers = _platform_config['with_containers']
puhti_module = _platform_config['puhti_module']
venv_path = _platform_config['venv_path']
container = _platform_config['container']
use_srun = False

# SLURM request: one node, one task, one CPU and one GPU for up to 24 hours.
# Built in stages so identifier-like keys can use keyword syntax; the
# hyphenated SLURM option names need a dict literal. Insertion order is the
# same as a single flat literal would give.
slurm_args = dict(
    job_dir='jobs',
    job_ids_file='job_ids.txt',
    output_dir='output',
    platform=platform,
    project=project,
    time='24:00:00',
    partition=partition,
    nodes=1,
)
slurm_args.update({
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
})
slurm_args.update(
    mem='100G',  # 50G not enough for uspto_full
    with_containers=with_containers,
    use_srun=use_srun,
    container=container,
    venv_path=venv_path,
    puhti_module=puhti_module,
    # Array covers task IDs 0..0, i.e. a single task.
    start_array_job=0,
    end_array_job=0,
)
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Evaluation setup for the similarity-guided regressor.
classifier_guidance = 'similarity'
similarity_type = 'tanimoto'
classifier_dataset_name = 'uspto_190_with_dot_separator'
# classifier_test_file = 'val_with_tanimoto_fraction0.4_completion0.8_augment5.csv'
classifier_test_file = 'test_linear_routes_with_tanimoto_fraction1.0_completion0.8_augment5.csv'
classifier_checkpoint_path = os.path.join(
    PROJECT_ROOT,
    'checkpoints',
    'tanimoto_routes_fraction0.4_completion0.8_augmentation5',
    'checkpoint_90.pt',
)
classifier_onmt_checkpoint_path = os.path.join(
    PROJECT_ROOT,
    'checkpoints',
    'rsmiles_50k_checkpoints',
    'USPTO_50K_PtoR.pt',
)

# Job name embeds dataset, test file (without ".csv"), similarity type and
# submission time so repeated runs do not collide.
experiment_name = (
    f'evaluate_regressor_{classifier_dataset_name}'
    f'_{classifier_test_file.split(".csv")[0]}'
    f'_{similarity_type}_time{time_stamp}'
)

script_args = {
    "script_dir": SCRIPT_DIR,
    "use_torchrun": 'false',
    # Script arguments; the "$start_idx"/"$end_idx" values are literal shell
    # references, presumably expanded in the generated batch script from the
    # "variables" entries below -- confirm in slurm_utils.
    "args": {
        "classifier_guidance": classifier_guidance,
        "classifier_guidance.similarity_type": similarity_type,
        "classifier_guidance.experiment_name": experiment_name,
        "classifier_guidance.dataset.dataset_name": classifier_dataset_name,
        "classifier_guidance.dataset.test_file": classifier_test_file,
        "classifier_guidance.checkpoint_path": classifier_checkpoint_path,
        "classifier_guidance.onmt_checkpoint_path": classifier_onmt_checkpoint_path,
        "classifier_guidance.dataset.start_idx": "$start_idx",
        "classifier_guidance.dataset.end_idx": "$end_idx",
    },
    # Shell arithmetic; later entries reference earlier ones, so keep the order.
    # targets_per_job is deliberately huge: the array in slurm_args has only
    # task 0, giving start_idx=0 and end_idx=100000000. That single job
    # presumably covers the whole test file.
    "variables": {
        'targets_per_job': 100000000,
        'offset': 0,
        'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
        'end_idx': '$((start_idx+targets_per_job))',
    },
    "script_name": 'evaluate_regressors.py',
}
# NOTE(review): `task` is not referenced elsewhere in the visible script.
task = 'evaluate_regressors'
slurm_args['job_name'] = experiment_name
output = create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)
