from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

# Parent of the directory that contains this file (symlinks resolved).
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]

# Platforms for which this script defines a cluster configuration.
SUPPORTED_PLATFORMS = ('puhti', 'mahti', 'lumi')

parser = argparse.ArgumentParser(
    description='Create and submit a SLURM batch job for find_token_distribution.py.'
)
parser.add_argument(
    '--interactive',
    action='store_true',
    help='Forwarded as the `interactive` flag of create_and_submit_batch_job.',
)
# Validate the platform at parse time so a missing or misspelled value yields
# a usage error instead of a late "Platform None not supported" traceback.
parser.add_argument(
    '--platform',
    type=str,
    required=True,
    choices=SUPPORTED_PLATFORMS,
    help='Cluster whose project/partition/environment settings should be used.',
)
args = parser.parse_args()
# Directory (passed to the job helper as "script_dir") holding the script to run.
SCRIPT_DIR = 'scripts'
platform = args.platform

# Per-platform cluster settings.  Every supported branch defines the same set
# of names: project, partition, with_containers, container, venv_path and
# puhti_module.
use_srun = False
if platform in ('puhti', 'mahti'):
    # Both platforms share one configuration: a module-provided PyTorch plus
    # a virtual environment, no container.
    project = 'project_2007775'
    partition = 'small'
    with_containers = False
    puhti_module = 'pytorch/2.4'
    venv_path = '/projappl/project_2007775/syntheseus-python-10'
    container = None
elif platform == 'lumi':
    # Container-based setup; no module is loaded.
    project = 'project_462000833'
    partition = 'small'
    with_containers = True
    container = 'multiguide-lumi.sif'
    venv_path = 'multiguide-lumi-container'
    puhti_module = None
else:
    raise ValueError(f'Platform {platform} not supported')

# Settings handed to create_and_submit_batch_job as its first argument.
# The container/venv entries come from the platform-specific configuration
# selected above rather than being pinned to the LUMI values, so that
# non-container platforms are not submitted with a container setup.
# NOTE(review): `puhti_module` is computed per platform but is not forwarded
# here -- confirm whether create_and_submit_batch_job expects it.
slurm_args = {
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
    'platform': platform,
    'project': project,
    'time': '12:00:00',
    'partition': partition,
    'nodes': 1,
    'gpus-per-node': 0,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'mem': '10G',  # 50G not enough for uspto_full
    'with_containers': with_containers,
    'use_srun': use_srun,
    'container': container,
    'venv_path': venv_path,
    'start_array_job': 0,  # 5 to 37
    'end_array_job': 0,  # 37
}
# Dataset selection forwarded to the script as reaction_dataset.* arguments.
data_dir = "uspto_full/raw"
subset = "val_no_overlap.csv"
start_idx, end_idx = 0, 50_000_000
# Submission time (YYYYmmdd_HHMMSS), used as a suffix of the job name.
time_stamp = f"{datetime.now():%Y%m%d_%H%M%S}"

# Description of the script to run, handed to create_and_submit_batch_job as
# its second argument.  Each "args" key appears once; the original literal
# repeated the data_dir and subset keys with the same values.
script_args = {
    "script_dir": SCRIPT_DIR,
    "use_torchrun": 'false',
    "args": {
        "reaction_dataset.data_dir": data_dir,
        "reaction_dataset.subset": subset,
        "reaction_dataset.start_idx": start_idx,
        "reaction_dataset.end_idx": end_idx,
    },
    "script_name": 'find_token_distribution.py',
}

# Use only the last path component of data_dir / subset in the job name.
dd = data_dir.split('/')[-1]
ss = subset.split('/')[-1]
slurm_args['job_name'] = f'find_token_distribution_data_dir_{dd}_subset_{ss}_{time_stamp}'

output = create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)