#!/usr/bin/env python
"""
GPQA Diamond Parallel Reasoning Experiment Template
ICLR 2026 Submission Code

This script implements parallel test-time scaling for GPQA Diamond problems.
Generates multiple independent reasoning chains and uses majority voting.
"""

import os
import requests
import json
import re
from collections import Counter
import time
from datetime import datetime
import pandas as pd
from datasets import load_dataset
import random
import math
import concurrent.futures

# --- Main Experiment Configuration ---
MODEL_TO_TEST = "openai/gpt-oss-120b"  # Replace with desired model
MAX_CHAINS = 3  # Number of parallel reasoning chains (3, 6, or 9)

# --- Configuration ---
# The key is read from the OPENROUTER_API_KEY environment variable when it is set
# (and non-empty) so that real credentials never need to be written into this file.
# The original placeholder remains the fallback, so editing it in place still works.
API_KEY = os.environ.get("OPENROUTER_API_KEY") or "YOUR_OPENROUTER_KEY"
HUGGING_FACE_DATASET_ID = "Idavidrein/gpqa"  # GPQA Diamond dataset
RATE_LIMIT_DELAY = 0.5  # Seconds slept before the first attempt of every API call
MAX_RETRIES = 5  # Attempts per API call before giving up
BACKOFF_MULTIPLIER = 2  # Factor applied to the sleep delay after each failed attempt
MAX_TOKENS_PER_STEP = 4096  # Passed as `max_tokens` for every chain

# --- OpenRouter Client Class ---
class OpenRouterClient:
    """Minimal HTTP client for the OpenRouter chat-completions endpoint."""

    def __init__(self, api_key=None):
        # An omitted (or empty) key falls back to the module-level API_KEY.
        self.api_key = api_key if api_key else API_KEY
        self.url = "https://openrouter.ai/api/v1/chat/completions"

        request_headers = {}
        request_headers["Authorization"] = f"Bearer {self.api_key}"
        request_headers["Content-Type"] = "application/json"
        request_headers["HTTP-Referer"] = "https://github.com/anonymous-repo"
        request_headers["X-Title"] = "GPQA Diamond Parallel Reasoning"
        self.headers = request_headers

    def call_model(self, model_name, messages, max_tokens, temperature=0.7, top_p=0.8):
        """POST one non-streaming chat-completion request and return the raw response.

        The returned object is whatever ``requests.post`` yields; status checking
        and JSON decoding are left to the caller.
        """
        # Routing is pinned to a single provider with fallbacks disabled.
        provider_settings = {
            "order": ["deepinfra/fp4"],
            "require_parameters": True,
            "allow_fallbacks": False,
        }
        sampling_settings = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        body = {
            "model": model_name,
            "messages": messages,
            **sampling_settings,
            "stream": False,
            "reasoning": {"effort": "high"},
            "provider": provider_settings,
        }
        http_response = requests.post(
            self.url,
            headers=self.headers,
            json=body,
            timeout=240,
        )
        return http_response

# --- Metrics Tracking ---
class MetricsTracker:
    """Accumulate and print running statistics for one experiment run.

    Tracks how many questions were processed and answered correctly, how many
    produced no extractable answer, the number of logged API calls, and the
    token usage reported by the API.
    """

    def __init__(self, model_name, strategy_name):
        self.model_name = model_name
        self.strategy_name = strategy_name
        self.total_questions_processed = 0
        self.total_questions_planned = 0
        self.correct_counts = 0
        self.total_tokens_used = {"input": 0, "output": 0, "total": 0}
        self.total_api_calls = 0
        self.total_failed_questions = 0
        self.start_time = None  # Set by start_experiment()

    def start_experiment(self, total_questions):
        """Record the start time and planned question count, then print the banner."""
        self.start_time = time.time()
        self.total_questions_planned = total_questions
        self.print_header()

    def log_api_call(self, usage_dict):
        """Count one API call and add its token usage, if any was reported.

        Args:
            usage_dict: Mapping with ``prompt_tokens`` / ``completion_tokens`` /
                ``total_tokens`` entries, or a falsy value when no usage is known.
        """
        self.total_api_calls += 1
        if usage_dict:
            # `or 0` also covers a field that is present but null.
            self.total_tokens_used["input"] += usage_dict.get("prompt_tokens") or 0
            self.total_tokens_used["output"] += usage_dict.get("completion_tokens") or 0
            self.total_tokens_used["total"] += usage_dict.get("total_tokens") or 0

    def log_failed_question(self):
        """Count a question for which no valid answer could be extracted."""
        self.total_failed_questions += 1

    def complete_question(self, evaluation_result):
        """Mark one question as processed; it is correct only for the string "Correct"."""
        self.total_questions_processed += 1
        if evaluation_result == "Correct":
            self.correct_counts += 1

    def print_header(self):
        """Print the experiment banner (reads the module-level MAX_CHAINS)."""
        print("\n" + "="*100)
        print(f"🚀 STARTING PARALLEL EXPERIMENT: {self.strategy_name}")
        print(f"🤖 Model: {self.model_name} | Max Chains: {MAX_CHAINS}")
        print("="*100)

    def print_running_metrics(self):
        """Print timing, token and accuracy figures for the questions processed so far.

        Prints nothing until at least one question has been completed and the
        experiment has been started (otherwise there is no elapsed time to report).
        """
        processed = self.total_questions_processed
        if processed == 0 or self.start_time is None:
            return
        elapsed_time = time.time() - self.start_time
        avg_time = elapsed_time / processed
        eta_minutes = (self.total_questions_planned - processed) * avg_time / 60
        acc = round(self.correct_counts / processed * 100, 2)

        print("\n" + f"🔄 LIVE METRICS UPDATE ({processed}/{self.total_questions_planned}) for {self.strategy_name}" + "="*20)
        print(f"⏱️  Avg Time/Q: {avg_time:.1f}s | Total Time: {elapsed_time/60:.1f}m | ETA: {eta_minutes:.1f}m")
        print(f"💰 Total Tokens: {self.total_tokens_used['total']:,} | API Calls: {self.total_api_calls:,}")
        print(f"📉 Failed Questions (no answer): {self.total_failed_questions}")
        print(f"   🎯 Simple Majority Accuracy: {acc}% ({self.correct_counts}/{processed})")
        print("="*100)

# --- Initialize Global Client ---
# Module-level client used by call_model_with_rate_limiting(). It is created
# without arguments, so OpenRouterClient falls back to the module-level API_KEY.
client = OpenRouterClient()

# --- Utility Functions ---
def load_gpqa_dataset_from_hf(dataset_id: str):
    """Fetch the ``gpqa_diamond`` config of a Hugging Face dataset as a DataFrame.

    Returns the ``train`` split converted to pandas. Any failure is reported on
    stdout and an empty DataFrame is returned instead of raising.
    """
    try:
        print(f"📚 Loading dataset '{dataset_id}' from Hugging Face...")
        train_split = load_dataset(dataset_id, "gpqa_diamond")['train']
        frame = train_split.to_pandas()
        print("✅ Dataset loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        frame = pd.DataFrame()
    return frame

def call_model_with_rate_limiting(model_name: str, messages: list, max_tokens: int, temperature: float) -> dict | None:
    """Call the model through the shared client, retrying with a growing delay.

    Each attempt first sleeps for the current delay (initially RATE_LIMIT_DELAY);
    after every failed attempt the delay is multiplied by BACKOFF_MULTIPLIER.

    Args:
        model_name: Model identifier forwarded to the client.
        messages: Chat messages forwarded to the client.
        max_tokens: Completion token limit forwarded to the client.
        temperature: Sampling temperature forwarded to the client.

    Returns:
        The decoded JSON response body, or None once MAX_RETRIES attempts have
        all failed.
    """
    current_delay = RATE_LIMIT_DELAY
    for attempt in range(MAX_RETRIES):
        try:
            time.sleep(current_delay)
            response = client.call_model(model_name, messages, max_tokens, temperature)
            response.raise_for_status()
            return response.json()

        # Listed before RequestException so the raw body can be dumped for debugging.
        except requests.exceptions.JSONDecodeError:
            print(f"    ❌ JSONDecodeError on attempt {attempt+1}/{MAX_RETRIES}.")
            print(f"    Raw Response Text (Status {response.status_code}):\n---")
            print(response.text)
            print("---\n    Retrying...")
        except requests.exceptions.HTTPError as e:
            # Guard the attribute access: an HTTPError can be built without a response.
            body = e.response.text if e.response is not None else "<no response body>"
            print(f"❌ HTTP Error on attempt {attempt+1}/{MAX_RETRIES}: {e}\n    Response Body: {body}")
        except requests.exceptions.RequestException as e:
            print(f"❌ Request Error on attempt {attempt+1}/{MAX_RETRIES}: {e}")

        # Only reached after a failed attempt (success returns above).
        current_delay *= BACKOFF_MULTIPLIER
    print(f"❌ Error: Max retries ({MAX_RETRIES}) exceeded for model {model_name}.")
    return None

def extract_answer(raw_response: dict) -> str | None:
    """Extract the final answer from model response."""
    if not raw_response or not raw_response.get('choices'): return None
    response_text = raw_response['choices'][0].get('message', {}).get('content', '')
    if not response_text: return None
    
    # Look for various answer patterns
    answer_patterns = [
        r'(?:answer|final answer).*?(?:is|:)\\s*\\(?([A-D])\\)?',
        r'\\(?([A-D])\\)?\\s*(?:is|$)',
        r'(?:option|choice)\\s*\\(?([A-D])\\)?',
        r'\\bthe\\s+answer\\s+is\\s+\\(?([A-D])\\)?',
        r'\\bcorrect\\s+answer\\s+is\\s+\\(?([A-D])\\)?'
    ]
    
    for pattern in answer_patterns:
        match = re.search(pattern, response_text, re.IGNORECASE)
        if match:
            return match.group(1).upper()
    
    return None

# --- Core Strategy Function ---
def run_parallel_strategy(model_name: str, question_data: dict, num_chains: int, metrics: MetricsTracker):
    """Run `num_chains` independent reasoning chains for one question in parallel.

    Every chain receives the same prompt; chain i is sampled at temperature
    0.4 + 0.05 * i. Successful responses are logged to `metrics` and collected.

    Args:
        model_name: Model identifier passed to call_model_with_rate_limiting.
        question_data: Mapping with the keys 'Question', 'Correct Answer' and
            'Incorrect Answer 1'..'Incorrect Answer 3'.
        num_chains: Number of chains (also the thread-pool size).
        metrics: Tracker that receives one log_api_call() per successful response.

    Returns:
        List of decoded API responses in completion order; chains that failed
        (returned None or raised) are omitted.
    """
    system_prompt = """You are an expert in graduate-level physics, chemistry, and biology. Your task is to solve challenging scientific questions that require deep understanding and careful reasoning.

Follow these instructions precisely:
1. **Analyze the Question:** Read the question carefully and identify the key concepts, principles, and information provided.
2. **Think Step-by-Step:** Work through the problem systematically, explaining your reasoning at each step.
3. **Consider All Options:** Evaluate each multiple choice option carefully.
4. **Final Answer:** Provide your final answer as a single letter (A, B, C, or D).

Be thorough in your analysis and double-check your reasoning before providing the final answer."""

    question_text = question_data['Question']
    # NOTE(review): the correct answer is always rendered as option A, so any
    # positional preference of the model for "A" inflates accuracy. Shuffling
    # would also require the caller's ground-truth letter to change — confirm
    # whether this is intended before relying on the reported numbers.
    choices = (
        f"A) {question_data['Correct Answer']}\n"
        f"B) {question_data['Incorrect Answer 1']}\n"
        f"C) {question_data['Incorrect Answer 2']}\n"
        f"D) {question_data['Incorrect Answer 3']}"
    )

    full_question = f"{question_text}\n\n{choices}"

    initial_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Please solve the following scientific question:\n\n{full_question}"}
    ]

    api_responses = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_chains) as executor:
        future_to_chain = {
            executor.submit(call_model_with_rate_limiting, model_name, initial_messages, MAX_TOKENS_PER_STEP, temperature=0.4 + (i * 0.05)): i
            for i in range(num_chains)
        }

        # Results are handled on the calling thread, so `metrics` is never
        # touched from inside a worker.
        for future in concurrent.futures.as_completed(future_to_chain):
            chain_index = future_to_chain[future]
            try:
                response = future.result()
                print(f"\n  📍 Parallel Chain {chain_index+1}/{num_chains} completed.")
                if response:
                    api_responses.append(response)
                    # `or {}` covers a 'usage' key that is present but null.
                    usage = response.get('usage') or {}
                    metrics.log_api_call(usage)

                    chain_answer = extract_answer(response)
                    token_str = f"In:{usage.get('prompt_tokens',0)}/Out:{usage.get('completion_tokens',0)}/Total:{usage.get('total_tokens',0)}"
                    print(f"    ➡️ Chain {chain_index+1} Answer: {chain_answer} | Tokens Used: {token_str}")

            except Exception as exc:
                print(f"  ❌ Parallel chain {chain_index+1} generated an exception: {exc}")
    return api_responses

# --- Analysis Function ---
def analyze_and_evaluate(api_responses: list, ground_truth: str):
    """Majority-vote the answers extracted from the chains and grade the winner.

    Args:
        api_responses: Decoded API responses (None entries are ignored).
        ground_truth: The correct option letter.

    Returns:
        A dict that always contains 'voted_answer', 'evaluation' ("Correct" or
        "Incorrect") and 'has_valid_answer'. When `api_responses` is empty it
        additionally carries an 'error' message.
    """
    if not api_responses:
        # Same keys as every other return path, plus the original 'error' entry.
        return {
            "error": "No API responses.",
            'voted_answer': "N/A",
            'evaluation': "Incorrect",
            'has_valid_answer': False,
        }

    extracted_answers = [extract_answer(resp) for resp in api_responses if resp is not None]
    valid_answers = [ans for ans in extracted_answers if ans is not None]

    if not valid_answers:
        return {'voted_answer': "N/A", 'evaluation': "Incorrect", 'has_valid_answer': False}

    # Ties go to the answer that appears first in `valid_answers`, i.e. they
    # follow the order of `api_responses`.
    majority_vote = Counter(valid_answers).most_common(1)[0][0]
    evaluation = "Correct" if majority_vote == ground_truth else "Incorrect"

    print(f"\n🎯 QUESTION RESULTS (Ground Truth: {ground_truth})")
    status = "✅" if evaluation == "Correct" else "❌"
    print(f"   {status} Simple Majority Vote: {majority_vote} ({evaluation})")
    print(f"   (Individual answers from chains: {valid_answers})")

    return {'voted_answer': majority_vote, 'evaluation': evaluation, 'has_valid_answer': True}

# --- Main Execution Block ---
if __name__ == "__main__":
    test_df = load_gpqa_dataset_from_hf(HUGGING_FACE_DATASET_ID)
    if test_df.empty:
        # exit() is a helper injected by the `site` module for interactive use;
        # SystemExit is always available and a non-zero status signals the failure.
        raise SystemExit(1)

    strategy_name = f"{MODEL_TO_TEST.split('/')[-1]}_parallel_{MAX_CHAINS}chains"
    metrics = MetricsTracker(MODEL_TO_TEST, strategy_name)
    total_questions = len(test_df)
    metrics.start_experiment(total_questions)
    all_question_results = []

    # `position` drives the progress counter so it stays correct even if the
    # DataFrame index is not a 0-based integer range.
    for position, (index, row) in enumerate(test_df.iterrows(), start=1):
        question_id = index  # GPQA doesn't have explicit IDs
        # run_parallel_strategy always renders the 'Correct Answer' column as option A.
        # NOTE(review): a fixed answer position lets a positional bias of the
        # model toward "A" inflate accuracy — confirm this is intended.
        ground_truth_answer = "A"

        print("\n" + "🔸" * 50 + f"\n📝 Processing Question {question_id} ({position}/{total_questions})")
        print(f"   Correct Answer is: {ground_truth_answer}")

        api_log = run_parallel_strategy(MODEL_TO_TEST, row, MAX_CHAINS, metrics)
        analysis = analyze_and_evaluate(api_log, ground_truth_answer)

        # A question with no extractable answer is counted as failed AND as incorrect.
        if not analysis.get('has_valid_answer'):
            metrics.log_failed_question()

        metrics.complete_question(analysis.get('evaluation', 'Incorrect'))
        metrics.print_running_metrics()

        all_question_results.append({ "question_id": question_id, "analysis_results": analysis, "api_log": api_log })

    # --- Final Report ---
    processed_count = metrics.total_questions_processed
    final_accuracy = round(metrics.correct_counts / processed_count * 100, 2) if processed_count > 0 else 0

    final_report = {
        "experiment_summary": {
            "model": MODEL_TO_TEST,
            "strategy": "parallel",
            "max_chains": MAX_CHAINS,
            "dataset": HUGGING_FACE_DATASET_ID,
            "total_questions_processed": processed_count,
            "total_failed_questions": metrics.total_failed_questions,
            "total_api_calls": metrics.total_api_calls,
            "total_tokens_used": metrics.total_tokens_used,
            "average_tokens_per_question": metrics.total_tokens_used['total'] / processed_count if processed_count > 0 else 0,
            "average_time_per_question_sec": (time.time() - metrics.start_time) / processed_count if processed_count > 0 else 0,
        },
        "final_accuracy_percent": {"simple_majority": final_accuracy},
        "detailed_results_per_question": all_question_results
    }

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"GPQA_Parallel_{MAX_CHAINS}chains_{strategy_name}_{timestamp}.json"
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(final_report, f, indent=2, ensure_ascii=False)

    print(f"\n🏁🏁🏁 PARALLEL EXPERIMENT COMPLETE 🏁🏁🏁\n📁 Full report saved to: {filename}")
    print("\n--- FINAL SUMMARY ---")
    print(json.dumps(final_report["experiment_summary"], indent=2))