#!/usr/bin/env python3
"""
GSM8K Multi-Agent Synthetic Data Generation using vLLM
Creates training data for Generator, Verifier, and Refinement models

Pipeline:
1. Generator: Question -> Answer (3 outputs)
2. Verifier: Question + Answer -> Critique (3 outputs per answer) 
3. Refinement: Question + Answer + Critique -> Refined Answer (3 outputs per critique)

Uses downstream performance to label upstream components as correct/incorrect.
Uses base models only with completions API.
"""

import re
import json
import random
import time
import logging
from pathlib import Path
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional
from datasets import load_dataset
from collections import defaultdict
from openai import OpenAI
import requests
import warnings
warnings.filterwarnings('ignore')

# 8-shot Chain-of-Thought examples for base models.
# This text is prepended verbatim to every generator prompt.
# NOTE(review): the literal below contains only a newline, so prompts built
# from it carry no few-shot examples — confirm whether the 8 examples were
# stripped from this copy of the file.
GSM8K_COT_EXAMPLES = """
"""

class GSM8KMultiAgentGenerator:
    """Multi-agent synthetic data generator using vLLM backend for base models only.

    One served model is prompted in three roles (generator, verifier and
    refiner) through the OpenAI-compatible completions endpoint. The class
    also holds the numeric answer extraction used to grade generated text
    against GSM8K reference solutions.
    """

    # Signed integer or decimal literal; shared by every extraction strategy.
    _NUMBER_RE = re.compile(r'([+-]?\d+(?:\.\d+)?)')

    def __init__(self,
                 base_url: str = "http://localhost:8000/v1",
                 model_name: str = "qwen-2.5-1.5b",
                 api_key: str = "EMPTY"):
        """Initialize the multi-agent generator with vLLM API client.

        Args:
            base_url: Base URL of the OpenAI-compatible API.
            model_name: Model name sent with every completion request.
            api_key: API key handed to the OpenAI client.

        Raises:
            ConnectionError: If the server health check never succeeds.
        """
        print(f"Connecting to vLLM server at: {base_url}")
        print(f"Model: {model_name}")
        print("Model type: Base (using completions API)")

        self.base_url = base_url
        self.model_name = model_name

        # Initialize OpenAI client for vLLM
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key,
        )

        # Fail fast if the server is not reachable
        self._test_connection()

    @staticmethod
    def _health_url(base_url: str) -> str:
        """Return the health-check URL for ``base_url``.

        Only a trailing ``/v1`` path segment is removed. A blanket
        ``str.replace`` would also delete ``/v1`` occurring elsewhere in the
        URL (for example inside a longer path segment).
        """
        root = base_url.rstrip('/')
        if root.endswith('/v1'):
            root = root[:-len('/v1')]
        return f"{root}/health"

    def _test_connection(self):
        """Poll the server health endpoint until it answers with HTTP 200.

        Raises:
            ConnectionError: If none of the attempts succeeds.
        """
        max_retries = 5
        health_url = self._health_url(self.base_url)
        for attempt in range(1, max_retries + 1):
            try:
                response = requests.get(health_url, timeout=10)
                if response.status_code == 200:
                    print("✓ Successfully connected to vLLM server")
                    return
            except requests.exceptions.RequestException:
                # Server not reachable yet; fall through to the retry wait.
                pass

            print(f"Attempt {attempt}/{max_retries}: Waiting for vLLM server...")
            # No point in waiting after the final attempt; fail immediately.
            if attempt < max_retries:
                time.sleep(5)

        raise ConnectionError("Could not connect to vLLM server. Make sure it's running.")

    def generate_text(self, prompt: str, max_tokens: int = 512, retries: int = 3,
                      stop_tokens: Optional[List[str]] = None) -> str:
        """Generate text using vLLM completions API with retries.

        Args:
            prompt: Full prompt sent to the completions endpoint.
            max_tokens: Maximum number of tokens to generate.
            retries: Number of attempts before giving up.
            stop_tokens: Stop sequences; a default list is used when None.

        Returns:
            The stripped completion text, or "" when every attempt fails
            (or when ``retries`` is not positive).
        """
        if stop_tokens is None:
            stop_tokens = ["\n\n", "\nQ:", "Question:", "You are", "Q:"]

        for attempt in range(retries):
            try:
                response = self.client.completions.create(
                    model=self.model_name,
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=0.2,
                    top_p=0.9,
                    stop=stop_tokens,
                )
                return response.choices[0].text.strip()

            except Exception as e:
                # Best-effort: callers treat "" as "no output", so never raise.
                if attempt < retries - 1:
                    print(f"API call failed (attempt {attempt + 1}/{retries}): {str(e)}")
                    time.sleep(1)
                else:
                    print(f"API call failed after {retries} attempts: {str(e)}")

        # Reached when all attempts failed or when retries <= 0.
        return ""

    def create_generator_prompt(self, question: str) -> str:
        """Create prompt for generator model: few-shot prefix plus ``Q:``/``A:`` frame."""
        return GSM8K_COT_EXAMPLES + f"Q: {question}\nA:"

    def create_verifier_prompt(self, question: str, answer: str) -> str:
        """Create prompt for verifier model from a question and a candidate answer."""
        return f"""You are a verifier and need to properly check this solution.

Question: {question}

Answer: {answer}

Verification:"""

    def create_refinement_prompt(self, question: str, answer: str, critique: str) -> str:
        """Create prompt for refinement model from question, answer and critique."""
        return f"""You are a refinement model and get a question, generated answer and an associated critique to improve it. You need to refine the answer to ensure it is correct.

Question: {question}

Generated Answer: {answer}

Critique: {critique}

Refined Answer:"""

    def generate_answer(self, question: str) -> str:
        """Generate an answer for a question; returns "" on API failure."""
        prompt = self.create_generator_prompt(question)
        # Generator-specific stop sequences (no "You are" / bare "Q:" here)
        return self.generate_text(prompt, stop_tokens=["\n\n", "\nQ:", "Question:"])

    def generate_verification(self, question: str, answer: str) -> str:
        """Generate a verification for a question-answer pair; "" on API failure."""
        prompt = self.create_verifier_prompt(question, answer)
        return self.generate_text(prompt, stop_tokens=["\n\n", "Question:", "Q:"])

    def generate_refinement(self, question: str, answer: str, critique: str) -> str:
        """Generate a refined answer given question, answer, and critique; "" on API failure."""
        prompt = self.create_refinement_prompt(question, answer, critique)
        return self.generate_text(prompt, stop_tokens=["\n\n", "Question:", "Q:"])

    @staticmethod
    def _to_float(token: str) -> Optional[float]:
        """Convert a matched numeric token to float, or None if it does not parse."""
        try:
            return float(token)
        except ValueError:
            return None

    def extract_answer_from_text(self, text: str) -> Optional[float]:
        """Extract the numeric answer from generated text.

        Strategies are tried in order and the first that yields a number wins:
        1. an explicit "The answer is X" phrase,
        2. the first number after the last "####" marker,
        3. the first number after the last occurrence of "answer"
           (case-insensitive),
        4. the last number on the last line.

        Returns:
            The extracted value, or None when no number is found.
        """
        # Drop thousands separators so "1,234" parses as a single number
        text = text.replace(",", "")

        # 1. Look for "The answer is X" pattern
        match = re.search(r"[Tt]he answer is:?\s*([+-]?\d+(?:\.\d+)?)", text)
        if match:
            value = self._to_float(match.group(1))
            if value is not None:
                return value

        # 2. Look for #### pattern (GSM8K reference-solution style)
        if "####" in text:
            after_hash = text.split("####")[-1].strip()
            match = self._NUMBER_RE.search(after_hash)
            if match:
                value = self._to_float(match.group(1))
                if value is not None:
                    return value

        # 3. First number following the last occurrence of "answer"
        parts = re.split(r'answer', text, flags=re.IGNORECASE)
        if len(parts) > 1:
            numbers = self._NUMBER_RE.findall(parts[-1])
            if numbers:
                value = self._to_float(numbers[0])
                if value is not None:
                    return value

        # 4. Last resort: the last number in the last line
        lines = text.strip().split('\n')
        if lines:
            numbers = self._NUMBER_RE.findall(lines[-1])
            if numbers:
                value = self._to_float(numbers[-1])
                if value is not None:
                    return value

        return None

    def extract_ground_truth(self, solution: str) -> float:
        """Extract the ground truth answer from a GSM8K solution.

        The value is the first number after the last "####" marker, with
        thousands separators removed.

        Raises:
            ValueError: If the marker or a number after it is missing.
        """
        parts = solution.split("####")
        if len(parts) >= 2:
            answer_str = parts[-1].strip().replace(",", "")
            match = self._NUMBER_RE.search(answer_str)
            if match:
                return float(match.group())
        raise ValueError(f"Could not extract answer from solution: {solution}")

    def is_correct_answer(self, generated_text: str, ground_truth_solution: str) -> bool:
        """Check if the generated answer matches the ground truth within 0.01.

        Returns False when no number can be extracted from either side.
        """
        try:
            predicted = self.extract_answer_from_text(generated_text)
            ground_truth = self.extract_ground_truth(ground_truth_solution)

            if predicted is None:
                return False

            return abs(predicted - ground_truth) < 0.01
        except Exception:
            # Deliberate best-effort: any parsing failure counts as "incorrect"
            # rather than aborting a long generation run.
            return False


def setup_logging():
    """Configure file logging for the run and return the log file path.

    A timestamped log file name is built under the
    'gsm8k_multiagent_logs_0to1000' directory (created if missing), and the
    root logger is configured via ``logging.basicConfig`` at INFO level.

    Returns:
        Path of the log file for this run.
    """
    run_stamp = time.strftime("%Y%m%d_%H%M%S")

    target_dir = Path('gsm8k_multiagent_logs_0to1000')
    target_dir.mkdir(parents=True, exist_ok=True)
    log_file = target_dir / f'multiagent_generation_{run_stamp}.log'

    # filemode 'w' truncates the target file if it already exists
    config = {
        'level': logging.INFO,
        'filename': log_file,
        'filemode': 'w',
        'format': '%(asctime)s - %(levelname)s - %(message)s',
    }
    logging.basicConfig(**config)

    logging.info(f'Logging to {log_file}')
    return log_file


def save_data_to_json(file_name: str, data_to_save: List[Dict]):
    """Save data to a JSON file, appending to existing data if the file exists.

    Args:
        file_name: Destination path; parent directories are created as needed.
        data_to_save: Records to write. Nothing is written when it is empty,
            and the list itself is never mutated.

    Existing content handling:
        * a JSON list is extended with the new records;
        * any other JSON value is kept as the first record (with a warning)
          instead of crashing on ``.extend``;
        * unparseable content is replaced, with a warning instead of silently.

    The result is written to a sibling ``.tmp`` file and then moved over the
    destination, so an interrupted write does not destroy data already saved.
    """
    if not data_to_save:
        logging.info(f"No new data to save for {file_name}")
        return

    file_path = Path(file_name)
    file_path.parent.mkdir(parents=True, exist_ok=True)

    logging.info(f"Attempting to save {len(data_to_save)} items to {file_path}")

    merged: List[Dict] = []
    if file_path.exists():
        try:
            with file_path.open('r') as f:
                loaded = json.load(f)
        except json.JSONDecodeError as exc:
            logging.warning(
                f"Existing file {file_path} is not valid JSON ({exc}); its contents will be replaced"
            )
        else:
            if isinstance(loaded, list):
                merged = loaded
            else:
                logging.warning(
                    f"Existing file {file_path} does not contain a JSON list; keeping its value as the first item"
                )
                merged = [loaded]
    merged.extend(data_to_save)

    tmp_path = file_path.with_name(file_path.name + '.tmp')
    try:
        with tmp_path.open('w') as f:
            json.dump(merged, f, indent=2)
        # Swap the fully written file into place in one step.
        tmp_path.replace(file_path)
    finally:
        # Only present here if the write or the swap failed.
        if tmp_path.exists():
            tmp_path.unlink()

    logging.info(f"Successfully saved {len(merged)} total items to {file_path}")


def _majority_correct(refinement_outputs: List[Dict]) -> bool:
    """Return True when strictly more than half of the refinements are correct."""
    total = len(refinement_outputs)
    correct = sum(1 for ro in refinement_outputs if ro['is_correct'])
    return total > 0 and correct > total / 2


def _pair_outputs(prompt: str, correct: List[str], incorrect: List[str]) -> List[Dict]:
    """Pair correct with incorrect outputs position-wise, truncating to the shorter list."""
    return [
        {'input': prompt, 'correct_output': good, 'incorrect_output': bad}
        for good, bad in zip(correct, incorrect)
    ]


def _sample_generation_tree(generator: 'GSM8KMultiAgentGenerator', question: str,
                            ground_truth_solution: str, n_outputs: int,
                            question_counter: int) -> List[Dict]:
    """Sample answers, then verifications per answer, then graded refinements per verification.

    Returns a list of ``{'answer', 'verifier_outputs'}`` dicts; the list is
    empty when every generator call produced an empty answer.
    """
    # Step 1: Generator outputs
    generator_outputs = []
    for _ in range(n_outputs):
        answer = generator.generate_answer(question)
        if not answer:
            logging.warning(f"Empty answer from generator for question {question_counter}")
            continue
        generator_outputs.append({'answer': answer, 'verifier_outputs': []})

    if not generator_outputs:
        return generator_outputs

    # Step 2: Verifier outputs
    for gen_output in generator_outputs:
        verifier_outputs = []
        for _ in range(n_outputs):
            verification = generator.generate_verification(question, gen_output['answer'])
            if not verification:
                logging.warning(f"Empty verification for question {question_counter}")
                continue
            verifier_outputs.append({'verification': verification, 'refinement_outputs': []})
        gen_output['verifier_outputs'] = verifier_outputs

    # Step 3: Refinement outputs, graded against the reference solution
    for gen_output in generator_outputs:
        for verifier_output in gen_output['verifier_outputs']:
            refinement_outputs = []
            for _ in range(n_outputs):
                refinement = generator.generate_refinement(
                    question, gen_output['answer'], verifier_output['verification']
                )
                if not refinement:
                    logging.warning(f"Empty refinement for question {question_counter}")
                    continue
                refinement_outputs.append({
                    'refinement': refinement,
                    'is_correct': generator.is_correct_answer(refinement, ground_truth_solution)
                })
            verifier_output['refinement_outputs'] = refinement_outputs

    return generator_outputs


def _build_training_pairs(generator: 'GSM8KMultiAgentGenerator', question: str,
                          generator_outputs: List[Dict]) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Turn one question's generation tree into (generator, verifier, refinement) pairs.

    Upstream outputs are labelled correct when more than 50% of the
    refinements downstream of them are correct.
    """
    # Generator classification: majority vote over all refinements below each answer
    gen_correct_outputs = []
    gen_incorrect_outputs = []
    for gen_output in generator_outputs:
        downstream = [
            ro
            for verifier_output in gen_output['verifier_outputs']
            for ro in verifier_output['refinement_outputs']
        ]
        bucket = gen_correct_outputs if _majority_correct(downstream) else gen_incorrect_outputs
        bucket.append(gen_output['answer'])

    generator_pairs = _pair_outputs(
        generator.create_generator_prompt(question), gen_correct_outputs, gen_incorrect_outputs
    )

    # Verifier classification, grouped by the answer text being verified
    verifier_correct_outputs = defaultdict(list)
    verifier_incorrect_outputs = defaultdict(list)
    for gen_output in generator_outputs:
        answer = gen_output['answer']
        for verifier_output in gen_output['verifier_outputs']:
            if _majority_correct(verifier_output['refinement_outputs']):
                verifier_correct_outputs[answer].append(verifier_output['verification'])
            else:
                verifier_incorrect_outputs[answer].append(verifier_output['verification'])

    verifier_pairs = []
    for answer, correct_verifications in verifier_correct_outputs.items():
        # .get avoids inserting empty entries into the defaultdict
        incorrect_verifications = verifier_incorrect_outputs.get(answer, [])
        verifier_pairs.extend(_pair_outputs(
            generator.create_verifier_prompt(question, answer),
            correct_verifications,
            incorrect_verifications,
        ))

    # Refinement classification: each refinement is graded directly
    refinement_pairs = []
    for gen_output in generator_outputs:
        answer = gen_output['answer']
        for verifier_output in gen_output['verifier_outputs']:
            outputs = verifier_output['refinement_outputs']
            correct_refinements = [ro['refinement'] for ro in outputs if ro['is_correct']]
            incorrect_refinements = [ro['refinement'] for ro in outputs if not ro['is_correct']]
            refinement_pairs.extend(_pair_outputs(
                generator.create_refinement_prompt(question, answer, verifier_output['verification']),
                correct_refinements,
                incorrect_refinements,
            ))

    return generator_pairs, verifier_pairs, refinement_pairs


def _save_batch(data_dir: Path, batch_label: str, generator_data: List[Dict],
                verifier_data: List[Dict], refinement_data: List[Dict]) -> None:
    """Write the three pair lists to timestamped JSON files under ``data_dir``."""
    current_timestamp = time.strftime("%Y%m%d_%H%M%S")
    for prefix, rows in (('generator', generator_data),
                         ('verifier', verifier_data),
                         ('refinement', refinement_data)):
        filename = data_dir / f'{prefix}_data_{batch_label}_{current_timestamp}.json'
        save_data_to_json(str(filename), rows)


def main():
    """Main data generation function.

    Loads the GSM8K training split, samples a generator -> verifier ->
    refinement tree for each selected question, converts the trees into
    correct/incorrect preference pairs and writes them to JSON files in
    batches, followed by a final file for any remainder.
    """

    # Configuration
    vllm_base_url = "http://localhost:8000/v1"
    vllm_model_name = "qwen-2.5-1.5b"
    num_questions = 5000  # Number of questions to process
    n_outputs = 3  # Number of outputs per stage
    batch_size = 10  # Flush accumulated pairs to disk every `batch_size` questions
    seed = 42

    # Set up logging
    log_file = setup_logging()

    # Seed Python's `random` module (nothing in this function draws from it)
    random.seed(seed)

    print("="*80)
    print("GSM8K Multi-Agent Synthetic Data Generation")
    print("Model type: Base (8-shot CoT)")
    print(f"vLLM Server: {vllm_base_url}")
    print(f"vLLM Model Name: {vllm_model_name}")
    print(f"Number of questions: {num_questions}")
    print(f"Outputs per stage: {n_outputs}")
    print(f"Batch size: {batch_size}")
    print(f"Log file: {log_file}")
    print("="*80)

    # Load GSM8K dataset
    print("\nLoading GSM8K training dataset...")
    dataset = load_dataset("openai/gsm8k", "main", trust_remote_code=True)
    train_dataset = dataset["train"]
    print(f"Total training samples: {len(train_dataset)}")

    # Select subset of questions; derived from num_questions so the banner
    # above and the actual selection cannot disagree.
    start_idx = 0
    end_idx = start_idx + num_questions
    questions_to_process = train_dataset.select(range(start_idx, min(end_idx, len(train_dataset))))

    # Initialize generator
    generator = GSM8KMultiAgentGenerator(
        base_url=vllm_base_url,
        model_name=vllm_model_name
    )

    # Data storage
    data_dir = Path('gsm8k_multiagent_data_0to5000')
    data_dir.mkdir(parents=True, exist_ok=True)

    generator_data = []
    verifier_data = []
    refinement_data = []

    question_counter = 0
    batch_number = 1

    print("\nStarting data generation...")
    start_time = time.time()

    for item in tqdm(questions_to_process, desc='Processing questions'):
        question_counter += 1
        question = item['question']
        ground_truth_solution = item['answer']

        logging.info(f"Processing question {question_counter}: {question}")

        generator_outputs = _sample_generation_tree(
            generator, question, ground_truth_solution, n_outputs, question_counter
        )

        if generator_outputs:
            # Step 4: Classify outputs based on downstream performance
            gen_pairs, ver_pairs, ref_pairs = _build_training_pairs(
                generator, question, generator_outputs
            )
            generator_data.extend(gen_pairs)
            verifier_data.extend(ver_pairs)
            refinement_data.extend(ref_pairs)
        else:
            logging.warning(f"No valid generator outputs for question {question_counter}")

        # Save data every batch_size questions. This check deliberately runs
        # for skipped questions too, so a batch boundary is never missed and
        # batch labels stay aligned with question_counter.
        if question_counter % batch_size == 0:
            batch_label = f"{batch_number}_{batch_size * batch_number}"
            _save_batch(data_dir, batch_label, generator_data, verifier_data, refinement_data)

            # Print progress
            elapsed_time = time.time() - start_time
            print(f"\nBatch {batch_number} completed after processing {question_counter} questions")
            print(f"Generator pairs: {len(generator_data)}")
            print(f"Verifier pairs: {len(verifier_data)}")
            print(f"Refinement pairs: {len(refinement_data)}")
            print(f"Time elapsed: {elapsed_time:.2f} seconds")
            print(f"Average time per question: {elapsed_time / question_counter:.2f} seconds")

            # Clear data for next batch
            generator_data = []
            verifier_data = []
            refinement_data = []
            batch_number += 1

    # Save any remaining data
    if generator_data or verifier_data or refinement_data:
        _save_batch(data_dir, f"final_{question_counter}",
                    generator_data, verifier_data, refinement_data)
        print("\nFinal batch saved with remaining data")

    total_time = time.time() - start_time
    print("\n" + "="*80)
    print("DATA GENERATION COMPLETED")
    print("="*80)
    print(f"Total questions processed: {question_counter}")
    print(f"Total time: {total_time:.2f} seconds")
    # Guard against an empty selection (no questions processed)
    if question_counter:
        print(f"Average time per question: {total_time / question_counter:.2f} seconds")
    print(f"Data saved to: {data_dir}")
    print(f"Log file: {log_file}")
    print("="*80)


# Run the data generation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()