#!/usr/bin/env python3
"""
Evaluate HQQ quantized models with LoRA adapters on WikiText-2 PPL and MMLU.

Usage:
    python eval_hqq_model.py \
        --hqq_model_path experiments/motivating_example/quantized_models/config_D_ours \
        --lora_path experiments/motivating_example/checkpoints/config_D_ours \
        --output_file experiments/motivating_example/results/config_D_ours.json \
        --eval_mmlu
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch
from datasets import load_dataset
from tqdm import tqdm


def load_hqq_model(hqq_model_path: str):
    """Restore an HQQ-quantized model from disk and return it.

    Args:
        hqq_model_path: Location passed to ``AutoHQQHFModel.from_quantized``.

    Returns:
        The restored model, after ``eval()`` has been called on it.
    """
    # Imported lazily so that merely importing this module does not need hqq.
    from hqq.models.hf.base import AutoHQQHFModel

    print(f"Loading HQQ model from: {hqq_model_path}")
    quantized = AutoHQQHFModel.from_quantized(hqq_model_path)
    # This script only runs inference, so leave training mode right away.
    quantized.eval()
    return quantized


def load_lora_adapter(model, lora_path: str):
    """Attach a trained LoRA adapter to ``model`` and return the wrapped model.

    Args:
        model: Base model handed to ``PeftModel.from_pretrained``.
        lora_path: Location of the saved adapter.

    Returns:
        The PEFT-wrapped model, after ``eval()`` has been called on it.
    """
    # Imported lazily so that merely importing this module does not need peft.
    from peft import PeftModel

    print(f"Loading LoRA adapter from: {lora_path}")
    adapted = PeftModel.from_pretrained(model, lora_path)
    # Call eval() again on the wrapper returned by PEFT.
    adapted.eval()
    return adapted


def evaluate_ppl(model, tokenizer, device='cuda', max_samples=50):
    """Evaluate perplexity on the WikiText-2 (raw) test split.

    Every non-empty line of the split is scored independently: it is
    tokenized (truncated to 512 tokens), passed through the model with the
    input ids as labels, and the returned loss is weighted by the number of
    predicted tokens (sequence length minus one). Lines that tokenize to
    fewer than two tokens, or whose loss is NaN/inf, are skipped.

    Args:
        model: Called as ``model(input_ids, labels=input_ids)``; the result
            must expose a scalar ``loss`` tensor.
        tokenizer: Called with ``return_tensors="pt"``; the result must
            expose ``input_ids``.
        device: Device the token ids are moved to before the forward pass.
        max_samples: Maximum number of non-empty lines to score. ``None`` or
            a negative value scores every line, matching the convention of
            ``evaluate_mmlu`` in this file.

    Returns:
        ``exp`` of the token-weighted mean loss, or ``float('inf')`` when no
        token could be scored.
    """
    print("Evaluating WikiText-2 perplexity...")

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    # Blank lines carry no tokens worth scoring; drop them up front.
    texts = [t for t in dataset["text"] if t.strip()]

    # Slicing with a negative limit would silently drop lines from the END of
    # the split (texts[:-1]), so only slice for a real, non-negative limit.
    if max_samples is not None and max_samples >= 0:
        texts = texts[:max_samples]

    total_loss = 0.0
    total_tokens = 0

    for text in tqdm(texts, desc="Calculating PPL"):
        # Tokenize with limited length
        encodings = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        input_ids = encodings.input_ids.to(device)

        # A sequence of length n yields n - 1 next-token predictions.
        num_predicted = input_ids.size(1) - 1
        if num_predicted < 1:
            continue

        with torch.no_grad():
            loss = model(input_ids, labels=input_ids).loss

        # A single degenerate line must not poison the aggregate.
        if torch.isnan(loss) or torch.isinf(loss):
            continue

        # NOTE(review): assumes ``loss`` is the mean over the predicted
        # tokens, so multiplying by their count recovers the summed loss.
        total_loss += loss.item() * num_predicted
        total_tokens += num_predicted

    if total_tokens == 0:
        print("No valid tokens to evaluate!")
        return float('inf')

    avg_loss = total_loss / total_tokens
    ppl = torch.exp(torch.tensor(avg_loss)).item()
    print(f"WikiText-2 PPL: {ppl:.2f}")
    return ppl


def evaluate_mmlu(model, tokenizer, device='cuda', num_few_shot=5, max_samples=-1):
    """Evaluate MMLU multiple-choice accuracy from next-token logits.

    For every question a prompt of the form ``Question: ... / A. ... /
    B. ... / C. ... / D. ... / Answer:`` is built, the model is run once, and
    the prediction is whichever of the tokens A/B/C/D has the highest logit
    at the final position.

    NOTE(review): ``num_few_shot`` is only reported in the log line; no
    in-context examples are added to the prompt, so the evaluation is
    effectively zero-shot. A warning is printed when a non-zero value is
    requested.

    Args:
        model: Called as ``model(input_ids)``; the result must expose
            ``logits`` indexable as ``[0, -1, :]``.
        tokenizer: Used both to build the prompt tensors and to look up the
            token id of each answer letter.
        device: Device the token ids are moved to before the forward pass.
        num_few_shot: Reported in the log only (see note above).
        max_samples: Number of samples to evaluate. -1 or None for full
            evaluation. When smaller than the dataset, a fixed-seed random
            subset is used.

    Returns:
        Accuracy in ``[0, 1]``, or ``None`` if no MMLU dataset could be
        loaded.
    """
    print(f"Evaluating MMLU ({num_few_shot}-shot)...")
    if num_few_shot:
        print(
            "Warning: few-shot examples are not added to the prompt; "
            "questions are evaluated zero-shot."
        )

    # Load MMLU dataset
    try:
        dataset = load_dataset("cais/mmlu", "all", split="test")
    except Exception as e:
        print(f"Failed to load MMLU dataset: {e}")
        print("Trying alternative MMLU source...")
        try:
            # NOTE(review): the loop below reads 'question', 'choices' and
            # 'answer'; confirm this mirror exposes the same fields.
            dataset = load_dataset("lukaemon/mmlu", "all", split="test")
        except Exception as fallback_error:
            # Not a bare ``except``: Ctrl-C / SystemExit must still propagate.
            print(f"Failed to load alternative MMLU source: {fallback_error}")
            print("Could not load MMLU dataset, skipping...")
            return None

    print(f"MMLU dataset size: {len(dataset)}")

    # Sample examples or use full dataset
    if max_samples is not None and max_samples > 0 and len(dataset) > max_samples:
        import random

        # A private generator gives the same indices as seeding the global
        # RNG with 42, without disturbing random state used elsewhere.
        rng = random.Random(42)
        indices = rng.sample(range(len(dataset)), max_samples)
        samples = [dataset[i] for i in indices]
        print(f"Using {max_samples} sampled examples")
    else:
        samples = list(dataset)
        print(f"Using full dataset: {len(samples)} examples")

    choices = ['A', 'B', 'C', 'D']

    # The answer-letter token ids do not depend on the question, so resolve
    # them once. ``None`` marks a letter the tokenizer maps to no tokens.
    choice_token_ids = []
    for choice in choices:
        token_ids = tokenizer.encode(choice, add_special_tokens=False)
        choice_token_ids.append(token_ids[0] if token_ids else None)

    correct = 0
    total = 0

    for sample in tqdm(samples, desc="MMLU Evaluation"):
        question = sample['question']
        options = sample['choices']
        answer_idx = sample['answer']  # 0, 1, 2, or 3

        # Format prompt
        prompt_lines = [f"Question: {question}"]
        prompt_lines.extend(f"{choices[i]}. {opt}" for i, opt in enumerate(options))
        prompt_lines.append("Answer:")
        prompt = "\n".join(prompt_lines)

        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        input_ids = inputs.input_ids.to(device)

        # Get logits for the next token
        with torch.no_grad():
            outputs = model(input_ids)
            next_token_logits = outputs.logits[0, -1, :]

        # Score each answer letter; an untokenizable letter can never win.
        choice_logits = [
            next_token_logits[token_id].item() if token_id is not None else float('-inf')
            for token_id in choice_token_ids
        ]

        # Ties resolve to the earliest letter, as max() keeps the first maximum.
        pred_idx = max(range(len(choices)), key=choice_logits.__getitem__)

        if pred_idx == answer_idx:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"MMLU Accuracy: {accuracy:.4f} ({correct}/{total})")
    return accuracy


def main():
    """Command-line entry point.

    Parses the arguments, loads the tokenizer, the HQQ model and its LoRA
    adapter, runs the WikiText-2 perplexity evaluation (and MMLU when
    ``--eval_mmlu`` is given), then writes the collected metrics as JSON to
    ``--output_file`` and echoes them to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hqq_model_path", type=str, required=True)
    parser.add_argument("--lora_path", type=str, required=True, help="Path to trained LoRA adapter")
    parser.add_argument("--output_file", type=str, required=True)
    parser.add_argument("--max_samples", type=int, default=100)
    parser.add_argument("--eval_mmlu", action="store_true", help="Also evaluate MMLU")
    parser.add_argument("--mmlu_samples", type=int, default=200, help="Number of MMLU samples")
    args = parser.parse_args()

    # Deferred until after argument parsing, so usage errors surface first.
    from transformers import AutoTokenizer

    banner = "=" * 60

    print(banner)
    print("Evaluating HQQ Model + LoRA")
    print(f"HQQ Model: {args.hqq_model_path}")
    print(f"LoRA Path: {args.lora_path}")
    print(banner)

    # Tokenizer is read from the quantized model directory.
    tokenizer = AutoTokenizer.from_pretrained(args.hqq_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS when the tokenizer defines no padding token.
        tokenizer.pad_token = tokenizer.eos_token

    # Quantized base model first, then the adapter on top of it.
    base_model = load_hqq_model(args.hqq_model_path)
    model = load_lora_adapter(base_model, args.lora_path)

    results = {
        "hqq_model_path": args.hqq_model_path,
        "lora_path": args.lora_path,
    }

    # Perplexity is always evaluated.
    results["wikitext2_ppl"] = evaluate_ppl(model, tokenizer, max_samples=args.max_samples)

    # MMLU is opt-in; a None result means it was skipped, so record nothing.
    if args.eval_mmlu:
        mmlu_acc = evaluate_mmlu(model, tokenizer, max_samples=args.mmlu_samples)
        if mmlu_acc is not None:
            results["mmlu_accuracy"] = mmlu_acc

    # Persist the metrics, creating the destination directory if needed.
    destination = Path(args.output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open('w') as handle:
        json.dump(results, handle, indent=2)

    print("\n" + banner)
    print("Evaluation complete!")
    print(f"Results saved to: {args.output_file}")
    print(json.dumps(results, indent=2))
    print(banner)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
