import argparse
import random
from pathlib import Path

import numpy as np
import torch
from scenario_datasets import MergedDataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)


def set_seed(seed_val=42):
    """Seed every random number generator this script relies on.

    Applies ``seed_val`` to Python's ``random`` module, NumPy's global RNG,
    PyTorch's default generator and the generators of all CUDA devices.

    Args:
        seed_val: Integer seed handed unchanged to each library.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_val)



# Fallback end-of-sequence token for tokenizers that ship without one.
# NOTE(review): "</s>" follows the LLaMA/SentencePiece convention and is an
# assumption -- this name was previously referenced without being defined,
# which raised NameError whenever the fallback branch was taken.
DEFAULT_EOS_TOKEN = "</s>"

# Local directory that holds the converted LLaMA checkpoints.
LLAMA_WEIGHTS_DIR = "/data/private_models/xx_models/llama/llama_hf_weights_v1.1"

# Exact model name -> (per-device train batch size, number of training epochs).
MODEL_PRESETS = {
    "gpt2-large": (8, 10),
    "gpt2-xl": (4, 8),
    "EleutherAI/pythia-410m-deduped": (8, 10),
    "EleutherAI/pythia-1.4b-deduped": (4, 8),
    "EleutherAI/pythia-6.9b-deduped": (2, 6),
    "EleutherAI/pythia-12b-deduped": (2, 6),
}

# Preset applied to any model whose name contains "llama".
LLAMA_PRESET = (2, 3)

# Hyperparameters shared by every supported model.
GRADIENT_ACCUMULATION_STEPS = 1
LEARNING_RATE = 1e-5


def resolve_model_config(model_name):
    """Map a ``--model_name`` value to its checkpoint location and hyperparameters.

    Exact names listed in ``MODEL_PRESETS`` are used as-is. Any other name
    containing ``"llama"`` is treated as a sub-directory of
    ``LLAMA_WEIGHTS_DIR``.

    Args:
        model_name: Value passed on the command line.

    Returns:
        A dict with the keys ``model_name`` (name or path to load from),
        ``train_batch_size``, ``gradient_accumulation_steps``,
        ``learning_rate`` and ``num_train_epochs``.

    Raises:
        ValueError: If ``model_name`` matches no known preset.
    """
    if model_name in MODEL_PRESETS:
        resolved_name = model_name
        batch_size, epochs = MODEL_PRESETS[model_name]
    elif "llama" in model_name:
        resolved_name = f"{LLAMA_WEIGHTS_DIR}/{model_name}"
        batch_size, epochs = LLAMA_PRESET
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    return {
        "model_name": resolved_name,
        "train_batch_size": batch_size,
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
        "learning_rate": LEARNING_RATE,
        "num_train_epochs": epochs,
    }


def make_output_dir(model_name):
    """Return the checkpoint directory name for ``model_name``.

    The final path component is used in full. ``Path.stem`` must not be used
    here: it drops everything after the last dot, so
    ``pythia-1.4b-deduped`` would be truncated to ``pythia-1``.

    NOTE(review): the "3ep" suffix is kept verbatim for compatibility with
    existing directory names even though the epoch count varies per model.
    """
    return Path(model_name).name + "3ep"


if __name__ == "__main__":
    # Parse command-line arguments. ``--local_rank`` is accepted but not read
    # by this script (presumably supplied by the distributed launcher).
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()

    config = resolve_model_config(args.model_name)
    model_name = config["model_name"]
    train_batch_size = config["train_batch_size"]
    gradient_accumulation_steps = config["gradient_accumulation_steps"]
    learning_rate = config["learning_rate"]
    num_train_epochs = config["num_train_epochs"]

    output_dir = make_output_dir(model_name)
    eval_batch_size = train_batch_size
    eval_steps = 5000
    max_input_length = 1024
    save_steps = 100000
    # Seed Python, NumPy and torch together instead of only ``random``.
    set_seed(42)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if not tokenizer.eos_token:
        tokenizer.eos_token = DEFAULT_EOS_TOKEN
    # Pad on the right and drop the oldest tokens when truncating.
    tokenizer.padding_side = "right"
    tokenizer.truncation_side = "left"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Reuse the EOS token for padding and keep the embedding matrix in sync
    # with the tokenizer's vocabulary size.
    tokenizer.pad_token = tokenizer.eos_token
    model.resize_token_embeddings(len(tokenizer))
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id

    # Set up the datasets
    train_dataset = MergedDataset(
        train_path="",
        tokenizer=tokenizer,
        split="train",
        max_length=max_input_length,
        train_type="sft",
    )
    dev_dataset = MergedDataset(
        train_path="",
        tokenizer=tokenizer,
        split="test",
        max_length=max_input_length,
        train_type="sft",
    )

    def preprocess_logits_for_metrics(logits, labels):
        """Reduce model logits to predicted token ids.

        Takes the first element when ``logits`` is a tuple, then returns the
        argmax over the last dimension. ``labels`` is unused.
        """
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    # Prepare the trainer and start training
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        eval_accumulation_steps=1,
        learning_rate=learning_rate,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        # gradient_checkpointing=True,
        # The backend is selected by name; a boolean is not a valid value.
        half_precision_backend="auto",
        bf16=True,
        adam_beta1=0.9,
        adam_beta2=0.95,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        warmup_steps=100,
        eval_steps=eval_steps,
        save_steps=save_steps,
        load_best_model_at_end=True,
        logging_steps=50,
        deepspeed="./ds_config_gpt.json",
        # NOTE(review): depending on the transformers version, None may fall
        # back to all installed integrations; use "none" if the intent is to
        # disable reporting -- confirm before changing.
        report_to=None,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        data_collator=default_data_collator,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()
    trainer.save_model(output_dir)
