import argparse
import itertools
import os
from datetime import datetime
from pathlib import Path

from slurm_utils import create_and_submit_batch_job

# Second-level parent of this file's resolved path, i.e. the directory that
# contains the folder this script lives in.
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]

# Command-line interface: a single boolean switch whose value is forwarded to
# the job submitter as its `interactive` argument.
parser = argparse.ArgumentParser()
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
# DATA_DIR = 'uspto_full'
# Passed to the job submitter as the "script_dir" entry of the script arguments.
SCRIPT_DIR = "scripts"

# Platform name and project identifier; both are copied into slurm_args below.
platform = "lumi"
project = "project_462000833"

# Options handed to create_and_submit_batch_job for every submitted job.
# The submission loop adds a per-job "job_name" entry before each call.
slurm_args = {
    # Bookkeeping locations.
    "job_dir": "jobs",
    "job_ids_file": "job_ids.txt",
    "output_dir": "output",
    # Where the job runs.
    "platform": platform,
    "project": project,
    # Resource request.
    "time": "24:00:00",
    "partition": "small-g",
    "nodes": 1,
    "ntasks-per-node": 1,
    "cpus-per-task": 1,
    "gpus-per-node": 1,
    "mem": "100G",  # 50G not enough for uspto_full
    # Runtime environment.
    "with_containers": True,
    "container": "property-funnel.sif",
    "venv_path": "property-funnel",
    # Array-job bounds (author's note on an alternative range: 5 to 37).
    "start_array_job": 0,
    "end_array_job": 0,
}
# --- Settings shared by every job in the sweep ------------------------------
# All values are strings (except `epochs`) because they are forwarded verbatim
# as command-line style overrides to the training script.
classifier_guidance = "toxicity"
wandb_mode = "online"
scheduler = "true"
learning_rate = "1e-4"  # 1e-2
epochs = 500
vocab_list = ["vocab_rsmiles_50k.txt"]

# NOTE(review): both branches yield '1', so the 'debug' check currently has no
# effect -- confirm whether a smaller fraction was intended for debug runs.
fraction = "1" if "debug" in classifier_guidance else "1"

# One timestamp for the whole sweep; it is appended to every experiment name.
time_stamp = f"{datetime.now():%Y%m%d_%H%M%S}"

# --- Values swept over by the submission loop -------------------------------
loss_list = ["mse", "nll"]
model_log_var_list = ["false", "true"]
weighted_loss_list = ["true", "false"]
curriculum_learning_list = ["false", "true"]
dataset_names = [
    "startum4_frac1.0_max_augmentations_5_min_samples_5",
    "startum4_frac0.8_max_augmentations_5_min_samples_5",
]

# Submit one training job per unique hyperparameter combination.
#
# Sweep axes:
#   * dataset_name           -- every entry of dataset_names
#   * (loss, model_log_var)  -- paired element-wise (mse/false, nll/true)
#   * weighted_loss          -- every entry of weighted_loss_list
#   * curriculum_learning    -- every entry of curriculum_learning_list
#   * vocab                  -- every entry of vocab_list
#
# dataset_names is iterated exactly once, so no configuration is submitted
# twice under the same experiment name. weighted_loss and curriculum_learning
# are crossed (full product) rather than paired.
# NOTE(review): if they were meant to be paired, replace the two axes with
# zip(weighted_loss_list, curriculum_learning_list).
sweep_configs = itertools.product(
    dataset_names,
    zip(loss_list, model_log_var_list),
    weighted_loss_list,
    curriculum_learning_list,
    vocab_list,
)

for dataset_name, (loss, model_log_var), weighted_loss, curriculum_learning, vocab in sweep_configs:
    # NOTE: the name encodes neither `vocab` nor `model_log_var`. It stays
    # unique only while vocab_list has a single entry and model_log_var is
    # tied to `loss` through the zip above.
    experiment_name = (
        f"{classifier_guidance}_{dataset_name}_loss{loss}"
        f"_wloss{weighted_loss}_curriculum{curriculum_learning}_{time_stamp}"
    )

    script_args = {
        "script_dir": SCRIPT_DIR,
        "use_torchrun": 'false',
        # Key/value overrides forwarded to the training script.
        "args": {
            "+experiment": "sa_lambda01.yaml",
            "classifier_guidance": classifier_guidance,
            "general.experiment_name": experiment_name,
            "classifier_guidance.train.loss": loss,
            "classifier_guidance.train.model_log_var": model_log_var,
            "classifier_guidance.train.use_scheduler": scheduler,
            "classifier_guidance.train.learning_rate": learning_rate,
            "classifier_guidance.train.weighted_loss": weighted_loss,
            "classifier_guidance.dataset.dataset_name": dataset_name,
            "classifier_guidance.dataset.fraction": fraction,
            "classifier_guidance.dataset.vocab_file": vocab,
            "classifier_guidance.train.curriculum_learning": curriculum_learning,
            "classifier_guidance.train.num_epochs": epochs,
            "general.wandb.mode": wandb_mode,
            "general.wandb.name": experiment_name,
        },
        "script_name": 'train_property_predictor.py',
    }

    # slurm_args is shared across iterations; only the job name changes.
    slurm_args['job_name'] = experiment_name
    output = create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)

# Path left by the original author (its purpose is not evident from this script):
# /scratch/project_462000833/multiguide/data/predictors_old