# train.py
# Optional: supervised fine-tuning for Shapley NEAR, needed only when a model must be fine-tuned (e.g., on hallucination-labeled data)

import argparse

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


def _parse_args(argv=None):
    """Parse command-line options for the fine-tuning run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options. Defaults reproduce the
        previously hard-coded settings (batch size 4, one epoch, "context").
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="Model name or path")
    parser.add_argument("--dataset", type=str, default="coqa", help="Dataset name")
    parser.add_argument("--output_dir", type=str, default="./finetuned_model")
    parser.add_argument(
        "--text_column",
        type=str,
        default="context",
        help="Dataset column holding the text to train on",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=None,
        help="Truncation length in tokens (default: the tokenizer's own limit)",
    )
    parser.add_argument(
        "--batch_size", type=int, default=4, help="Per-device train batch size"
    )
    parser.add_argument(
        "--epochs", type=float, default=1, help="Number of training epochs"
    )
    return parser.parse_args(argv)


def _ensure_pad_token(tokenizer):
    """Give *tokenizer* a padding token if it lacks one.

    Some causal-LM tokenizers (e.g. GPT-2's) define no pad token, which makes
    batch padding fail; reusing the EOS token is the common workaround.

    Raises:
        ValueError: if the tokenizer has neither a pad nor an EOS token.
    """
    if tokenizer.pad_token is not None:
        return
    if tokenizer.eos_token is None:
        raise ValueError(
            "Tokenizer has no pad_token and no eos_token; cannot pad batches."
        )
    tokenizer.pad_token = tokenizer.eos_token


def _tokenize_splits(dataset, tokenizer, text_column, max_length=None):
    """Tokenize the *text_column* of every split in *dataset*.

    Original columns are removed so only tokenizer outputs remain. No padding
    is applied here: the data collator pads each training batch, so sequences
    of different lengths can be mixed freely.

    Returns:
        dict mapping split name -> tokenized dataset split.

    Raises:
        ValueError: if a split has no *text_column*.
    """

    def tokenize_batch(batch):
        return tokenizer(batch[text_column], truncation=True, max_length=max_length)

    tokenized = {}
    for split_name, split in dataset.items():
        if text_column not in split.column_names:
            raise ValueError(
                f"Column {text_column!r} not found in split {split_name!r}; "
                f"available columns: {split.column_names}"
            )
        tokenized[split_name] = split.map(
            tokenize_batch, batched=True, remove_columns=split.column_names
        )
    return tokenized


def main():
    """Fine-tune a causal LM on one text column of a Hugging Face dataset.

    Trains on the ``train`` split. The ``validation`` split is passed to the
    Trainer when present. The model and tokenizer are written to
    ``--output_dir``.

    Raises:
        ValueError: if the dataset lacks a ``train`` split or the text column,
            or the tokenizer cannot be given a pad token.
    """
    args = _parse_args()

    model = AutoModelForCausalLM.from_pretrained(args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    _ensure_pad_token(tokenizer)

    dataset = load_dataset(args.dataset)
    tokenized = _tokenize_splits(dataset, tokenizer, args.text_column, args.max_length)
    if "train" not in tokenized:
        raise ValueError(
            f"Dataset {args.dataset!r} has no 'train' split; found {sorted(tokenized)}"
        )

    # mlm=False selects the causal-LM objective. The collator pads each batch
    # and copies input_ids into labels, with padding masked out of the loss.
    # Without labels the model returns no loss and training cannot start.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        save_total_limit=1,
        logging_dir="./logs",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        # NOTE: no evaluation strategy is configured above, so by default this
        # split is not evaluated during train().
        eval_dataset=tokenized.get("validation"),
        data_collator=collator,
    )

    trainer.train()
    # Save the tokenizer next to the weights so output_dir can be reloaded
    # with from_pretrained on its own.
    trainer.save_model(args.output_dir)
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(args.output_dir)


# Run fine-tuning only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
