from custom_datasets import ToyCvDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
from huggingface_hub import login
from datasets import Dataset

# Authenticate with the Hugging Face Hub up front; the script pushes the
# trained model and tokenizer to the Hub once training has finished.
login()

# Hparams
# -- checkpoint / data layout
model_name = "openai-community/gpt2"
dataset_text_field = "sft_text"
use_seeded_prompt = False

# -- optimisation schedule
batch_size = 64
num_epochs = 3

# -- upper bounds on how many examples are kept from each split
train_data_size = 20_000
eval_data_size = 20_000

def filter_dataset_columns(dataset: Dataset, keep_columns: list) -> Dataset:
    """Return a new dataset restricted to the columns in *keep_columns*.

    Names in *keep_columns* that are not columns of *dataset* are ignored
    rather than raising, and duplicate names are collapsed.

    Args:
        dataset: The dataset to project.
        keep_columns: Names of the columns to retain.

    Returns:
        A dataset holding only the requested columns that actually exist.
    """
    available = set(dataset.column_names)
    # dict.fromkeys de-duplicates while preserving the caller's ordering.
    selected = [col for col in dict.fromkeys(keep_columns) if col in available]
    # Project the columns directly instead of running a per-row Python
    # callback through `map` just to drop fields.
    return dataset.select_columns(selected)

# Model
# Load the tokenizer and the pretrained causal LM from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Reuse the EOS token for padding when the tokenizer does not define a pad
# token of its own.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

# Only the text column consumed by the trainer is kept. It is derived from
# `dataset_text_field` so the column filter and the trainer config cannot
# drift apart if the hyperparameter is changed.
sft_columns = [dataset_text_field]

# Train data
# NOTE(review): the train split is built with batch_size=256 while the eval
# split uses the `batch_size` hyperparameter — confirm this is intentional.
train_dataset_obj = ToyCvDataset(
    tokenizer=tokenizer,
    batch_size=256,
    split="train",
    use_seeded_prompt=use_seeded_prompt,
)
train_dataset = train_dataset_obj.get_data()
train_dataset = filter_dataset_columns(train_dataset, keep_columns=sft_columns)

# Eval data (read from the dataset's "test" split)
eval_dataset_obj = ToyCvDataset(
    tokenizer=tokenizer,
    batch_size=batch_size,
    split="test",
    use_seeded_prompt=use_seeded_prompt,
)
eval_dataset = eval_dataset_obj.get_data()
eval_dataset = filter_dataset_columns(eval_dataset, keep_columns=sft_columns)

# Trim data: cap each split at its configured size, never asking for more
# rows than the split actually contains.
train_dataset = train_dataset.select(
    range(min(train_data_size, len(train_dataset))))
eval_dataset = eval_dataset.select(
    range(min(eval_data_size, len(eval_dataset))))

# Init SFT Trainer
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the pinned library version.
sft_config = SFTConfig(
    output_dir="/tmp",
    dataset_text_field=dataset_text_field,
    max_seq_length=32,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Run evaluation once at the end of every epoch.
    evaluation_strategy="epoch",
)
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Do training
trainer.train()

# Publish the fine-tuned weights first, then the tokenizer, to the same
# Hugging Face Hub repository.
model_name = "rmcc11/gpt2-toy-cv"  # Replace with your desired model name
for artifact in (trainer.model, tokenizer):
    artifact.push_to_hub(model_name)