from transformers import AutoTokenizer, AutoModelForPreTraining, Trainer, TrainingArguments, AutoModelForCausalLM
from peft import get_peft_model

from models import IRTGenerator
from lora_config import GPT_LORA_CONFIG, PHI_LORA_CONFIG
from data_utils import IRTGenDataset


# Run configuration.
_type = 'bias'        # first positional argument given to IRTGenDataset
output_dir = ''       # forwarded to TrainingArguments(output_dir=...)

# Optimisation settings.
n_epochs = 10
learning_rate = 5e-5
batch_size = 8        # used as both the train and eval per-device batch size
grad_acc = 4          # gradient accumulation steps

# Step intervals for logging, evaluation and checkpointing.
logging_steps = 10
eval_steps = 50
save_steps = 50


# Tokenizer setup.
# Alternative checkpoint: "microsoft/Phi-3-mini-4k-instruct"
model_name = "openai-community/gpt2-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pad with the end-of-sequence token id.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Register the extra '[UNK]' token; the model's embedding matrix is resized to
# len(tokenizer) when the model is loaded further down.
extra_tokens = ['[UNK]']
tokenizer.add_tokens(extra_tokens)


# Build the training split and the validation split (the latter is read from
# the 'test' partition), then print the size and first item of each as a
# quick sanity check.
train_dataset = IRTGenDataset(_type, 'train', tokenizer, data_path='')
val_dataset = IRTGenDataset(_type, 'test', tokenizer, data_path='')
for split_dataset in (train_dataset, val_dataset):
    print(len(split_dataset))
    print(split_dataset[0])


# load the model
# For each supported family: load the base LM, resize its embedding matrix to
# len(tokenizer) (the tokenizer had a token added above), wrap it in
# IRTGenerator, and attach the family-specific LoRA adapter.
# `save_safetensors` is chosen per family and forwarded to TrainingArguments.
if 'gpt' in model_name:
    gpt2 = AutoModelForPreTraining.from_pretrained(model_name, use_cache=False)
    gpt2.resize_token_embeddings(len(tokenizer))
    model = IRTGenerator(model=gpt2, num_param_tokens=5)
    model = get_peft_model(model, GPT_LORA_CONFIG)
    save_safetensors = False
elif 'Phi' in model_name:
    phi3 = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    phi3.resize_token_embeddings(len(tokenizer))
    model = IRTGenerator(model=phi3, num_param_tokens=5)
    model = get_peft_model(model, PHI_LORA_CONFIG)
    save_safetensors = True
else:
    # Without this branch `model` and `save_safetensors` would be left
    # unbound and the script would die later with an unrelated NameError.
    raise ValueError(
        f"Unsupported model_name {model_name!r}: "
        "expected a name containing 'gpt' or 'Phi'."
    )
model.print_trainable_parameters()


# Training setup: evaluate every `eval_steps` steps, checkpoint every
# `save_steps` steps, and reload the checkpoint with the best 'loss' metric
# once training finishes.
training_args = TrainingArguments(
    output_dir=output_dir,
    # optimisation
    num_train_epochs=n_epochs,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=grad_acc,
    # logging / evaluation cadence
    logging_steps=logging_steps,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    label_names=['labels'],
    # checkpointing and best-model selection
    save_steps=save_steps,
    save_total_limit=1,
    save_safetensors=save_safetensors,
    load_best_model_at_end=True,
    metric_for_best_model='loss',
    # deepspeed='ds_cfg.json',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

trainer.train()