#!/usr/bin/env python3
import argparse
import json
import logging
import time
import os
from pathlib import Path
from typing import Dict, Any, List
from tqdm import tqdm
from openai import OpenAI

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# OpenRouter API configuration.
# The key is taken from the OPENROUTER_API_KEY environment variable when it is
# set; otherwise the literal placeholder below is used and must be replaced
# with a real key. The constant keeps its historical name (GEMINI_KEY) so any
# code referring to it keeps working.
GEMINI_KEY = os.environ.get("OPENROUTER_API_KEY", "OPENROUTER_API_KEY")
# OpenAI-compatible client pointed at the OpenRouter endpoint; the model that
# is actually queried is chosen per request via its model id.
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=GEMINI_KEY,
)


def format_chat_message(generic_expert_prompt: str, cot: bool = True, domain: str = "math") -> List[Dict[str, str]]:
    """Build the two-turn (system + user) chat payload for one prompt.

    Args:
        generic_expert_prompt: Problem text; it is sent as the user turn,
            prefixed with "/no_think" and a newline.
        cot: When true, the system turn asks for short step-by-step reasoning
            ending in an "Answer: <final>" line. When false, it asks for the
            final answer only.
        domain: Only consulted when ``cot`` is false. "math" (compared
            case-insensitively) requests a bare number; any other value
            requests the number of the correct choice.

    Returns:
        A list holding the system message followed by the user message.
    """
    # Pick the instruction text; `domain` is deliberately not touched on the
    # chain-of-thought path.
    if cot:
        instructions = 'Please strictly follow the instructions. DO NOT REASONING FOR TOO LONG AND ANSWER IN 100 TOKENS. Prioritize giving the <final> answer. Solve the following problem step-by-step and END WITH A LINE: "Answer: <final>" <final> should be only the correct choice(1, 2, 3, 4, etc.) ANSWER IN 100 TOKENS! ALWAYS END WITH A LINE: "Answer: <final>"'
    elif domain.lower() == "math":
        instructions = 'Provide ONLY the final numerical answer. Do not show any work, reasoning, calculations, or explanation. Respond with just the number after "Answer: <final>".'
    else:
        instructions = 'Please strictly follow the instructions. Provide ONLY the final answer. Do not show any work, reasoning, or explanation. Respond with just the correct choice(1, 2, 3, 4, etc.) after "Answer: <final>".'

    system_turn = {"role": "system", "content": instructions}
    user_turn = {"role": "user", "content": "/no_think\n" + generic_expert_prompt}
    return [system_turn, user_turn]


def generate_answer(messages: List[Dict[str, str]], model_id: str = "qwen/qwen3-32b", max_new_tokens: int = 512) -> str:
    """Request a chat completion through the OpenRouter client, with retries.

    Args:
        messages: Chat messages (role/content dicts) to send.
        model_id: Model identifier passed to the API.
        max_new_tokens: Upper bound on generated tokens, forwarded to the API
            as ``max_tokens``.

    Returns:
        The stripped text of the first completion choice. If every attempt
        fails, a string starting with "[ERROR:" is returned instead of
        raising: the last exception text for unclassified errors, or
        "[ERROR: Max retries exceeded]" for rate-limit/API errors.
    """
    max_retries = 3
    base_delay = 1

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = client.chat.completions.create(
                extra_body={},
                model=model_id,
                messages=messages,
                # Honour the caller's limit instead of a hard-coded value.
                max_tokens=max_new_tokens,
                temperature=0.0,
                stop=["# END"]
            )

            generated_text = response.choices[0].message.content.strip()
            print(generated_text)
            return generated_text

        except Exception as e:
            # Errors are classified by message text only; anything raised in
            # the try body (including a missing/empty response) lands here.
            error_text = str(e)
            lowered = error_text.lower()
            if "rate_limit" in lowered or "429" in error_text:
                wait_time = base_delay * (2 ** attempt) + 5  # Exponential backoff + 5s for rate limit
                logger.warning(f"Rate limit hit. Waiting {wait_time}s before retry {attempt + 1}/{max_retries}")
            elif "api" in lowered or "400" in error_text or "500" in error_text:
                wait_time = base_delay * (2 ** attempt)
                logger.warning(f"API error: {e}. Retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})")
            else:
                logger.error(f"Unexpected error: {e}")
                if is_last_attempt:
                    return f"[ERROR: {error_text}]"
                wait_time = base_delay * (2 ** attempt)

            # No retry follows the final attempt, so sleeping there would only
            # delay the error result.
            if not is_last_attempt:
                time.sleep(wait_time)

    return "[ERROR: Max retries exceeded]"


def process_item(item: Dict[str, Any], model_id: str = "qwen/qwen3-32b", cot: bool = True, domain: str = "math") -> Dict[str, Any]:
    """Run one dataset record through the model and package the outcome.

    Args:
        item: Input record; the keys "id", "domain", "generic_expert_prompt"
            and "gold_answer" are read from it.
        model_id: Model identifier forwarded to ``generate_answer``.
        cot: Forwarded to ``format_chat_message``.
        domain: Forwarded to ``format_chat_message``; the "domain" value in
            the returned record is copied from ``item``, not from this
            argument.

    Returns:
        A new dict with the four copied input fields plus "generated_answer".
    """
    logger.info(f"Processing item {item['id']}")

    # Build the chat payload from the prompt and query the model.
    prompt_text = item['generic_expert_prompt']
    chat = format_chat_message(prompt_text, cot, domain)
    answer_text = generate_answer(chat, model_id)

    # Copy the pass-through fields in output order, then attach the answer.
    passthrough_keys = ("id", "domain", "generic_expert_prompt", "gold_answer")
    record = {key: item[key] for key in passthrough_keys}
    record["generated_answer"] = answer_text
    return record


def load_jsonl(filepath: str) -> List[Dict[str, Any]]:
    """Load records from a JSONL file (one JSON document per line).

    Blank or whitespace-only lines are skipped, so stray empty lines in the
    middle or at the end of the file do not abort loading.

    Args:
        filepath: Path of the UTF-8 encoded JSONL file.

    Returns:
        The decoded records in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Nothing to decode on an empty line.
                continue
            data.append(json.loads(stripped))
    return data


def save_json(data: List[Dict[str, Any]], filepath: str) -> None:
    """Write ``data`` to ``filepath`` as pretty-printed UTF-8 JSON.

    The parent directory is created if needed. The content is first written
    to a sibling "<name>.tmp" file and then moved over the destination with
    ``os.replace``, so a crash or serialization error part-way through never
    leaves a truncated results file behind.

    Args:
        data: JSON-serializable records to store.
        filepath: Destination path of the JSON file.

    Raises:
        TypeError: If ``data`` contains values JSON cannot encode.
        OSError: If the directory or file cannot be written.
    """
    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)

    tmp_path = target.with_name(target.name + ".tmp")
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, target)
    finally:
        # After a successful replace the temp file is already gone; after a
        # failure this removes the partial file.
        tmp_path.unlink(missing_ok=True)


def load_existing_results(filepath: str, redo: bool = False) -> tuple[List[Dict[str, Any]], set]:
    """Load previously saved results and the set of ids to skip.

    Args:
        filepath: Path of the JSON results file written by an earlier run.
        redo: When true, a result whose "generated_answer" is missing, null
            or blank after stripping is NOT counted as processed, so it can
            be regenerated. When false, every loaded result counts as
            processed.

    Returns:
        ``(results, processed_ids)``. Both are empty when the file does not
        exist or does not contain valid JSON; in the latter case a warning is
        logged because the unreadable content is ignored.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            results = json.load(f)
    except FileNotFoundError:
        # First run: nothing to resume from.
        return [], set()
    except json.JSONDecodeError as exc:
        # Make the data loss visible instead of silently starting over.
        logging.getLogger(__name__).warning(
            "Existing results file %s is not valid JSON (%s); ignoring its contents",
            filepath, exc,
        )
        return [], set()

    processed_ids = set()
    for result in results:
        if redo:
            # A null answer is treated like an empty one rather than crashing
            # on .strip().
            answer = result.get('generated_answer')
            if answer is not None and answer.strip() != '':
                processed_ids.add(result['id'])
        else:
            processed_ids.add(result['id'])

    return results, processed_ids


def main():
    """Command-line entry point: run inference over a JSONL file, resumably.

    Reads the input records, loads any results already present in the output
    file, skips ids that count as processed, generates an answer for each
    remaining record, and rewrites the output file after every record so an
    interrupted run can be resumed. With ``--redo``, stored results with an
    empty answer are regenerated and replaced in place.
    """
    parser = argparse.ArgumentParser(description="Simple Qwen inference on generic expert prompts")
    parser.add_argument("--input", required=True, help="Input JSONL file path")
    parser.add_argument("--output", required=True, help="Output JSON file path")
    parser.add_argument("--model_id", default="qwen/qwen3-32b", help="Qwen model ID")
    parser.add_argument("--cot", action="store_true", help="Use chain-of-thought prompting (default: no CoT)")
    parser.add_argument("--domain", default="math", help="Domain for the problems (math/medical/psychology/legal)")
    parser.add_argument("--redo", action="store_true", help="Regenerate items with empty generated_answer")
    args = parser.parse_args()

    # Load input data
    logger.info(f"Loading data from {args.input}")
    input_data = load_jsonl(args.input)
    logger.info(f"Loaded {len(input_data)} items")

    # Load existing results to resume from where we left off
    results, processed_ids = load_existing_results(args.output, args.redo)
    logger.info(f"Found {len(results)} existing results")
    if args.redo:
        empty_count = len(results) - len(processed_ids)
        logger.info(f"Redo mode: Found {empty_count} items with empty generated_answer to regenerate")
    logger.info(f"Will skip {len(processed_ids)} already processed items")

    # In redo mode, map each stored id to its FIRST position in `results` so
    # a regenerated answer can replace the old entry without rescanning the
    # whole list for every input item. Outside redo mode the map stays empty,
    # so every new result is appended.
    result_index: Dict[Any, int] = {}
    if args.redo:
        for position, existing_result in enumerate(results):
            result_index.setdefault(existing_result['id'], position)

    # Process remaining items
    for item in tqdm(input_data, desc="Processing items"):
        item_id = item['id']

        # Skip if already processed (and not empty in redo mode)
        if item_id in processed_ids:
            logger.info(f"Skipping already processed item {item_id}")
            continue

        # Position of a stored result for this id, if any (redo mode only)
        existing_item_index = result_index.get(item_id)

        result = process_item(item, args.model_id, args.cot, args.domain)

        if existing_item_index is not None:
            # Replace existing item with empty answer
            results[existing_item_index] = result
            logger.info(f"Regenerated item {item_id} (replaced empty answer)")
        else:
            # Add new result
            results.append(result)
            if args.redo:
                # Keep the map in step with the list so a repeated input id
                # replaces this entry instead of appending a duplicate.
                result_index[item_id] = len(results) - 1
            logger.info(f"Processed new item {item_id}")

        # Save after each generation
        save_json(results, args.output)

    logger.info(f"Processing complete! Final results saved to {args.output}")


if __name__ == "__main__":
    main()