#!/usr/bin/env python3
"""
Local Dataset SFT Training Script for Qwen2.5-1.5B using HuggingFace SFTTrainer
"""

import os
import json
import torch
from datasets import Dataset
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
import wandb

# Route every wandb run started by this script into the "qwen-local-sft" project.
# (Setting WANDB_LOG_MODEL="checkpoint" as well would also upload model checkpoints.)
os.environ.update({"WANDB_PROJECT": "qwen-local-sft"})

def format_local_prompt(example, tokenizer):
    """Build the SFT training text for one local-dataset record.

    The text is ``example["input"]`` verbatim, a single space,
    ``example["correct_output"]``, then the tokenizer's EOS token so the model
    learns where the answer ends.

    NOTE(review): the input is NOT stripped (no leading newline is removed);
    confirm this matches the prompt format used at inference time.

    Args:
        example: one dataset record with ``"input"`` and ``"correct_output"``
            string fields.
        tokenizer: any object exposing ``eos_token`` (e.g. a HF tokenizer).

    Returns:
        ``{"text": <formatted string>}``, the column the trainer is configured
        to read via ``dataset_text_field="text"``.
    """
    input_text = example["input"]
    answer_text = example["correct_output"]
    # A tokenizer without an EOS token would otherwise make the f-string embed
    # the literal text "None" at the end of every training example.
    eos = tokenizer.eos_token or ""
    return {"text": f"{input_text} {answer_text}{eos}"}

def _load_splits(dataset_path, tokenizer, test_size, seed):
    """Load the JSON records, format them into a "text" column, and split train/eval.

    Returns:
        ``(formatted_dataset, train_ds, eval_ds)``.
    """
    # Explicit UTF-8 so the file decodes the same way regardless of the host locale.
    with open(dataset_path, "r", encoding="utf-8") as f:
        records = json.load(f)

    # Dataset.from_list expects a list of dicts; format_local_prompt reads the
    # "input" and "correct_output" keys of each record.
    raw = Dataset.from_list(records)
    formatted = raw.map(
        lambda x: format_local_prompt(x, tokenizer),
        remove_columns=raw.column_names,
    )
    split = formatted.train_test_split(test_size=test_size, seed=seed)
    return formatted, split["train"], split["test"]


def _debug_collator(sample_text, tokenizer, data_collator, max_length):
    """Print which tokens of one example the collator trains on (sanity check only)."""
    print("=" * 50)
    print("DEBUGGING DATA COLLATOR")
    print("=" * 50)
    print(f"Full text: {sample_text}")

    # Tokenize roughly like SFTTrainer does internally, with the SAME length cap
    # as training so the check reflects what the model will actually see.
    tokenized = tokenizer(
        sample_text,
        truncation=True,
        max_length=max_length,
        padding=False,
        return_tensors=None,
    )
    print(f"Tokenized keys: {tokenized.keys()}")

    collated = data_collator([tokenized])
    input_ids = collated["input_ids"][0]
    labels = collated["labels"][0]

    print(f"Input length: {len(input_ids)}")
    print(f"Labels length: {len(labels)}")

    # If the response template was not matched in the token ids, every label
    # is -100 and this example contributes no loss.
    if not (labels != -100).any():
        print("\nWARNING: every label is MASKED - response template not found in this sample")

    print("\nLast 15 tokens and label status:")
    for token_id, label in zip(input_ids[-15:], labels[-15:]):
        token_text = tokenizer.decode([token_id])
        label_status = "MASKED" if label == -100 else "TRAIN"
        print(f"  {token_text:15} -> {label_status}")

    # Check that the final EOS token is trained on, so the model learns to stop.
    eos_positions = (input_ids == tokenizer.eos_token_id).nonzero()
    if len(eos_positions) > 0:
        eos_pos = eos_positions[-1].item()
        eos_label = labels[eos_pos]
        print(f"\nEOS token at position {eos_pos}")
        print(f"EOS label: {eos_label} ({'TRAIN' if eos_label != -100 else 'MASKED'})")
    else:
        print("\nNO EOS TOKEN FOUND IN INPUT_IDS!")

    print("=" * 50)


def _build_training_args(output_dir, max_seq_length):
    """Return the SFTConfig for this run, choosing bf16 or fp16 mixed precision."""
    use_bf16 = torch.cuda.is_bf16_supported()
    # fp16 AMP requires fp32 master weights: loading the model in float16 and
    # training with fp16=True makes the gradient scaler fail with
    # "Attempting to unscale FP16 gradients". bf16 uses no scaler, so bf16
    # weights are fine on that path.
    model_dtype = torch.bfloat16 if use_bf16 else torch.float32
    return SFTConfig(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        optim="adamw_torch",
        learning_rate=1e-5,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=10,
        save_strategy="no",  # no intermediate checkpoints; main() saves once after training
        eval_strategy="epoch",  # evaluate at the end of every epoch
        bf16=use_bf16,
        fp16=not use_bf16,
        report_to="wandb",
        run_name="qwen-local-sft",
        load_best_model_at_end=False,
        greater_is_better=False,
        push_to_hub=False,
        max_seq_length=max_seq_length,
        dataset_text_field="text",
        packing=False,
        dataset_num_proc=4,
        model_init_kwargs={"torch_dtype": model_dtype},
    )


def main():
    """Fine-tune Qwen2.5-1.5B on the local refinement dataset with completion-only loss.

    Loads and formats the JSON dataset, holds out a small eval split, prints a
    label-masking sanity check for one example, trains with SFTTrainer, and
    saves the model and tokenizer to ``output_dir``.
    """
    model_name = "Qwen/Qwen2.5-1.5B"
    output_dir = "./qwen-refinement-sft"
    dataset_path = "/homes/55/sumeet/qwenma/refinement_data/refinement_data.json"
    max_seq_length = 768  # shared by the debug tokenization and SFTConfig
    eval_fraction = 0.04  # hold out 4% of examples for evaluation (96% train)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The completion-only collator (via its DataCollatorForLanguageModeling base)
    # sets labels to -100 wherever input == pad_token_id. If pad shared the EOS
    # id, the EOS label would be masked and the model would never learn to stop,
    # so give padding its own token in that case.
    # NOTE(review): nothing here resizes the model's embeddings; this relies on
    # the checkpoint's embedding matrix having spare rows for the new id - confirm.
    if tokenizer.pad_token is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
        tokenizer.add_special_tokens({"pad_token": "<pad>"})

    formatted_dataset, train_ds, eval_ds = _load_splits(
        dataset_path, tokenizer, test_size=eval_fraction, seed=42
    )

    print(f"Dataset loaded: {len(formatted_dataset)} total examples")
    print(f"Training examples: {len(train_ds)}")
    print(f"Evaluation examples: {len(eval_ds)}")

    # Mask every token up to and including "Refined Answer:" so loss is computed
    # only on the refined answer. BPE may tokenize the template differently in
    # context; _debug_collator shows whether it matched for one sample.
    data_collator = DataCollatorForCompletionOnlyLM(
        response_template="Refined Answer:", tokenizer=tokenizer
    )

    _debug_collator(train_ds[0]["text"], tokenizer, data_collator, max_seq_length)

    training_args = _build_training_args(output_dir, max_seq_length)

    # Passing the model name lets SFTTrainer load it with model_init_kwargs.
    trainer = SFTTrainer(
        model=model_name,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()

    # Save model + tokenizer (tokenizer includes the added pad token, if any).
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Final wandb log, only from the main process.
    if trainer.state.is_world_process_zero:
        wandb.log({"training_completed": True})
        wandb.finish()

    print(f"Training completed! Model saved to {output_dir}")

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()