#!/usr/bin/env python
"""
GPQA Diamond Sequential Reasoning Experiment Template
ICLR 2026 Submission Code

This script implements sequential test-time scaling for GPQA Diamond problems.
Generates iterative reasoning chains with multiple voting methods including entropy weighting.
"""

import os
import requests
import json
import re
from collections import Counter
import time
from datetime import datetime
import pandas as pd
from datasets import load_dataset
import random
import math

# --- Main Experiment Configuration ---
MODEL_TO_TEST = "openai/gpt-oss-120b"  # Replace with desired model
MAX_STEPS = 3  # Number of sequential reasoning steps (3, 6, or 9)

# --- Configuration ---
# Prefer the OPENROUTER_API_KEY environment variable so the secret never has to
# be pasted into (and committed with) this file; the placeholder is only a fallback.
API_KEY = os.environ.get("OPENROUTER_API_KEY") or "YOUR_OPENROUTER_KEY"
HUGGING_FACE_DATASET_ID = "Idavidrein/gpqa"  # GPQA Diamond dataset
RATE_LIMIT_DELAY = 0.5  # seconds slept before each API attempt (grows on retry)
MAX_RETRIES = 5  # attempts per API call before giving up
BACKOFF_MULTIPLIER = 2  # factor applied to the delay after each failed attempt
MAX_TOKENS_PER_STEP = 4096  # max_tokens sent with every step's request

# --- OpenRouter Client Class ---
class OpenRouterClient:
    """Minimal client for OpenRouter's chat-completions endpoint.

    The endpoint URL and request headers are fixed at construction time.
    ``call_model`` sends one non-streaming completion request and returns the
    raw ``requests.Response``; status checking and JSON decoding are left to
    the caller.
    """

    def __init__(self, api_key=None):
        # An explicit (truthy) key wins; otherwise fall back to module-level API_KEY.
        if not api_key:
            api_key = API_KEY
        self.api_key = api_key
        self.url = "https://openrouter.ai/api/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # App-attribution headers (optional per OpenRouter's docs).
            "HTTP-Referer": "https://github.com/anonymous-repo",
            "X-Title": "GPQA Diamond Sequential Reasoning",
        }

    def call_model(self, model_name, messages, max_tokens, temperature=0.7, top_p=0.8):
        """POST a single chat-completion request and return the ``requests.Response``.

        Token logprobs (top 5 alternatives per token) are requested because the
        entropy-weighted vote is computed from them. Routing is pinned to one
        provider endpoint with fallbacks disabled.
        """
        routing = {
            "order": ["deepinfra/fp4"],
            "require_parameters": True,
            "allow_fallbacks": False,
        }
        request_body = {
            "model": model_name,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "stream": False,
            "logprobs": True,
            "top_logprobs": 5,
            "reasoning": {"effort": "high"},
            "provider": routing,
        }
        return requests.post(
            self.url,
            json=request_body,
            headers=self.headers,
            timeout=240,  # seconds
        )

# Every voting method scored per question; MetricsTracker seeds a zero count for
# each so all of them show up in the accuracy table.
ALL_VOTING_METHODS = [
    # Baselines: the last answer given, and a plain majority vote.
    'raw_answer_at_end',
    'simple_majority',
    # Recency-weighted: later steps count more.
    'linear_increase',
    'exp_increase',
    'inv_rank_increase',
    # Primacy-weighted: earlier steps count more.
    'linear_decay',
    'exp_decay',
    'inv_rank_decay',
    # Confidence-weighted: weight = 1 / mean token entropy of the step.
    'entropy_weighted',
]

# --- Metrics Tracking ---
class MetricsTracker:
    """Accumulate run-level accounting and print live progress.

    Tracks API calls, token usage (input/output/total), questions that yielded
    no answer, and the number of correct answers per voting method.
    """

    def __init__(self, model_name, strategy_name):
        self.model_name = model_name
        self.strategy_name = strategy_name
        # Planned vs. finished question counts (used for the ETA).
        self.total_questions_planned = 0
        self.total_questions_processed = 0
        # Pre-seed every method with 0 so each one is reported even if never correct.
        self.correct_counts = Counter(dict.fromkeys(ALL_VOTING_METHODS, 0))
        self.total_tokens_used = {bucket: 0 for bucket in ("input", "output", "total")}
        self.total_api_calls = 0
        self.total_failed_questions = 0
        self.start_time = None  # set by start_experiment()

    def start_experiment(self, total_questions):
        """Start the clock, record how many questions are planned, print the banner."""
        self.start_time = time.time()
        self.total_questions_planned = total_questions
        self.print_header()

    def log_api_call(self, usage_dict):
        """Count one API call and add its token usage, if any was reported."""
        self.total_api_calls += 1
        if not usage_dict:
            return
        usage_fields = (
            ("input", "prompt_tokens"),
            ("output", "completion_tokens"),
            ("total", "total_tokens"),
        )
        for bucket, field in usage_fields:
            self.total_tokens_used[bucket] += usage_dict.get(field, 0)

    def log_failed_question(self):
        """Count a question for which no step produced an extractable answer."""
        self.total_failed_questions += 1

    def complete_question(self, evaluation_results):
        """Mark a question finished and credit every method graded "Correct"."""
        self.total_questions_processed += 1
        winners = [method for method, verdict in evaluation_results.items() if verdict == "Correct"]
        self.correct_counts.update(winners)

    def print_header(self):
        """Print the experiment banner (strategy, model, step count)."""
        rule = "=" * 100
        print("\\n" + rule)
        print(f"🚀 STARTING SEQUENTIAL EXPERIMENT: {self.strategy_name}")
        print(f"🤖 Model: {self.model_name} | Max Steps: {MAX_STEPS}")
        print(rule)

    def print_running_metrics(self):
        """Print timing/ETA, token and call totals, and per-method accuracy.

        Does nothing until at least one question has been completed.
        """
        done = self.total_questions_processed
        if not done:
            return
        elapsed = time.time() - self.start_time
        per_question = elapsed / done
        remaining = self.total_questions_planned - done
        eta_minutes = remaining * per_question / 60

        print(f"\\n🔄 LIVE METRICS UPDATE ({done}/{self.total_questions_planned}) for {self.strategy_name}" + "=" * 20)
        print(f"⏱️  Avg Time/Q: {per_question:.1f}s | Total Time: {elapsed/60:.1f}m | ETA: {eta_minutes:.1f}m")
        print(f"💰 Total Tokens: {self.total_tokens_used['total']:,} | API Calls: {self.total_api_calls:,}")
        print(f"📉 Failed Questions (no answer): {self.total_failed_questions}")
        print("   --- Accuracy by Voting Method ---")
        for method in sorted(self.correct_counts):
            hits = self.correct_counts[method]
            accuracy = round(hits / done * 100, 2)  # done > 0 is guaranteed above
            print(f"   🎯 {method.replace('_', ' ').title():<30}: {accuracy}% ({hits}/{done})")
        print("=" * 100)

# --- Initialize Global Client ---
# Single module-level client used by call_model_with_rate_limiting(). It is built
# at import time with no explicit key, so it falls back to the module-level API_KEY.
client = OpenRouterClient()

# --- Utility Functions ---
def load_gpqa_dataset_from_hf(dataset_id: str):
    """Fetch the ``gpqa_diamond`` config's ``train`` split as a pandas DataFrame.

    Any failure is printed and turned into an empty DataFrame (best effort),
    so the caller can simply check ``.empty``.
    """
    subset, split = "gpqa_diamond", "train"
    try:
        print(f"📚 Loading dataset '{dataset_id}' from Hugging Face...")
        frame = load_dataset(dataset_id, subset)[split].to_pandas()
        print("✅ Dataset loaded successfully.")
    except Exception as err:
        print(f"❌ Error loading dataset: {err}")
        frame = pd.DataFrame()
    return frame

def call_model_with_rate_limiting(model_name: str, messages: list, max_tokens: int) -> dict | None:
    """Call the model through the shared ``client`` with pacing and backoff retries.

    Every attempt first sleeps: RATE_LIMIT_DELAY on the first try, multiplied by
    BACKOFF_MULTIPLIER after each failure. The following are retried up to
    MAX_RETRIES times in total:
      * network errors,
      * undecodable JSON,
      * HTTP 408/429/5xx,
      * a decoded body without ``choices`` (an error payload).
    Any other HTTP 4xx (e.g. 400 bad request, 401 invalid key, 402 no credits)
    fails immediately, because resending the identical request cannot succeed.

    Returns:
        The decoded JSON response (a dict with non-empty ``choices``), or None
        when the request was rejected or every attempt failed.
    """
    current_delay = RATE_LIMIT_DELAY
    for attempt in range(MAX_RETRIES):
        try:
            time.sleep(current_delay)
            response = client.call_model(model_name, messages, max_tokens)
            response.raise_for_status()
            data = response.json()
        except requests.exceptions.JSONDecodeError:
            print(f"    ❌ JSONDecodeError on attempt {attempt+1}/{MAX_RETRIES}.")
            print(f"    Raw Response Text (Status {response.status_code}):\\n---")
            print(response.text)
            print("---\\n    Retrying...")
        except requests.exceptions.HTTPError as e:
            print(f"❌ HTTP Error on attempt {attempt+1}/{MAX_RETRIES}: {e}\\n    Response Body: {e.response.text}")
            status = e.response.status_code
            # 408 (timeout) and 429 (rate limit) are transient; other 4xx are not.
            if 400 <= status < 500 and status not in (408, 429):
                print(f"❌ Non-retryable HTTP status {status}; giving up on this request.")
                return None
        except requests.exceptions.RequestException as e:
            print(f"❌ Request Error on attempt {attempt+1}/{MAX_RETRIES}: {e}")
        else:
            if isinstance(data, dict) and data.get('choices'):
                return data
            # A 200 response without choices is an error object, not a completion;
            # returning it would feed an empty turn into the reasoning chain.
            print(f"    ❌ Response without 'choices' on attempt {attempt+1}/{MAX_RETRIES}: {str(data)[:500]}")
        current_delay *= BACKOFF_MULTIPLIER
    print(f"❌ Error: Max retries ({MAX_RETRIES}) exceeded for model {model_name}.")
    return None

def extract_answer(raw_response: dict) -> str | None:
    """Extract the final answer from model response."""
    if not raw_response or not raw_response.get('choices'): return None
    response_text = raw_response['choices'][0].get('message', {}).get('content', '')
    if not response_text: return None
    
    # Look for various answer patterns
    answer_patterns = [
        r'(?:answer|final answer).*?(?:is|:)\\s*\\(?([A-D])\\)?',
        r'\\(?([A-D])\\)?\\s*(?:is|$)',
        r'(?:option|choice)\\s*\\(?([A-D])\\)?',
        r'\\bthe\\s+answer\\s+is\\s+\\(?([A-D])\\)?',
        r'\\bcorrect\\s+answer\\s+is\\s+\\(?([A-D])\\)?'
    ]
    
    for pattern in answer_patterns:
        match = re.search(pattern, response_text, re.IGNORECASE)
        if match:
            return match.group(1).upper()
    
    return None

def calculate_mean_sequence_entropy(api_response: dict) -> float | None:
    """Calculate mean sequence entropy from logprobs."""
    if not api_response: return None
    try:
        logprobs_content = api_response['choices'][0]['logprobs']['content']
        if not logprobs_content: return None
    except (KeyError, TypeError, IndexError):
        return None
    token_entropies = []
    for token_info in logprobs_content:
        top_logprobs = token_info.get('top_logprobs', [])
        if not top_logprobs: continue
        log_probs = [item['logprob'] for item in top_logprobs]
        probs = [math.exp(lp) for lp in log_probs]
        total_prob = sum(probs)
        if total_prob == 0: continue
        normalized_probs = [p / total_prob for p in probs]
        entropy = -sum(p * math.log2(p) for p in normalized_probs if p > 0)
        token_entropies.append(entropy)
    if not token_entropies: return None
    return sum(token_entropies) / len(token_entropies)

def apply_voting_methods(answers: list, entropies: list) -> dict:
    """Combine per-step answers into one winning answer per voting method.

    Args:
        answers: Extracted answer per step (None where extraction failed).
        entropies: Mean token entropy per step, aligned with ``answers``
            (None where unavailable). Unpaired trailing entries are ignored.

    Returns:
        Mapping of method name -> winning answer. Empty when no step produced
        an answer. 'entropy_weighted' is omitted when no step has both an
        answer and an entropy.

    Position-based methods re-index the non-None answers as 0..n-1, so failed
    steps are dropped before weighting. Ties go to the answer seen first,
    because ``Counter.most_common`` keeps first-encountered order for equal counts.
    """
    valid_answers = [ans for ans in (answers or []) if ans is not None]
    if not valid_answers:
        return {}

    methods = {
        'raw_answer_at_end': valid_answers[-1],
        'simple_majority': Counter(valid_answers).most_common(1)[0][0],
    }

    scores = {name: Counter() for name in (
        'linear_increase', 'exp_increase', 'inv_rank_increase',
        'linear_decay', 'exp_decay', 'inv_rank_decay',
        'entropy_weighted',
    )}

    n = len(valid_answers)
    for i, answer in enumerate(valid_answers):
        # Recency-weighted: later answers count more.
        scores['linear_increase'][answer] += i + 1
        scores['exp_increase'][answer] += 2 ** i
        scores['inv_rank_increase'][answer] += 1 / (n - i)  # n - i >= 1 always
        # Primacy-weighted: earlier answers count more.
        scores['linear_decay'][answer] += n - i
        scores['exp_decay'][answer] += 2 ** (n - 1 - i)
        scores['inv_rank_decay'][answer] += 1 / (i + 1)

    # Confidence weighting: weight = 1 / mean entropy, so lower entropy counts more.
    # Near-zero entropies are clamped to a floor instead of being skipped: a
    # zero-entropy step is the MOST confident one and must not lose its vote.
    min_entropy = 1e-9
    for answer, entropy in zip(answers, entropies):
        if answer is not None and entropy is not None:
            scores['entropy_weighted'][answer] += 1.0 / max(entropy, min_entropy)

    for name, counter in scores.items():
        if counter:
            methods[name] = counter.most_common(1)[0][0]

    return methods

# --- Core Strategy Function ---
def run_sequential_strategy(model_name: str, question_data: dict, num_steps: int, metrics: MetricsTracker):
    """Run one question through up to ``num_steps`` rounds of answer-then-self-review.

    Step 1 asks the question. Each later step resends the whole conversation
    plus a fixed refinement prompt, so the model revises its previous reply.
    The chain stops early if an API call ultimately fails.

    Args:
        model_name: OpenRouter model id.
        question_data: Row with 'Question', 'Correct Answer' and
            'Incorrect Answer 1'..'3' fields (a pandas Series in the main loop).
        num_steps: Maximum number of model calls for this question.
        metrics: Tracker whose API-call and token counters are updated per step.

    Returns:
        The raw API response dicts, one per completed step, in order.
    """
    system_prompt = """You are an expert in graduate-level physics, chemistry, and biology. Your task is to solve challenging scientific questions that require deep understanding and careful reasoning.

Follow these instructions precisely:
1. **Analyze the Question:** Read the question carefully and identify the key concepts, principles, and information provided.
2. **Think Step-by-Step:** Work through the problem systematically, explaining your reasoning at each step.
3. **Consider All Options:** Evaluate each multiple choice option carefully.
4. **Final Answer:** Provide your final answer as a single letter (A, B, C, or D).

Be thorough in your analysis and double-check your reasoning before providing the final answer."""
    # Sent as a user turn after every model reply so the next step revises it.
    refinement_prompt = "Please review your reasoning carefully. Consider if there are any errors in your analysis, alternative approaches, or additional factors you should consider. Provide your refined answer after this careful review."

    question_text = question_data['Question']
    # NOTE(review): the correct answer is always listed as option A, and the main
    # loop grades against a fixed "A". A model with a positional bias toward A
    # would look more accurate than it is; shuffling the options per question
    # (and remapping the ground-truth letter) would remove that confound.
    choices = f"A) {question_data['Correct Answer']}\\nB) {question_data['Incorrect Answer 1']}\\nC) {question_data['Incorrect Answer 2']}\\nD) {question_data['Incorrect Answer 3']}"

    full_question = f"{question_text}\\n\\n{choices}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Please solve the following scientific question:\\n\\n{full_question}"}
    ]
    api_responses = []

    for i in range(num_steps):
        print(f"\\n  📍 Step {i+1}/{num_steps}...")
        raw_response = call_model_with_rate_limiting(model_name, messages, MAX_TOKENS_PER_STEP)
        if not raw_response:
            print("    ❗️ Breaking chain due to API call failure.")
            break

        api_responses.append(raw_response)
        # 'usage' may be missing or null; treat both as "no token counts".
        usage = raw_response.get('usage') or {}
        metrics.log_api_call(usage)

        step_answer = extract_answer(raw_response)
        token_str = f"In:{usage.get('prompt_tokens',0)}/Out:{usage.get('completion_tokens',0)}/Total:{usage.get('total_tokens',0)}"
        print(f"    ➡️ Step {i+1} Answer: {step_answer} | Tokens Used: {token_str}")

        # Echo the model's reply back, but never as None: content can be null
        # (e.g. when the token budget is spent on reasoning), and an empty
        # string keeps the next request's conversation well-formed.
        reply_choices = raw_response.get('choices') or [{}]
        reply_message = reply_choices[0].get('message') or {}
        response_text = reply_message.get('content') or ''
        messages.append({"role": "assistant", "content": response_text})
        messages.append({"role": "user", "content": refinement_prompt})

    return api_responses

# --- Analysis Function ---
def analyze_and_evaluate(api_responses: list, ground_truth: str):
    """Score one question's reasoning chain under every voting method and print a recap.

    Returns:
        ``{"error": ...}`` when there are no responses. Otherwise a dict with
        'voted_answers', 'evaluation' (method -> "Correct"/"Incorrect"),
        'has_valid_answer', 'step_by_step_eval' and 'step_by_step_entropies'.
        When no step yielded an answer, the voting fields are empty and
        'has_valid_answer' is False.
    """
    if not api_responses:
        return {"error": "No API responses."}

    answers = [extract_answer(resp) for resp in api_responses]
    entropies = [calculate_mean_sequence_entropy(resp) for resp in api_responses]

    if all(ans is None for ans in answers):
        return {'voted_answers': {}, 'evaluation': {}, 'has_valid_answer': False, 'step_by_step_eval': [], 'step_by_step_entropies': entropies}

    def verdict(candidate):
        return "Correct" if candidate == ground_truth else "Incorrect"

    voted_answers = apply_voting_methods(answers, entropies)
    evaluation = {method: verdict(choice) for method, choice in voted_answers.items()}
    step_by_step_eval = [
        {"step": step_no, "answer": ans, "evaluation": verdict(ans)}
        for step_no, ans in enumerate(answers, start=1)
    ]

    print(f"\\n🎯 QUESTION RESULTS (Ground Truth: {ground_truth})")
    print("   --- Final Voted Answers ---")
    for method in sorted(voted_answers):
        mark = "✅" if evaluation.get(method, "N/A") == "Correct" else "❌"
        print(f"   {mark} {method.replace('_', ' ').title():<30}: {voted_answers[method]}")

    print("\\n   --- Step-by-Step Raw Answers (Recap) ---")
    for entry in step_by_step_eval:
        mark = "✅" if entry['evaluation'] == "Correct" else "❌"
        print(f"   {mark} Step {entry['step']}: {entry['answer']}")

    print("\\n   --- Step-by-Step Entropies (Recap) ---")
    formatted = [f'{e:.2f}' if e is not None else 'N/A' for e in entropies]
    print(f"   {formatted}")

    return {
        'voted_answers': voted_answers,
        'evaluation': evaluation,
        'has_valid_answer': True,
        'step_by_step_eval': step_by_step_eval,
        'step_by_step_entropies': entropies,
    }

# --- Main Execution Block ---
if __name__ == "__main__":
    test_df = load_gpqa_dataset_from_hf(HUGGING_FACE_DATASET_ID)
    if test_df.empty:
        # Exit non-zero so a failed dataset load is visible to shells/schedulers.
        # SystemExit is always available, unlike exit(), which the `site` module injects.
        raise SystemExit(1)

    strategy_name = f"{MODEL_TO_TEST.split('/')[-1]}_sequential_{MAX_STEPS}steps_entropy"
    metrics = MetricsTracker(MODEL_TO_TEST, strategy_name)
    metrics.start_experiment(len(test_df))
    all_question_results = []
    run_completed = False

    try:
        for index, row in test_df.iterrows():
            question_id = index  # GPQA doesn't have explicit IDs
            # run_sequential_strategy always lists 'Correct Answer' as option A.
            # NOTE(review): a fixed correct position lets any bias toward "A"
            # inflate accuracy; consider shuffling options per question.
            ground_truth_answer = "A"

            print("\\n" + "🔸" * 50 + f"\\n📝 Processing Question {question_id} ({index + 1}/{len(test_df)})")
            print(f"   Correct Answer is: {ground_truth_answer}")

            api_log = run_sequential_strategy(MODEL_TO_TEST, row, MAX_STEPS, metrics)
            analysis = analyze_and_evaluate(api_log, ground_truth_answer)

            if not analysis.get('has_valid_answer'):
                metrics.log_failed_question()

            metrics.complete_question(analysis.get('evaluation', {}))
            metrics.print_running_metrics()

            all_question_results.append({
                "question_id": question_id,
                "analysis_results": analysis,
                "api_log": api_log
            })
        run_completed = True
    finally:
        # Always write the report, also after Ctrl-C or an unexpected exception,
        # so results from API calls already paid for are not lost. Incomplete
        # runs are marked with a "_PARTIAL" filename suffix.
        processed_count = metrics.total_questions_processed
        final_accuracies = {
            method: f"{round(count / processed_count * 100, 2)}%"
            for method, count in sorted(metrics.correct_counts.items())
        } if processed_count > 0 else {}

        final_report = {
            "experiment_summary": {
                "model": MODEL_TO_TEST,
                "strategy": "sequential",
                "max_steps": MAX_STEPS,
                "dataset": HUGGING_FACE_DATASET_ID,
                "total_questions_processed": processed_count,
                "total_failed_questions": metrics.total_failed_questions,
                "total_api_calls": metrics.total_api_calls,
                "total_tokens_used": metrics.total_tokens_used,
                "average_tokens_per_question": metrics.total_tokens_used['total'] / processed_count if processed_count > 0 else 0,
                "average_time_per_question_sec": (time.time() - metrics.start_time) / processed_count if processed_count > 0 else 0
            },
            "final_accuracies_percent": final_accuracies,
            "detailed_results_per_question": all_question_results
        }

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        partial_suffix = "" if run_completed else "_PARTIAL"
        filename = f"GPQA_Sequential_{MAX_STEPS}steps_{strategy_name}_{timestamp}{partial_suffix}.json"
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(final_report, f, indent=2, ensure_ascii=False)
        if not run_completed:
            print(f"⚠️ Run stopped early after {processed_count} questions; partial report saved to: {filename}")

    print(f"\\n🏁🏁🏁 SEQUENTIAL EXPERIMENT COMPLETE 🏁🏁🏁\\n📁 Full report saved to: {filename}")
    print("\\n--- FINAL SUMMARY ---")
    print(json.dumps(final_report["experiment_summary"], indent=2))
    print("\\n--- FINAL ACCURACIES ---")
    print(json.dumps(final_report["final_accuracies_percent"], indent=2))