import argparse
from datasets import load_from_disk
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from trl import GKDConfig, GKDTrainer
import math

def build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the GKD distillation script."""
    parser = argparse.ArgumentParser(description="Fine-tune a student language model using TRL GKD Trainer.")
    parser.add_argument("--student_model_id", type=str, required=True)
    parser.add_argument("--teacher_model_id", type=str, required=True)
    parser.add_argument(
        "--teacher_lora_path",
        type=str,
        default=None,
        help="Optional PEFT adapter loaded on top of the teacher model.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Name of a dataset directory under dataset/ (loaded with datasets.load_from_disk).",
    )

    parser.add_argument("--per_device_train_batch_size", type=int, default=2)
    parser.add_argument("--target_batch_size", type=int, default=16, help="Target effective batch size")
    parser.add_argument("--max_seq_length", type=int, default=2 * 1024, help="Max Sequence Length")

    parser.add_argument("--kd_temperature", type=float, default=0.6, help="Temperature for knowledge distillation.")
    parser.add_argument(
        "--jsd_beta",
        type=float,
        default=0.9,
        help="Interpolation coefficient of the generalized JSD loss (0:=forward KL, 1:=reverse KL).",
    )
    parser.add_argument(
        "--gkd_student_on_policy_ratio",
        type=float,
        default=1,
        help="Probability of doing on-policy gen.",
    )
    # Previously referenced as `args.gdk_teacher_on_policy` without ever being
    # defined, which crashed before training started. Off by default.
    parser.add_argument(
        "--gkd_teacher_on_policy",
        action="store_true",
        help="Enable sequence-level KD (GKDConfig.seq_kd): train on teacher-generated outputs.",
    )
    return parser


def compute_gradient_accumulation_steps(target_batch_size: int, per_device_batch_size: int, num_devices: int) -> int:
    """Return the accumulation steps needed to reach at least ``target_batch_size``.

    Args:
        target_batch_size: Desired effective (global) batch size.
        per_device_batch_size: Micro-batch size on each device; must be >= 1.
        num_devices: Number of data-parallel devices. Values below 1 (e.g. the 0
            reported by ``torch.cuda.device_count()`` on CPU-only hosts) count as 1.

    Returns:
        ``ceil(target / (devices * per_device))``, never less than 1. The resulting
        effective batch may therefore exceed the target when it is not divisible.

    Raises:
        ValueError: if ``per_device_batch_size`` is less than 1.
    """
    if per_device_batch_size < 1:
        raise ValueError(f"per_device_batch_size must be >= 1, got {per_device_batch_size}")
    batch_per_step = max(1, num_devices) * per_device_batch_size
    return max(1, math.ceil(target_batch_size / batch_per_step))


def build_output_dir(student_model_id: str, dataset: str, teacher_model_id: str, teacher_lora_path=None) -> str:
    """Return ``model/<student>_GKD_<dataset>_<teacher_tag>``.

    ``<student>`` is the last path component of ``student_model_id``. ``<teacher_tag>``
    is the parent directory of ``teacher_lora_path`` (e.g. ``run1`` for
    ``model/run1/checkpoint-100``). If the adapter path has a single component, that
    component is used. Without an adapter, the last path component of
    ``teacher_model_id`` is used.
    """
    student_tag = student_model_id.rstrip("/").split("/")[-1]
    # Drop empty components so trailing or duplicate slashes don't shift the index.
    lora_parts = [part for part in (teacher_lora_path or "").split("/") if part]
    if len(lora_parts) >= 2:
        teacher_tag = lora_parts[-2]
    elif lora_parts:
        teacher_tag = lora_parts[0]
    else:
        teacher_tag = teacher_model_id.rstrip("/").split("/")[-1]
    return f"model/{student_tag}_GKD_{dataset}_{teacher_tag}"


def attn_implementation_for(model_id: str):
    """Return ``"eager"`` for model ids containing ``gemma-3``, else ``None`` (library default)."""
    return "eager" if "gemma-3" in model_id else None


def load_causal_lm(model_id: str):
    """Load ``model_id`` as a bfloat16 causal LM with the attention implementation chosen above."""
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation_for(model_id),
    )


def main(argv=None) -> None:
    """Distill a teacher LM into a student with TRL's GKDTrainer and save the result.

    Args:
        argv: Argument list to parse; ``None`` means argparse reads ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)
    print("Parsed arguments:", args)

    # Data parallel analysis. torch reports 0 CUDA devices on CPU-only hosts; treat
    # that as a single device so the batch arithmetic stays well-defined.
    num_devices = max(1, torch.cuda.device_count())
    effective_batch_size_per_step = num_devices * args.per_device_train_batch_size
    gradient_accumulation_steps = compute_gradient_accumulation_steps(
        args.target_batch_size, args.per_device_train_batch_size, num_devices
    )
    print("=== Data Parallel Analysis ===")
    print(f"Effective batch size per step: {effective_batch_size_per_step}")
    print(f"Calculated gradient accumulation steps: {gradient_accumulation_steps}")
    print("===============================")

    # Load tokenizer
    print(f"Loading tokenizer from: {args.student_model_id}")
    tokenizer = AutoTokenizer.from_pretrained(args.student_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f"Tokenizer pad_token was None, set to eos_token: {tokenizer.eos_token}")

    # Load models
    print(f"Loading student model: {args.student_model_id}")
    student_model = load_causal_lm(args.student_model_id)
    print(f"Loading teacher model: {args.teacher_model_id}")
    teacher_model = load_causal_lm(args.teacher_model_id)
    # NOTE(review): a previously commented-out workaround set
    # `torch._dynamo.config.disable = True` when the teacher is a
    # Gemma3ForConditionalGeneration; reinstate it if that model type needs it.
    if args.teacher_lora_path:
        teacher_model = PeftModel.from_pretrained(teacher_model, args.teacher_lora_path)

    # Load and prepare dataset: render each chat-structured prompt to a string that
    # ends with the generation prompt, so the student continues from there.
    print(f"Loading dataset from: dataset/{args.dataset}")
    dataset = load_from_disk(f"dataset/{args.dataset}")['train']
    prompts = [
        tokenizer.apply_chat_template(structure, tokenize=False, add_generation_prompt=True)
        for structure in dataset['prompt_structures']
    ]
    dataset = dataset.add_column("prompt", prompts)
    # Empty 'completion' placeholder for the SFT data path (GKDTrainer inherits from SFTTrainer).
    dataset = dataset.add_column("completion", [""] * len(dataset))
    print(f"Training dataset size: {len(dataset)}")

    # Setup output directory (works with or without --teacher_lora_path).
    output_dir = build_output_dir(
        args.student_model_id, args.dataset, args.teacher_model_id, args.teacher_lora_path
    )
    print(f"Output directory: {output_dir}")

    training_args = GKDConfig(
        bf16=True,
        num_train_epochs=1,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_length=args.max_seq_length,

        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.15,

        output_dir=output_dir,
        logging_dir=f"{output_dir}/logs",
        logging_steps=1,
        report_to="tensorboard",

        temperature=args.kd_temperature,
        beta=args.jsd_beta,
        lmbda=args.gkd_student_on_policy_ratio,
        seq_kd=args.gkd_teacher_on_policy,
    )

    # Initialize GKD Trainer
    print("Creating GKD Trainer...")
    trainer = GKDTrainer(
        args=training_args,
        model=student_model,
        teacher_model=teacher_model,
        processing_class=tokenizer,
        train_dataset=dataset,
    )

    # Start training
    print("Starting training...")
    trainer.train()

    # Save final model
    print("Saving final model...")
    trainer.save_model()
    tokenizer.save_pretrained(output_dir)
    print(f"Training completed! Model saved to: {output_dir}")


if __name__ == "__main__":
    main()