from slurm_utils import create_and_submit_batch_job
from datetime import datetime
import os
from pathlib import Path
import argparse
import yaml

# Repository root: two directory levels above this script's symlink-resolved path.
_script_path = Path(os.path.realpath(__file__))
PROJECT_ROOT = _script_path.parents[1]

# Cluster/account settings for the submitted job.
# (A previous configuration also defined DATA_DIR = 'uspto_full'.)
SCRIPT_DIR = "scripts"
platform = "puhti"
project = "project_2007775"

# Command line: a single --interactive switch that is forwarded to the job submitter.
parser = argparse.ArgumentParser()
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()

# Settings passed to create_and_submit_batch_job for the batch job.
slurm_args = {
    "job_dir": "jobs",
    "job_ids_file": "job_ids.txt",
    "output_dir": "output",
    "platform": platform,
    "project": project,
    "time": "04:00:00",
    "partition": "gpu",
    "nodes": 1,
    "ntasks-per-node": 1,
    "cpus-per-task": 1,
    "gpus-per-node": 1,
    # Original note: 50G not enough for uspto_full.
    "mem": "10G",
    # Always false for puhti and mahti.
    "with_containers": False,
    "venv_path": "/projappl/project_2007775/multiguide",
    # Array-job bounds; an earlier note mentions the range 5 to 37.
    "start_array_job": 0,
    "end_array_job": 0,
    "puhti_module": "pytorch/2.1",
}
# Launch time, used to make the experiment name unique per invocation.
time_stamp = f"{datetime.now():%Y%m%d_%H%M%S}"

# Dataset directory (under data/toy_experiment) used for this run.
# Previously selected directories, kept for reference:
#   toy_experiment_debug_classifier_experiment_20250720_145125
#   toy_experiment_debug_classifier_experiment_20250720_150040
#   toy_experiment_debug_classifier_experiment_20250720_152522
#   toy_experiment_debug_classifier_experiment_20250720_154344
data_dir = 'toy_experiment_debug_classifier_experiment_20250721_152030'
#experiment_name = 'toy_experiment_debug_classifier_experiment_20250721_152030_smaller_net'

# Common prefix of every corpus file for the selected dataset.
_data_root = f'{PROJECT_ROOT}/data/toy_experiment/{data_dir}'
path_src = f'{_data_root}/train.src'
path_tgt = f'{_data_root}/train.tgt'
# NOTE(review): validation points at the same files as training, so validation
# metrics are computed on the training data. Confirm this is intended for the
# toy experiment before relying on them.
valid_path_src = f'{_data_root}/train.src'
valid_path_tgt = f'{_data_root}/train.tgt'
# Load the base train.yml, override the corpus paths and hyper-parameters for
# this run, and write the result into a fresh per-experiment checkpoint directory.
with open(f'{PROJECT_ROOT}/checkpoints/toy_experiment/training/train.yml', 'r', encoding='utf-8') as f:
    train_yml = yaml.safe_load(f)
# To continue from an earlier checkpoint, set train_from, e.g.:
#train_yml['train_from'] = f'{PROJECT_ROOT}/checkpoints/toy_experiment/{experiment_name}/onmt_model_step_2000.pt'

# Override data.corpus_1 and data.valid so they point at the selected dataset.
train_yml['data']['corpus_1']['path_src'] = path_src
train_yml['data']['corpus_1']['path_tgt'] = path_tgt
train_yml['data']['valid']['path_src'] = valid_path_src
train_yml['data']['valid']['path_tgt'] = valid_path_tgt

# Single-GPU run; gpu_ranks is appended to the training command line.
train_yml['gpus'] = 1
gpu_ranks = '-gpu_ranks 0'

# Schedule and model-size overrides.
# (An earlier configuration used word_vec_size: 128 and rnn_size: 128 with
# layers: 16, transformer_ff: 1024, heads: 8.)
train_yml.update({
    'train_steps': 10000,
    'valid_steps': 1000,
    'warmup_steps': 2000,
    'report_every': 100,
    'word_vec_size': 256,
    'rnn_size': 256,
    'layers': 16,
    'transformer_ff': 1024,
    'heads': 8,
    'learning_rate': 0.5,
    'label_smoothing': 0.1,
})
# Remaining optimisation parameters.
train_yml.update({
    'max_grad_norm': 1.0,
    'batch_size': 16,
    'accum_count': 4,
})

# Source and target share one vocabulary file from the dataset directory.
train_yml['src_vocab'] = f'{PROJECT_ROOT}/data/toy_experiment/{data_dir}/vocab.txt'
train_yml['tgt_vocab'] = f'{PROJECT_ROOT}/data/toy_experiment/{data_dir}/vocab.txt'

# Derive the experiment name from the values actually written to the config so
# the directory/job name cannot drift from the model (it was previously
# hard-coded as "layers_24_heads_16" while the config used 16 layers, 8 heads).
#experiment_name = f"onmt_train_learning_rate_{train_yml['learning_rate']}"
experiment_name = (
    f"onmt_train_layers_{train_yml['layers']}_heads_{train_yml['heads']}"
    f"_time_stamp_{time_stamp}"
)
train_yml['save_data'] = f'{PROJECT_ROOT}/data/toy_experiment/{experiment_name}'

# Checkpoints and the resolved config live together in the experiment directory.
out_dir = f'{PROJECT_ROOT}/checkpoints/toy_experiment/{experiment_name}'
os.makedirs(out_dir, exist_ok=True)
train_yml['save_model'] = f'{out_dir}/onmt_model'
out_train_yml = f'{out_dir}/train.yml'
with open(out_train_yml, 'w', encoding='utf-8') as f:
    yaml.dump(train_yml, f)
# NOTE: experiment_name contains a time stamp, so every invocation gets its own
# save_data and checkpoint paths; remove the time stamp to use the same data in
# multiple experiments.
# The virtualenv must be on PATH in the job environment:
#   export PATH=/projappl/project_2007775/multiguide/bin:$PATH
script_args = {
    "script_dir": '',
    "use_torchrun": 'false',
    "args": {},
}
# Pass the path of the config file that was just written (out_train_yml)
# instead of rebuilding it relative to the job's working directory, so the
# command does not depend on where the batch job is started from.
script_args['script_name'] = f'python3 -m onmt.bin.train -config {out_train_yml} {gpu_ranks}'
slurm_args['job_name'] = experiment_name
output = create_and_submit_batch_job(slurm_args, script_args, interactive=args.interactive, with_python=False)
# /scratch/project_462000833/multiguide/data/predictors_old