#!/usr/bin/env python
"""
AIME 2024 Parallel Reasoning Experiment Template
ICLR 2026 Submission Code

This script implements parallel test-time scaling for AIME 2024 problems.
Generates multiple independent reasoning chains and uses majority voting.
"""

import os
import requests
import json
import re
from collections import Counter
import time
from datetime import datetime
import pandas as pd
from datasets import load_dataset
import random
import math
import concurrent.futures

# --- Main Experiment Configuration ---
MODEL_TO_TEST = "openai/gpt-oss-120b"  # OpenRouter model slug; replace with the desired model
MAX_CHAINS = 3  # Number of parallel reasoning chains per question (3, 6, or 9)

# --- Configuration ---
# Read the key from the environment so it never has to be committed to the
# repository; the placeholder is kept as a fallback for backward compatibility.
API_KEY = os.environ.get("OPENROUTER_API_KEY") or "YOUR_OPENROUTER_KEY"
HUGGING_FACE_DATASET_ID = "Maxwell-Jia/AIME_2024"
RATE_LIMIT_DELAY = 0.5  # Seconds slept before the first API attempt (initial backoff delay)
MAX_RETRIES = 5  # Maximum number of attempts per API call
BACKOFF_MULTIPLIER = 2  # Factor applied to the delay after each failed attempt
# max_tokens sent with every completion request.
# NOTE(review): with reasoning effort "high", reasoning tokens may count toward
# this limit on some providers — confirm 4096 leaves room for a final answer.
MAX_TOKENS_PER_STEP = 4096

# --- OpenRouter Client Class ---
class OpenRouterClient:
    """Minimal client for OpenRouter's chat-completions endpoint.

    Stores the endpoint URL and the request headers (bearer auth, JSON content
    type, and the HTTP-Referer / X-Title headers) and exposes ``call_model``,
    which performs a single POST with no retry logic of its own.
    """

    def __init__(self, api_key=None):
        # A falsy api_key (None or "") falls back to the module-level API_KEY.
        resolved_key = api_key or API_KEY
        self.api_key = resolved_key
        self.url = "https://openrouter.ai/api/v1/chat/completions"
        bearer = f"Bearer {resolved_key}"
        self.headers = {
            "Authorization": bearer,
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/anonymous-repo",
            "X-Title": "AIME Parallel Reasoning",
        }

    def call_model(self, model_name, messages, max_tokens, temperature=0.7, top_p=0.8):
        """Send one non-streaming chat-completion request.

        Args:
            model_name: Model slug placed in the ``model`` field.
            messages: Chat messages list forwarded verbatim.
            max_tokens: Completion token limit.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.

        Returns:
            The raw ``requests.Response`` (240 s timeout). No status check or JSON
            decoding happens here; callers must handle both.
        """
        # Pin routing to a single provider and disable fallbacks to others.
        provider_routing = {
            "order": ["deepinfra/fp4"],
            "require_parameters": True,
            "allow_fallbacks": False,
        }
        request_body = dict(
            model=model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=False,
            reasoning={"effort": "high"},
            provider=provider_routing,
        )
        return requests.post(self.url, headers=self.headers, json=request_body, timeout=240)

# --- Metrics Tracking ---
class MetricsTracker:
    """Accumulate run-level metrics for one experiment and print progress reports.

    Tracks questions processed/planned, majority-vote correct count, token usage
    (input/output/total), API calls logged via ``log_api_call``, questions with
    no extractable answer, and wall-clock time since ``start_experiment``.
    """

    def __init__(self, model_name, strategy_name, num_chains=None):
        """Create an empty tracker.

        Args:
            model_name: Model identifier shown in the header.
            strategy_name: Strategy label shown in the header and live updates.
            num_chains: Chain count shown in the header. ``None`` (default) falls
                back to the module-level ``MAX_CHAINS`` when the header is printed.
        """
        self.model_name = model_name
        self.strategy_name = strategy_name
        self.num_chains = num_chains
        self.total_questions_processed = 0
        self.total_questions_planned = 0
        self.correct_counts = 0
        self.total_tokens_used = {"input": 0, "output": 0, "total": 0}
        self.total_api_calls = 0
        self.total_failed_questions = 0
        self.start_time = None

    def start_experiment(self, total_questions):
        """Record the start time and planned question count, then print the banner."""
        self.start_time = time.time()
        self.total_questions_planned = total_questions
        self.print_header()

    def log_api_call(self, usage_dict):
        """Count one API call and add its token usage, if any was reported.

        Missing *or null* usage fields are treated as 0, so a sparse ``usage``
        payload (e.g. ``"completion_tokens": null``) cannot crash the run.
        """
        self.total_api_calls += 1
        if not usage_dict:
            return
        for bucket, field in (
            ("input", "prompt_tokens"),
            ("output", "completion_tokens"),
            ("total", "total_tokens"),
        ):
            # `.get(field, 0)` alone would still return None for explicit nulls.
            self.total_tokens_used[bucket] += usage_dict.get(field) or 0

    def log_failed_question(self):
        """Count a question for which no chain produced an extractable answer."""
        self.total_failed_questions += 1

    def complete_question(self, evaluation_result):
        """Mark one question as processed; count it correct iff result == "Correct"."""
        self.total_questions_processed += 1
        if evaluation_result == "Correct":
            self.correct_counts += 1

    def print_header(self):
        """Print the experiment banner (strategy, model, chain count)."""
        chains = MAX_CHAINS if self.num_chains is None else self.num_chains
        print("\n" + "=" * 100)
        print(f"🚀 STARTING PARALLEL EXPERIMENT: {self.strategy_name}")
        print(f"🤖 Model: {self.model_name} | Max Chains: {chains}")
        print("=" * 100)

    def print_running_metrics(self):
        """Print timing, ETA, token, failure and accuracy figures so far.

        Does nothing until at least one question is processed and
        ``start_experiment`` has set the start time.
        """
        processed = self.total_questions_processed
        if processed == 0 or self.start_time is None:
            return
        elapsed_time = time.time() - self.start_time
        avg_time = elapsed_time / processed
        eta_minutes = (self.total_questions_planned - processed) * avg_time / 60
        acc = round(self.correct_counts / processed * 100, 2)

        print("\n" + f"🔄 LIVE METRICS UPDATE ({processed}/{self.total_questions_planned}) for {self.strategy_name}" + "=" * 20)
        print(f"⏱️  Avg Time/Q: {avg_time:.1f}s | Total Time: {elapsed_time/60:.1f}m | ETA: {eta_minutes:.1f}m")
        print(f"💰 Total Tokens: {self.total_tokens_used['total']:,} | API Calls: {self.total_api_calls:,}")
        print(f"📉 Failed Questions (no answer): {self.total_failed_questions}")
        print(f"   🎯 Simple Majority Accuracy: {acc}% ({self.correct_counts}/{processed})")
        print("=" * 100)

# --- Initialize Global Client ---
# Single shared client used by call_model_with_rate_limiting, including from the
# thread-pool workers; its key is the module-level API_KEY.
client = OpenRouterClient(api_key=API_KEY)

# --- Utility Functions ---
def load_aime_dataset_from_hf(dataset_id: str):
    """Fetch the ``train`` split of *dataset_id* from Hugging Face as a DataFrame.

    Any exception is reported on stdout and converted into an empty DataFrame,
    so callers detect failure via ``.empty`` rather than by catching errors.
    """
    try:
        print(f"📚 Loading dataset '{dataset_id}' from Hugging Face...")
        splits = load_dataset(dataset_id)
        train_frame = splits['train'].to_pandas()
        print("✅ Dataset loaded successfully.")
    except Exception as load_error:
        print(f"❌ Error loading dataset: {load_error}")
        train_frame = pd.DataFrame()
    return train_frame

def call_model_with_rate_limiting(model_name: str, messages: list, max_tokens: int, temperature: float) -> dict | None:
    """Call the model through the global client with rate limiting and backoff.

    Each attempt sleeps first: ``RATE_LIMIT_DELAY`` seconds initially, multiplied
    by ``BACKOFF_MULTIPLIER`` after every failure. At most ``MAX_RETRIES``
    attempts are made.

    Args:
        model_name: Model slug passed to ``client.call_model``.
        messages: Chat messages for the request.
        max_tokens: Completion token limit.
        temperature: Sampling temperature.

    Returns:
        The decoded JSON body on success. Returns None when all attempts fail, or
        immediately on a non-retryable client error (any 4xx except 408/429 —
        e.g. a bad key or malformed request will not succeed on retry).
    """
    retryable_client_statuses = {408, 429}
    current_delay = RATE_LIMIT_DELAY
    for attempt in range(1, MAX_RETRIES + 1):
        time.sleep(current_delay)
        try:
            response = client.call_model(model_name, messages, max_tokens, temperature)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.JSONDecodeError:
            # Must precede RequestException: JSONDecodeError subclasses it.
            print(f"    ❌ JSONDecodeError on attempt {attempt}/{MAX_RETRIES}.")
            print(f"    Raw Response Text (Status {response.status_code}):\n---")
            print(response.text)
            print("---\n    Retrying...")
        except requests.exceptions.HTTPError as e:
            err_response = e.response
            body = err_response.text if err_response is not None else ""
            print(f"❌ HTTP Error on attempt {attempt}/{MAX_RETRIES}: {e}\n    Response Body: {body}")
            status = err_response.status_code if err_response is not None else None
            if status is not None and 400 <= status < 500 and status not in retryable_client_statuses:
                print(f"❌ Non-retryable client error ({status}) for model {model_name}; giving up.")
                return None
        except requests.exceptions.RequestException as e:
            print(f"❌ Request Error on attempt {attempt}/{MAX_RETRIES}: {e}")
        current_delay *= BACKOFF_MULTIPLIER
    print(f"❌ Error: Max retries ({MAX_RETRIES}) exceeded for model {model_name}.")
    return None

def extract_answer(raw_response: dict) -> str | None:
    """Extract the final answer from model response."""
    if not raw_response or not raw_response.get('choices'): return None
    response_text = raw_response['choices'][0].get('message', {}).get('content', '')
    if not response_text: return None
    boxed_match = re.search(r'\\\\boxed{([^}]+)}', response_text)
    if boxed_match: return boxed_match.group(1).strip()
    answer_tag_match = re.search(r"<answer>\\s*(.*?)\\s*</answer>", response_text, re.DOTALL)
    if answer_tag_match: return answer_tag_match.group(1).strip()
    return None

# --- Core Strategy Function ---
def run_parallel_strategy(model_name: str, question_text: str, num_chains: int, metrics: MetricsTracker):
    """Sample ``num_chains`` independent solutions to one problem concurrently.

    Every chain receives the same system + user prompt; chain ``i`` (0-based) is
    sampled at temperature ``0.4 + i * 0.05`` to diversify reasoning paths.
    Requests run on a thread pool with one worker per chain, and each
    successful chain's ``usage`` is passed to ``metrics.log_api_call``.

    Args:
        model_name: Model slug sent to the API.
        question_text: Problem statement inserted into the user message.
        num_chains: Number of chains to sample; values < 1 send no requests.
        metrics: Tracker receiving per-call token usage.

    Returns:
        Successful raw API responses ordered by chain index (not by completion
        time), so majority-vote tie-breaking downstream is reproducible across
        runs. Chains that failed or raised are omitted.
    """
    if num_chains < 1:
        return []

    system_prompt = """You are a world-class mathematician and an expert in solving problems from the American Invitational Mathematics Examination (AIME). Your task is to solve the given problem with exceptional rigor and clarity.

Follow these instructions precisely:
1.  **Deconstruct the Problem:** Read the problem carefully. Identify the core mathematical concepts involved. State your initial interpretation and the goal.
2.  **Think Step-by-Step:** Use `<think>` tags to enclose your entire reasoning process. Work through the problem logically. Show all calculations and explain *why* you are taking each step.
3.  **Final Answer Formulation:** After your reasoning, provide the final answer. The answer to an AIME problem is always an integer between 000 and 999. You MUST enclose the final answer in \\boxed{} and <answer> tags. For example: `\\boxed{123}` and `<answer>123</answer>`."""

    initial_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Please solve the following AIME problem:\n\n{question_text}"},
    ]

    # One slot per chain; filled as futures finish, read back in index order.
    responses_by_chain = [None] * num_chains
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_chains) as executor:
        future_to_chain = {
            executor.submit(
                call_model_with_rate_limiting,
                model_name,
                initial_messages,
                MAX_TOKENS_PER_STEP,
                temperature=0.4 + (i * 0.05),
            ): i
            for i in range(num_chains)
        }

        for future in concurrent.futures.as_completed(future_to_chain):
            chain_index = future_to_chain[future]
            try:
                response = future.result()
            except Exception as exc:
                print(f"  ❌ Parallel chain {chain_index+1} generated an exception: {exc}")
                continue

            print(f"\n  📍 Parallel Chain {chain_index+1}/{num_chains} completed.")
            if not response:
                continue
            responses_by_chain[chain_index] = response
            metrics.log_api_call(response.get('usage'))

            chain_answer = extract_answer(response)
            # `usage` may be absent or explicitly null in the response body.
            usage = response.get('usage') or {}
            token_str = f"In:{usage.get('prompt_tokens',0)}/Out:{usage.get('completion_tokens',0)}/Total:{usage.get('total_tokens',0)}"
            print(f"    ➡️ Chain {chain_index+1} Answer: {chain_answer} | Tokens Used: {token_str}")

    return [resp for resp in responses_by_chain if resp is not None]

# --- Analysis Function ---
def analyze_and_evaluate(api_responses: list, ground_truth: str):
    """Analyze responses and evaluate using majority voting."""
    if not api_responses: return {"error": "No API responses."}
    
    extracted_answers = [extract_answer(resp) for resp in api_responses if resp is not None]
    valid_answers = [ans for ans in extracted_answers if ans is not None]

    if not valid_answers:
        return {'voted_answer': "N/A", 'evaluation': "Incorrect", 'has_valid_answer': False}
    
    majority_vote = Counter(valid_answers).most_common(1)[0][0]
    evaluation = "Correct" if majority_vote == ground_truth else "Incorrect"
    
    print(f"\\n🎯 QUESTION RESULTS (Ground Truth: {ground_truth})")
    status = "✅" if evaluation == "Correct" else "❌"
    print(f"   {status} Simple Majority Vote: {majority_vote} ({evaluation})")
    print(f"   (Individual answers from chains: {valid_answers})")

    return {'voted_answer': majority_vote, 'evaluation': evaluation, 'has_valid_answer': True}

# --- Main Execution Block ---
if __name__ == "__main__":
    import sys  # only needed for a proper non-zero exit status

    test_df = load_aime_dataset_from_hf(HUGGING_FACE_DATASET_ID)
    if test_df.empty:
        # The loader already printed the reason. Exit non-zero so callers see a
        # failure (the bare `exit()` site builtin exited with status 0).
        sys.exit(1)

    total_questions = len(test_df)
    strategy_name = f"{MODEL_TO_TEST.split('/')[-1]}_parallel_{MAX_CHAINS}chains"
    metrics = MetricsTracker(MODEL_TO_TEST, strategy_name)
    metrics.start_experiment(total_questions)
    all_question_results = []

    # enumerate() yields a 1-based position independent of the DataFrame index,
    # which is not guaranteed to be a 0..n-1 RangeIndex.
    for position, (_, row) in enumerate(test_df.iterrows(), start=1):
        problem_id = row['ID']
        question_text = row['Problem']
        ground_truth_answer = str(row['Answer'])

        print("\n" + "🔸" * 50 + f"\n📝 Processing Question {problem_id} ({position}/{total_questions})")
        print(f"   Correct Answer is: {ground_truth_answer}")

        api_log = run_parallel_strategy(MODEL_TO_TEST, question_text, MAX_CHAINS, metrics)
        analysis = analyze_and_evaluate(api_log, ground_truth_answer)

        if not analysis.get('has_valid_answer'):
            metrics.log_failed_question()

        metrics.complete_question(analysis.get('evaluation', 'Incorrect'))
        metrics.print_running_metrics()

        all_question_results.append({"problem_id": problem_id, "analysis_results": analysis, "api_log": api_log})

    # --- Final Report ---
    processed_count = metrics.total_questions_processed
    final_accuracy = round(metrics.correct_counts / processed_count * 100, 2) if processed_count > 0 else 0

    final_report = {
        "experiment_summary": {
            "model": MODEL_TO_TEST,
            "strategy": "parallel",
            "max_chains": MAX_CHAINS,
            "dataset": HUGGING_FACE_DATASET_ID,
            "total_questions_processed": processed_count,
            "total_failed_questions": metrics.total_failed_questions,
            "total_api_calls": metrics.total_api_calls,
            "total_tokens_used": metrics.total_tokens_used,
            "average_tokens_per_question": metrics.total_tokens_used['total'] / processed_count if processed_count > 0 else 0,
            "average_time_per_question_sec": (time.time() - metrics.start_time) / processed_count if processed_count > 0 else 0,
        },
        "final_accuracy_percent": {"simple_majority": final_accuracy},
        "detailed_results_per_question": all_question_results,
    }

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"AIME24_Parallel_{MAX_CHAINS}chains_{strategy_name}_{timestamp}.json"
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(final_report, f, indent=2, ensure_ascii=False)

    print(f"\n🏁🏁🏁 PARALLEL EXPERIMENT COMPLETE 🏁🏁🏁\n📁 Full report saved to: {filename}")
    print("\n--- FINAL SUMMARY ---")
    print(json.dumps(final_report["experiment_summary"], indent=2))