#!/usr/bin/env python3
"""
Gemini inference script for multi-domain LLM evaluation pipeline.
Uses OpenRouter API to access Google Gemini models for generating responses to different prompt variants.
"""
import argparse
import json
import logging
import time
import os
from pathlib import Path
from typing import Dict, Any, List
from tqdm import tqdm
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# OpenRouter API configuration.
# The key is read from the OPENROUTER_API_KEY environment variable. The literal
# placeholder is kept as the fallback so that behaviour without the variable is
# unchanged (the client is still constructed; requests will simply be rejected).
GEMINI_KEY = os.environ.get("OPENROUTER_API_KEY", "OPENROUTER_API_KEY")
# Attempt cap handed to tenacity's stop_after_attempt.
STOP_AFTER_ATTEMPT = 10

# Shared OpenAI-compatible client pointed at the OpenRouter endpoint.
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=GEMINI_KEY,
)

def format_chat_message(user_prompt: str, is_persona: bool = False, use_cot: bool = True, domain: str = "") -> List[Dict[str, str]]:
    """Build the system + user message pair for one prompt variant.

    This function is pure string selection and performs no I/O, so it is
    deliberately not wrapped in a retry policy: retrying could never change
    its outcome and would only delay and mask a genuine error.

    Args:
        user_prompt: Text sent as the user message.
        is_persona: Select the system-prompt wording that mentions the assigned role.
        use_cot: When True use the step-by-step ("Answer: <final>") system prompt;
            when False use an answer-only system prompt chosen by ``domain``.
        domain: Domain name, compared case-insensitively. Only "math" gets the
            numerical-answer wording; every other value (including empty or
            None) gets the multiple-choice wording. Ignored when ``use_cot`` is True.

    Returns:
        A two-element list: the system message followed by the user message.
    """
    if use_cot:
        if is_persona:
            system_content = "Please strictly follow the instructions and adhere to the role given to you. Avoid over-thinking and ANSWER IN 100 TOKENS. Prioritize giving the <final> answer. Solve the following problem step-by-step and END WITH A LINE: \"Answer: <final>\" <final> should be only the correct choice(1, 2, 3, 4, etc.) ANSWER IN 100 TOKENS! ALWAYS END WITH A LINE: \"Answer: <final>\""
        else:
            system_content = "Please strictly follow the instructions. Avoid over-thinking and ANSWER IN 100 TOKENS. Prioritize giving the <final> answer. Solve the following problem step-by-step and END WITH A LINE: \"Answer: <final>\" <final> should be only the correct choice(1, 2, 3, 4, etc.) ANSWER IN 100 TOKENS! ALWAYS END WITH A LINE: \"Answer: <final>\""
    elif (domain or "").lower() == "math":
        # Math items expect a bare number rather than a choice index.
        if is_persona:
            system_content = "You are given a role to follow, but you must ONLY provide the final numerical answer. Do not show any work, reasoning, calculations, or explanation. Respond with just the number."
        else:
            system_content = "Provide ONLY the final numerical answer. Do not show any work, reasoning, calculations, or explanation. Respond with just the number."
    else:
        # For non-math domains (medical, psychology, legal, etc.) - typically multiple choice
        if is_persona:
            system_content = "You are given a role to follow, but you must ONLY provide the final answer letter. Do not show any work, reasoning, or explanation. Respond with just the correct choice(1, 2, 3, 4, etc.)."
        else:
            system_content = "Provide ONLY the final answer letter. Do not show any work, reasoning, or explanation. Respond with just the correct choice(1, 2, 3, 4, etc.)."

    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_prompt},
    ]


def generate_text(messages: List[Dict[str, str]], model_id: str = "google/gemini-2.5-flash", 
                 max_tokens: int = 512, variant_name: str = "") -> str:
    """Request one chat completion through the module-level OpenRouter client.

    The call is attempted up to three times. Failures are classified by
    substring matching on the exception text and followed by an exponential
    back-off sleep (with an extra 5 seconds for rate-limit errors).

    Args:
        messages: Chat messages (system + user) to send.
        model_id: OpenRouter model identifier.
        max_tokens: Completion token limit forwarded to the API.
        variant_name: Label used only in console and log output.

    Returns:
        The stripped completion text on success. On failure a marker string is
        returned instead of raising: ``"[ERROR: <message>]"`` when the last
        attempt fails with an unclassified error, otherwise
        ``"[ERROR: Max retries exceeded]"``.
    """
    print(f"  Generating {variant_name}..." if variant_name else "  Generating...")

    max_retries = 3
    base_delay = 1

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                extra_body={},
                model=model_id,
                messages=messages,
                max_tokens=max_tokens,  # honour the caller's limit instead of a fixed value
                temperature=0.0,
                stop=["# END"],
            )

            content = response.choices[0].message.content
            if content is None:
                # Routed through the handler below so the recorded error marker
                # names the real problem rather than an AttributeError on None.
                raise ValueError("Model returned no content")

            generated_text = content.strip()
            print(generated_text)
            return generated_text

        except Exception as e:
            error_text = str(e)
            error_lower = error_text.lower()
            if "rate_limit" in error_lower or "429" in error_text:
                wait_time = base_delay * (2 ** attempt) + 5  # Exponential backoff + 5s for rate limit
                logger.warning(f"Rate limit hit for {variant_name}. Waiting {wait_time}s before retry {attempt + 1}/{max_retries}")
                time.sleep(wait_time)
            elif "api" in error_lower or "400" in error_text or "500" in error_text:
                wait_time = base_delay * (2 ** attempt)
                logger.warning(f"API error for {variant_name}: {e}. Retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})")
                time.sleep(wait_time)
            else:
                logger.error(f"Unexpected error for {variant_name}: {e}")
                if attempt == max_retries - 1:
                    return f"[ERROR: {error_text}]"
                time.sleep(base_delay * (2 ** attempt))

    return "[ERROR: Max retries exceeded]"


def process_single_item(item: Dict[str, Any], model_id: str = "google/gemini-2.5-flash", 
                       max_tokens: int = 512, use_cot: bool = True, domain: str = "") -> Dict[str, Any]:
    """Run the baseline, primed, and every persona prompt of one item through the model.

    Args:
        item: Input record; the keys "id", "question", "gold_answer", "domain",
            "baseline_prompt", "primed_prompt" and "persona_prompts" are read.
        model_id: Model identifier forwarded to ``generate_text``.
        max_tokens: Token limit forwarded to ``generate_text``.
        use_cot: Forwarded to ``format_chat_message``.
        domain: Forwarded to ``format_chat_message``.

    Returns:
        A new dict carrying the input fields listed above plus
        "baseline_output", "primed_output" and "persona_outputs".
    """

    def run_variant(prompt_text: str, label: str, persona_mode: bool) -> str:
        # Format one prompt and send it to the model under the given label.
        chat = format_chat_message(prompt_text, is_persona=persona_mode, use_cot=use_cot, domain=domain)
        return generate_text(chat, model_id, max_tokens, label)

    print(f"\n=== Processing item {item['id']} ===")
    started_at = time.time()

    print("Processing baseline variant...")
    baseline_text = run_variant(item['baseline_prompt'], "baseline", False)

    print("Processing primed variant...")
    primed_text = run_variant(item['primed_prompt'], "primed", False)

    print("Processing persona variants...")
    persona_texts = {}
    persona_iter = tqdm(item['persona_prompts'].items(), desc="Personas", leave=False)
    for label, prompt_text in persona_iter:
        persona_texts[label] = run_variant(prompt_text, label, True)
        # Brief pause between consecutive persona requests.
        time.sleep(0.1)

    elapsed = time.time() - started_at
    logger.info(f"Item {item['id']} processed in {elapsed:.2f}s")

    # Carry the input fields over first so the output key order is stable.
    carried_keys = (
        "id",
        "question",
        "gold_answer",
        "domain",
        "baseline_prompt",
        "primed_prompt",
        "persona_prompts",
    )
    record = {key: item[key] for key in carried_keys}
    record["baseline_output"] = baseline_text
    record["primed_output"] = primed_text
    record["persona_outputs"] = persona_texts
    return record


def load_jsonl(filepath: str) -> List[Dict[str, Any]]:
    """Load records from a JSON Lines file.

    Blank and whitespace-only lines (for example a trailing empty line) are
    skipped rather than being handed to the JSON parser.

    Args:
        filepath: Path of the UTF-8 encoded JSONL file.

    Returns:
        The parsed value of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if stripped:
                data.append(json.loads(stripped))
    return data


def get_processed_ids(output_file: str) -> set:
    """Collect the "id" of every record already written to an output JSONL file.

    The scan is best-effort. A line that is not valid JSON, or that has no
    usable "id", is logged and skipped on its own, so one damaged line (such
    as a partially written record) does not hide the IDs on the lines after it.

    Args:
        output_file: Path of the JSONL output file; it may not exist yet.

    Returns:
        The set of IDs found. Empty when the file is missing or unreadable.
    """
    processed_ids = set()
    path = Path(output_file)
    if not path.exists():
        return processed_ids

    try:
        with open(path, 'r', encoding='utf-8') as f:
            for line_number, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    continue
                try:
                    processed_ids.add(json.loads(stripped)['id'])
                except (json.JSONDecodeError, KeyError, TypeError) as e:
                    # TypeError covers a non-object JSON value or an unhashable id.
                    logger.warning(f"Skipping unreadable line {line_number} in existing output file: {e}")
    except (OSError, UnicodeError) as e:
        logger.warning(f"Error reading existing output file: {e}")
    return processed_ids


def process_domain_data(input_data: List[Dict[str, Any]], output_file: str, 
                       model_id: str = "google/gemini-2.5-flash", max_tokens: int = 512, use_cot: bool = True, 
                       domain: str = "", resume: bool = False) -> None:
    """Run inference over a list of items and append each result to a JSONL file.

    Each result is written as soon as it is produced, so an interrupted run
    keeps everything finished so far and can be continued with ``resume=True``.

    Args:
        input_data: Items to process; each must carry an "id".
        output_file: Destination JSONL path; parent directories are created.
        model_id: Model identifier forwarded to ``process_single_item``.
        max_tokens: Token limit forwarded to ``process_single_item``.
        use_cot: Forwarded to ``process_single_item``.
        domain: Domain label, used in log messages and forwarded to ``process_single_item``.
        resume: When True, keep the existing output file and skip items whose
            "id" is already in it. When False, truncate the output file first.
    """
    if not input_data:
        logger.warning(f"No data to process for {domain}")
        return

    domain_prefix = f"{domain} " if domain else ""
    logger.info(f"Processing {domain_prefix}data with {len(input_data)} items using {model_id}...")

    # Prepare output file
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    # Check for resume functionality
    processed_ids = set()
    if resume:
        processed_ids = get_processed_ids(output_file)
        if processed_ids:
            logger.info(f"Resume mode: Found {len(processed_ids)} already processed items")
            logger.info(f"Will skip items with IDs: {sorted(processed_ids)}")
        else:
            logger.info("Resume mode: No existing output found, starting from beginning")
    else:
        # Clear the file if not resuming
        with open(output_file, 'w', encoding='utf-8'):
            pass

    # Filter out already processed items if resuming
    items_to_process = input_data
    if resume and processed_ids:
        items_to_process = [item for item in input_data if item['id'] not in processed_ids]
        logger.info(f"Resume mode: {len(items_to_process)}/{len(input_data)} items remaining to process")

    if not items_to_process:
        # Everything was already done; returning here also avoids dividing by
        # a zero item count in the timing summary below.
        logger.info(f"No {domain_prefix}items left to process; {output_file} is already complete")
        return

    # Progress and ETA are measured against this run's workload, not the full
    # input, so that resumed runs report meaningful numbers.
    total_to_process = len(items_to_process)
    processed_count = 0
    total_start_time = time.time()

    for item in items_to_process:
        result = process_single_item(item, model_id, max_tokens, use_cot=use_cot, domain=domain)

        # Append result immediately to file
        with open(output_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')

        processed_count += 1
        if processed_count % 5 == 0:
            elapsed = time.time() - total_start_time
            avg_time = elapsed / processed_count
            remaining = (total_to_process - processed_count) * avg_time
            logger.info(f"Processed {processed_count}/{total_to_process} {domain_prefix}items. ETA: {remaining/60:.1f}m")

    total_time = time.time() - total_start_time
    logger.info(f"Successfully processed and saved {processed_count} {domain_prefix}items to {output_file}")
    logger.info(f"Total processing time: {total_time/60:.2f} minutes ({total_time/processed_count:.2f}s per item)")


def main():
    """Parse command-line options, load the prompts file, and run inference over it.

    Returns:
        Process exit status: 0 on success, 1 when the input path does not
        exist or is not a regular file.
    """
    parser = argparse.ArgumentParser(description="Run Gemini inference on prompt variants")
    parser.add_argument("--in", default="data", dest="input_file", help="Input file path")
    parser.add_argument("--out", default="data", help="Output directory path")
    parser.add_argument("--domain", help="Specific domain to process (math/medical/psychology/legal)")
    # %(default)s keeps the help text in step with the actual default value.
    parser.add_argument("--model_id", default="openai/gpt-4.1", help="OpenRouter model ID (default: %(default)s)")
    parser.add_argument("--api_key", help="OpenRouter API key (can also use the OPENROUTER_API_KEY environment variable)")
    parser.add_argument("--max_tokens", type=int, default=512, help="Maximum tokens to generate (default: 512)")
    parser.add_argument("--cot", action="store_true", help="Use chain-of-thought prompting (default: no CoT)")
    parser.add_argument("--resume", action="store_true", help="Resume from where generation left off (skip already processed items)")
    args = parser.parse_args()

    # Apply the key to the shared client: the command-line value wins, then the
    # environment variable; with neither, the client's existing key is kept.
    api_key = args.api_key or os.environ.get("OPENROUTER_API_KEY")
    if api_key:
        client.api_key = api_key

    # Determine input files and output paths
    input_path = Path(args.input_file)
    cot_suffix = "_cot" if args.cot else "_no_cot"

    if not input_path.exists():
        logger.error(f"Input file {args.input_file} does not exist")
        return 1
    if not input_path.is_file():
        # A directory (such as the default "data") cannot be read as JSONL.
        logger.error(f"Input path {args.input_file} is not a file")
        return 1

    # Create output directory and construct output file path
    os.makedirs(args.out, exist_ok=True)
    model_name = args.model_id.split('/')[-1].replace('-', '_')
    output_path = Path(args.out) / f"generations_{model_name}{cot_suffix}.jsonl"

    # Input is a single file
    logger.info(f"Loading data from {args.input_file}")
    input_data = load_jsonl(args.input_file)
    logger.info(f"Loaded {len(input_data)} items")

    # Filter by domain if specified
    if args.domain:
        input_data = [item for item in input_data if item.get('domain') == args.domain]
        logger.info(f"Filtered to {len(input_data)} items for domain: {args.domain}")

    # Process single file
    process_domain_data(input_data, str(output_path), args.model_id, args.max_tokens, args.cot, args.domain or "", args.resume)

    logger.info("Gemini inference pipeline completed successfully!")
    return 0


if __name__ == "__main__":
    # Raise SystemExit directly so main()'s return value becomes the exit
    # status without relying on the exit() helper that the site module
    # installs for interactive use.
    raise SystemExit(main())


# Usage Examples:
#   # Run inference on all domains with CoT
#   python infer_gemini.py --in data --out data --cot

#   # Run on specific domain without CoT (omit --cot)
#   python infer_gemini.py --domain commonsense

#   # Use different model
#   python infer_gemini.py --model_id google/gemini-2.5-flash --cot

#   # Single file processing
#   python infer_gemini.py --in data/commonsense/dataset/prompts.jsonl --out results --cot

#   Output Structure:
#   - Written to {out}/generations_{model_name}_{cot|no_cot}.jsonl, where {model_name} is the part of --model_id after the last "/" with "-" replaced by "_"
#   - Same JSON structure with baseline_output, primed_output, and persona_outputs