import argparse
import torch
import re
import random
from datasets import load_from_disk
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
import math

# Script entry point: parse CLI options, load the teacher (policy being trained)
# and the student (frozen scorer used by the reward), and build the "prompt"
# column that the GRPO trainer consumes.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train defensive teacher model using GRPO.")
    # Model ids / dataset name. The dataset name is resolved below as
    # "dataset/<name>" relative to the current working directory.
    parser.add_argument("--teacher_model_id", type=str, required=True)
    parser.add_argument("--student_model_id", type=str, required=True)
    parser.add_argument("--dataset", type=str, required=True)

    # Batch geometry. gradient accumulation is derived from these two values
    # and the number of processes (see below).
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--target_batch_size", type=int, default=128, help="Target effective batch size")
    # Used both as the completion length limit for generation and as the
    # truncation length when the student scores prompt + completion.
    parser.add_argument("--max_seq_length", type=int, default=2*1024, help="Max Sequence Length")

    parser.add_argument("--use_lora", action="store_true", help="Use LoRA adapter for teacher model")
    parser.add_argument("--lora_r", type=int, default=128, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=128, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout")

    # Integer multiplier applied to the student-loss reward in perplexity_reward.
    # NOTE(review): type=int rejects fractional coefficients such as 0.5 — confirm
    # this restriction is intended.
    parser.add_argument("--coe", type=int, default=1)
    # Optional tag used in the output directory name ('defensive' when empty).
    parser.add_argument("--name_tag", type=str, default="")
    args = parser.parse_args()
    print("Parsed arguments:", args)

    accelerator = Accelerator()
    # When a DeepSpeed plugin is active, overwrite its micro-batch size with the CLI value.
    if accelerator.state.deepspeed_plugin is not None:
        accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
    
    # Samples processed per forward/backward step across all processes.
    effective_batch_size_per_step = accelerator.num_processes * args.per_device_train_batch_size
    # Smallest accumulation count that reaches at least target_batch_size; because
    # of the ceil, the resulting effective batch can exceed the target.
    gradient_accumulation_steps = max(1, math.ceil(args.target_batch_size / effective_batch_size_per_step))
    print(f"=== Data Parallel Analysis ===")
    print(f"Effective batch size per step: {effective_batch_size_per_step}")
    print(f"Calculated gradient accumulation steps: {gradient_accumulation_steps}")
    print("===============================")

    print(f"Loading teacher model: {args.teacher_model_id}")
    # Model ids containing 'gemma-3' get use_fast=True and eager attention;
    # every other id passes None for both options.
    teacher_tokenizer = AutoProcessor.from_pretrained(
        args.teacher_model_id, 
        use_fast=(True if 'gemma-3' in args.teacher_model_id else None)
    )
    teacher_model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation=('eager' if 'gemma-3' in args.teacher_model_id else None),
    )
    # For Gemma3ForConditionalGeneration the loaded object is unwrapped to its
    # `.tokenizer` attribute so the rest of the script works with a tokenizer.
    if type(teacher_model).__name__ == "Gemma3ForConditionalGeneration":
        teacher_tokenizer = teacher_tokenizer.tokenizer
        # torch._dynamo.config.disable = True
    # Padding is required for batching; reuse EOS when no pad token is defined.
    if teacher_tokenizer.pad_token is None:
        teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
    
    if args.use_lora:
        print("Applying LoRA adapter to teacher model...")
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules="all-linear",
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        teacher_model = get_peft_model(teacher_model, lora_config)
        teacher_model.print_trainable_parameters()
        print("LoRA adapter applied successfully!")

    # Load student model for reward computation (it is only read by perplexity_reward).
    print(f"Loading student model: {args.student_model_id}")
    student_tokenizer = AutoProcessor.from_pretrained(args.student_model_id)
    student_model = AutoModelForCausalLM.from_pretrained(
        args.student_model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    # Same unwrapping and pad-token fallback as for the teacher.
    if type(student_model).__name__ == "Gemma3ForConditionalGeneration":
        student_tokenizer = student_tokenizer.tokenizer
        # torch._dynamo.config.disable = True
    if student_tokenizer.pad_token is None:
        student_tokenizer.pad_token = student_tokenizer.eos_token
    
    # Despite the log messages, the student is not passed to accelerator.prepare():
    # it is switched to eval mode and moved to this process's device.
    print("Preparing student model with accelerator...")
    student_model.eval()
    student_model = student_model.to(accelerator.device)
    print(f"Student model prepared successfully with accelerator state: {accelerator.state}")

    # Take the 'train' split and render every entry of 'prompt_structures' with the
    # teacher's chat template (as text, ending with the generation prompt).
    # Presumably each entry is a list of chat messages — verify against the dataset.
    dataset = load_from_disk(f"dataset/{args.dataset}")['train']
    prompts = [teacher_tokenizer.apply_chat_template(
        ps,
        tokenize = False,
        add_generation_prompt=True,
    ) for ps in dataset['prompt_structures']]
    dataset = dataset.add_column("prompt", prompts)
    print(f"Loaded dataset & built prompt from chat template on {len(dataset)} examples.")

    def correctness_reward(prompts, completions, **kwargs):
        """Score each completion 1.0 if its ``<answer>`` integer matches the label, else 0.0.

        The first ``<answer>...</answer>`` span on a single line is extracted,
        stripped and parsed with ``int``. A completion scores 0.0 when it has no
        such span, when the span is not an integer, or when no ground truth is
        available for its index.

        Args:
            prompts: Prompts aligned with ``completions``; not used for scoring.
            completions: Generated completion strings.
            **kwargs: Extra per-sample columns. ``ground_truth`` is read as a
                sequence of expected answers aligned with ``completions``.

        Returns:
            list[float]: One reward (0.0 or 1.0) per completion.
        """
        ground_truth = kwargs.get('ground_truth') or []
        if not ground_truth:
            print("Warning: No ground truth found in kwargs")
            print("Available kwargs keys:", list(kwargs.keys()))

        rewards = []
        for i, completion in enumerate(completions):
            match = re.search(r'<answer>(.*?)</answer>', completion)
            if match is None:
                # Built outside the f-string: a backslash inside an f-string
                # replacement field is a SyntaxError before Python 3.12.
                tail = completion[-15:].replace('\n', ' ')
                print(f"No answer: ..{tail}")
                rewards.append(0.0)
                continue

            candidate = match.group(1).strip()
            try:
                candidate_int = int(candidate)
            except ValueError:
                # Only a failed integer parse is expected here; anything else
                # (e.g. KeyboardInterrupt) must propagate.
                print(f"No parse: {candidate}")
                rewards.append(0.0)
                continue

            if i >= len(ground_truth):
                # No label for this completion (already warned above if the
                # column is missing entirely): it cannot be counted as correct.
                rewards.append(0.0)
                continue

            rewards.append(1.0 if candidate_int == ground_truth[i] else 0.0)

        if rewards:
            # Guarded so that an empty batch does not divide by zero.
            print(f"\t{len(rewards)} mean correctness={sum(rewards)/len(rewards):.2f}")
        return rewards
    
    def perplexity_reward(prompts, completions, **kwargs):
        """Reward each completion with the student model's mean token loss on it.

        The student scores ``prompt + completion`` in one batched forward pass.
        The reward is the mean cross-entropy over the completion tokens only,
        multiplied by ``args.coe``. Despite the name, the value is never
        exponentiated, so it is a per-token negative log-likelihood
        (log-perplexity), not a perplexity.

        Args:
            prompts: Prompt strings, one per completion.
            completions: Generated completion strings aligned with ``prompts``.
            **kwargs: Ignored.

        Returns:
            list[float]: One reward per prompt. 0.0 when no completion token is
            left after truncation to ``args.max_seq_length``.
        """
        if not prompts:
            # Nothing to score; avoids tokenizing an empty batch and dividing
            # by zero in the summary line.
            return []

        full_texts = [prompt + completion for prompt, completion in zip(prompts, completions)]
        tokenize_kwargs = {
            "return_tensors": "pt",
            "truncation": True,
            "max_length": args.max_seq_length,
            "padding": True,
        }

        perplexity_rewards = []
        with torch.no_grad():
            full_inputs = student_tokenizer(full_texts, **tokenize_kwargs).to(accelerator.device)
            # Only the number of prompt tokens is needed from this encoding.
            # NOTE(review): this assumes the separately tokenized prompt is a
            # token-level prefix of the tokenized prompt + completion — confirm
            # for the student tokenizer in use.
            prompt_inputs = student_tokenizer(prompts, **tokenize_kwargs)

            logits = student_model(**full_inputs).logits
            full_ids = full_inputs['input_ids']
            full_mask = full_inputs['attention_mask'].bool()
            prompt_lengths = prompt_inputs['attention_mask'].sum(dim=1).tolist()

            for i, prompt_length in enumerate(prompt_lengths):
                # Drop padded positions first so the slicing below is correct
                # whether the tokenizer pads on the left or on the right.
                keep = full_mask[i]
                token_ids = full_ids[i][keep]
                token_logits = logits[i][keep]

                # The logit at position t predicts token t + 1, so the very
                # first token of a sequence can never be scored.
                start = max(prompt_length, 1)
                completion_labels = token_ids[start:]
                if completion_labels.numel() == 0:
                    perplexity_rewards.append(0.0)
                    continue
                completion_logits = token_logits[start - 1:-1]

                # Compute the loss in float32 (the model runs in bfloat16) and
                # return a plain Python float rather than a device tensor.
                mean_loss = torch.nn.functional.cross_entropy(
                    completion_logits.float(),
                    completion_labels,
                    reduction='mean',
                ).item()
                perplexity_rewards.append(mean_loss * args.coe)

        print(f"\t{len(perplexity_rewards)} mean perplexity={sum(perplexity_rewards)/len(perplexity_rewards):.2f}")
        return perplexity_rewards

    # Output directory: model/<teacher model basename>_<name_tag or 'defensive'>_<dataset>[_lora]
    output_dir = f"model/{args.teacher_model_id.split('/')[-1]}_{args.name_tag if args.name_tag else 'defensive'}_{args.dataset}"
    if args.use_lora:
        output_dir += "_lora"
    print(f"Output directory: {output_dir}")

    training_args = GRPOConfig(
        bf16=True,
        # NOTE(review): both an epoch count and a step cap are set; per the
        # transformers TrainingArguments docs a positive max_steps takes
        # precedence over num_train_epochs — confirm which limit is intended.
        num_train_epochs=10,
        max_steps=100,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        
        # Generation settings: 8 sampled completions per prompt, generated with
        # vLLM in "colocate" mode.
        num_generations=8,
        max_completion_length=args.max_seq_length,
        temperature=1,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.37,
        vllm_tensor_parallel_size=1,
        # Repeats the temperature and length limit already given above.
        generation_kwargs={
            "temperature": 1,
            "max_tokens": args.max_seq_length,
        },

        # Optimisation settings.
        loss_type="dr_grpo",
        scale_rewards="batch",
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        
        # Log every step to TensorBoard under <output_dir>/logs.
        logging_dir=f"{output_dir}/logs",
        logging_steps=1,
        report_to="tensorboard",

        # Checkpoint every 50 steps into output_dir.
        output_dir=output_dir,
        save_steps=50,
    )
    
    # Create GRPO trainer
    # Two reward functions are registered: answer correctness and the student's
    # mean token loss scaled by --coe.
    print("Creating GRPO trainer...")
    trainer = GRPOTrainer(
        model=teacher_model,
        processing_class=teacher_tokenizer,
        reward_funcs=[correctness_reward, perplexity_reward],
        args=training_args,
        train_dataset=dataset,
    )
    
    # Train
    print("Starting training...")
    trainer.train()

    # Persist the final teacher weights and its tokenizer to output_dir.
    print(f"Saving final model...")
    trainer.save_model()
    teacher_tokenizer.save_pretrained(output_dir)    
    print(f"Training completed! Model saved to: {output_dir}")