from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

# Repository root: parents[0] is the directory holding this script,
# parents[1] is the directory above it (symlinks resolved first).
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]

parser = argparse.ArgumentParser(
    description='Submit train_property_predictor.py as a SLURM batch job.')
parser.add_argument('--interactive', action='store_true',
                    help='Forwarded as interactive=... to '
                         'create_and_submit_batch_job.')
# Only these clusters have settings configured below. Validating here gives a
# usage error instead of a bare ValueError ("Platform None not supported")
# when the flag is forgotten or misspelled.
parser.add_argument('--platform', type=str, required=True,
                    choices=('lumi', 'puhti'),
                    help='Target cluster; selects project, partition and '
                         'software environment.')
args = parser.parse_args()
# DATA_DIR = 'uspto_full'
SCRIPT_DIR = 'scripts'
platform = args.platform

# Cluster-specific job settings, keyed by the --platform value.
# LUMI runs inside a container image; Puhti loads a module plus a venv instead.
_PLATFORM_SETTINGS = {
    'lumi': {
        'project': 'project_462000833',
        'partition': 'small-g',
        'with_containers': True,
        'puhti_module': None,
        'venv_path': 'multiguide-lumi-container',
        'container': 'multiguide-lumi.sif',
    },
    'puhti': {
        'project': 'project_2007775',
        'partition': 'gpu',
        'with_containers': False,
        'puhti_module': 'pytorch/2.4',
        'venv_path': '/projappl/project_2007775/syntheseus-python-10',
        'container': None,
    },
}

_settings = _PLATFORM_SETTINGS.get(platform)
if _settings is None:
    raise ValueError(f'Platform {platform} not supported')

project = _settings['project']
partition = _settings['partition']
with_containers = _settings['with_containers']
puhti_module = _settings['puhti_module']
venv_path = _settings['venv_path']
container = _settings['container']
use_srun = False

# Hardware requested for the job. Kept as a separate dict for readability and
# splatted back in at the same position so slurm_args keeps its original key
# order (create_and_submit_batch_job may iterate it in order).
_resources = {
    'nodes': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
    'mem': '100G',  # 50G not enough for uspto_full
}

slurm_args = {
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
    'platform': platform,
    'project': project,
    'time': '72:00:00',  # wall-clock limit, HH:MM:SS
    'partition': partition,
    **_resources,
    'with_containers': with_containers,
    'use_srun': use_srun,
    'container': container,
    'venv_path': venv_path,
    'puhti_module': puhti_module,
    # Array-job index bounds; earlier runs apparently used 5 to 37.
    # NOTE(review): confirm whether end_array_job is inclusive.
    'start_array_job': 0,
    'end_array_job': 0,
}
# Local-time stamp appended to the run name so repeated submissions don't collide.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Training hyperparameters, passed to the training script through script_args.
loss = 'ce'
epochs = 1000
weight_decay = 1e-5
dropout = 0.1
classifier_guidance = 'reaction_type'
dataset_name = 'uspto_190_fraction0.4_threshold10'
# train_file = 'train_debug.csv'
# val_file = 'val_debug.csv' # could try evaluating with augmentation
train_file = 'train_completion0.8_augment5.csv'
val_file = 'val_completion0.8_augment5.csv' # could try evaluating with augmentation
#onmt_model_path = os.path.join(PROJECT_ROOT, 'checkpoints', 'rsmiles_uspto_190_fraction0.4_thresh10_dropped_PtoR_aug5', 'model.product-reactants_step_10000.pt')
onmt_model_path = os.path.join(
    PROJECT_ROOT,
    'checkpoints',
    'rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5',
    'model.product-reactants_step_250000.pt',
)

experiment = 'reaction_type_constant_heuristic'
wandb_mode = 'online'
# Training-file name without its extension. Path.stem strips only the final
# suffix, whereas split('.csv')[0] would cut at the first '.csv' anywhere.
subset = Path(train_file).stem
experiment_name = f"{classifier_guidance}_loss{loss}_dataset{dataset_name}_{subset}_{time_stamp}"
# experiment_name = 'reaction_type_lossce_datasetuspto_190_train_unique_reactions_completion0.8_augment5_20250830_155344'
# resume = True
# resume_path = 'checkpoint_450.pt'
resume = False
# The string 'null' (not None) is passed through as the override value.
resume_path = 'null'
eval_interval = 5

# Key/value overrides handed to the training script. The "+experiment" and
# dotted-path keys look like Hydra overrides — confirm in the consumer.
# experiment_name is reused for the general, classifier_guidance and W&B names.
_overrides = {
    "+experiment": experiment,
    "general.experiment_name": experiment_name,
    "classifier_guidance": classifier_guidance,
    "classifier_guidance.train.resume": resume,
    "classifier_guidance.train.resume_path": resume_path,
    "classifier_guidance.train.loss": loss,
    "classifier_guidance.dataset.dataset_name": dataset_name,
    "classifier_guidance.dataset.train_file": train_file,
    "classifier_guidance.dataset.val_file": val_file,
    "classifier_guidance.train.num_epochs": epochs,
    "classifier_guidance.train.weight_decay": weight_decay,
    "classifier_guidance.train.eval_interval": eval_interval,
    "classifier_guidance.model.dropout": dropout,
    "classifier_guidance.experiment_name": experiment_name,
    "classifier_guidance.onmt_checkpoint_path": onmt_model_path,
    "general.wandb.mode": wandb_mode,
    "general.wandb.name": experiment_name,
}

# script_name is last, matching the key order of the original construction.
script_args = {
    "script_dir": SCRIPT_DIR,
    "use_torchrun": 'false',
    "args": _overrides,
    'script_name': 'train_property_predictor.py',
}

# The SLURM job is named after the run so both can be matched up.
slurm_args.update(job_name=experiment_name)
output = create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)
# /scratch/project_462000833/multiguide/data/predictors_old