import os
import yaml
import argparse
from datetime import datetime
from slurm_utils import create_and_submit_batch_job
from pathlib import Path

# Directory one level above the folder that contains this script, computed
# from the symlink-resolved path of this file (parents[0] is the script's own
# directory, parents[1] is its parent).
PROJECT_ROOT = Path(os.path.realpath(__file__)).parents[1]

# Command-line interface of this launcher. Arguments are parsed at module
# level because the file is run as a script, not imported as a library.
parser = argparse.ArgumentParser(
    description=(
        'Create and submit one batch job per '
        '(target class, guidance scale, eos penalty) combination.'
    ),
)
parser.add_argument(
    '--interactive',
    action='store_true',
    help=(
        'passed as interactive=True to create_and_submit_batch_job '
        '(default: False)'
    ),
)
args = parser.parse_args()
# DATA_DIR = 'uspto_full'
# Directory name handed to the job builder as "script_dir".
SCRIPT_DIR = 'scripts'

# Target platform and project id; both are copied into slurm_args below.
platform = 'puhti'
project = 'project_2007775'

# The settings are grouped by theme and merged in the order shown, so the
# resulting dict has the same keys, values and insertion order as a single
# flat literal would.
# NOTE(review): how each key is interpreted is decided by
# create_and_submit_batch_job in slurm_utils — confirm there before changing.
_job_files = {
    'job_dir': 'jobs',
    'job_ids_file': 'job_ids.txt',
    'output_dir': 'output',
}
_resources = {
    'time': '02:00:00',
    'partition': 'gpu',
    'nodes': 1,
    'ntasks-per-node': 1,
    'cpus-per-task': 1,
    'gpus-per-node': 1,
    'mem': '10G',  # 50G not enough for uspto_full
}
_environment = {
    'with_containers': False,  # always false for puhti and mahti
    'venv_path': '/projappl/project_2007775/multiguide',
    'start_array_job': 0,  # 5 to 37
    'end_array_job': 0,  # 37
    'puhti_module': 'pytorch/2.1',
}
slurm_args = {
    **_job_files,
    'platform': platform,
    'project': project,
    **_resources,
    **_environment,
}
# --- Experiment hyperparameters ---------------------------------------------
# Each name below holds exactly one value. Earlier alternatives that used to be
# assigned and then immediately overwritten are kept as reference comments
# instead of dead assignments.
classifier_guidance = 'toy_experiment'
max_num_expressions = 50000
wandb_mode = 'online'
epochs = 100
# NOTE(review): the submission loop in this file passes a literal 1 for
# "classifier_guidance.train.print_every" and does not read this value.
print_every = 10
max_num = 3

# Values iterated as "classifier_guidance.target_class_index".
targets = [0, 1]

# Guidance scale(s) to launch. Wider grid kept for reference:
#   [0., 0.001, 0.01, 0.05, 200, 1000]
guidance_scales = [0.3]

# Candidate value grids for the ablations.
min_lengths = [0, 1, 2, 3, 4, 5]
eos_penalties = [0., -0.1, -1, -2, -20, -100]

# Label of the ablated property; it is embedded in each job name. It must
# describe the grid the submission loop actually sweeps, which is currently
# `eos_penalties`. The label for a `min_lengths` sweep would be
# 'min_length_for_guidance'.
ablation_property_name = 'eos_penalty'
# TODO: ablation on the eos penalty, on min guidance length, min confidence, and seeds

# Taken once at launch, so every job submitted by this run shares one stamp.
time_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Existing experiment whose checkpoints are loaded, and the dataset directory
# name passed to the script.
classifier_experiment_name = 'toy_classifier_20250722_151842'
data_dir = "toy_experiment_debug_classifier_experiment_20250721_152030"

# Run name reported to wandb: source experiment plus the launch time stamp.
experiment_name = f'{classifier_experiment_name}_ts_{time_stamp}'

# Checkpoint selection: both files live under the same experiment directory.
onmt_epoch = 6000
classifier_epoch = 1500
_experiment_dir = f'{PROJECT_ROOT}/experiments/toy_experiment/{classifier_experiment_name}'
onmt_checkpoint_path = f'{_experiment_dir}/onmt/onmt_model_step_{onmt_epoch}.pt'
classifier_checkpoint = f'{_experiment_dir}/classifier/classifier_{classifier_epoch}.pt'
def _build_script_args(target_class_index, guidance_scale, ablation_value):
    """Return the script description for a single job.

    The result holds the script directory, the torchrun flag, the override
    mapping under "args" and the script name, in that key order. The three
    parameters fill the per-job overrides; every other value is read from the
    module-level settings at call time.
    """
    overrides = {
        "+experiment": "toy_experiment.yaml",
        "classifier_guidance": classifier_guidance,
        "classifier_guidance.phase": "translate",
        "classifier_guidance.onmt_checkpoint_path": onmt_checkpoint_path,
        "classifier_guidance.experiment_name": classifier_experiment_name,
        "classifier_guidance.guidance_scale": guidance_scale,
        "classifier_guidance.eos_penalty": ablation_value,
        "classifier_guidance.checkpoint_path": classifier_checkpoint,
        "classifier_guidance.dataset.data_dir": data_dir,
        "classifier_guidance.dataset.max_num_expressions": max_num_expressions,
        "classifier_guidance.dataset.max_num": max_num,
        "classifier_guidance.train.num_epochs": epochs,
        # NOTE(review): literal 1, not the module-level `print_every` —
        # confirm this is intended.
        "classifier_guidance.train.print_every": 1,
        "classifier_guidance.are_scores_close_debug": False,
        "classifier_guidance.are_scores_close_tolerance": 1e-1,
        "classifier_guidance.target_class_index": target_class_index,
        "general.wandb.mode": wandb_mode,
        "general.wandb.name": experiment_name,
    }
    return {
        "script_dir": SCRIPT_DIR,
        "use_torchrun": 'false',
        "args": overrides,
        "script_name": 'toy_experiment.py',
    }


# Submit one job per (target class, guidance scale, eos penalty) combination.
for target_class_index in targets:
    for guidance_scale in guidance_scales:
        for ablation_value in eos_penalties:
            script_args = _build_script_args(
                target_class_index, guidance_scale, ablation_value
            )
            # The shared slurm_args dict is updated in place, so it keeps the
            # name of the most recently submitted job.
            # NOTE(review): the name does not include guidance_scale; with
            # more than one guidance scale, several jobs get the same name.
            slurm_args['job_name'] = (
                f'{classifier_experiment_name}_target_class_{target_class_index}'
                f'_{ablation_property_name}_{ablation_value}_ts_{time_stamp}'
            )
            output = create_and_submit_batch_job(
                slurm_args, script_args, interactive=args.interactive
            )
