#!/usr/bin/env python
"""
AIME 2025 Sequential Reasoning Experiment Template
ICLR 2026 Submission Code

This script implements sequential test-time scaling for AIME 2025 problems.
Generates iterative reasoning chains with multiple voting methods including entropy weighting.
"""

import os
import requests
import json
import re
from collections import Counter
import time
from datetime import datetime
import pandas as pd
from datasets import load_dataset
import random
import math

# --- Main Experiment Configuration ---
MODEL_TO_TEST = "openai/gpt-oss-120b"  # OpenRouter model slug; replace with the model under test
MAX_STEPS = 3  # Number of sequential reasoning steps per question (e.g. 3, 6, or 9)

# --- Configuration ---
# The OpenRouter key is read from the environment so it never has to be committed
# to source control; the placeholder remains as a fallback.
API_KEY = os.environ.get("OPENROUTER_API_KEY", "YOUR_OPENROUTER_KEY")
HUGGING_FACE_DATASET_ID = "Maxwell-Jia/AIME_2025"  # AIME 2025 dataset on the Hugging Face Hub
RATE_LIMIT_DELAY = 0.5  # Initial delay (seconds) slept before each API attempt
MAX_RETRIES = 5  # Maximum number of attempts per API call
BACKOFF_MULTIPLIER = 2  # Factor applied to the delay after every failed attempt
MAX_TOKENS_PER_STEP = 4096  # max_tokens requested for each step's completion

# --- OpenRouter Client Class ---
class OpenRouterClient:
    """Minimal client for the OpenRouter chat-completions HTTP endpoint."""

    def __init__(self, api_key=None):
        # Any falsy key (None, "") falls back to the module-level API_KEY.
        self.api_key = api_key or API_KEY
        self.url = "https://openrouter.ai/api/v1/chat/completions"
        bearer = f"Bearer {self.api_key}"
        self.headers = dict([
            ("Authorization", bearer),
            ("Content-Type", "application/json"),
            ("HTTP-Referer", "https://github.com/anonymous-repo"),
            ("X-Title", "AIME 2025 Sequential Reasoning"),
        ])

    def call_model(self, model_name, messages, max_tokens, temperature=0.7, top_p=0.8):
        """POST one non-streaming chat request and return the raw ``requests.Response``.

        The payload asks for token logprobs (top 5 alternatives per position) so
        callers can estimate entropy, sets reasoning effort to "high", and lists a
        single provider ("deepinfra/fp4") with ``allow_fallbacks`` disabled.
        """
        sampling = {"max_tokens": max_tokens, "temperature": temperature, "top_p": top_p}
        routing = {"order": ["deepinfra/fp4"], "require_parameters": True, "allow_fallbacks": False}
        payload = {
            "model": model_name,
            "messages": messages,
            **sampling,
            "stream": False,
            "logprobs": True,
            "top_logprobs": 5,
            "reasoning": {"effort": "high"},
            "provider": routing,
        }
        return requests.post(self.url, headers=self.headers, json=payload, timeout=240)

# Every voting rule whose accuracy is tracked. Positional rules weight a step's
# answer by its position ("increase" favours later steps, "decay" earlier ones);
# entropy_weighted weights each answer by the inverse of its step's entropy.
ALL_VOTING_METHODS = [
    # Unweighted baselines
    'raw_answer_at_end',
    'simple_majority',
    # Later steps weigh more
    'linear_increase',
    'exp_increase',
    'inv_rank_increase',
    # Earlier steps weigh more
    'linear_decay',
    'exp_decay',
    'inv_rank_decay',
    # Confidence-based
    'entropy_weighted',
]

# --- Metrics Tracking ---
class MetricsTracker:
    """Accumulate run-wide statistics for one experiment and print progress reports.

    Tracks per-voting-method correct counts, token usage, API-call and
    failed-question counts, and wall-clock timing.
    """

    def __init__(self, model_name, strategy_name):
        self.model_name = model_name
        self.strategy_name = strategy_name
        self.total_questions_processed = 0
        self.total_questions_planned = 0
        # Seed every tracked method with 0 so methods that never score still
        # appear (at 0%) in the accuracy tables.
        self.correct_counts = Counter({method: 0 for method in ALL_VOTING_METHODS})
        self.total_tokens_used = {"input": 0, "output": 0, "total": 0}
        self.total_api_calls = 0
        self.total_failed_questions = 0
        self.start_time = None  # set by start_experiment()

    def start_experiment(self, total_questions):
        """Record the start time and planned question count, then print the banner."""
        self.start_time = time.time()
        self.total_questions_planned = total_questions
        self.print_header()

    def log_api_call(self, usage_dict):
        """Count one API call and add its token usage.

        Missing or null usage fields count as 0, so a partially populated
        ``usage`` object cannot raise a TypeError.
        """
        self.total_api_calls += 1
        if usage_dict:
            self.total_tokens_used["input"] += usage_dict.get("prompt_tokens") or 0
            self.total_tokens_used["output"] += usage_dict.get("completion_tokens") or 0
            self.total_tokens_used["total"] += usage_dict.get("total_tokens") or 0

    def log_failed_question(self):
        """Increment the count of questions that yielded no valid answer."""
        self.total_failed_questions += 1

    def complete_question(self, evaluation_results):
        """Mark one question as processed and credit every method judged "Correct"."""
        self.total_questions_processed += 1
        for method, result in evaluation_results.items():
            if result == "Correct":
                self.correct_counts[method] += 1

    def print_header(self):
        """Print the experiment banner."""
        print("\n" + "=" * 100)
        print(f"🚀 STARTING SEQUENTIAL EXPERIMENT: {self.strategy_name}")
        print(f"🤖 Model: {self.model_name} | Max Steps: {MAX_STEPS}")
        print("=" * 100)

    def print_running_metrics(self):
        """Print timing, token, failure and per-method accuracy statistics so far.

        Does nothing until at least one question has been completed.
        """
        processed = self.total_questions_processed
        if processed == 0:
            return
        elapsed_time = time.time() - self.start_time
        avg_time = elapsed_time / processed
        # Naive ETA: remaining questions times the average time per question so far.
        eta_minutes = (self.total_questions_planned - processed) * avg_time / 60

        print(f"\n🔄 LIVE METRICS UPDATE ({processed}/{self.total_questions_planned}) for {self.strategy_name}" + "=" * 20)
        print(f"⏱️  Avg Time/Q: {avg_time:.1f}s | Total Time: {elapsed_time/60:.1f}m | ETA: {eta_minutes:.1f}m")
        print(f"💰 Total Tokens: {self.total_tokens_used['total']:,} | API Calls: {self.total_api_calls:,}")
        print(f"📉 Failed Questions (no answer): {self.total_failed_questions}")
        print("   --- Accuracy by Voting Method ---")
        for method, count in sorted(self.correct_counts.items()):
            acc = round(count / processed * 100, 2)  # processed > 0 is guaranteed above
            print(f"   🎯 {method.replace('_', ' ').title():<30}: {acc}% ({count}/{processed})")
        print("=" * 100)

# --- Initialize Global Client ---
# Module-level client shared by every API call; with no explicit key it falls
# back to API_KEY inside OpenRouterClient.__init__.
client = OpenRouterClient(api_key=None)

# --- Utility Functions ---
def load_aime_dataset_from_hf(dataset_id: str):
    """Fetch the 'train' split of a Hugging Face dataset as a pandas DataFrame.

    Any exception is reported on stdout and an empty DataFrame is returned
    instead, so callers check ``.empty`` rather than catching errors.
    """
    try:
        print(f"📚 Loading dataset '{dataset_id}' from Hugging Face...")
        splits = load_dataset(dataset_id)
        train_frame = splits['train'].to_pandas()
        print("✅ Dataset loaded successfully.")
    except Exception as exc:
        print(f"❌ Error loading dataset: {exc}")
        train_frame = pd.DataFrame()
    return train_frame

def call_model_with_rate_limiting(model_name: str, messages: list, max_tokens: int) -> dict | None:
    """Send a chat request through the global ``client`` with retries and backoff.

    Before every attempt the function sleeps for the current delay. The delay
    starts at ``RATE_LIMIT_DELAY`` and is multiplied by ``BACKOFF_MULTIPLIER``
    after each failed attempt. At most ``MAX_RETRIES`` attempts are made.

    Returns:
        The decoded JSON body of the first successful (2xx, valid JSON)
        response, or ``None`` once every attempt has failed.
    """
    current_delay = RATE_LIMIT_DELAY
    for attempt in range(1, MAX_RETRIES + 1):
        time.sleep(current_delay)
        try:
            response = client.call_model(model_name, messages, max_tokens)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.JSONDecodeError:
            # Listed before RequestException, of which it is a subclass.
            print(f"    ❌ JSONDecodeError on attempt {attempt}/{MAX_RETRIES}.")
            print(f"    Raw Response Text (Status {response.status_code}):\n---")
            print(response.text)
            print("---\n    Retrying...")
        except requests.exceptions.HTTPError as e:
            body = e.response.text if e.response is not None else ""
            print(f"❌ HTTP Error on attempt {attempt}/{MAX_RETRIES}: {e}\n    Response Body: {body}")
        except requests.exceptions.RequestException as e:
            print(f"❌ Request Error on attempt {attempt}/{MAX_RETRIES}: {e}")
        # Only reached after a failed attempt: back off before retrying.
        current_delay *= BACKOFF_MULTIPLIER
    print(f"❌ Error: Max retries ({MAX_RETRIES}) exceeded for model {model_name}.")
    return None

def extract_answer(raw_response: dict) -> str | None:
    """Extract the final answer from model response."""
    if not raw_response or not raw_response.get('choices'): return None
    response_text = raw_response['choices'][0].get('message', {}).get('content', '')
    if not response_text: return None
    boxed_match = re.search(r'\\\\boxed{([^}]+)}', response_text)
    if boxed_match: return boxed_match.group(1).strip()
    answer_tag_match = re.search(r"<answer>\\s*(.*?)\\s*</answer>", response_text, re.DOTALL)
    if answer_tag_match: return answer_tag_match.group(1).strip()
    return None

def calculate_mean_sequence_entropy(api_response: dict) -> float | None:
    """Calculate mean sequence entropy from logprobs."""
    if not api_response: return None
    try:
        logprobs_content = api_response['choices'][0]['logprobs']['content']
        if not logprobs_content: return None
    except (KeyError, TypeError, IndexError):
        return None
    token_entropies = []
    for token_info in logprobs_content:
        top_logprobs = token_info.get('top_logprobs', [])
        if not top_logprobs: continue
        log_probs = [item['logprob'] for item in top_logprobs]
        probs = [math.exp(lp) for lp in log_probs]
        total_prob = sum(probs)
        if total_prob == 0: continue
        normalized_probs = [p / total_prob for p in probs]
        entropy = -sum(p * math.log2(p) for p in normalized_probs if p > 0)
        token_entropies.append(entropy)
    if not token_entropies: return None
    return sum(token_entropies) / len(token_entropies)

def apply_voting_methods(answers: list, entropies: list, min_entropy: float = 1e-9) -> dict:
    """Aggregate per-step answers into one winning answer per voting method.

    Args:
        answers: Extracted answer for each step, ``None`` where extraction failed.
        entropies: Mean token entropy for each step, aligned with ``answers``.
            ``None`` or missing trailing entries are ignored by entropy weighting.
        min_entropy: Floor applied to an entropy before it is inverted. A
            (near-)zero-entropy step, i.e. the most confident one, therefore
            receives the largest weight instead of being dropped.

    Returns:
        ``{method_name: winning_answer}``.
        - Positional methods weight answers by their rank among the valid
          (non-None) answers only.
        - ``entropy_weighted`` is absent when no step has an entropy.
        - Ties go to the answer encountered first, per ``Counter.most_common``
          ordering.
        - Returns an empty dict when there is no valid answer.
    """
    valid_answers = [ans for ans in answers if ans is not None]
    if not valid_answers:
        return {}

    methods = {
        'raw_answer_at_end': valid_answers[-1],
        'simple_majority': Counter(valid_answers).most_common(1)[0][0],
    }

    scores = {
        name: Counter()
        for name in (
            'linear_increase', 'exp_increase', 'inv_rank_increase',
            'linear_decay', 'exp_decay', 'inv_rank_decay',
            'entropy_weighted',
        )
    }

    n = len(valid_answers)
    for i, answer in enumerate(valid_answers):
        # "increase" rules favour later answers; "decay" rules favour earlier ones.
        scores['linear_increase'][answer] += i + 1
        scores['exp_increase'][answer] += 2 ** i
        scores['inv_rank_increase'][answer] += 1 / (n - i)  # n - i >= 1 because i < n
        scores['linear_decay'][answer] += n - i
        scores['exp_decay'][answer] += 2 ** (n - 1 - i)
        scores['inv_rank_decay'][answer] += 1 / (i + 1)

    # Entropy weighting pairs each step's answer with that step's own entropy,
    # so it iterates the unfiltered lists. Lower entropy means a higher weight.
    for answer, entropy in zip(answers, entropies):
        if answer is not None and entropy is not None:
            scores['entropy_weighted'][answer] += 1.0 / max(entropy, min_entropy)

    for name, counter in scores.items():
        if counter:
            methods[name] = counter.most_common(1)[0][0]

    return methods

# --- Core Strategy Function ---
def run_sequential_strategy(model_name: str, question_text: str, num_steps: int, metrics: MetricsTracker):
    """Solve one problem by repeatedly asking the model to revisit its own answer.

    Step 1 sends the system prompt plus the problem. After each step, the
    assistant's reply and a fixed "review and refine" user turn are appended to
    the conversation, so every later step sees the full history. The chain
    stops early if an API call still fails after all retries.

    Args:
        model_name: OpenRouter model identifier.
        question_text: Problem statement placed in the first user turn.
        num_steps: Maximum number of model calls for this problem.
        metrics: Tracker that receives one ``log_api_call`` per successful step.

    Returns:
        List of raw JSON responses, one per completed step (possibly fewer than
        ``num_steps``).
    """
    system_prompt = """You are a world-class mathematician and an expert in solving problems from the American Invitational Mathematics Examination (AIME). Your task is to solve the given problem with exceptional rigor and clarity.

Follow these instructions precisely:
1.  **Deconstruct the Problem:** Read the problem carefully. Identify the core mathematical concepts involved. State your initial interpretation and the goal.
2.  **Think Step-by-Step:** Use `<think>` tags to enclose your entire reasoning process. Work through the problem logically. Show all calculations and explain *why* you are taking each step.
3.  **Final Answer Formulation:** After your reasoning, provide the final answer. The answer to an AIME problem is always an integer between 000 and 999. You MUST enclose the final answer in \\boxed{} and <answer> tags. For example: `\\boxed{123}` and `<answer>123</answer>`."""
    refinement_prompt = "Wait, continue your analysis. Review your previous reasoning, identify any gaps or errors, and verify your approach to reach a more confident conclusion. Remember to put your final answer within \\boxed{} and <answer> tags."

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Please solve the following AIME problem:\n\n{question_text}"},
    ]
    api_responses = []

    for step in range(1, num_steps + 1):
        print(f"\n  📍 Step {step}/{num_steps}...")
        raw_response = call_model_with_rate_limiting(model_name, messages, MAX_TOKENS_PER_STEP)
        if not raw_response:
            print("    ❗️ Breaking chain due to API call failure.")
            break

        api_responses.append(raw_response)
        metrics.log_api_call(raw_response.get('usage'))

        step_answer = extract_answer(raw_response)
        # "usage" may be present but null, so fall back to {} explicitly.
        usage = raw_response.get('usage') or {}
        token_str = f"In:{usage.get('prompt_tokens',0)}/Out:{usage.get('completion_tokens',0)}/Total:{usage.get('total_tokens',0)}"
        print(f"    ➡️ Step {step} Answer: {step_answer} | Tokens Used: {token_str}")

        # Tolerate an empty "choices" list and a null message/content. Without
        # this, they would raise here or send a null assistant turn next step.
        choices = raw_response.get('choices') or [{}]
        message = choices[0].get('message') or {}
        response_text = message.get('content') or ''
        messages.append({"role": "assistant", "content": response_text})
        messages.append({"role": "user", "content": refinement_prompt})

    return api_responses

# --- Analysis Function ---
def _normalize_answer(answer):
    """Canonicalise an answer so that e.g. "070", "70", "+70" and "70.0" compare equal.

    Integer-valued strings are reduced to ``str(int(...))``. Any other text is
    returned stripped, and ``None`` passes through unchanged.
    """
    if answer is None:
        return None
    text = str(answer).strip()
    match = re.fullmatch(r'([+-]?\d+)(?:\.0+)?', text)
    if match:
        return str(int(match.group(1)))
    return text


def analyze_and_evaluate(api_responses: list, ground_truth: str):
    """Score one question's step responses against the ground truth and print a recap.

    Answers are extracted from every step and canonicalised with
    ``_normalize_answer``; the ground truth is canonicalised the same way.
    AIME answers padded to three digits (e.g. "070") therefore match "70" and
    do not split votes. The answers are then aggregated by every voting method.

    Returns:
        ``{"error": ...}`` when there are no responses. Otherwise, a dict with
        ``voted_answers``, ``evaluation`` ("Correct"/"Incorrect" per method),
        ``has_valid_answer``, ``step_by_step_eval`` and
        ``step_by_step_entropies``.
    """
    if not api_responses:
        return {"error": "No API responses."}

    target = _normalize_answer(ground_truth)
    step_by_step_answers = [_normalize_answer(extract_answer(resp)) for resp in api_responses]
    step_by_step_entropies = [calculate_mean_sequence_entropy(resp) for resp in api_responses]

    if all(ans is None for ans in step_by_step_answers):
        return {'voted_answers': {}, 'evaluation': {}, 'has_valid_answer': False, 'step_by_step_eval': [], 'step_by_step_entropies': step_by_step_entropies}

    voted_answers = apply_voting_methods(step_by_step_answers, step_by_step_entropies)
    evaluation = {method: "Correct" if vote == target else "Incorrect" for method, vote in voted_answers.items()}

    step_by_step_eval = [
        {"step": i + 1, "answer": ans, "evaluation": "Correct" if ans == target else "Incorrect"}
        for i, ans in enumerate(step_by_step_answers)
    ]

    print(f"\n🎯 QUESTION RESULTS (Ground Truth: {ground_truth})")
    print("   --- Final Voted Answers ---")
    for method, vote in sorted(voted_answers.items()):
        result = evaluation.get(method, "N/A")
        status = "✅" if result == "Correct" else "❌"
        print(f"   {status} {method.replace('_', ' ').title():<30}: {vote}")

    print("\n   --- Step-by-Step Raw Answers (Recap) ---")
    for step_eval in step_by_step_eval:
        status = "✅" if step_eval['evaluation'] == "Correct" else "❌"
        print(f"   {status} Step {step_eval['step']}: {step_eval['answer']}")

    print("\n   --- Step-by-Step Entropies (Recap) ---")
    print(f"   {[f'{e:.2f}' if e is not None else 'N/A' for e in step_by_step_entropies]}")

    return {
        'voted_answers': voted_answers,
        'evaluation': evaluation,
        'has_valid_answer': True,
        'step_by_step_eval': step_by_step_eval,
        'step_by_step_entropies': step_by_step_entropies
    }

# --- Main Execution Block ---
if __name__ == "__main__":
    test_df = load_aime_dataset_from_hf(HUGGING_FACE_DATASET_ID)
    if test_df.empty:
        # The loader has already printed the reason. Exit non-zero so wrapper
        # scripts can detect the failure; this does not rely on the site-module exit().
        raise SystemExit(1)

    strategy_name = f"{MODEL_TO_TEST.split('/')[-1]}_sequential_{MAX_STEPS}steps_entropy"
    metrics = MetricsTracker(MODEL_TO_TEST, strategy_name)
    total_questions = len(test_df)
    metrics.start_experiment(total_questions)
    all_question_results = []

    # enumerate() gives a 1-based position independent of the DataFrame's index labels.
    for position, (_, row) in enumerate(test_df.iterrows(), start=1):
        problem_id = row['ID']
        # NOTE(review): numeric columns may surface as NumPy scalars, which
        # json.dump cannot serialise; .item() converts them to built-in values.
        if hasattr(problem_id, 'item'):
            problem_id = problem_id.item()
        question_text = row['Problem']
        ground_truth_answer = str(row['Answer'])

        print("\n" + "🔸" * 50 + f"\n📝 Processing Question {problem_id} ({position}/{total_questions})")
        print(f"   Correct Answer is: {ground_truth_answer}")

        api_log = run_sequential_strategy(MODEL_TO_TEST, question_text, MAX_STEPS, metrics)
        analysis = analyze_and_evaluate(api_log, ground_truth_answer)

        if not analysis.get('has_valid_answer'):
            metrics.log_failed_question()

        metrics.complete_question(analysis.get('evaluation', {}))
        metrics.print_running_metrics()

        all_question_results.append({
            "problem_id": problem_id,
            "analysis_results": analysis,
            "api_log": api_log
        })

    # --- Final Report ---
    processed_count = metrics.total_questions_processed
    final_accuracies = (
        {
            method: f"{round(count / processed_count * 100, 2)}%"
            for method, count in sorted(metrics.correct_counts.items())
        }
        if processed_count > 0
        else {}
    )

    final_report = {
        "experiment_summary": {
            "model": MODEL_TO_TEST,
            "strategy": "sequential",
            "max_steps": MAX_STEPS,
            "dataset": HUGGING_FACE_DATASET_ID,
            "total_questions_processed": processed_count,
            "total_failed_questions": metrics.total_failed_questions,
            "total_api_calls": metrics.total_api_calls,
            "total_tokens_used": metrics.total_tokens_used,
            "average_tokens_per_question": metrics.total_tokens_used['total'] / processed_count if processed_count > 0 else 0,
            "average_time_per_question_sec": (time.time() - metrics.start_time) / processed_count if processed_count > 0 else 0
        },
        "final_accuracies_percent": final_accuracies,
        "detailed_results_per_question": all_question_results
    }

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"AIME25_Sequential_{MAX_STEPS}steps_{strategy_name}_{timestamp}.json"
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(final_report, f, indent=2, ensure_ascii=False)

    print(f"\n🏁🏁🏁 SEQUENTIAL EXPERIMENT COMPLETE 🏁🏁🏁\n📁 Full report saved to: {filename}")
    print("\n--- FINAL SUMMARY ---")
    print(json.dumps(final_report["experiment_summary"], indent=2))
    print("\n--- FINAL ACCURACIES ---")
    print(json.dumps(final_report["final_accuracies_percent"], indent=2))