from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

# Parent of the directory that contains this script, after resolving
# symlinks with os.path.realpath.
# NOTE(review): not referenced anywhere else in this script.
PROJECT_ROOT = Path(os.path.realpath(__file__)).parent.parent

# Command-line interface: a single boolean switch, False unless given.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--interactive',
    action='store_true',
    dest='interactive',
)
# Parsed at import time; args.interactive is passed as the `interactive=`
# keyword to create_and_submit_batch_job later in this script.
args = parser.parse_args()
# Cluster / SLURM settings for the LUMI runs.
# (An earlier dataset setting, DATA_DIR = 'uspto_full', is no longer used.)
SCRIPT_DIR = 'scripts'
platform = 'lumi'
project = 'project_462000833'
partition = 'small-g'

# slurm_args is assembled in stages. The final key insertion order matches a
# single literal, in case the job helper iterates the dict in order.
slurm_args = {
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
    'platform': platform,
    'project': project,
}

# Resource request: one node, one task, one CPU and one GPU for 24 hours.
slurm_args.update({
    'time': '24:00:00',
    'partition': partition,
    'nodes': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
    # NOTE(review): an earlier note said 50G was not enough for uspto_full;
    # 20G is presumably sized for the toy experiment. Confirm before reusing.
    'mem': '20G',
})

# Execution environment.
slurm_args.update({
    'with_containers': True,
    'container': 'multiguide-lumi.sif',
    'venv_path': 'multiguide-lumi-container',
    # Array index range; earlier runs apparently used 5..37.
    'start_array_job': 0,
    'end_array_job': 0,
    'puhti_module': 'pytorch/2.1',
})
# Experiment hyperparameters, later passed as override values to the training
# script. Most are deliberately strings, presumably so they can be written out
# verbatim as command-line overrides. TODO confirm against the job helper.
classifier_guidance = 'toy_experiment'
wandb_mode = 'online'
scheduler = 'true'
learning_rate = '1e-4'  # 1e-2 was tried previously
epochs = 2000

# One job is submitted per entry. All jobs in this invocation share a single
# local-time timestamp (YYYYmmdd_HHMMSS) in their names.
loss_list = ['mse']
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

for loss in loss_list:
    # Each job name combines the guidance mode, the loss and the shared timestamp.
    experiment_name = f"{classifier_guidance}_loss{loss}_{time_stamp}"

    # Key/value overrides for the training script (the '+experiment' and
    # dotted keys look like Hydra overrides; confirm). The same name is used
    # for the run, the classifier experiment and the W&B run.
    overrides = {
        "+experiment": "toy_experiment.yaml",
        "classifier_guidance": classifier_guidance,
        "general.experiment_name": experiment_name,
        "classifier_guidance.experiment_name": experiment_name,
        "classifier_guidance.train.loss": loss,
        "classifier_guidance.train.use_scheduler": scheduler,
        "classifier_guidance.train.learning_rate": learning_rate,
        "classifier_guidance.train.num_epochs": epochs,
        "general.wandb.mode": wandb_mode,
        "general.wandb.name": experiment_name,
    }
    script_args = {
        "script_dir": SCRIPT_DIR,
        "use_torchrun": 'false',
        "args": overrides,
        "script_name": 'toy_experiment.py',
    }

    # NOTE(review): this writes job_name into the shared module-level
    # slurm_args. If the helper keeps a reference to the dict, later
    # iterations overwrite earlier jobs' names; consider passing a copy.
    slurm_args['job_name'] = experiment_name
    output = create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)
    # NOTE(review): unreferenced path left by the author:
    # /scratch/project_462000833/multiguide/data/predictors_old