from nesim.experiments.gpt_neo_125m import GPTNeo125mTraining, GPTNeo125mTrainingConfig
from nesim.utils.json_stuff import load_json_as_dict
from transformers import TrainingArguments
from nesim.losses.nesim_loss import (
    NesimConfig,
)
import os
import argparse
from lightning import seed_everything

# Seed with a fixed value (0) before anything else in the script runs.
seed_everything(0)

# Command-line interface of the training script. Options without a default and
# without `required=True` come back as None when omitted.
parser = argparse.ArgumentParser(
    description="Trains a gpt neo 125m on the dataset selected with --dataset-name"
)
# Required: the path is handed to os.path.basename() and NesimConfig.from_json()
# unconditionally further down, so omitting it could only end in a TypeError.
parser.add_argument(
    "--nesim-config",
    type=str,
    help="Path to the nesim config json file",
    required=True,
)
parser.add_argument(
    "--checkpoint-every-n-steps", type=int, help="save checkpoint every n step"
)
parser.add_argument("--num-warmup-steps", type=int, help="number of warmup steps")
parser.add_argument("--batch-size", type=int, help="batch size")
parser.add_argument("--context-length", type=int, help="context window length")
# Required: the value is multiplied by 100 when the experiment config is built,
# which raises a TypeError if the option was left out.
parser.add_argument(
    "--gradient-accumulation-steps",
    type=int,
    help="number of gradient accumulation steps",
    required=True,
)
parser.add_argument(
    "--apply-nesim-every-n-steps", type=int, help="apply-nesim-every-n-steps"
)
parser.add_argument(
    "--resume-from-checkpoint",
    type=str,
    help="resume from this checkpoint folder",
    default=None,
)
parser.add_argument(
    "--resume-wandb-id",
    type=str,
    help="resume logging to this wandb run ID",
    default=None,
)
parser.add_argument(
    "--dataset-name",
    type=str,
    help="name of dataset to train upon: should be 'wikipedia' or 'openwebtext'",
    required=True,
)

parser.add_argument(
    "--num-train-epochs",
    type=int,
    help="number of epochs for training",
    required=True,
)

parser.add_argument(
    "--learning-rate",
    type=float,
    help="learning rate of the model after warmup",
    required=True,
)

args = parser.parse_args()

# The wandb project name embeds the dataset name, so each dataset gets its own
# project.
os.environ["WANDB_PROJECT"] = f"iclr-nesim-gpt-neo-125m-{args.dataset_name}"

# A wandb run may only be resumed together with a checkpoint to resume from.
# This is reported through argparse (usage message, exit status 2) instead of
# `assert`, because assertions are stripped under `python -O` and would then
# let the invalid combination through.
if args.resume_wandb_id is not None and args.resume_from_checkpoint is None:
    parser.error(
        "if --resume-wandb-id is not None, the value for --resume-from-checkpoint should also be provided"
    )


def get_json_filename_from_path(path: str) -> str:
    """Return the file name of ``path`` without a trailing ``.json`` extension.

    Only the final ``.json`` suffix is stripped. Earlier occurrences of
    ``.json`` inside the name are kept, so ``v1.json_sweep.json`` becomes
    ``v1.json_sweep`` rather than ``v1_sweep``.

    Args:
        path: Path to a file; only its last component is used.

    Returns:
        The base name with the ``.json`` suffix removed, or the unchanged base
        name when it does not end in ``.json``.
    """
    json_suffix = ".json"
    filename = os.path.basename(path)
    if filename.endswith(json_suffix):
        return filename[: -len(json_suffix)]
    return filename


# The run name encodes the hyperparameters that identify this run; it is used
# for logging and as the name of the checkpoint directory.
nesim_config_filename = get_json_filename_from_path(args.nesim_config)
run_name = f"apply_nesim_every_n_steps_{args.apply_nesim_every_n_steps}_nesim_config_{nesim_config_filename}_checkpoint_every_n_steps_{args.checkpoint_every_n_steps}_num_warmup_steps_{args.num_warmup_steps}_batch_size_{args.batch_size}_context_length_{args.context_length}_{args.dataset_name}"

output_dir = f"./checkpoints/{run_name}"

# When resuming, the run directory must already exist. The check has to happen
# BEFORE the directory is created, otherwise it can never fail.
if args.resume_from_checkpoint is not None and not os.path.isdir(output_dir):
    raise FileNotFoundError(
        f"Expected checkpoint dir to already exist when resuming run: {output_dir}"
    )

# os.makedirs instead of shelling out to `mkdir -p`: no shell interpretation of
# characters in the run name, works without a POSIX shell, and raises on
# failure instead of silently returning a non-zero status.
os.makedirs(output_dir, exist_ok=True)

# Per-dataset settings, later indexed by the dataset name. The path is relative,
# so it is presumably resolved against the current working directory.
dataset_info = load_json_as_dict("dataset_info.json")

training_arguments = TrainingArguments(
    # -- run identity, logging and output location --
    run_name=run_name,
    output_dir=output_dir,
    report_to="wandb",
    logging_steps=1,
    push_to_hub=False,
    # -- batching --
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    # -- optimisation schedule --
    num_train_epochs=args.num_train_epochs,
    learning_rate=args.learning_rate,
    lr_scheduler_type="cosine",
    warmup_steps=args.num_warmup_steps,
    weight_decay=0.0,
    fp16=True,
    # -- evaluation --
    evaluation_strategy="steps",
    eval_steps=12_000,
    # -- checkpointing --
    save_steps=args.checkpoint_every_n_steps,
    save_safetensors=False,
    # NOTE(review): Hugging Face documents this field as not consumed by
    # Trainer itself; presumably GPTNeo125mTraining reads it - confirm.
    resume_from_checkpoint=args.resume_from_checkpoint,
    # save_total_limit is left unset here. Setting it (e.g. save_total_limit=3)
    # would limit the total amount of checkpoints and delete the older
    # checkpoints in output_dir.
)

nesim_config = NesimConfig.from_json(args.nesim_config)

# Entry of dataset_info.json describing the dataset chosen on the command line.
selected_dataset = dataset_info[args.dataset_name]

# Prompts passed to the experiment as `sample_prompts`.
demo_prompts = [
    "The President of the United States is",
    "An apple a day",
    "The Eiffel Tower is in",
]

config = GPTNeo125mTrainingConfig(
    training_arguments=training_arguments,
    context_length=args.context_length,
    # Only the nesim loss is configured; the other two loss configs are off.
    nesim_config=nesim_config,
    bimt_config=None,
    cross_layer_correlation_loss_config=None,
    apply_nesim_every_n_steps=args.apply_nesim_every_n_steps,
    neighbourhood_cosine_similarity_loss_lower_bound=0.0,
    # wandb logging, optionally continuing an existing run.
    wandb_log=True,
    resume_wandb_id=args.resume_wandb_id,
    # Dataset selection and its on-disk locations.
    dataset_name=args.dataset_name,
    tokenized_dataset_path=selected_dataset["tokenized_dataset_path"],
    dataset_cache_dir=selected_dataset["dataset_cache_dir"],
    # Sample text generation: 4 completion tokens per prompt, at an interval
    # of 100 x gradient_accumulation_steps steps (presumably every 100
    # optimizer updates if the counter counts micro-batches - TODO confirm).
    sample_prompts=demo_prompts,
    num_sample_completion_tokens=4,
    generate_sample_text_every_n_steps=int(args.gradient_accumulation_steps * 100),
)

# Build the experiment from the assembled config and run it.
experiment = GPTNeo125mTraining(config=config)
experiment.run()
