#!/usr/bin/env python3
"""
Comprehensive evaluation script for QR-Adaptor experiments.

Supports:
- WikiText-2 PPL
- C4 PPL  
- ARC-Easy, ARC-Challenge
- PIQA
- HellaSwag
- WinoGrande
- MMLU (5-shot)
- GSM8K

Usage:
    python eval_comprehensive.py \
        --hqq_model_path path/to/quantized_model \
        --lora_path path/to/lora_adapter \
        --output_file results.json \
        --tasks all
"""

from __future__ import annotations

import argparse
import json
import os
import time
from pathlib import Path
from typing import Dict, List, Optional, Any

import torch
from tqdm import tqdm


# ============================================================
# Task Configurations
# ============================================================

def _ppl_task(dataset: str, subset: str, split: str) -> Dict[str, Any]:
    """Build the config entry for a perplexity task (lower is better)."""
    return {
        "type": "ppl",
        "dataset": dataset,
        "subset": subset,
        "split": split,
        "metric": "perplexity",
        "direction": "lower",
    }


def _lm_eval_task(task_name: str, metric: str, num_fewshot: int = 0) -> Dict[str, Any]:
    """Build the config entry for an lm-eval-harness task (higher is better)."""
    return {
        "type": "lm_eval",
        "task_name": task_name,
        "metric": metric,
        "direction": "higher",
        "num_fewshot": num_fewshot,
    }


# Registry of every task this script knows about, keyed by the name used on
# the command line. "type" selects the evaluator ("ppl" or "lm_eval").
# NOTE(review): "metric" is the key looked up in the evaluator's result dict;
# confirm that lm-eval really reports "acc" for gsm8k.
TASK_CONFIGS = {
    "wikitext2": _ppl_task("wikitext", "wikitext-2-raw-v1", "test"),
    "c4": _ppl_task("c4", "realnewslike", "validation"),
    "arc_easy": _lm_eval_task("arc_easy", "acc"),
    "arc_challenge": _lm_eval_task("arc_challenge", "acc_norm"),
    "piqa": _lm_eval_task("piqa", "acc"),
    "hellaswag": _lm_eval_task("hellaswag", "acc_norm"),
    "winogrande": _lm_eval_task("winogrande", "acc"),
    "mmlu": _lm_eval_task("mmlu", "acc", num_fewshot=5),
    "gsm8k": _lm_eval_task("gsm8k", "acc", num_fewshot=8),
}

_PPL_TASKS = ("wikitext2", "c4")
_COMMONSENSE_TASKS = ("arc_easy", "arc_challenge", "piqa", "hellaswag", "winogrande")

# Named shortcuts accepted by --tasks; each maps to a list of TASK_CONFIGS keys.
TASK_GROUPS = {
    "all": list(TASK_CONFIGS),
    "ppl": list(_PPL_TASKS),
    "commonsense": list(_COMMONSENSE_TASKS),
    "core": ["wikitext2", "mmlu"],
    # Everything except gsm8k.
    "benchmark": [*_PPL_TASKS, *_COMMONSENSE_TASKS, "mmlu"],
}


# ============================================================
# Model Loading
# ============================================================

def load_model_and_tokenizer(
    hqq_model_path: str,
    lora_path: Optional[str] = None,
    device: str = "cuda",
):
    """Load the base model and tokenizer, then attach an optional LoRA adapter.

    The base model is loaded through HQQ when ``hqq_model_path`` is a
    directory containing ``qmodel.pt``; otherwise it is loaded as a regular
    float16 causal LM with ``device_map="auto"`` (which also allows a hub
    model ID to be passed).

    Args:
        hqq_model_path: Directory of an HQQ-quantized model, or any
            path/ID accepted by ``AutoModelForCausalLM.from_pretrained``.
            The tokenizer is loaded from the same location.
        lora_path: Optional adapter directory. When given but missing on
            disk, a warning is printed and the base model is returned
            without an adapter.
        device: Accepted for interface compatibility; not used here.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    from transformers import AutoTokenizer

    # A "qmodel.pt" file inside the directory marks an HQQ checkpoint.
    if (Path(hqq_model_path) / "qmodel.pt").exists():
        from hqq.models.hf.base import AutoHQQHFModel
        print(f"Loading HQQ model from: {hqq_model_path}")
        model = AutoHQQHFModel.from_quantized(hqq_model_path)
    else:
        from transformers import AutoModelForCausalLM
        print(f"Loading FP16 model from: {hqq_model_path}")
        model = AutoModelForCausalLM.from_pretrained(
            hqq_model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

    tokenizer = AutoTokenizer.from_pretrained(hqq_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer has none.
        tokenizer.pad_token = tokenizer.eos_token

    if lora_path:
        if Path(lora_path).exists():
            from peft import PeftModel
            print(f"Loading LoRA adapter from: {lora_path}")
            model = PeftModel.from_pretrained(model, lora_path)
        else:
            # Previously skipped silently, so a mistyped path produced base-model
            # numbers that looked like adapter results.
            print(
                f"WARNING: LoRA path does not exist: {lora_path} "
                "-- evaluating the base model WITHOUT an adapter"
            )

    model.eval()
    return model, tokenizer


# ============================================================
# Perplexity Evaluation
# ============================================================

def evaluate_perplexity(
    model,
    tokenizer,
    dataset_name: str = "wikitext",
    subset: str = "wikitext-2-raw-v1",
    split: str = "test",
    max_samples: int = 200,
    max_length: int = 512,
    stride: int = 256,
    device: str = "cuda",
) -> Dict[str, float]:
    """Compute perplexity over the first non-empty texts of a dataset split.

    Each text is tokenized on its own (truncated to ``max_length`` tokens)
    and scored with ``model(input_ids, labels=input_ids)``. The per-text
    losses are averaged with weight ``sequence_length - 1``, and perplexity
    is ``exp`` of that average.

    Args:
        model: Callable returning an object with a ``.loss`` tensor.
        tokenizer: Tokenizer returning ``input_ids`` as a PyTorch tensor.
        dataset_name: First argument to ``datasets.load_dataset``.
        subset: Dataset config name; omitted from the call when falsy.
        split: Dataset split to load.
        max_samples: Number of non-blank texts kept (list slice).
        max_length: Truncation length passed to the tokenizer.
        stride: Not used by this implementation.
        device: Device the input ids are moved to.

    Returns:
        ``{"perplexity": ..., "avg_loss": ...}``. Both are ``inf`` when the
        dataset cannot be loaded, has no usable text, or no sample yields a
        finite loss.
    """
    from datasets import load_dataset

    print(f"\nEvaluating perplexity on {dataset_name}/{subset} ({split})...")

    failure = {"perplexity": float('inf'), "avg_loss": float('inf')}

    # The config name is only passed positionally when one was given.
    load_args = (dataset_name, subset) if subset else (dataset_name,)
    try:
        dataset = load_dataset(*load_args, split=split, trust_remote_code=True)
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return failure

    # Prefer a column called "text"; otherwise fall back to the first column.
    columns = dataset.column_names
    text_column = "text" if "text" in columns else columns[0]

    non_blank = [row for row in dataset[text_column] if row and row.strip()]
    texts = non_blank[:max_samples]
    if not texts:
        print("No valid texts found in dataset!")
        return failure

    nll_sum = 0.0
    n_scored = 0

    with torch.no_grad():
        for text in tqdm(texts, desc=f"PPL Eval ({dataset_name})"):
            encoded = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
            )
            input_ids = encoded.input_ids.to(device)

            # At least two tokens are required to have one scored position.
            n_targets = input_ids.size(1) - 1
            if n_targets < 1:
                continue

            # Best effort: a failing sample is reported and skipped.
            try:
                loss = model(input_ids, labels=input_ids).loss
                if torch.isnan(loss) or torch.isinf(loss):
                    continue
                nll_sum += loss.item() * n_targets
                n_scored += n_targets
            except Exception as e:
                print(f"Error computing loss: {e}")

    if n_scored == 0:
        return failure

    avg_loss = nll_sum / n_scored
    ppl = torch.exp(torch.tensor(avg_loss)).item()

    print(f"  Perplexity: {ppl:.2f}")
    print(f"  Avg Loss: {avg_loss:.4f}")

    return {"perplexity": ppl, "avg_loss": avg_loss}


# ============================================================
# LM-Eval Harness Evaluation
# ============================================================

def evaluate_with_lm_eval(
    model,
    tokenizer,
    tasks: List[str],
    batch_size: int = 4,
    device: str = "cuda",
) -> Dict[str, Dict]:
    """Run the lm-eval-harness tasks among ``tasks``, one task at a time.

    Only names whose ``TASK_CONFIGS`` entry has ``type == "lm_eval"`` are
    run; everything else is skipped. Each task is evaluated in its own
    ``simple_evaluate`` call so one failing task does not abort the rest.

    Args:
        model: Already-loaded model handed to ``HFLM(pretrained=...)``.
        tokenizer: Tokenizer handed to ``HFLM``.
        tasks: Task names (keys of ``TASK_CONFIGS``).
        batch_size: Batch size for both the wrapper and ``simple_evaluate``.
        device: Accepted for interface compatibility; not used here.

    Returns:
        Mapping of task name to the harness's result dict for that task, or
        to ``{"error": "<message>"}`` when the task raised. Tasks for which
        the harness returned no entry are absent.

    Raises:
        ImportError: If lm-eval is not installed (callers use this to fall
            back to the simple evaluator).
    """
    import lm_eval
    from lm_eval.models.huggingface import HFLM

    def _find_metric(task_result: Any, metric: str) -> Any:
        """Return the value for *metric*, or None if it is not reported.

        lm-eval 0.4+ stores metrics under "<metric>,<filter>" keys such as
        "acc,none", so a bare-key lookup alone would always miss.
        """
        if not isinstance(task_result, dict):
            return None
        for key in (metric, f"{metric},none"):
            if key in task_result:
                return task_result[key]
        prefix = f"{metric},"
        for key, value in task_result.items():
            if isinstance(key, str) and key.startswith(prefix):
                return value
        return None

    print(f"\nEvaluating with lm-eval-harness on tasks: {tasks}")

    # Wrap the in-memory model so the harness can drive it.
    lm_model = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        batch_size=batch_size,
    )

    results = {}
    for task_name in tasks:
        task_config = TASK_CONFIGS.get(task_name, {})
        if task_config.get("type") != "lm_eval":
            continue

        lm_task = task_config.get("task_name", task_name)
        num_fewshot = task_config.get("num_fewshot", 0)
        metric = task_config.get("metric", "acc")

        print(f"\n  Running task: {lm_task} ({num_fewshot}-shot)")

        try:
            task_results = lm_eval.simple_evaluate(
                model=lm_model,
                tasks=[lm_task],
                num_fewshot=num_fewshot,
                batch_size=batch_size,
            )

            if "results" in task_results and lm_task in task_results["results"]:
                task_result = task_results["results"][lm_task]
                results[task_name] = task_result

                value = _find_metric(task_result, metric)
                # Format only real numbers: a non-numeric value used to raise
                # here and replace the stored result with an error entry.
                if isinstance(value, (int, float)):
                    print(f"    {metric}: {value:.4f}")
                else:
                    available = sorted(map(str, task_result)) if isinstance(task_result, dict) else []
                    print(f"    {metric}: not reported (available keys: {available})")
            else:
                print(f"    No results returned for {lm_task}")
        except Exception as e:
            # Best effort: record the failure and continue with the next task.
            print(f"    Error evaluating {lm_task}: {e}")
            results[task_name] = {"error": str(e)}

    return results


# ============================================================
# Simple MMLU Evaluation (Fallback)
# ============================================================

def evaluate_mmlu_simple(
    model,
    tokenizer,
    num_fewshot: int = 5,
    max_samples: int = 500,
    device: str = "cuda",
) -> Dict[str, float]:
    """Approximate MMLU accuracy without lm-eval-harness.

    A fixed-seed random subset of the test split is scored. For each
    question the prompt lists the options as ``A.``-``D.`` followed by
    ``Answer:``, and the prediction is the letter whose first token has the
    highest next-token logit.

    Args:
        model: Causal LM returning ``.logits``.
        tokenizer: Matching tokenizer.
        num_fewshot: Only echoed in the log line. No few-shot exemplars are
            added to the prompt, so evaluation is effectively zero-shot.
        max_samples: Upper bound on the number of questions scored.
        device: Device the prompt ids are moved to.

    Returns:
        ``{"accuracy", "correct", "total"}``, or ``{"accuracy": 0.0}`` when
        the dataset cannot be loaded.
    """
    from datasets import load_dataset
    import random

    print(f"\nEvaluating MMLU ({num_fewshot}-shot, {max_samples} samples)...")
    if num_fewshot:
        # The log line above would otherwise overstate what is evaluated.
        print("  NOTE: this fallback does not add few-shot exemplars; prompts are zero-shot.")

    try:
        dataset = load_dataset("cais/mmlu", "all", split="test")
    except Exception as primary_error:
        # `except Exception` (not bare) so Ctrl-C still interrupts the run.
        print(f"  Could not load cais/mmlu ({primary_error}); trying lukaemon/mmlu...")
        try:
            dataset = load_dataset("lukaemon/mmlu", "all", split="test")
        except Exception as e:
            print(f"Failed to load MMLU: {e}")
            return {"accuracy": 0.0}

    # A local generator seeded with 42 draws the same subset as seeding the
    # global RNG would, without disturbing global random state.
    rng = random.Random(42)
    indices = rng.sample(range(len(dataset)), min(max_samples, len(dataset)))
    samples = [dataset[i] for i in indices]

    choices = ['A', 'B', 'C', 'D']

    # The answer-letter token ids do not depend on the question, so they are
    # resolved once. None marks a letter the tokenizer encodes to nothing.
    choice_token_ids = []
    for letter in choices:
        token_ids = tokenizer.encode(letter, add_special_tokens=False)
        choice_token_ids.append(token_ids[0] if token_ids else None)

    correct = 0
    total = 0

    for sample in tqdm(samples, desc="MMLU"):
        question = sample['question']
        options = sample['choices']
        answer_idx = sample['answer']

        option_lines = "".join(
            f"{letter}. {opt}\n" for letter, opt in zip(choices, options)
        )
        prompt = f"Question: {question}\n{option_lines}Answer:"

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        input_ids = inputs.input_ids.to(device)

        with torch.no_grad():
            outputs = model(input_ids)
            # Logits for the token that would follow "Answer:".
            logits = outputs.logits[0, -1, :]

        choice_logits = [
            logits[token_id].item() if token_id is not None else float('-inf')
            for token_id in choice_token_ids
        ]

        # Ties resolve to the earliest letter.
        pred_idx = max(range(len(choices)), key=choice_logits.__getitem__)
        if pred_idx == answer_idx:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"  MMLU Accuracy: {accuracy:.4f} ({correct}/{total})")

    return {"accuracy": accuracy, "correct": correct, "total": total}


# ============================================================
# Main Evaluation Function
# ============================================================

def evaluate_all(
    model,
    tokenizer,
    tasks: List[str],
    max_ppl_samples: int = 200,
    max_benchmark_samples: int = -1,
    batch_size: int = 4,
    device: str = "cuda",
    use_lm_eval: bool = True,
) -> Dict[str, Any]:
    """Run every requested task and collect metrics, timings and a summary.

    Tasks are dispatched by their ``TASK_CONFIGS`` type: ``"ppl"`` tasks go
    to ``evaluate_perplexity``; ``"lm_eval"`` tasks go to
    ``evaluate_with_lm_eval``. When lm-eval is disabled or raises, only
    ``mmlu`` has a fallback (``evaluate_mmlu_simple``).

    Args:
        model: Loaded model.
        tokenizer: Matching tokenizer.
        tasks: Task names; names missing from ``TASK_CONFIGS`` are reported
            and ignored.
        max_ppl_samples: Sample cap for perplexity tasks.
        max_benchmark_samples: Sample cap for the MMLU fallback; values
            ``<= 0`` mean 500.
        batch_size: Batch size for lm-eval.
        device: Device passed on to the evaluators.
        use_lm_eval: Whether to try lm-eval-harness first.

    Returns:
        Dict with ``"tasks_evaluated"`` (the requested list, as given),
        ``"metrics"`` (per-task results), ``"timing"`` (seconds) and
        ``"summary"`` (from ``compute_summary_metrics``).
    """
    results = {
        "tasks_evaluated": tasks,
        "metrics": {},
        "timing": {},
    }

    # Surface typos instead of silently evaluating nothing for them.
    unknown_tasks = [t for t in tasks if t not in TASK_CONFIGS]
    if unknown_tasks:
        print(f"WARNING: ignoring unknown task(s): {unknown_tasks}")

    # Group tasks by evaluator type.
    ppl_tasks = [t for t in tasks if TASK_CONFIGS.get(t, {}).get("type") == "ppl"]
    lm_eval_tasks = [t for t in tasks if TASK_CONFIGS.get(t, {}).get("type") == "lm_eval"]

    for task in ppl_tasks:
        start_time = time.time()
        config = TASK_CONFIGS[task]

        ppl_result = evaluate_perplexity(
            model, tokenizer,
            dataset_name=config["dataset"],
            subset=config.get("subset"),
            split=config["split"],
            max_samples=max_ppl_samples,
            device=device,
        )

        results["metrics"][task] = ppl_result
        results["timing"][task] = time.time() - start_time

    if lm_eval_tasks:
        if use_lm_eval:
            try:
                start_time = time.time()
                lm_eval_results = evaluate_with_lm_eval(
                    model, tokenizer,
                    tasks=lm_eval_tasks,
                    batch_size=batch_size,
                    device=device,
                )

                for task, result in lm_eval_results.items():
                    results["metrics"][task] = result

                # One combined timing entry covers all lm-eval tasks.
                results["timing"]["lm_eval"] = time.time() - start_time
            except Exception as e:
                # Typically lm-eval missing or the model wrapper failing.
                print(f"lm-eval failed: {e}")
                print("Falling back to simple evaluation...")
                use_lm_eval = False

        if not use_lm_eval:
            # Only MMLU has a simple fallback; say so for everything else.
            not_evaluated = [t for t in lm_eval_tasks if t != "mmlu"]
            if not_evaluated:
                print(
                    "WARNING: no simple fallback for task(s) "
                    f"{not_evaluated}; they were NOT evaluated"
                )

            if "mmlu" in lm_eval_tasks:
                start_time = time.time()
                mmlu_result = evaluate_mmlu_simple(
                    model, tokenizer,
                    max_samples=max_benchmark_samples if max_benchmark_samples > 0 else 500,
                    device=device,
                )
                results["metrics"]["mmlu"] = mmlu_result
                results["timing"]["mmlu"] = time.time() - start_time

    summary = compute_summary_metrics(results["metrics"])
    results["summary"] = summary

    return results


def _metric_value(task_data: Dict[str, Any], name: str) -> Any:
    """Return the value stored for metric *name*, or None when absent.

    Accepts the bare key (``"acc"``) as well as ``"<name>,<filter>"`` keys
    (``"acc,none"``). Precedence: bare key, then the ``none`` filter, then
    the first suffixed key in dict order. ``"acc_stderr,none"`` does not
    match ``"acc"``.
    """
    for key in (name, f"{name},none"):
        if key in task_data:
            return task_data[key]
    prefix = f"{name},"
    for key, value in task_data.items():
        if isinstance(key, str) and key.startswith(prefix):
            return value
    return None


def compute_summary_metrics(metrics: Dict[str, Any]) -> Dict[str, float]:
    """Condense per-task results into a flat summary.

    Args:
        metrics: Mapping of task name to that task's result dict.

    Returns:
        Dict that may contain ``wikitext2_ppl`` / ``c4_ppl`` (``inf`` when
        the task has no ``"perplexity"`` key), one ``<task>_acc`` entry per
        benchmark with a positive numeric score, and ``avg_accuracy`` (the
        unweighted mean of those scores).

    For each benchmark the first usable metric among ``acc``, ``acc_norm``,
    ``accuracy`` and ``exact_match`` is taken, in that order. Scores that
    are missing, non-numeric or not greater than zero are left out, so
    error entries and failed fallbacks do not drag the average down.
    """
    summary = {}

    for task, label in (("wikitext2", "wikitext2_ppl"), ("c4", "c4_ppl")):
        if task in metrics:
            summary[label] = metrics[task].get("perplexity", float('inf'))

    acc_metrics = []
    benchmark_tasks = (
        "arc_easy", "arc_challenge", "piqa", "hellaswag",
        "winogrande", "mmlu", "gsm8k",
    )
    # "exact_match" is last so tasks that report only that metric are still
    # summarised. NOTE(review): presumably what gsm8k reports -- confirm.
    metric_names = ("acc", "acc_norm", "accuracy", "exact_match")

    for task in benchmark_tasks:
        task_data = metrics.get(task)
        if not isinstance(task_data, dict):
            continue
        for name in metric_names:
            value = _metric_value(task_data, name)
            if isinstance(value, (int, float)) and value > 0:
                summary[f"{task}_acc"] = value
                acc_metrics.append(value)
                break

    if acc_metrics:
        summary["avg_accuracy"] = sum(acc_metrics) / len(acc_metrics)

    return summary


# ============================================================
# Main Entry Point
# ============================================================

def main():
    """Command-line entry point: parse options, evaluate, save and summarise.

    The model (and optional LoRA adapter) is loaded, the selected tasks are
    run through ``evaluate_all``, the results plus the run configuration are
    written as JSON to ``--output_file`` (parent directories are created),
    and the summary is printed.
    """
    cli = argparse.ArgumentParser(description="Comprehensive evaluation for QR-Adaptor")

    # Model / output locations.
    cli.add_argument("--hqq_model_path", type=str, required=True,
                     help="Path to HQQ quantized model or HF model ID")
    cli.add_argument("--lora_path", type=str, default=None,
                     help="Path to LoRA adapter (optional)")
    cli.add_argument("--output_file", type=str, required=True,
                     help="Output JSON file for results")

    # What to evaluate and how much of it.
    cli.add_argument("--tasks", type=str, default="core",
                     help="Tasks to evaluate (comma-separated or group name: all, ppl, commonsense, core, benchmark)")
    cli.add_argument("--max_ppl_samples", type=int, default=200,
                     help="Max samples for PPL evaluation")
    cli.add_argument("--max_benchmark_samples", type=int, default=-1,
                     help="Max samples for benchmark evaluation (-1 for all)")
    cli.add_argument("--batch_size", type=int, default=4,
                     help="Batch size for evaluation")
    cli.add_argument("--no_lm_eval", action="store_true",
                     help="Disable lm-eval-harness (use simple evaluation)")

    opts = cli.parse_args()

    # A group name expands to its task list; anything else is treated as a
    # comma-separated list of task names.
    tasks = TASK_GROUPS.get(opts.tasks)
    if tasks is None:
        tasks = [name.strip() for name in opts.tasks.split(",")]

    rule = "=" * 60
    banner = (
        rule,
        "QR-Adaptor Comprehensive Evaluation",
        rule,
        f"Model: {opts.hqq_model_path}",
        f"LoRA: {opts.lora_path}",
        f"Tasks: {tasks}",
        rule,
    )
    for line in banner:
        print(line)

    model, tokenizer = load_model_and_tokenizer(
        opts.hqq_model_path,
        opts.lora_path,
    )

    results = evaluate_all(
        model, tokenizer,
        tasks=tasks,
        max_ppl_samples=opts.max_ppl_samples,
        max_benchmark_samples=opts.max_benchmark_samples,
        batch_size=opts.batch_size,
        use_lm_eval=not opts.no_lm_eval,
    )

    # Record how the run was configured next to its results.
    results["config"] = {
        "hqq_model_path": opts.hqq_model_path,
        "lora_path": opts.lora_path,
        "tasks": tasks,
    }

    output_path = Path(opts.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open('w') as f:
        json.dump(results, f, indent=2)

    print("\n" + rule)
    print("EVALUATION SUMMARY")
    print(rule)
    for key, value in results.get("summary", {}).items():
        # Floats get four decimals; everything else is printed as-is.
        rendered = f"{value:.4f}" if isinstance(value, float) else f"{value}"
        print(f"  {key}: {rendered}")
    print(rule)
    print(f"Results saved to: {opts.output_file}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
