'''
(1)
Train a GPT-2 model on IMDb (just a linear layer).
Input: dataset name (IMDb, etc.).
Output: a GPT-2 checkpoint.
The point is to run GPT-2 on this data for 1 epoch to bring the pretrained GPT body in distribution with our data.
Switched to SFTTrainer for better results (previously used the regular Trainer because we were doing classification).
'''

from datasets import load_dataset, Dataset
import os
from sklearn.model_selection import train_test_split
from trl import SFTTrainer
import wandb
import torch
from transformers import (
    TrainingArguments,
    AutoTokenizer,
    GPT2LMHeadModel,
    DataCollatorForLanguageModeling,
    set_seed,
)

# Set constants
# Hugging Face model identifier; used below for both the model and the tokenizer.
MODEL_NAME = "openai-community/gpt2"  # Base GPT-2 model
# Passed to TrainingArguments(output_dir=...); the trained model and tokenizer
# are written here at the end of the script.
OUTPUT_DIR = "/home/miria/CVXDPO/checkpoint_gpt2_e1_march12/"
# Number of training epochs (forwarded as num_train_epochs).
EPOCHS = 1
# Root folder read by load_imdb_dataset(); it must contain `neg/` and `pos/`
# subdirectories holding the review text files.
data_dir = "/home/miria/CVXDPO/datasets/aclImdb_/train_imdb/"  # IMDb dataset path

# Fix the global seed before the model is loaded and training starts.
# Note: the train/test split further down uses its own random_state=42.
set_seed(1024)

# Initialize wandb
# The run name is built from the last path component of MODEL_NAME and data_dir
# (normpath drops the trailing slash first); with the constants above on a
# POSIX system this evaluates to "lmbody_gpt2_train_imdb".
wandb.init(
    project="ICML_CVXDPO2",
    name=f"lmbody_{os.path.basename(os.path.normpath(MODEL_NAME))}_{os.path.basename(os.path.normpath(data_dir))}",
)

# Load GPT-2 (with LM head)
print("Loading GPT-2 LM model...")
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)  

# Load GPT-2 Tokenizer
print("Loading GPT-2 tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the end-of-sequence token as the padding token so that the
# padding="max_length" tokenization below has a token to pad with.
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 does not have a default pad token

# Resize token embeddings if necessary
# NOTE(review): aliasing pad to EOS presumably adds no new vocabulary entry, so
# this is likely a no-op kept as a safeguard -- confirm if tokens are ever added.
model.resize_token_embeddings(len(tokenizer))

# Load IMDb dataset (text-only, no labels needed)
def load_imdb_dataset(data_dir):
    """Read every review under ``data_dir/neg`` and ``data_dir/pos`` as raw text.

    The sentiment label is only used to locate the files; it is not returned,
    because the texts are used for language-model training only.

    Files are visited in sorted name order within each label directory
    (``neg`` first, then ``pos``). ``os.listdir`` makes no ordering guarantee,
    so without sorting the returned list -- and any seeded split computed from
    it -- could differ between machines or filesystems.

    Args:
        data_dir: Directory containing ``neg`` and ``pos`` subdirectories of
            UTF-8 encoded text files.

    Returns:
        A list with the full contents of each file, one string per file.

    Raises:
        OSError: If a label directory is missing or a file cannot be read
            (e.g. ``FileNotFoundError``).
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    texts = []
    for label in ("neg", "pos"):
        label_dir = os.path.join(data_dir, label)
        # Sort for a deterministic order independent of the filesystem.
        for file_name in sorted(os.listdir(label_dir)):
            file_path = os.path.join(label_dir, file_name)
            # Skip stray subdirectories or other non-regular entries, which
            # would otherwise make open() fail.
            if not os.path.isfile(file_path):
                continue
            with open(file_path, "r", encoding="utf-8") as file:
                texts.append(file.read())
    return texts

# Read all review texts (labels are not kept) from the neg/ and pos/ folders.
texts = load_imdb_dataset(data_dir)
# Hold out 20% of the reviews; random_state pins the shuffle used for the split.
train_texts, test_texts = train_test_split(texts, test_size=0.2, random_state=42)

# Convert to Hugging Face Dataset format
# Each dataset has a single "text" column, which tokenize_function reads.
train_dataset = Dataset.from_dict({"text": train_texts})
test_dataset = Dataset.from_dict({"text": test_texts})

# Tokenization function
def tokenize_function(examples):
    """Tokenize the ``"text"`` entries of a batch with the module-level tokenizer.

    Intended as the callable for ``Dataset.map(..., batched=True)``: it reads
    ``examples["text"]`` and returns whatever the tokenizer produces for it.
    """
    # Truncate long reviews and pad short ones so every sequence has 512 tokens.
    encode_options = {"padding": "max_length", "truncation": True, "max_length": 512}
    batch_texts = examples["text"]
    return tokenizer(batch_texts, **encode_options)

# Tokenize datasets
# batched=True hands tokenize_function a batch of rows at a time. No columns
# are explicitly removed here, so the raw "text" column is not dropped by
# these calls.
tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True)

# Use proper data collator for causal LM
# mlm=False selects the causal (next-token) objective rather than masked LM.
# NOTE(review): the collator is expected to derive labels from input_ids and
# ignore pad positions; since pad == EOS here, real EOS tokens would be
# ignored as well -- confirm against the transformers documentation.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)  

# Define training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    # do_train gates the training/saving section at the bottom of the script.
    do_train=True,
    # Evaluation is disabled; the evaluation code at the end of the file is
    # commented out.
    do_eval=False,
    per_device_train_batch_size=10,
    # Only relevant if evaluation is ever run (do_eval is False above).
    per_device_eval_batch_size=100,
    learning_rate=2e-5,
    weight_decay=0.1,
    # Gradient-norm clipping threshold.
    max_grad_norm=1.0,
    num_train_epochs=EPOCHS,
    # NOTE(review): -1 is presumably meant to disable intermediate checkpoints
    # (the final model is saved explicitly after training) -- confirm this is
    # accepted by the installed transformers version; save_strategy="no" is
    # the documented way to express that.
    save_steps=-1,
    logging_dir="./logs",
    logging_steps=500,  # Reduced for better logging
    # Send training logs to the wandb run started near the top of the script.
    report_to="wandb",
)

# Load `SFTTrainer`
# The datasets passed in are already tokenized, and the causal-LM collator
# built earlier is supplied explicitly. The eval dataset is provided even
# though do_eval is False in training_args.
print("Loading `SFTTrainer`...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    data_collator=data_collator,
)

# Train model
if training_args.do_train:
    print("Start training...")
    trainer.train()
    # Write the final model to output_dir after training completes.
    trainer.save_model(training_args.output_dir)

    # Save tokenizer
    # Only the main process writes the tokenizer files, so multiple processes
    # do not write to the same directory at once.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)

# Finish wandb logging
# Runs whether or not training was executed, closing the run opened by
# wandb.init() at the top of the script.
wandb.finish()



# from gpt_utils import plot_dict # plots loss and perplexity if needed

# # Plot Losses.
# plot_dict(loss_history, start_step=training_args.logging_steps, 
#           step_size=training_args.logging_steps, use_title='Loss', 
#           use_xlabel='Train Steps', use_ylabel='Values', magnify=2)

# print()

# # Plot Perplexities.
# plot_dict(perplexity_history, start_step=training_args.logging_steps, 
#           step_size=training_args.logging_steps, use_title='Perplexity', 
#           use_xlabel='Train Steps', use_ylabel='Values', magnify=2)

# # check if `do_eval` flag is set.
# if training_args.do_eval:
  
#   # capture output if trainer evaluate.
#   eval_output = trainer.evaluate()
#   # compute perplexity from model loss.
#   perplexity = math.exp(eval_output["eval_loss"])
#   print('\nEvaluate Perplexity: {:10,.2f}'.format(perplexity))
# else:
#   print('No evaluation needed. No evaluation data provided, `do_eval=False`!')
  