from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]


def get_num_jobd_and_num_targets_per_job(data_dir, num_targets_per_job=5):
    """Compute how many array jobs each reaction file in ``data_dir`` needs.

    Every regular file directly inside ``PROJECT_ROOT/data/<data_dir>`` is
    treated as a reaction table with one header line. Files are ordered by the
    integer that follows ``"_reaction"`` in their name (up to the first ``"."``),
    and the position in that ordering becomes the key of the returned dict.

    Args:
        data_dir: Directory name, joined onto ``PROJECT_ROOT/data``.
        num_targets_per_job: Number of reactions one array job handles.
            Defaults to 5, the value that used to be hard-coded.

    Returns:
        A tuple ``(num_jobs_per_file, num_targets_per_job)`` where
        ``num_jobs_per_file`` maps the file's position in the sorted listing
        to ``ceil(num_reactions / num_targets_per_job)``.

    Raises:
        ValueError: If ``num_targets_per_job`` is smaller than 1, or if a file
            name does not carry an integer after ``"_reaction"``.
    """
    if num_targets_per_job < 1:
        raise ValueError(f'num_targets_per_job must be >= 1, got {num_targets_per_job}')

    data_dir_path = os.path.join(PROJECT_ROOT, 'data', data_dir)

    def get_file_index(file_name):
        # e.g. "foo_reaction12.csv" -> 12
        return int(file_name.split("_reaction")[-1].split(".")[0])

    def get_num_reactions_in_file(file_path):
        # Close the handle deterministically instead of leaking it.
        with open(file_path, 'r') as handle:
            num_non_blank_lines = sum(1 for line in handle if line.strip())
        # -1 for the header; clamp so a completely empty file counts as 0
        # reactions rather than -1 (which used to yield one spurious job).
        return max(num_non_blank_lines - 1, 0)

    # Drop sub-directories *before* sorting so that their names never reach
    # the integer parsing in get_file_index.
    file_names = [
        file_name for file_name in os.listdir(data_dir_path)
        if os.path.isfile(os.path.join(data_dir_path, file_name))
    ]
    file_names.sort(key=get_file_index)

    # get the number of reactions in each file
    num_reactions_per_file = [
        get_num_reactions_in_file(os.path.join(data_dir_path, file_name))
        for file_name in file_names
    ]

    # Ceiling division in pure integer arithmetic: number of array jobs needed
    # to cover every reaction of a file with num_targets_per_job per job.
    num_jobs_per_file = {
        idx: -(-num_rxn // num_targets_per_job)
        for idx, num_rxn in enumerate(num_reactions_per_file)
    }
    print(f'num_jobs_per_file: {num_jobs_per_file}')
    return num_jobs_per_file, num_targets_per_job

# Per-platform values for the names consumed by the SLURM argument dict.
# Keys are the accepted values of --platform.
_PLATFORM_SETTINGS = {
    'puhti': {
        'project': 'project_2007775',
        'partition': 'gpu',
        'with_containers': False,
        'puhti_module': 'pytorch/2.4',
        'venv_path': '/projappl/project_2007775/syntheseus-python-10',
        'container': None,
    },
    'mahti': {
        'project': 'project_2007775',
        'partition': 'gpusmall',
        'with_containers': False,
        'puhti_module': 'pytorch/2.4',
        'venv_path': '/projappl/project_2007775/syntheseus-python-10',
        'container': None,
    },
    'lumi': {
        'project': 'project_462000833',
        'partition': 'small-g',
        'with_containers': True,
        'puhti_module': None,
        'venv_path': 'multiguide-lumi-container',
        'container': 'multiguide-lumi.sif',
    },
}

parser = argparse.ArgumentParser(
    description='Create and submit a SLURM batch job for evaluate_single_step_dataset.py.')
parser.add_argument('--interactive', action='store_true',
                    help='passed on as interactive=True to create_and_submit_batch_job')
parser.add_argument('--platform', type=str,
                    help='target cluster, one of: ' + ', '.join(_PLATFORM_SETTINGS))
args = parser.parse_args()
SCRIPT_DIR = 'scripts'
platform = args.platform

# Reject unknown (or missing, i.e. None) platforms before anything is submitted.
if platform not in _PLATFORM_SETTINGS:
    raise ValueError(f'Platform {platform} not supported')

# Unpack into the module-level names the rest of the script reads.
_settings = _PLATFORM_SETTINGS[platform]
project = _settings['project']
partition = _settings['partition']
with_containers = _settings['with_containers']
puhti_module = _settings['puhti_module']
venv_path = _settings['venv_path']
container = _settings['container']
use_srun = False

# The groups below are merged in this exact order so that the resulting
# dict has the same key order as a single literal would.
_job_files = {
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
}
_resources = {
    'platform': platform,
    'project': project,
    'time': '01:00:00',
    'partition': partition,
    'nodes': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
    'mem': '100G',  # 50G not enough for uspto_full
}
_environment = {
    'with_containers': with_containers,
    'use_srun': use_srun,
    'container': container,
    'venv_path': venv_path,
    'puhti_module': puhti_module,
}
# Author's notes on array sizes (presumably the end_array_job value to use
# per reaction-file index -- confirm before relying on them):
#   26 for 0 and 1, 23 to reaction 2, 19 to reaction 3, 15 for 4 and 5
#   7 for 6, 5 for 7, 3 for 8, 2 for 9 and 10, 0 for 11 to 13
_array_range = {
    'start_array_job': 0,
    'end_array_job': 0,
}

slurm_args = {**_job_files, **_resources, **_environment, **_array_range}
# Timestamp that makes every submission's run name unique.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

experiment = 'token_prefix_fraction1.0_thresh500'
# NOTE: unclear whether these can also live in the experiment yaml file,
# because they are needed here to split the runs.
data_dir = 'uspto_190/reactions_with_starting_material'
subset_prefix = 'test_reactions_with_starting_material.csv'
# NOTE: it is also unclear how to set these programmatically in hydra.
# TODO: probably requires refactoring to set the full path dynamically in the script
model_checkpoint = os.path.join(PROJECT_ROOT, 'checkpoints',
                                'rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5',
                                'model.product-reactants_step_250000.pt')
predictor_checkpoint = os.path.join(PROJECT_ROOT, 'checkpoints',
                                    'tanimoto_routes_fraction1.0_thresh500_completion0.8_augmentation5',
                                    'checkpoint_100.pt')
default_num_results = 100
num_targets_per_job = 1
# Alternative (currently disabled): derive the per-file job counts with
# get_num_jobd_and_num_targets_per_job(data_dir) and submit one array per file.

# subset_prefix already carries the ".csv" extension, so appending another one
# unconditionally produced "....csv.csv". Only add the extension when missing.
# NOTE(review): confirm nothing downstream expects the doubled suffix.
if subset_prefix.endswith('.csv'):
    subset = subset_prefix
else:
    subset = f"{subset_prefix}.csv"
run_name = f'{experiment}_time{time_stamp}'
script_args = {
    "script_dir": SCRIPT_DIR,
    "use_torchrun": 'false',
    # Overrides handed to the evaluation script; '$start_idx' / '$end_idx'
    # refer to the shell variables defined under "variables".
    "args": {
        "+experiment": experiment,
        "general.experiment_name": run_name,
        "single_step_evaluation.data_dir": data_dir,
        "single_step_evaluation.subset": subset,
        "single_step_evaluation.start_idx": '$start_idx',
        "single_step_evaluation.end_idx": '$end_idx',
        "single_step_model.default_num_results": default_num_results,
        "classifier_guidance.checkpoint_path": predictor_checkpoint,
        "classifier_guidance.onmt_checkpoint_path": model_checkpoint,
    },
    # Shell arithmetic: each array task gets its own window of
    # targets_per_job targets, selected through SLURM_ARRAY_TASK_ID.
    "variables": {
        'targets_per_job': num_targets_per_job,  # targets handled by each job in the array
        'offset': 0,
        'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
        'end_idx': '$((start_idx+targets_per_job))',
    },
}
script_args['script_name'] = 'evaluate_single_step_dataset.py'
task = 'evaluate_single_step_dataset'
slurm_args['job_name'] = run_name
output = create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)

