from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse

# Directory (relative name) that holds the script to be launched; it is passed
# on through the "script_dir" entry of the script arguments further below.
SCRIPT_DIR = "scripts"

# Repository root: two levels up from the resolved location of this file.
_this_file = Path(os.path.realpath(__file__))
PROJECT_ROOT = _this_file.parents[1]

# Command-line interface:
#   --interactive  flag handed to create_and_submit_batch_job(interactive=...)
#   --platform     name of the cluster whose settings are selected below
parser = argparse.ArgumentParser()
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--platform", type=str)
args = parser.parse_args()

platform = args.platform

# Cluster-specific settings, keyed by the value given to --platform.
#
# Each platform has its own entry. Previously the first branch matched both
# 'puhti' and 'mahti', so the dedicated mahti branch was unreachable and mahti
# jobs were configured with partition 'gpu' instead of 'gpusmall'.
_PLATFORM_SETTINGS = {
    'puhti': {
        'project': 'project_2007775',
        'partition': 'gpu',
        'with_containers': False,
        'puhti_module': 'pytorch/2.4',
        'venv_path': '/projappl/project_2007775/syntheseus-python-10',
        'container': None,
    },
    'mahti': {
        'project': 'project_2007775',
        'partition': 'gpusmall',
        'with_containers': False,
        'puhti_module': 'pytorch/2.4',
        'venv_path': '/projappl/project_2007775/syntheseus-python-10',
        'container': None,
    },
    'lumi': {
        'project': 'project_462000833',
        'partition': 'small-g',
        'with_containers': True,
        'puhti_module': None,
        'venv_path': 'multiguide-lumi-container',
        'container': 'multiguide-lumi.sif',
    },
}

# An unknown (or missing, i.e. None) platform is a hard error.
if platform not in _PLATFORM_SETTINGS:
    raise ValueError(f'Platform {platform} not supported')

_selected = _PLATFORM_SETTINGS[platform]
project = _selected['project']
partition = _selected['partition']
with_containers = _selected['with_containers']
puhti_module = _selected['puhti_module']
venv_path = _selected['venv_path']
container = _selected['container']
use_srun = False

# Settings handed to create_and_submit_batch_job as its first argument.
# Key order is kept as-is on purpose.
slurm_args = {
    # --- bookkeeping locations ---
    "job_dir": "jobs",
    "job_ids_file": "job_ids.txt",
    "output_dir": "output",
    # --- where the job runs (values chosen by the platform selection above) ---
    "platform": platform,
    "project": project,
    # --- resource request ---
    "time": "06:00:00",
    "partition": partition,
    "nodes": 1,
    "ntasks-per-node": 1,
    "cpus-per-task": 1,
    "gpus-per-node": 1,
    "mem": "100G",  # 50G not enough for uspto_full
    # --- software environment ---
    "with_containers": with_containers,
    "use_srun": use_srun,
    "container": container,
    "venv_path": venv_path,
    "puhti_module": puhti_module,
    # --- job-array bounds; 'end_array_job' is overwritten in the submission loop ---
    "start_array_job": 0,
    "end_array_job": 0,
    # Earlier notes, presumably end_array_job values per reaction file (TODO confirm):
    # 26 for 0 and 1, 23 to reaction 2, 19 to reaction 3, 15 for 4 and 5
    # 7 for 6, 5 for 7, 3 for 8, 2 for 9 and 10, 0 for 11 to 13
}

# Timestamp (YYYYmmdd_HHMMSS) appended to every experiment name of this run.
time_stamp = f"{datetime.now():%Y%m%d_%H%M%S}"

# ---------------------------------------------------------------------------
# Checkpoints
# ---------------------------------------------------------------------------
_checkpoints_dir = os.path.join(PROJECT_ROOT, "checkpoints")

# Classifier checkpoint in use.
# Alternatives kept for reference (directory/file under checkpoints/):
#   tanimoto_routes_fraction0.4_thresh10_completion0.8_augmentation5/checkpoint_35.pt
#   tanimoto_routes_fraction1.0_thresh500_completion0.8_augmentation5/checkpoint_100.pt
classifier_checkpoint = os.path.join(
    _checkpoints_dir,
    "reaction_type_routes_fraction1.0_thresh500_completion0.8_augmentation5",
    "checkpoint_25.pt",
)

# ONMT checkpoint in use.
# Alternatives kept for reference (directory/file under checkpoints/):
#   rsmiles_50k_checkpoints/USPTO_50K_PtoR.pt
#   rsmiles_uspto_190_fraction0.4_thresh10_dropped_PtoR_aug5/model.product-reactants_step_10000.pt
#   rsmiles_full_no_overlap_checkpoints/USPTO_full_no_overlap.pt
onmt_checkpoint_path = os.path.join(
    _checkpoints_dir,
    "rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5",
    "model.product-reactants_step_250000.pt",
)

# ---------------------------------------------------------------------------
# Classifier guidance (placed into the script arguments by the submission loop)
# ---------------------------------------------------------------------------
classifier_guidance = "reaction_type"
with_starting_material = "false"
guidance_scales = [1.5]  # TODO: 10 is done
min_lengths = [10]
eos_penalty = -10.0
n_candidates_to_evaluate = 72
num_classes = 11
search_batch_size = 1024
separator = "."

# ---------------------------------------------------------------------------
# Search
# ---------------------------------------------------------------------------
steered = "true"

# ---------------------------------------------------------------------------
# Single-step model
# Alternative kept for reference:
#   model_type = 'neuralsym', model_dir = 'neuralsym/model_retro.pt'
# ---------------------------------------------------------------------------
model_type = "rootaligned"
model_dir = "rsmiles_uspto_190_fraction1.0_thresh500_dropped_PtoR_aug5"
default_num_results = 100

# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
single_step_evaluation = "50k"
compute_classifier_score = "true"
data_dir = "uspto_190/first_reactions_nonlinear_with_targets"
subset_prefix = "test_nonlinear_routes_classes_reaction"

# Absolute location of the evaluation data on disk.
data_dir_path = os.path.join(PROJECT_ROOT, "data", data_dir)
def get_num_reactions_in_file(file_path):
    """Return the number of data rows in the text file at ``file_path``.

    Lines that are empty or contain only whitespace are ignored. One of the
    remaining lines is assumed to be a header and is not counted.

    Args:
        file_path: Path of the file to read (opened in text mode).

    Returns:
        The number of non-blank lines minus one for the header, never less
        than 0 (a file without any non-blank line yields 0).
    """
    # The context manager closes the file handle deterministically; the
    # previous implementation left it open until garbage collection.
    with open(file_path, 'r') as handle:
        non_blank_lines = sum(1 for line in handle if line.strip())
    return max(non_blank_lines - 1, 0)

# Regular files of the data directory, ordered by the integer that follows
# "_reaction" in the file name. Entries that are not regular files are removed
# *before* sorting, so a sub-directory whose name has no numeric suffix can no
# longer abort the script inside the sort key.
sorted_list_of_files = sorted(
    (entry for entry in os.listdir(data_dir_path)
     if os.path.isfile(os.path.join(data_dir_path, entry))),
    key=lambda name: int(name.split("_reaction")[-1].split(".")[0]),
)
# get the number of reactions in each file (same order as sorted_list_of_files)
num_reactions_per_file = [
    get_num_reactions_in_file(os.path.join(data_dir_path, file_name))
    for file_name in sorted_list_of_files
]
# Number of targets each job of the array processes.
num_targets_per_job = 1


def get_num_jobs_in_array(num_reactions_in_file, num_targets_per_job):
    """Return how many array jobs are needed to cover all reactions of a file.

    This is the ceiling of ``num_reactions_in_file / num_targets_per_job``,
    computed with integer arithmetic only so the result stays exact for
    arbitrarily large counts (the float-based version lost precision there).

    Args:
        num_reactions_in_file: Number of reactions in the file; expected to be
            a non-negative integer.
        num_targets_per_job: Number of reactions handled by one job; must be a
            positive integer.

    Returns:
        The smallest integer ``k`` with
        ``k * num_targets_per_job >= num_reactions_in_file``.

    Raises:
        ZeroDivisionError: If ``num_targets_per_job`` is 0.
    """
    # Negating twice turns floor division into ceiling division.
    return -(-num_reactions_in_file // num_targets_per_job)
# Number of array jobs required for every input file, keyed by the file's
# position in num_reactions_per_file.
_jobs_for_all_files = {}
for _position, _reaction_count in enumerate(num_reactions_per_file):
    _jobs_for_all_files[_position] = get_num_jobs_in_array(_reaction_count, num_targets_per_job)

# Only the entry at position 4 is kept (empty dict if there is none) and reported.
num_jobs_per_file = {4: _jobs_for_all_files[4]} if 4 in _jobs_for_all_files else {}
print(f'num_jobs_per_file: {num_jobs_per_file}')

# Hard-coded override: whatever was computed above is discarded and a single
# job is submitted for file index 0.
num_jobs_per_file = {0: 1}

# Short identifiers that become part of every experiment name.
# (The checkpoint id could also be derived: onmt_checkpoint_path.split('/')[-2])
onmt_checkpoint_id = "fraction1.0_thresh500"
data_dir_id = "uspto_190_nonlinear"

# One submission per (input file, guidance scale, minimum guidance length).
for file_idx, num_jobs in num_jobs_per_file.items():
    # Array indices start at 0, so the last index is num_jobs - 1 (never below 0).
    slurm_args["end_array_job"] = max(num_jobs - 1, 0)
    subset = f"{subset_prefix}{file_idx}.csv"

    for guidance_scale in guidance_scales:
        for min_length_for_guidance in min_lengths:
            experiment_name = (
                f"{model_type}_{onmt_checkpoint_id}_{data_dir_id}_{classifier_guidance}"
                f"_reaction{file_idx}_steered{steered}_guidance{guidance_scale}"
                f"_length{min_length_for_guidance}_results{default_num_results}"
                f"_time{time_stamp}"
            )

            # Key/value pairs passed under "args"; insertion order is preserved.
            overrides = {
                "classifier_guidance": classifier_guidance,
                "classifier_guidance.experiment_name": experiment_name,
                "classifier_guidance.checkpoint_path": classifier_checkpoint,
                "classifier_guidance.onmt_checkpoint_path": onmt_checkpoint_path,
                "classifier_guidance.guidance_scale": guidance_scale,
                "classifier_guidance.n_candidates_to_evaluate": n_candidates_to_evaluate,
                "classifier_guidance.min_length_for_guidance": min_length_for_guidance,
                "classifier_guidance.eos_penalty": eos_penalty,
                "classifier_guidance.with_starting_material": with_starting_material,
                "classifier_guidance.dataset.separator": separator,
                "single_step_evaluation": single_step_evaluation,
                "single_step_evaluation.data_dir": data_dir,
                "single_step_evaluation.subset": subset,
                "single_step_evaluation.start_idx": "$start_idx",
                "single_step_evaluation.end_idx": "$end_idx",
                "single_step_evaluation.forward_model_dir": "forward_rsmiles_mit",
                "single_step_evaluation.compute_classifier_score": compute_classifier_score,
                "search.steered": steered,
                "single_step_model.model_dir": model_dir,
                "single_step_model.default_num_results": default_num_results,
                "single_step_model.model_type": model_type,
                "classifier_guidance.model.num_classes": num_classes,
                # "classifier_guidance.target_class_index": target_class_index,
                "classifier_guidance.search_batch_size": search_batch_size,
            }

            # Variables passed under "variables". start_idx/end_idx are written
            # as shell arithmetic over SLURM_ARRAY_TASK_ID; presumably they are
            # expanded inside the generated batch script (TODO confirm).
            shell_variables = {
                "targets_per_job": num_targets_per_job,  # targets handled by each array task
                "offset": 0,
                "start_idx": "$((offset+(SLURM_ARRAY_TASK_ID * targets_per_job)))",
                "end_idx": "$((start_idx+targets_per_job))",
            }

            script_args = {
                "script_dir": SCRIPT_DIR,
                "use_torchrun": "false",
                "args": overrides,
                "variables": shell_variables,
                "script_name": "evaluate_single_step_dataset.py",
            }
            task = "evaluate_single_step_dataset"
            slurm_args["job_name"] = experiment_name

            output = create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive)
            print(f'******** submitting for file {file_idx} with {num_jobs} jobs')
