import argparse
from utils import get_run_name
import os

"""
args:
--checkpoint-filename # IS A FOLDERNAME
--config-json-filename

files:
train.py
generate_configs.py
generate_slurm_commands.py - saves everything in a ./slurm folder as 0.sh, 1.sh, etc
start_all_slurm_jobs.py
"""
from nesim.utils.json_stuff import load_json_as_dict

import transformers
from transformers import (
    TrainingArguments,
    DataCollatorForLanguageModeling,
    AutoTokenizer,
)
from model import initialize_model
from dataset import load_and_preprocess_data
import json
import argparse
import os

# Project name picked up by the wandb reporter (the Trainer below is created
# with report_to="wandb").
os.environ["WANDB_PROJECT"] = "lora-gpt-neo-125m-dair-ai-emotion"

# Command-line interface: the only input is the path to the JSON config.
# The flag is required so that a missing argument produces a usage error
# instead of a TypeError from os.path.exists(None).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config-filename",
    required=True,
    help="Path to the configuration JSON file",
)
args = parser.parse_args()

# Check that the config file exists before trying to load it. An explicit
# raise is used instead of `assert` because asserts are stripped when Python
# runs with -O, and because it lets us report which path was missing.
if not os.path.exists(args.config_filename):
    raise FileNotFoundError(
        f"Configuration file not found: {args.config_filename}"
    )
config = load_json_as_dict(filename=args.config_filename)

# Pull the individual settings out of the loaded config. Keys are read in the
# order listed, so a missing key raises KeyError for the first absent entry.

# Hugging Face identifiers for the model, tokenizer and dataset.
hf_model_name, hf_tokenizer_name, hf_dataset_name = (
    config[key]
    for key in ("hf_model_name", "hf_tokenizer_name", "hf_dataset_name")
)

# LoRA settings that are handed to initialize_model further down.
(
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_bias,
    lora_task_type,
    lora_target_modules,
) = (
    config[key]
    for key in (
        "lora_r",
        "lora_alpha",
        "lora_dropout",
        "lora_bias",
        "lora_task_type",
        "lora_target_modules",
    )
)

# Optimisation and logging settings that are handed to TrainingArguments.
(
    batch_size,
    gradient_accumulation_steps,
    warmup_steps,
    max_steps,
    learning_rate,
    logging_steps,
) = (
    config[key]
    for key in (
        "batch_size",
        "gradient_accumulation_steps",
        "warmup_steps",
        "max_steps",
        "learning_rate",
        "logging_steps",
    )
)

# Name of this run, built by get_run_name from the config's "name" and
# "checkpoint_name" entries; it is used for both the checkpoint folder and
# the run name reported by the Trainer.
run_name = get_run_name(
    config_name=config["name"], checkpoint_name=config["checkpoint_name"]
)

# Where the dataset cache and the training checkpoints live on disk.
dataset_cache_dir = f"./hf_cache/{config['hf_dataset_name']}"
output_dir = f"./checkpoints/{run_name}"

# Create both directories (including parents) in-process. This replaces
# shelling out to `mkdir -p` with an interpolated string, which broke on
# names containing spaces or shell metacharacters, silently ignored
# failures, and only worked where a POSIX `mkdir` was available.
os.makedirs(dataset_cache_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

# Load the tokenizer named in the config.
# NOTE(review): `padding` and `truncation` are forwarded to from_pretrained
# here; they are normally options of the tokenizer *call* — confirm they have
# the intended effect.
tokenizer = AutoTokenizer.from_pretrained(
    hf_tokenizer_name,
    padding=True,
    truncation=True,
)

# LoRA settings forwarded to initialize_model, using the values already
# unpacked from the config above.
lora_settings = {
    # presumably the LoRA rank (an earlier note said "attention heads") —
    # confirm against model.initialize_model
    "lora_r": lora_r,
    "lora_alpha": lora_alpha,  # alpha scaling
    "lora_dropout": lora_dropout,
    "lora_bias": lora_bias,
    "lora_task_type": lora_task_type,  # set this for CLM or Seq2Seq
    "lora_target_modules": lora_target_modules,
}

# Build the model from the base model name, the checkpoint given in the
# config, and the LoRA settings.
lora_model = initialize_model(
    hf_model_name,
    checkpoint_filename=config["checkpoint_filename"],
    **lora_settings,
)

# Per-item token limit from the config, passed through to the dataset loader.
max_item_tokens = config["max_tokens_per_dataset_item"]

# Load the configured dataset and preprocess it with the tokenizer, using the
# dataset cache directory prepared earlier.
train_dataset = load_and_preprocess_data(
    dataset_name=hf_dataset_name,
    tokenizer=tokenizer,
    cache_dir=dataset_cache_dir,
    max_tokens_per_dataset_item=max_item_tokens,
)

# Training hyperparameters come from the config; checkpointing and reporting
# options are fixed in this script.
training_args = TrainingArguments(
    # Where checkpoints go and how the run is labelled/reported.
    output_dir=output_dir,
    overwrite_output_dir=True,
    run_name=run_name,
    report_to="wandb",
    # Optimisation schedule.
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    warmup_steps=warmup_steps,
    max_steps=max_steps,
    fp16=True,
    logging_steps=logging_steps,
    # Save a checkpoint every 100 steps and keep only the last 10.
    save_strategy="steps",
    save_steps=100,
    save_total_limit=10,
)

# Collator with masked-language-modelling disabled (mlm=False).
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = transformers.Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collator,
)

# Turn off the model's `use_cache` flag before training.
# NOTE(review): presumably the generation cache, which is not needed while
# training — confirm.
lora_model.config.use_cache = False
trainer.train()
