from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

# Directory one level above the folder that contains this file (symlinks resolved);
# data and checkpoint paths in this script are built relative to it.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

def get_num_jobd_and_num_targets_per_job(data_dir):
    """Compute how many array jobs each reaction file in ``data_dir`` needs.

    Every regular file directly inside ``PROJECT_ROOT/data/<data_dir>`` is
    inspected. The reaction index of a file is the integer found between the
    last ``_reaction`` in its name and the following ``.`` (for example
    ``foo_reaction3.csv`` -> ``3``). Each file is assumed to hold one header
    line followed by one reaction per non-blank line.

    Args:
        data_dir: Directory name (or relative path) below ``PROJECT_ROOT/data``.

    Returns:
        A tuple ``(num_jobs_per_file, num_targets_per_job)`` where
        ``num_jobs_per_file`` maps the reaction index parsed from each file
        name to the number of array jobs needed to cover that file, ordered by
        index, and ``num_targets_per_job`` is the number of reactions handled
        by one job.

    Raises:
        ValueError: If two files in the directory carry the same reaction index.
        OSError: If the directory or one of its files cannot be read.
    """
    data_dir_path = os.path.join(PROJECT_ROOT, 'data', data_dir)
    num_targets_per_job = 5

    def get_file_index(file_name):
        # '<anything>_reaction<N>.<ext>' -> N
        return int(file_name.split("_reaction")[-1].split(".")[0])

    def get_num_reactions_in_file(file_path):
        # The context manager closes the handle even if reading fails.
        with open(file_path, 'r') as reactions_file:
            num_non_blank_lines = sum(1 for line in reactions_file if line.strip())
        # -1 for the header; a completely empty file holds zero reactions, not -1.
        return max(num_non_blank_lines - 1, 0)

    num_jobs_per_file = {}
    for file_name in os.listdir(data_dir_path):
        file_path = os.path.join(data_dir_path, file_name)
        # Filter before parsing so sub-directories never reach the name parser.
        if not os.path.isfile(file_path):
            continue
        try:
            file_idx = get_file_index(file_name)
        except ValueError:
            # Not named like a reaction file; say so instead of crashing.
            print(f'skipping {file_name}: no reaction index in file name')
            continue
        if file_idx in num_jobs_per_file:
            raise ValueError(
                f'more than one file in {data_dir_path} has reaction index {file_idx}')
        num_reactions = get_num_reactions_in_file(file_path)
        # Integer ceiling division: a partially filled last job still counts as a job.
        num_jobs_per_file[file_idx] = -(-num_reactions // num_targets_per_job)

    # Key by the index taken from the file name (not by list position) so the
    # caller can rebuild the matching file name even when indices have gaps.
    num_jobs_per_file = dict(sorted(num_jobs_per_file.items()))
    #num_jobs_per_file = {key: value for key, value in num_jobs_per_file.items() if key==1}
    print(f'num_jobs_per_file: {num_jobs_per_file}')
    return num_jobs_per_file, num_targets_per_job

# Command line: --interactive is a plain on/off flag, --platform selects the
# cluster whose settings are looked up below.
parser = argparse.ArgumentParser()
parser.add_argument('--interactive', action='store_true')
parser.add_argument('--platform', type=str)
args = parser.parse_args()
# DATA_DIR = 'uspto_full'
SCRIPT_DIR = 'scripts'
platform = args.platform

# puhti and mahti share every setting except the partition name.
_SHARED_PUHTI_MAHTI_SETTINGS = {
    'project': 'project_2007775',
    'with_containers': False,
    'puhti_module': 'pytorch/2.4',
    'venv_path': '/projappl/project_2007775/syntheseus-python-10',
    'container': None,
}
_PLATFORM_SETTINGS = {
    'puhti': {**_SHARED_PUHTI_MAHTI_SETTINGS, 'partition': 'gpu'},
    'mahti': {**_SHARED_PUHTI_MAHTI_SETTINGS, 'partition': 'gpusmall'},
    'lumi': {
        'project': 'project_462000833',
        'partition': 'small-g',
        'with_containers': True,
        'container': 'multiguide-lumi.sif',
        'venv_path': 'multiguide-lumi-container',
        'puhti_module': None,
    },
}

_settings = _PLATFORM_SETTINGS.get(platform)
if _settings is None:
    # Also reached when --platform is omitted (platform is then None).
    raise ValueError(f'Platform {platform} not supported')

project = _settings['project']
partition = _settings['partition']
with_containers = _settings['with_containers']
puhti_module = _settings['puhti_module']
venv_path = _settings['venv_path']
container = _settings['container']
use_srun = False

# Settings passed to create_and_submit_batch_job for every submission.
# Built in groups; keys whose names contain '-' cannot be keyword arguments,
# so they go in through a dict.
slurm_args = dict(
    job_dir='jobs',
    job_ids_file='job_ids.txt',
    output_dir='output',
    platform=platform,
    project=project,
    time='06:00:00',
    partition=partition,
    nodes=1,
)
slurm_args.update({
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
})
slurm_args.update(
    mem='100G',  # 50G not enough for uspto_full
    with_containers=with_containers,
    use_srun=use_srun,
    container=container,
    venv_path=venv_path,
    puhti_module=puhti_module,
    start_array_job=0,
    # Placeholder: the submission loop below overwrites this for each file.
    end_array_job=0,
)
# Author's notes on end_array_job values (presumably per reaction file index):
# 26 for 0 and 1, 23 to reaction 2, 19 to reaction 3, 15 for 4 and 5
# 7 for 6, 5 for 7, 3 for 8, 2 for 9 and 10, 0 for 11 to 13
# One timestamp for the whole run; it becomes part of every experiment_params string.
time_stamp = f"{datetime.now():%Y%m%d_%H%M%S}"

experiment = 'tanimoto_rsmiles_uspto_190_fraction1.0_thresh500'
# NOTE: unclear whether these could live in the experiment yaml instead; they
# are needed here to split the runs per file.
data_dir = 'uspto_190/first_reactions_with_targets'
subset_prefix = 'test_linear_routes_classes_reaction'

# NOTE: also unclear how to set these programmatically in hydra.
# TODO: probably requires refactoring to set the full path dynamically in the script
_checkpoints_dir = os.path.join(PROJECT_ROOT, 'checkpoints')
model_checkpoint = os.path.join(
    _checkpoints_dir,
    'rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5',
    'model.product-reactants_step_250000.pt',
)
predictor_checkpoint = os.path.join(
    _checkpoints_dir,
    'tanimoto_routes_fraction1.0_thresh500_completion0.8_augmentation5',
    'checkpoint_100.pt',
)

default_num_results = 100
guidance_scale = 10  # 2 time20250923_103438
min_length_for_guidance = 15  # 10 time20250923_103438
combine_renormalize = 'true'

# Number of array jobs per reaction file, and how many targets each job handles.
num_jobs_per_file, num_targets_per_job = get_num_jobd_and_num_targets_per_job(data_dir)
# Manual override kept from the original script:
# num_targets_per_job = 1
# num_jobs_per_file = {0: 1}

# Submit one array job per reaction file.
# experiment_params does not depend on the file, so it is built once up front.
experiment_params = f'{experiment}_time{time_stamp}'

for file_idx, num_jobs in num_jobs_per_file.items():
    # Array task ids start at 0, so the last id is num_jobs - 1; clamp at 0 so
    # a file with no jobs never produces a negative end index.
    slurm_args['end_array_job'] = max(num_jobs - 1, 0)
    run_name = f'reaction{file_idx}'
    slurm_args['job_name'] = f'{experiment_params}_{run_name}'

    # Overrides passed to the evaluation script for this file.
    script_overrides = {
        "+experiment": experiment,
        "general.experiment_params": experiment_params,
        "general.experiment_name": run_name,
        "single_step_evaluation.data_dir": data_dir,
        "single_step_evaluation.subset": f"{subset_prefix}{file_idx}.csv",
        # '$start_idx' / '$end_idx' are left as literal shell references;
        # presumably the generated batch script expands them per array task
        # from the "variables" below -- confirm against slurm_utils.
        "single_step_evaluation.start_idx": '$start_idx',
        "single_step_evaluation.end_idx": '$end_idx',
        "single_step_model.default_num_results": default_num_results,
        "classifier_guidance.checkpoint_path": predictor_checkpoint,
        "classifier_guidance.onmt_checkpoint_path": model_checkpoint,
        "classifier_guidance.combine_renormalize": combine_renormalize,
        "classifier_guidance.guidance_scale": guidance_scale,
        "classifier_guidance.min_length_for_guidance": min_length_for_guidance,
    }
    # Shell arithmetic: each array task covers the half-open index range
    # [start_idx, start_idx + targets_per_job).
    shell_variables = {
        'targets_per_job': num_targets_per_job,
        'offset': 0,
        'start_idx': '$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))',
        'end_idx': '$((start_idx+targets_per_job))',
    }
    script_args = {
        "script_dir": SCRIPT_DIR,
        "use_torchrun": 'false',
        "args": script_overrides,
        "variables": shell_variables,
        "script_name": 'evaluate_single_step_dataset.py',
    }

    # Announce before submitting so a failing submission still shows which file it was on.
    print(f'******** submitting for file {file_idx} with {num_jobs} jobs')
    create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)
