#!/usr/bin/env python3
"""
GSM8K Single-Agent Evaluation Pipeline
Single path: Generator -> Answer

Pipeline:
1. Generator: Question -> Answer (1 output) - Uses vLLM

Evaluates on random subsets of the GSM8K test set and saves per-run results as JSON.
"""

import re
import json
import random
import time
import torch
from pathlib import Path
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
from datasets import load_dataset
from openai import OpenAI
import requests
import warnings
warnings.filterwarnings('ignore')

class GSM8KvLLMEvaluator:
    """Generic vLLM evaluator for any model component.

    Wraps an OpenAI client pointed at a vLLM OpenAI-compatible server. The
    server's ``/health`` endpoint is polled on construction, and text is
    generated through the completions API (the model is used as a base model,
    not a chat model).
    """

    def __init__(self,
                 base_url: str,
                 model_name: str,
                 component_name: str = "model",
                 api_key: str = "EMPTY"):
        """Create the API client and wait for the server to report healthy.

        Args:
            base_url: API root passed to the OpenAI client,
                e.g. ``http://localhost:8000/v1``.
            model_name: Model name sent with every completions request.
            component_name: Label used only in log messages.
            api_key: Key handed to the OpenAI client.

        Raises:
            ConnectionError: if the health check never succeeds.
        """
        print(f"Connecting to vLLM server for {component_name} at: {base_url}")
        print(f"{component_name.title()} model: {model_name}")
        print("Model type: Base (using completions API)")

        self.base_url = base_url
        self.model_name = model_name
        self.component_name = component_name

        # Initialize OpenAI client for vLLM
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )

        # Test connection
        self._test_connection()

    def _health_url(self) -> str:
        """Return the ``/health`` URL derived from ``self.base_url``.

        Only a trailing ``/v1`` segment (and any trailing slashes) is removed,
        so ``http://host:8000/v1/`` maps to ``http://host:8000/health`` rather
        than ``http://host:8000//health``.
        """
        root = self.base_url.rstrip('/')
        if root.endswith('/v1'):
            root = root[:-len('/v1')]
        return f"{root}/health"

    def _test_connection(self, max_retries: int = 5, delay: float = 5.0):
        """Poll the server's health endpoint until it answers HTTP 200.

        Args:
            max_retries: Number of health checks to attempt.
            delay: Seconds to wait between failed attempts.

        Raises:
            ConnectionError: if no attempt returns HTTP 200.
        """
        health_url = self._health_url()

        print(f"Testing {self.component_name} server connection...")
        for i in range(max_retries):
            try:
                response = requests.get(health_url, timeout=10)
                if response.status_code == 200:
                    print(f"✓ Successfully connected to {self.component_name} server")
                    return
            except requests.exceptions.RequestException:
                # Server not reachable yet (refused / timed out); retry below.
                pass

            print(f"Attempt {i+1}/{max_retries}: Waiting for {self.component_name} server...")
            # Do not sleep after the final attempt; raise right away instead.
            if i < max_retries - 1:
                time.sleep(delay)

        raise ConnectionError(f"Could not connect to {self.component_name} server at {self.base_url}.")

    def generate_text(self, prompt: str, max_tokens: int = 512, retries: int = 3,
                      stop_tokens: Optional[List[str]] = None) -> str:
        """Generate a completion for ``prompt``, retrying on API errors.

        Args:
            prompt: Raw prompt text for the completions endpoint.
            max_tokens: Maximum number of tokens to generate.
            retries: Number of attempts; values below 1 still make one attempt.
            stop_tokens: Stop sequences; a default set is used when None.

        Returns:
            The stripped completion text, or "" if every attempt failed.
        """
        if stop_tokens is None:
            stop_tokens = ["\n\n", "\nQ:", "Question:", "You are", "Q:"]

        # Always make at least one attempt so the method never falls through
        # the loop and implicitly returns None.
        attempts = max(1, retries)
        for attempt in range(attempts):
            try:
                response = self.client.completions.create(
                    model=self.model_name,
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=0.2,
                    top_p=0.9,
                    stop=stop_tokens,
                )
                return response.choices[0].text.strip()
            except Exception as e:
                # Deliberate best-effort: any client/server error is logged and
                # retried; the evaluation loop treats "" as a wrong answer.
                if attempt < attempts - 1:
                    print(f"API call failed (attempt {attempt + 1}/{attempts}): {str(e)}")
                    time.sleep(1)
                else:
                    print(f"API call failed after {attempts} attempts: {str(e)}")

        return ""


class GSM8KSingleAgentEvaluator:
    """Single-agent evaluator that uses only the generator via vLLM."""

    # Signed integer or decimal literal; commas are stripped before matching.
    _NUMBER_RE = re.compile(r'([+-]?\d+(?:\.\d+)?)')
    # "The answer is X", also "The final answer is: $X" (optional colon / dollar sign).
    _ANSWER_IS_RE = re.compile(r"[Tt]he (?:final )?answer is:?\s*\$?\s*([+-]?\d+(?:\.\d+)?)")
    # The bare word "answer" in any case, used to split off the text after it.
    _ANSWER_WORD_RE = re.compile(r'answer', re.IGNORECASE)

    def __init__(self, generator_base_url: str, generator_model_name: str):
        """Initialize the single-agent evaluator with one vLLM client."""
        self.generator = GSM8KvLLMEvaluator(generator_base_url, generator_model_name, "generator")

    def create_generator_prompt(self, question: str) -> str:
        """Build the zero-shot ``Q: ...\\nA:`` prompt for the generator."""
        return f"Q: {question}\nA:"

    def generate_answer(self, question: str) -> str:
        """Generate an answer for ``question`` using the generator model."""
        prompt = self.create_generator_prompt(question)
        return self.generator.generate_text(prompt, stop_tokens=["\n\n", "\nQ:", "Question:"])

    def extract_answer_from_text(self, text: str) -> Optional[float]:
        """Extract the predicted numeric answer from generated text.

        Commas are removed first. Then the following strategies are tried in
        order, and the first one that finds a number wins:

        1. "The answer is X" (also "The final answer is: $X").
        2. The first number after the last "####" marker.
        3. The first number after the last occurrence of "answer" (any case).
        4. The last number on the last line.

        Returns:
            The answer as a float, or None if no number is found.
        """
        text = text.replace(",", "")

        match = self._ANSWER_IS_RE.search(text)
        if match:
            return float(match.group(1))

        # Look for the GSM8K-style "#### X" marker.
        if "####" in text:
            match = self._NUMBER_RE.search(text.split("####")[-1])
            if match:
                return float(match.group(1))

        # First number after the last occurrence of "answer".
        parts = self._ANSWER_WORD_RE.split(text)
        if len(parts) > 1:
            match = self._NUMBER_RE.search(parts[-1])
            if match:
                return float(match.group(1))

        # Last resort: the last number on the last line.
        last_line = text.strip().split('\n')[-1]
        numbers = self._NUMBER_RE.findall(last_line)
        if numbers:
            return float(numbers[-1])

        return None

    def extract_ground_truth(self, solution: str) -> float:
        """Extract the reference answer that follows "####" in a GSM8K solution.

        Raises:
            ValueError: if there is no "####" marker followed by a number.
        """
        parts = solution.split("####")
        if len(parts) >= 2:
            answer_str = parts[-1].strip().replace(",", "")
            match = self._NUMBER_RE.search(answer_str)
            if match:
                return float(match.group(1))
        raise ValueError(f"Could not extract answer from solution: {solution}")

    def is_correct_answer(self, generated_text: str, ground_truth_solution: str) -> bool:
        """Return True if the extracted prediction is within 0.01 of the ground truth."""
        try:
            predicted = self.extract_answer_from_text(generated_text)
            ground_truth = self.extract_ground_truth(ground_truth_solution)

            if predicted is None:
                return False

            return abs(predicted - ground_truth) < 0.01
        except Exception:
            # Best-effort grading: an unparsable reference or input counts as
            # incorrect instead of aborting the whole evaluation run.
            return False


def run_single_evaluation(seed: int, run_num: int, evaluator: GSM8KSingleAgentEvaluator,
                          test_dataset, num_questions: int = 100) -> Dict:
    """Evaluate the generator on a seeded random subset of the test set.

    Args:
        seed: Seeds Python's ``random`` (which selects the subset) and torch.
            Calling again with the same seed selects the same questions.
        run_num: Run index, used only in log messages.
        evaluator: Evaluator used to generate and grade answers.
        test_dataset: Indexable dataset whose items have 'question' and
            'answer' fields.
        num_questions: Subset size; the whole dataset is used when it has
            fewer items.

    Returns:
        Dict with 'questions' (one record per question, including the
        dataset row index and whether the answer was graded correct),
        'accuracy', and 'seed'.
    """
    # Set random seed for this run
    random.seed(seed)
    torch.manual_seed(seed)

    # Select random subset of questions using the seed
    if len(test_dataset) < num_questions:
        questions_to_process_indices = list(range(len(test_dataset)))
    else:
        questions_to_process_indices = random.sample(range(len(test_dataset)), num_questions)

    questions_to_process = [test_dataset[i] for i in questions_to_process_indices]
    actual_num_questions = len(questions_to_process)

    print(f"  Run {run_num}: Selected {actual_num_questions} random samples for evaluation.")

    questions_data = []
    generator_correct = 0

    # Iterate over the randomly selected samples with progress bar
    for idx, (original_idx, item) in enumerate(tqdm(zip(questions_to_process_indices, questions_to_process),
                                                   total=actual_num_questions,
                                                   desc='    Processing questions')):
        question = item['question']
        ground_truth_solution = item['answer']

        # Step 1: Generator
        generator_answer = evaluator.generate_answer(question)
        generator_is_correct = evaluator.is_correct_answer(generator_answer, ground_truth_solution)
        if generator_is_correct:
            generator_correct += 1

        question_result = {
            'question_id': idx,
            # Row in the test split, so results can be traced back to the dataset.
            'dataset_index': original_idx,
            'question': question,
            'ground_truth_solution': ground_truth_solution,
            'generator_answer': generator_answer,
            # Persist the grading so per-question results can be analyzed later.
            'is_correct': generator_is_correct,
        }

        questions_data.append(question_result)

    # Calculate accuracy
    accuracy = generator_correct / actual_num_questions if actual_num_questions > 0 else 0.0

    results = {
        'questions': questions_data,
        'accuracy': accuracy,
        'seed': seed
    }

    print(f"  Run {run_num}: Generator accuracy: {accuracy:.3f} ({generator_correct}/{actual_num_questions})")

    return results


def main():
    """Run every configured evaluation, save per-run JSON, and print a summary.

    Loads the GSM8K test split and connects to the vLLM generator server. It
    then runs ``runs_per_seed`` evaluations for each seed, writing one JSON
    file per run into ``untrained_singleagent2/``. Finally it prints the
    total time and the mean/std accuracy across all runs.
    """
    import statistics

    # Configuration
    generator_base_url = "http://localhost:8000/v1"
    generator_model_name = "base_qwen_model"

    seeds = [45]
    runs_per_seed = 10

    print("="*80)
    print("GSM8K Single-Agent Evaluation")
    print(f"Generator: vLLM-based ({generator_base_url}, {generator_model_name})")
    print(f"Seeds: {seeds}")
    print(f"Runs per seed: {runs_per_seed}")
    print("Questions per run: 100")
    print("="*80)

    # Load GSM8K test dataset
    print("\nLoading GSM8K test dataset...")
    dataset = load_dataset("openai/gsm8k", "main", trust_remote_code=True)
    test_dataset = dataset["test"]
    print(f"Total test samples: {len(test_dataset)}")

    # Initialize evaluator once
    print("\nInitializing evaluator...")
    evaluator = GSM8KSingleAgentEvaluator(
        generator_base_url=generator_base_url,
        generator_model_name=generator_model_name
    )

    # Create output directory
    output_dir = Path('untrained_singleagent2')
    output_dir.mkdir(parents=True, exist_ok=True)

    print("\nStarting evaluation...")
    start_time = time.time()

    # Per-run accuracies, aggregated into a summary at the end.
    accuracies = []

    # Run evaluations
    for seed in seeds:
        print(f"\nSeed {seed}:")
        for run_num in range(1, runs_per_seed + 1):
            print(f"  Starting run {run_num}...")

            results = run_single_evaluation(seed, run_num, evaluator, test_dataset)
            accuracies.append(results['accuracy'])

            # Save results
            output_file = output_dir / f'seed_{seed}_run_{run_num}.json'
            with output_file.open('w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)

            print(f"  Run {run_num} completed. Results saved to: {output_file}")

    total_time = time.time() - start_time

    print("\n" + "="*80)
    print("ALL EVALUATIONS COMPLETED")
    print("="*80)
    print(f"Total files generated: {len(seeds) * runs_per_seed}")
    if accuracies:
        mean_accuracy = statistics.mean(accuracies)
        # statistics.stdev needs at least two data points.
        std_accuracy = statistics.stdev(accuracies) if len(accuracies) > 1 else 0.0
        print(f"Accuracy over {len(accuracies)} runs: mean {mean_accuracy:.3f}, std {std_accuracy:.3f}")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Results saved to: {output_dir}/")
    print("="*80)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()