#!/usr/bin/env python3
"""
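Evaluate HQQ + LoRA models using lm-evaluation-harness.
This provides standard, reproducible MMLU evaluations.

The base model may be an HQQ-quantized checkpoint, a standard Hugging Face
checkpoint directory, or a Hugging Face Hub model ID (which can be loaded in
8-bit with --load_in_8bit). The LoRA adapter is optional: omit --lora_path to
evaluate the base model alone. MMLU (5-shot) is the default, but any
comma-separated list of lm-eval tasks can be passed via --tasks.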

Usage:
    python eval_lm_harness.py \
        --hqq_model_path path/to/quantized_model \
        --lora_path path/to/lora_adapter \
        --output_file results.json \
        --tasks mmlu
"""

import argparse
import gc
import json
import math
import os
import time
from pathlib import Path

import torch


# Substrings in an exception message that indicate a transient network failure
# (e.g. while lm-eval downloads datasets from the Hub) worth retrying.
_TRANSIENT_ERROR_MARKERS = ("502", "503", "504", "Connection", "Timeout", "Gateway")


def _clear_gpu_memory():
    """Run the garbage collector and release cached CUDA memory when CUDA is available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def _load_hf_model(model_path, load_in_8bit):
    """Load a standard Hugging Face causal LM from a Hub ID or local directory.

    Args:
        model_path: Hub model ID or local checkpoint directory.
        load_in_8bit: If True, quantize on load with BitsAndBytes LLM.int8;
            otherwise load the weights in float16.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained`` with
        ``device_map="auto"``.
    """
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    kwargs = {"device_map": "auto", "trust_remote_code": True}
    if load_in_8bit:
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0,
        )
    else:
        kwargs["torch_dtype"] = torch.float16
    return AutoModelForCausalLM.from_pretrained(model_path, **kwargs)


def _load_model(model_path_str, load_in_8bit):
    """Load the base model from a Hub ID, a standard HF directory, or an HQQ directory.

    Args:
        model_path_str: Value of ``--hqq_model_path``.
        load_in_8bit: Value of ``--load_in_8bit``. It applies to Hub and standard
            HF checkpoints; HQQ-format models ignore it (a warning is printed).

    Returns:
        The loaded model.
    """
    model_path = Path(model_path_str)
    # A string containing "/" that is not an existing local path is treated as a Hub model ID.
    is_hub_model = "/" in model_path_str and not model_path.exists()

    if is_hub_model:
        precision = "8-bit" if load_in_8bit else "FP16"
        print(f"Loading {precision} model from HuggingFace Hub: {model_path_str}")
        return _load_hf_model(model_path_str, load_in_8bit)

    # Standard HF checkpoints ship model.safetensors, pytorch_model.bin, or sharded safetensors.
    has_hf_weights = (
        (model_path / "model.safetensors").exists()
        or (model_path / "pytorch_model.bin").exists()
        or any(model_path.glob("model-*.safetensors"))
    )

    if has_hf_weights:
        # Try the standard HF format first (e.g. OWQ quantized models); fall back to HQQ.
        try:
            print(f"Loading standard HF model from: {model_path}")
            return _load_hf_model(str(model_path), load_in_8bit)
        except Exception as e:
            print(f"Failed to load as HF model: {e}, trying HQQ format...")
        # Free memory outside the except clause: once it exits, the traceback (and
        # any partially loaded weights its frames reference) has been released.
        _clear_gpu_memory()
    else:
        print("Loading HQQ model...")

    if load_in_8bit:
        print("Warning: --load_in_8bit has no effect on HQQ-format models")

    # Imported lazily so Hub / standard-HF evaluation does not require hqq.
    from hqq.models.hf.base import AutoHQQHFModel

    return AutoHQQHFModel.from_quantized(model_path_str)


def _is_transient_error(exc):
    """Return True if *exc* looks like a retryable network failure."""
    if isinstance(exc, (ConnectionError, TimeoutError)):
        return True
    message = str(exc)
    return any(marker in message for marker in _TRANSIENT_ERROR_MARKERS)


def _evaluate_with_retries(lm_model, task_list, num_fewshot, batch_size,
                           max_retries=3, initial_delay=30):
    """Run ``lm_eval.simple_evaluate``, retrying transient network errors with exponential backoff.

    Args:
        lm_model: The lm-eval model wrapper.
        task_list: Task names to evaluate.
        num_fewshot: Number of few-shot examples.
        batch_size: Evaluation batch size.
        max_retries: Total number of attempts (must be >= 1).
        initial_delay: Seconds to wait before the first retry; doubled after each retry.

    Returns:
        The dict returned by ``lm_eval.simple_evaluate``.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: Re-raises non-transient errors immediately, and the last
            transient error once all attempts are exhausted.
    """
    import lm_eval

    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")

    delay = initial_delay
    for attempt in range(1, max_retries + 1):
        try:
            return lm_eval.simple_evaluate(
                model=lm_model,
                tasks=task_list,
                num_fewshot=num_fewshot,
                batch_size=batch_size,
            )
        except Exception as e:
            if attempt >= max_retries or not _is_transient_error(e):
                raise
            print(f"Network error (attempt {attempt}/{max_retries}): {e}")
            print(f"Retrying in {delay} seconds...")
            time.sleep(delay)
            delay *= 2  # Exponential backoff


def _json_safe_scalars(task_result):
    """Keep only scalar (int/float/str) entries of one task's results, mapping NaN/inf to None.

    ``json.dump`` would otherwise emit ``NaN``/``Infinity``, which strict JSON
    parsers reject.
    """
    cleaned = {}
    for key, value in task_result.items():
        if not isinstance(value, (int, float, str)):
            continue
        if isinstance(value, float) and not math.isfinite(value):
            value = None
        cleaned[key] = value
    return cleaned


def _peak_memory_gb():
    """Return peak allocated CUDA memory in GiB, summed over all visible devices.

    With ``device_map="auto"`` the model may be spread over several GPUs, so
    every device is included. The sum of per-device peaks is an upper bound on
    the simultaneous peak. Returns 0.0 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 0.0
    total_bytes = sum(
        torch.cuda.max_memory_allocated(device)
        for device in range(torch.cuda.device_count())
    )
    return total_bytes / (1024 ** 3)


def main():
    """Evaluate a (optionally LoRA-adapted) model with lm-eval and save the metrics.

    Parses CLI arguments, loads the base model (Hub, standard HF, or HQQ format)
    plus an optional LoRA adapter, and runs ``lm_eval.simple_evaluate`` on the
    requested tasks. The scalar metrics are written as JSON to
    ``--output_file``. A ``PEAK_MEMORY_GB:`` line and the JSON summary are
    printed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--hqq_model_path", type=str, required=True)
    parser.add_argument("--lora_path", type=str, default=None, help="Path to LoRA adapter (optional, for raw model eval)")
    parser.add_argument("--output_file", type=str, required=True)
    parser.add_argument("--tasks", type=str, default="mmlu", help="Tasks to evaluate (comma-separated)")
    parser.add_argument("--num_fewshot", type=int, default=5, help="Number of few-shot examples")
    parser.add_argument("--batch_size", type=int, default=1)  # Reduced from 4 to prevent OOM
    parser.add_argument("--load_in_8bit", action="store_true",
                       help="Load model in 8-bit using BitsAndBytes LLM.int8")
    args = parser.parse_args()

    # Validate the task list before the (slow) model load; drop empty entries such as "mmlu,".
    task_list = [t.strip() for t in args.tasks.split(",") if t.strip()]
    if not task_list:
        parser.error("--tasks must name at least one task")

    print("=" * 60)
    print("Evaluating with lm-evaluation-harness")
    print(f"HQQ Model: {args.hqq_model_path}")
    print(f"LoRA Path: {args.lora_path if args.lora_path else '(none - raw model)'}")
    print(f"Tasks: {args.tasks}")
    print("=" * 60)

    from transformers import AutoTokenizer

    # Clear GPU memory before loading
    _clear_gpu_memory()
    model = _load_model(args.hqq_model_path, args.load_in_8bit)

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.hqq_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if args.lora_path:
        # Imported lazily so raw-model evaluation does not require peft.
        from peft import PeftModel

        print("Loading LoRA adapter...")
        model = PeftModel.from_pretrained(model, args.lora_path)
    else:
        print("No LoRA adapter specified, evaluating raw model")
    model.eval()

    # Clear GPU memory before evaluation
    _clear_gpu_memory()

    from lm_eval.models.huggingface import HFLM

    print("Creating lm-eval model wrapper...")
    lm_model = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
    )

    print(f"Running evaluation on tasks: {task_list}")
    results = _evaluate_with_retries(
        lm_model, task_list, args.num_fewshot, args.batch_size
    )

    peak_memory_gb = _peak_memory_gb()

    # Extract key metrics (scalar values only, JSON-safe)
    output = {
        "hqq_model_path": args.hqq_model_path,
        "lora_path": args.lora_path if args.lora_path else "(raw model)",
        "tasks": args.tasks,
        "num_fewshot": args.num_fewshot,
        "results": {
            task_name: _json_safe_scalars(task_result)
            for task_name, task_result in results["results"].items()
        },
        "peak_memory_gb": peak_memory_gb,
    }

    # Save results
    output_path = Path(args.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    print("\n" + "=" * 60)
    print("Evaluation complete!")
    print(f"PEAK_MEMORY_GB: {peak_memory_gb:.3f}")
    print("=" * 60)
    print(json.dumps(output, indent=2))


if __name__ == "__main__":
    # Entry point: run the evaluation only when executed as a script, never on import.
    main()
