#!/usr/bin/env python3
"""
Evaluation script for OpenAI fine-tuned models.
This script allows you to evaluate your fine-tuned model on test data.
"""

import json
import os
import sys
import argparse
from typing import List, Dict, Any
from openai import OpenAI
import time


def load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Load every non-blank line of a JSONL file as a parsed JSON value.

    Args:
        path: Location of the JSONL file (read as UTF-8).

    Returns:
        The parsed records in file order. Lines that are empty or contain
        only whitespace are skipped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a line is not valid JSON. The message names
            the file and the 1-based line number of the offending line.
    """
    records: List[Dict[str, Any]] = []
    # JSONL is UTF-8; do not depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Each line is parsed on its own, so the decoder always
                # reports "line 1"; add the real position within the file.
                raise json.JSONDecodeError(
                    f"{path}, line {line_number}: {exc.msg}", exc.doc, exc.pos
                ) from exc
    return records


def load_job_info(job_info_path: str) -> Dict[str, Any]:
    """Load the saved fine-tuning job information from a JSON file.

    Args:
        job_info_path: Location of the JSON file (read as UTF-8).

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8; do not depend on the platform's default encoding.
    with open(job_info_path, "r", encoding="utf-8") as f:
        return json.load(f)


def evaluate_single_example(client: OpenAI, model_id: str, messages: List[Dict[str, str]],
                          temperature: float = 0.7, max_tokens: int = 500) -> Dict[str, Any]:
    """Send one chat request to the model and report the outcome as a dict.

    This function never raises: any exception from the request or from
    reading the response is converted into a failure record.

    Args:
        client: OpenAI client used for the request.
        model_id: Identifier of the model to query.
        messages: Chat messages to send.
        temperature: Sampling temperature forwarded to the request.
        max_tokens: Completion token cap forwarded to the request.

    Returns:
        On success: ``{"success": True, "response": <text>, "usage": <dict
        or None>, "finish_reason": <str>}``. ``usage`` holds
        ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``, or is
        ``None`` when the response carries no usage information.
        On failure: ``{"success": False, "error": <message>, "response":
        None, "usage": None, "finish_reason": None}``.
    """
    try:
        response = client.chat.completions.create(
            model=model_id,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        first_choice = response.choices[0]

        # A response without usage data is still a successful completion;
        # report the text and leave the token counts out instead of failing.
        usage = response.usage
        usage_summary = None
        if usage is not None:
            usage_summary = {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            }

        return {
            "success": True,
            "response": first_choice.message.content,
            "usage": usage_summary,
            "finish_reason": first_choice.finish_reason,
        }
    except Exception as e:
        # Deliberate best-effort boundary: one bad request must not abort
        # the whole evaluation run, so record the error and carry on.
        return {
            "success": False,
            "error": str(e),
            "response": None,
            "usage": None,
            "finish_reason": None,
        }


def evaluate_dataset(client: OpenAI, model_id: str, test_data: List[Dict[str, Any]],
                    temperature: float = 0.7, max_tokens: int = 500,
                    max_examples: int = None,
                    request_delay: float = 0.1) -> Dict[str, Any]:
    """Run the model over a test set and collect per-example results.

    Each example must either carry a ``"messages"`` list of chat messages or
    flat ``"user"`` / ``"assistant"`` keys. Only the messages whose role is
    ``"user"`` are sent to the model. Examples that cannot be turned into a
    request are skipped with a warning instead of aborting the run.

    Args:
        client: OpenAI client used for the requests.
        model_id: Identifier of the model to evaluate.
        test_data: Parsed test examples.
        temperature: Sampling temperature forwarded to each request.
        max_tokens: Completion token cap forwarded to each request.
        max_examples: When truthy, ``test_data`` is cut to its first
            ``max_examples`` entries; ``None`` or ``0`` means no limit.
        request_delay: Seconds to pause after each request as a simple rate
            limit. Values of zero or below disable the pause.

    Returns:
        A dict with two keys. ``"results"`` holds one record per example
        that was actually sent, each extended with ``example_index`` and
        ``input_messages``. ``"summary"`` holds the aggregate counts and
        token usage; ``failed_evaluations`` counts every example without a
        successful response, which includes the ``skipped_examples`` that
        were never sent.
    """
    if max_examples:
        test_data = test_data[:max_examples]

    results = []
    total_tokens = 0
    successful_evaluations = 0
    skipped_examples = 0

    print(f"Evaluating {len(test_data)} examples...")

    for i, example in enumerate(test_data):
        print(f"Evaluating example {i+1}/{len(test_data)}...")

        # Extract messages from the example. Non-dict records are treated
        # like any other unknown format so one bad line cannot abort a run
        # that may already have collected results.
        is_mapping = isinstance(example, dict)
        if is_mapping and "messages" in example:
            messages = example["messages"]
        elif is_mapping and "user" in example and "assistant" in example:
            messages = [
                {"role": "user", "content": example["user"]},
                {"role": "assistant", "content": example["assistant"]}
            ]
        else:
            found = list(example.keys()) if is_mapping else type(example).__name__
            print(f"Warning: Skipping example with unknown format: {found}")
            skipped_examples += 1
            continue

        # Keep only the user messages (this removes the assistant response).
        # NOTE(review): system messages and earlier assistant turns are
        # dropped as well, so a model fine-tuned with a system prompt is
        # evaluated without it - confirm this is intended.
        if not isinstance(messages, (list, tuple)):
            messages = []
        user_messages = [
            msg for msg in messages
            if isinstance(msg, dict) and msg.get("role") == "user"
        ]
        if not user_messages:
            print(f"Warning: Skipping example {i+1} with no user messages")
            skipped_examples += 1
            continue

        # Evaluate with just the user messages
        result = evaluate_single_example(
            client, model_id, user_messages, temperature, max_tokens
        )

        # Add context information
        result["example_index"] = i
        result["input_messages"] = user_messages

        if result["success"]:
            successful_evaluations += 1
            if result["usage"]:
                total_tokens += result["usage"]["total_tokens"]

        results.append(result)

        # Rate limiting - be nice to OpenAI's API
        if request_delay > 0:
            time.sleep(request_delay)

    # Calculate summary statistics
    summary = {
        "total_examples": len(test_data),
        "successful_evaluations": successful_evaluations,
        "failed_evaluations": len(test_data) - successful_evaluations,
        "skipped_examples": skipped_examples,
        "success_rate": successful_evaluations / len(test_data) if test_data else 0,
        "total_tokens_used": total_tokens,
        "average_tokens_per_example": total_tokens / successful_evaluations if successful_evaluations > 0 else 0
    }

    return {
        "summary": summary,
        "results": results
    }


def save_evaluation_results(results: Dict[str, Any], output_path: str):
    """Write evaluation results to ``output_path`` as indented JSON.

    Args:
        results: JSON-serializable evaluation results.
        output_path: Destination file; an existing file is overwritten.

    Raises:
        TypeError: If ``results`` contains a value JSON cannot represent.
            The destination file is left untouched in that case.
        OSError: If the destination cannot be opened or written.
    """
    # Serialize before opening the file: json.dump streams into a file that
    # has already been truncated, so a non-serializable value would leave a
    # corrupt partial file and destroy any earlier results at that path.
    serialized = json.dumps(results, indent=2)
    with open(output_path, 'w', encoding="utf-8") as f:
        f.write(serialized)
    print(f"Evaluation results saved to: {output_path}")


def print_summary(summary: Dict[str, Any]):
    """Write the evaluation summary to stdout as a banner-framed report.

    Args:
        summary: Mapping with the keys ``total_examples``,
            ``successful_evaluations``, ``failed_evaluations``,
            ``success_rate``, ``total_tokens_used`` and
            ``average_tokens_per_example``.
    """
    rule = "=" * 50
    # (label, summary key, format spec) for each report row, in print order.
    rows = (
        ("Total examples evaluated", "total_examples", ""),
        ("Successful evaluations", "successful_evaluations", ""),
        ("Failed evaluations", "failed_evaluations", ""),
        ("Success rate", "success_rate", ".2%"),
        ("Total tokens used", "total_tokens_used", ","),
        ("Average tokens per example", "average_tokens_per_example", ".1f"),
    )

    print(f"\n{rule}")
    print("EVALUATION SUMMARY")
    print(rule)
    for label, key, spec in rows:
        print(f"{label}: {summary[key]:{spec}}")
    print(rule)


def main():
    """Command-line entry point: evaluate a fine-tuned model on a test set.

    Parses the arguments, reads the model id from the job info file, loads
    the JSONL test data, runs the evaluation, prints the summary and saves
    the full results. Fatal setup errors are written to stderr and end the
    process with exit status 1.
    """
    parser = argparse.ArgumentParser(description="Evaluate OpenAI fine-tuned model")
    parser.add_argument("--job_info", required=True, help="Path to job_info.json file")
    parser.add_argument("--test_data", required=True, help="Path to test data JSONL file")
    parser.add_argument("--output", default="evaluation_results.json", help="Output file for results")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--max_tokens", type=int, default=500, help="Maximum tokens for response")
    parser.add_argument("--max_examples", type=int, help="Maximum number of examples to evaluate")

    args = parser.parse_args()

    # sys.exit with a string writes it to stderr and exits with status 1,
    # which keeps error text out of stdout.

    # Check if OpenAI API key is set
    if not os.environ.get("OPENAI_API_KEY"):
        sys.exit("Error: OPENAI_API_KEY environment variable is required")

    # Load job information
    try:
        job_info = load_job_info(args.job_info)
        model_id = job_info.get("finetuned_model")
    except Exception as e:
        sys.exit(f"Error loading job info: {e}")
    if not model_id:
        # A missing or null model id would make every request fail later,
        # so stop before any API call is made.
        sys.exit(
            f"Error loading job info: no 'finetuned_model' value found in {args.job_info}"
        )
    print(f"Using fine-tuned model: {model_id}")

    # Load test data
    try:
        test_data = load_jsonl(args.test_data)
    except Exception as e:
        sys.exit(f"Error loading test data: {e}")
    print(f"Loaded {len(test_data)} test examples")

    # Initialize OpenAI client
    client = OpenAI()

    # Run evaluation
    print("Starting evaluation...")
    evaluation_results = evaluate_dataset(
        client, model_id, test_data,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        max_examples=args.max_examples
    )

    # Print summary
    print_summary(evaluation_results["summary"])

    # Save results
    save_evaluation_results(evaluation_results, args.output)

    print(f"\nEvaluation completed! Results saved to: {args.output}")


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
