#!/usr/bin/env python3
"""
GSM8K Single-Agent Evaluation Pipeline with Majority Voting
Single path: Generator -> Answer (3 samples with majority voting)

Pipeline:
1. Generator: Question -> 3 Answers -> Majority Vote

Evaluates on GSM8K test set and saves results.
"""

import re
import json
import random
import time
import torch
from pathlib import Path
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
from datasets import load_dataset
from openai import OpenAI
import requests
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

class GSM8KvLLMEvaluator:
    """Client for a vLLM server reached through its OpenAI-compatible completions API.

    On construction the server's ``/health`` endpoint is polled until it
    answers with HTTP 200; a ``ConnectionError`` is raised if it never does.
    """

    def __init__(self,
                 base_url: str,
                 model_name: str,
                 component_name: str = "model",
                 api_key: str = "EMPTY"):
        """Create the API client and verify that the server is reachable.

        Args:
            base_url: Base URL handed to the OpenAI client (e.g. ``http://host:8001/v1``).
            model_name: Model identifier sent with every completion request.
            component_name: Human-readable label used in log messages.
            api_key: API key passed to the OpenAI client.

        Raises:
            ConnectionError: If the health check never succeeds.
        """
        print(f"Connecting to vLLM server for {component_name} at: {base_url}")
        print(f"{component_name.title()} model: {model_name}")
        print("Model type: Base (using completions API)")

        self.base_url = base_url
        self.model_name = model_name
        self.component_name = component_name

        # Initialize OpenAI client for vLLM
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )

        # Fail fast if the server is not up.
        self._test_connection()

    def _health_url(self) -> str:
        """Return the health-check URL derived from ``self.base_url``.

        Only a trailing ``/v1`` path segment (with or without a final slash) is
        removed, so a ``/v1`` appearing elsewhere in the URL is left intact.
        """
        root = self.base_url.rstrip("/")
        if root.endswith("/v1"):
            root = root[:-len("/v1")]
        return f"{root}/health"

    def _test_connection(self):
        """Poll the health endpoint until it returns HTTP 200.

        Makes up to five attempts, waiting five seconds between them.

        Raises:
            ConnectionError: If no attempt succeeded; chained to the last
                request error when there was one.
        """
        max_retries = 5
        health_url = self._health_url()
        last_error = None

        print(f"Testing {self.component_name} server connection...")
        for i in range(max_retries):
            try:
                response = requests.get(health_url, timeout=10)
                if response.status_code == 200:
                    print(f"✓ Successfully connected to {self.component_name} server")
                    return
            except requests.exceptions.RequestException as exc:
                # Server may still be starting; remember the error and retry.
                last_error = exc

            print(f"Attempt {i+1}/{max_retries}: Waiting for {self.component_name} server...")
            # No point in waiting after the final attempt; we are about to give up.
            if i < max_retries - 1:
                time.sleep(5)

        raise ConnectionError(
            f"Could not connect to {self.component_name} server at {self.base_url}."
        ) from last_error

    def generate_text(self, prompt: str, max_tokens: int = 512, retries: int = 3,
                      stop_tokens: Optional[List[str]] = None) -> str:
        """Request one completion for ``prompt`` and return its stripped text.

        Args:
            prompt: Raw prompt sent to the completions endpoint.
            max_tokens: Maximum number of tokens to generate.
            retries: Total number of attempts before giving up.
            stop_tokens: Stop sequences; a default list is used when ``None``.

        Returns:
            The first choice's text with surrounding whitespace removed, or an
            empty string when every attempt failed (or ``retries`` is not
            positive).
        """
        if stop_tokens is None:
            stop_tokens = ["\n\n", "\nQ:", "Question:", "You are", "Q:"]

        for attempt in range(retries):
            try:
                response = self.client.completions.create(
                    model=self.model_name,
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=0.2,
                    top_p=0.9,
                    stop=stop_tokens,
                )
                return response.choices[0].text.strip()

            except Exception as e:
                # Deliberately broad: generation is best-effort and a failure
                # must not abort the evaluation run.
                if attempt < retries - 1:
                    print(f"API call failed (attempt {attempt + 1}/{retries}): {str(e)}")
                    time.sleep(1)
                else:
                    print(f"API call failed after {retries} attempts: {str(e)}")

        # Reached when all attempts failed or when no attempt was made at all.
        return ""


class GSM8KSingleAgentMajorityVoteEvaluator:
    """Answer GSM8K questions with a single generator and vote across its samples."""

    def __init__(self, generator_base_url: str, generator_model_name: str):
        """Set up the vLLM client that serves as the generator."""
        self.generator = GSM8KvLLMEvaluator(
            base_url=generator_base_url,
            model_name=generator_model_name,
            component_name="generator",
        )

    def create_generator_prompt(self, question: str) -> str:
        """Wrap ``question`` in the ``Q: ... / A:`` prompt format."""
        prompt = f"Q: {question}\nA:"
        return prompt

    def generate_answer(self, question: str) -> str:
        """Ask the generator for one completion of the question's prompt."""
        return self.generator.generate_text(
            self.create_generator_prompt(question),
            stop_tokens=["\n\n", "\nQ:", "Question:"],
        )

    def generate_multiple_answers(self, question: str, num_samples: int = 3) -> List[str]:
        """Collect ``num_samples`` generated answers for the same question."""
        return [self.generate_answer(question) for _ in range(num_samples)]

    def extract_answer_from_text(self, text: str) -> Optional[float]:
        """Pull a numeric answer out of generated text.

        Commas are stripped first. Candidates are then tried in this order and
        the first one that converts to ``float`` wins:

        1. the number following "The answer is" / "the answer is";
        2. the first number after the last ``####`` marker;
        3. the first number after the last (case-insensitive) "answer";
        4. the last number on the final line.

        Returns ``None`` when no candidate is found.
        """
        number = r'([+-]?\d+(?:\.\d+)?)'
        cleaned = text.replace(",", "")

        def candidates():
            # Lazily yield candidate tokens so later heuristics only run when
            # the earlier ones did not produce a usable number.
            phrase_hit = re.search(r"[Tt]he answer is:?\s*([+-]?\d+(?:\.\d+)?)", cleaned)
            if phrase_hit:
                yield phrase_hit.group(1)

            if "####" in cleaned:
                marker_hit = re.search(number, cleaned.split("####")[-1].strip())
                if marker_hit:
                    yield marker_hit.group(1)

            pieces = re.split(r'answer', cleaned, flags=re.IGNORECASE)
            if len(pieces) > 1:
                keyword_hit = re.search(number, pieces[-1])
                if keyword_hit:
                    yield keyword_hit.group(1)

            trailing = re.findall(number, cleaned.strip().split('\n')[-1])
            if trailing:
                yield trailing[-1]

        for token in candidates():
            try:
                return float(token)
            except ValueError:
                continue
        return None

    def majority_vote(self, extracted_answers: List[Optional[float]]) -> Optional[float]:
        """Pick the most frequent non-``None`` answer.

        Ties (including the case where every answer is distinct) are resolved
        in favour of the answer that appears first in ``extracted_answers``.
        Returns ``None`` when there is no valid answer at all.
        """
        tally = Counter(ans for ans in extracted_answers if ans is not None)
        if not tally:
            return None
        # most_common keeps first-encountered order among equal counts.
        winner, _ = tally.most_common(1)[0]
        return winner

    def extract_ground_truth(self, solution: str) -> float:
        """Return the number that follows the last ``####`` in a GSM8K solution.

        Raises:
            ValueError: If there is no ``####`` marker or no number after it.
        """
        segments = solution.split("####")
        hit = None
        if len(segments) >= 2:
            tail = segments[-1].strip().replace(",", "")
            hit = re.search(r'([+-]?\d+(?:\.\d+)?)', tail)
        if hit is None:
            raise ValueError(f"Could not extract answer from solution: {solution}")
        return float(hit.group())

    def is_correct_answer(self, final_answer: Optional[float], ground_truth_solution: str) -> bool:
        """Tell whether ``final_answer`` is within 0.01 of the ground-truth value.

        A missing answer, or any error while reading the ground truth, counts
        as incorrect.
        """
        if final_answer is None:
            return False
        try:
            expected = self.extract_ground_truth(ground_truth_solution)
            return abs(final_answer - expected) < 0.01
        except Exception:
            return False


def run_single_evaluation(seed: int, run_num: int, evaluator: GSM8KSingleAgentMajorityVoteEvaluator, test_dataset,
                          num_questions: int = 100, num_samples: int = 3) -> Dict:
    """Evaluate majority-vote accuracy on a seeded random subset of the dataset.

    Args:
        seed: Seed for Python's ``random`` and for torch; it determines which
            questions are sampled.
        run_num: Run number, used only in log messages.
        evaluator: Object providing ``generate_multiple_answers``,
            ``extract_answer_from_text``, ``majority_vote`` and
            ``is_correct_answer``.
        test_dataset: Indexable collection of items with ``'question'`` and
            ``'answer'`` keys.
        num_questions: Number of questions to sample; the whole dataset is
            used (in order) when it holds fewer items than this.
        num_samples: Number of answers generated per question for the vote.

    Returns:
        A dict with ``'questions'`` (one record per question), ``'accuracy'``
        and ``'seed'``.
    """
    # Seed this process's RNGs so the question subset is reproducible.
    # NOTE(review): no seed is passed to the evaluator from here, so the
    # generated answers themselves may still vary between runs.
    random.seed(seed)
    torch.manual_seed(seed)

    # Select random subset of questions using the seed
    dataset_size = len(test_dataset)
    if dataset_size < num_questions:
        selected_indices = list(range(dataset_size))
    else:
        selected_indices = random.sample(range(dataset_size), num_questions)
    actual_num_questions = len(selected_indices)

    print(f"  Run {run_num}: Selected {actual_num_questions} random samples for evaluation.")

    questions_data = []
    correct_count = 0

    progress = tqdm(selected_indices,
                    total=actual_num_questions,
                    desc='    Processing questions')
    for idx, dataset_index in enumerate(progress):
        item = test_dataset[dataset_index]
        question = item['question']
        ground_truth_solution = item['answer']

        # Sample several answers, reduce each to a number, then vote.
        generator_answers = evaluator.generate_multiple_answers(question, num_samples=num_samples)
        extracted_answers = [evaluator.extract_answer_from_text(answer) for answer in generator_answers]
        majority_vote_answer = evaluator.majority_vote(extracted_answers)

        is_correct = evaluator.is_correct_answer(majority_vote_answer, ground_truth_solution)
        if is_correct:
            correct_count += 1

        questions_data.append({
            # Position within this run's sample (kept for existing consumers).
            'question_id': idx,
            # Index into test_dataset, so a record can be traced to its source item.
            'dataset_index': dataset_index,
            'question': question,
            'ground_truth_solution': ground_truth_solution,
            'generator_answers': generator_answers,
            'extracted_answers': extracted_answers,
            'majority_vote_answer': majority_vote_answer,
            'is_correct': is_correct,
        })

    # Guard against an empty selection to avoid dividing by zero.
    accuracy = correct_count / actual_num_questions if actual_num_questions > 0 else 0.0

    results = {
        'questions': questions_data,
        'accuracy': accuracy,
        'seed': seed
    }

    print(f"  Run {run_num}: Majority vote accuracy: {accuracy:.3f} ({correct_count}/{actual_num_questions})")

    return results


def main():
    """Run every (seed, run) evaluation and write one JSON result file per run."""

    # Configuration
    generator_base_url = "http://localhost:8001/v1"
    generator_model_name = "generator_dpo"
    seeds = [42, 43, 44, 45]
    runs_per_seed = 5

    rule = "=" * 80
    header_lines = (
        rule,
        "GSM8K Single-Agent Majority Vote Evaluation",
        f"Generator: vLLM-based ({generator_base_url}, {generator_model_name})",
        "Sampling strategy: 3 samples per question with majority voting",
        f"Seeds: {seeds}",
        f"Runs per seed: {runs_per_seed}",
        "Questions per run: 100",
        rule,
    )
    for line in header_lines:
        print(line)

    # Fetch the GSM8K test split.
    # NOTE(review): trust_remote_code=True permits code from the dataset
    # repository to run locally — confirm it is actually required here.
    print("\nLoading GSM8K test dataset...")
    test_dataset = load_dataset("openai/gsm8k", "main", trust_remote_code=True)["test"]
    print(f"Total test samples: {len(test_dataset)}")

    # One evaluator is shared by all runs.
    print("\nInitializing evaluator...")
    evaluator = GSM8KSingleAgentMajorityVoteEvaluator(
        generator_base_url=generator_base_url,
        generator_model_name=generator_model_name,
    )

    # Destination for the per-run result files.
    output_dir = Path('dpo_singleagent_mv')
    output_dir.mkdir(parents=True, exist_ok=True)

    print("\nStarting evaluation...")
    start_time = time.time()

    files_written = 0
    for seed in seeds:
        print(f"\nSeed {seed}:")
        for run_num in range(1, runs_per_seed + 1):
            print(f"  Starting run {run_num}...")

            results = run_single_evaluation(seed, run_num, evaluator, test_dataset)

            output_file = output_dir / f'seed_{seed}_run_{run_num}.json'
            with open(output_file, 'w') as handle:
                json.dump(results, handle, indent=2)
            files_written += 1

            print(f"  Run {run_num} completed. Results saved to: {output_file}")

    total_time = time.time() - start_time

    print("\n" + rule)
    print("ALL EVALUATIONS COMPLETED")
    print(rule)
    print(f"Total files generated: {files_written}")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Results saved to: {output_dir}/")
    print(rule)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()