#!/usr/bin/env python3
"""
Gemini cross-domain inference script for multi-domain LLM evaluation pipeline.
Uses OpenRouter API to access Google Gemini models for generating responses to different prompt variants in cross-domain settings.
"""
import argparse
import json
import logging
import time
import os
from pathlib import Path
from typing import Dict, Any, List
from tqdm import tqdm
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# OpenRouter API configuration
# The key is read from the OPENROUTER_API_KEY environment variable. When the
# variable is unset, the former placeholder string is used so that importing
# this module still succeeds; presumably requests then fail authentication
# until a real key is exported.
GEMINI_KEY = os.environ.get("OPENROUTER_API_KEY", "OPENROUTER_API_KEY")
# NOTE(review): not referenced anywhere in the visible code.
STOP_AFTER_ATTEMPT = 10

# OpenAI-compatible client pointed at the OpenRouter endpoint
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=GEMINI_KEY,
)


def extract_model_name(model_id: str) -> str:
    """Turn a model ID into a name that is safe to embed in a filename.

    Every "/", "." and "-" in ``model_id`` is replaced by "_"; all other
    characters are left untouched.
    """
    to_underscore = str.maketrans({"/": "_", ".": "_", "-": "_"})
    return model_id.translate(to_underscore)


def format_chat_message(user_prompt: str, is_persona: bool = False, use_cot: bool = True, domain: str = "") -> List[Dict[str, str]]:
    """Build the system + user message pair for one chat request.

    The system instruction is selected from six fixed strings:

    * ``use_cot`` true: a step-by-step instruction, with a persona and a
      non-persona wording (``domain`` is not consulted).
    * ``use_cot`` false and ``domain`` equal to "math" (case-insensitive):
      an answer-only instruction asking for a number.
    * ``use_cot`` false, any other domain: an answer-only instruction
      asking for the chosen option.

    The user message is ``user_prompt`` prefixed with "/no_think\\n".

    Returns:
        A two-element list: the system message followed by the user message.
    """
    cot_persona = "Please strictly follow the instructions and adhere to the role given to you. DO NOT REASONING FOR TOO LONG AND ANSWER IN 100 TOKENS. Prioritize giving the <final> answer. Solve the following problem step-by-step and END WITH A LINE: \"Answer: <final>\" <final> should be only the correct choice(1, 2, 3, 4, etc.) ANSWER IN 100 TOKENS! ALWAYS END WITH A LINE: \"Answer: <final>\""
    cot_plain = "Please strictly follow the instructions. DO NOT REASONING FOR TOO LONG AND ANSWER IN 100 TOKENS. Prioritize giving the <final> answer. Solve the following problem step-by-step and END WITH A LINE: \"Answer: <final>\" <final> should be only the correct choice(1, 2, 3, 4, etc.) ANSWER IN 100 TOKENS! ALWAYS END WITH A LINE: \"Answer: <final>\""
    math_persona = "You are given a role to follow, but you must ONLY provide the final numerical answer. Do not show any work, reasoning, calculations, or explanation. Respond with just the number."
    math_plain = "Provide ONLY the final numerical answer. Do not show any work, reasoning, calculations, or explanation. Respond with just the number."
    choice_persona = "You are given a role to follow, but you must ONLY provide the final answer letter. Do not show any work, reasoning, or explanation. Respond with just the correct choice(1, 2, 3, 4, etc.)."
    choice_plain = "Provide ONLY the final answer letter. Do not show any work, reasoning, or explanation. Respond with just the correct choice(1, 2, 3, 4, etc.)."

    if use_cot:
        system_content = cot_persona if is_persona else cot_plain
    elif domain.lower() == "math":
        # Math questions expect a bare number rather than an option index.
        system_content = math_persona if is_persona else math_plain
    else:
        # Every other domain (medical, psychology, legal, ...) is treated as multiple choice.
        system_content = choice_persona if is_persona else choice_plain

    system_message = {"role": "system", "content": system_content}
    user_message = {"role": "user", "content": "/no_think\n" + user_prompt}
    return [system_message, user_message]


def generate_text(messages: List[Dict[str, str]], model_id: str = "qwen/qwen3-32b",
                 max_tokens: int = 512, variant_name: str = "") -> str:
    """Request one chat completion from the module-level client, retrying on failure.

    Up to three attempts are made. The delay between attempts doubles each
    time (1s, 2s, ...), with an extra 5s when the error text mentions a rate
    limit or quota.

    Args:
        messages: Chat messages to send to the API.
        model_id: Model identifier forwarded to the API.
        max_tokens: Upper bound on generated tokens, forwarded to the API.
        variant_name: Label used only in progress and log output.

    Returns:
        The completion text with surrounding whitespace removed (an empty
        string when the response carries no text content). If every attempt
        fails, a string starting with "[ERROR:" is returned instead of
        raising.
    """
    print(f"  Generating {variant_name}..." if variant_name else "  Generating...")

    max_retries = 3
    base_delay = 1

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = client.chat.completions.create(
                extra_body={},
                model=model_id,
                messages=messages,
                # Honour the caller's limit instead of a hard-coded value.
                max_tokens=max_tokens,
                temperature=0.0,
                stop=["# END"],
            )
            # The message content may be None; treat that as an empty answer
            # rather than letting .strip() raise and consume the retries.
            generated_text = (response.choices[0].message.content or "").strip()
            print(generated_text)
            return generated_text

        except Exception as e:
            # Errors are classified by message text only.
            error_msg = str(e).lower()
            if "rate limit" in error_msg or "quota" in error_msg:
                wait_time = base_delay * (2 ** attempt) + 5  # Exponential backoff + 5s for rate limit
                logger.warning(f"Rate limit hit for {variant_name}. Waiting {wait_time}s before retry {attempt + 1}/{max_retries}")
            elif "api" in error_msg:
                wait_time = base_delay * (2 ** attempt)
                logger.warning(f"API error for {variant_name}: {e}. Retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})")
            else:
                logger.error(f"Unexpected error for {variant_name}: {e}")
                if is_last_attempt:
                    return f"[ERROR: {str(e)}]"
                wait_time = base_delay * (2 ** attempt)

            # No point sleeping when there is no further attempt to wait for.
            if not is_last_attempt:
                time.sleep(wait_time)

    return "[ERROR: Max retries exceeded]"


def process_single_item(item: Dict[str, Any], model_id: str, use_cot: bool = True, domain: str = "", output_file: str = "") -> Dict[str, Any]:
    """Generate a response for each persona prompt of one input item.

    Args:
        item: Input record. Must contain "id", "question", "gold_answer",
            "domain", "baseline_prompt", "primed_prompt" and
            "persona_prompts" (a mapping of persona name to prompt text).
            "persona_domain" and "target_domain" are copied when present.
        model_id: Model identifier handed to ``generate_text``.
        use_cot: Forwarded to ``format_chat_message``.
        domain: Forwarded to ``format_chat_message``.
        output_file: Accepted for interface compatibility; not used here.

    Returns:
        A new dict with the copied input fields plus "persona_outputs",
        which maps each persona name to its generated text.
    """
    print(f"\n=== Processing item {item['id']} ===")

    started_at = time.time()

    # Copy the fields in a fixed order so the serialized key order is stable.
    copied_fields = (
        "id",
        "question",
        "gold_answer",
        "domain",
        "baseline_prompt",
        "primed_prompt",
        "persona_prompts",
    )
    result = {field: item[field] for field in copied_fields}
    outputs_by_persona = {}
    result["persona_outputs"] = outputs_by_persona

    # Cross-domain metadata is optional.
    for optional_field in ("persona_domain", "target_domain"):
        if optional_field in item:
            result[optional_field] = item[optional_field]

    # Only the persona variants are generated.
    print("Processing persona variants...")
    persona_items = item['persona_prompts'].items()
    for persona_name, persona_prompt in tqdm(persona_items, desc="Personas", leave=False):
        chat = format_chat_message(persona_prompt, is_persona=True, use_cot=use_cot, domain=domain)
        outputs_by_persona[persona_name] = generate_text(chat, model_id, variant_name=persona_name)

    elapsed = time.time() - started_at
    print(f"Item {item['id']} completed in {elapsed:.2f}s")

    return result


def load_jsonl(filepath: str) -> List[Dict[str, Any]]:
    """Load records from a JSONL file, one JSON value per non-blank line.

    Blank and whitespace-only lines (for example a trailing empty line) are
    skipped instead of being handed to the JSON parser.

    Args:
        filepath: Path of the UTF-8 encoded JSONL file.

    Returns:
        The parsed records in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        return [json.loads(stripped) for line in f if (stripped := line.strip())]


def _has_empty_persona_output(item: Dict[str, Any]) -> bool:
    """Return True when any persona output of ``item`` is blank or not a string."""
    persona_outputs = item.get('persona_outputs', {})
    if not isinstance(persona_outputs, dict):
        # A present-but-invalid container cannot hold usable outputs.
        return True
    return any(
        not isinstance(output, str) or not output.strip()
        for output in persona_outputs.values()
    )


def load_existing_results(output_file: str, redo: bool = False) -> tuple[List[Dict[str, Any]], set]:
    """Read previously saved results and work out which IDs are already done.

    Malformed lines (invalid JSON, or a record without an "id") are skipped
    with a warning rather than aborting the read, so that one bad line does
    not hide every record that follows it.

    Args:
        output_file: Path of the JSONL results file; it may not exist yet.
        redo: When True, a record whose persona outputs include a blank or
            non-string value is returned in the results list but its ID is
            left out of the processed set so it can be regenerated.

    Returns:
        ``(results, processed_ids)``: every readable record in file order,
        and the set of IDs considered complete. Both are empty when the file
        does not exist.
    """
    results = []
    processed_ids = set()

    if not Path(output_file).exists():
        return results, processed_ids

    try:
        with open(output_file, 'r', encoding='utf-8') as f:
            for line_number, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    continue

                try:
                    item = json.loads(stripped)
                    item_id = item['id']
                    needs_redo = redo and _has_empty_persona_output(item)
                    if not needs_redo:
                        processed_ids.add(item_id)
                except (json.JSONDecodeError, KeyError, TypeError) as e:
                    logger.warning(f"Skipping malformed line {line_number} in {output_file}: {e}")
                    continue

                results.append(item)
    except (OSError, UnicodeDecodeError) as e:
        # Best effort: keep whatever was read before the failure.
        logger.warning(f"Could not read existing output file {output_file}: {e}")

    return results, processed_ids


def save_jsonl(data: List[Dict[str, Any]], filepath: str) -> None:
    """Write ``data`` to ``filepath`` as JSONL, replacing any previous content.

    The records are first written to a sibling "<name>.tmp" file, which is
    then moved over the target with ``os.replace``. An interruption while
    writing therefore leaves the previous file intact instead of a
    half-written one.

    Args:
        data: Records to serialize, one JSON document per line.
        filepath: Destination path; missing parent directories are created.

    Raises:
        OSError: If the directory or file cannot be written.
        TypeError: If a record is not JSON-serializable.
    """
    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)

    temp_path = target.with_name(target.name + ".tmp")
    try:
        with open(temp_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in data)
        os.replace(temp_path, target)
    finally:
        # Only present if writing or the rename failed.
        temp_path.unlink(missing_ok=True)


def process_cross_domain_data(input_file: str, output_file: str, model_id: str, use_cot: bool, domain: str = "", resume: bool = False, redo: bool = False) -> None:
    """Run generation for every pending input item and save results as it goes.

    Args:
        input_file: JSONL file with the items to process.
        output_file: JSONL file that receives the results; it is rewritten
            in full after every processed item.
        model_id: Model identifier forwarded to ``process_single_item``.
        use_cot: Forwarded to ``process_single_item``.
        domain: Forwarded to ``process_single_item``.
        resume: When True, existing results are loaded from ``output_file``
            and items whose IDs count as processed are skipped. When False,
            the run starts from an empty result list.
        redo: Passed to ``load_existing_results`` when resuming. Together
            with ``resume``, a regenerated item replaces its earlier entry
            in place instead of being appended.
    """
    logger.info(f"Loading data from {input_file}")
    input_data = load_jsonl(input_file)
    if not input_data:
        logger.warning(f"No data found in {input_file}")
        return

    logger.info(f"Processing cross-domain data with {len(input_data)} items...")

    # Ensure the destination directory exists before any generation happens.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    # A fresh run starts with nothing; a resumed run picks up what is on disk.
    results, processed_ids = [], set()
    if resume:
        results, processed_ids = load_existing_results(output_file, redo)
        logger.info(f"Resume mode: Found {len(results)} existing results")
        if redo:
            empty_count = len(results) - len(processed_ids)
            logger.info(f"Redo mode: Found {empty_count} items with empty persona_outputs to regenerate")
        logger.info(f"Will skip {len(processed_ids)} already processed items")

    # Drop items that are already complete when resuming.
    if resume and processed_ids:
        items_to_process = [entry for entry in input_data if entry['id'] not in processed_ids]
        logger.info(f"Resume mode: {len(items_to_process)}/{len(input_data)} items remaining to process")
    else:
        items_to_process = input_data

    if not items_to_process:
        logger.info("All items already processed!")
        return

    total_items = len(items_to_process)
    may_replace_existing = redo and resume

    for position, item in enumerate(items_to_process, start=1):
        logger.info(f"Processing item {item['id']} ({position}/{total_items})")

        # In redo mode the item may already be in the results (with empty
        # outputs); remember where so the new result can take its place.
        existing_item_index = None
        if may_replace_existing:
            existing_item_index = next(
                (idx for idx, earlier in enumerate(results) if earlier['id'] == item['id']),
                None,
            )

        result = process_single_item(item, model_id, use_cot=use_cot, domain=domain, output_file="")

        if existing_item_index is None:
            results.append(result)
            logger.info(f"Processed new item {item['id']}")
        else:
            results[existing_item_index] = result
            logger.info(f"Regenerated item {item['id']} (replaced empty persona_outputs)")

        # Persist the complete result list after each item.
        save_jsonl(results, output_file)

        if position % 10 == 0:
            logger.info(f"Processed {position}/{total_items} items")

    logger.info(f"Successfully processed and saved {total_items} cross-domain items to {output_file}")
    logger.info(f"Total items in output file: {len(results)}")


def main():
    """Parse the command line and run cross-domain generation.

    The output path is ``<out>/generations_<model>_cot.jsonl`` when ``--cot``
    is given and ``<out>/generations_<model>_no_cot.jsonl`` otherwise, where
    ``<model>`` is the model ID passed through ``extract_model_name``. The
    target domain is what gets forwarded as the prompting domain; the
    persona domain is only logged.
    """
    parser = argparse.ArgumentParser(description="Run Gemini inference on cross-domain prompt variants")
    parser.add_argument("--in", dest="input_file", required=True, help="Input JSONL file")
    parser.add_argument("--out", required=True, help="Output directory")
    parser.add_argument("--target-domain", required=True, help="Target domain (legal/medical/psychology)")
    parser.add_argument("--persona-domain", required=True, help="Persona domain (math/medical/psychology/legal)")
    parser.add_argument("--model_id", default="qwen/qwen3-32b", help="Gemini model ID")

    # The remaining options are plain on/off switches.
    switches = (
        ("--cot", "Use chain-of-thought prompting (default: no CoT)"),
        ("--resume", "Resume from where generation left off (skip already processed items)"),
        ("--redo", "Regenerate items with empty persona_outputs"),
    )
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", help=description)

    options = parser.parse_args()

    # Output filename: generations_<modelname>_cot.jsonl / _no_cot.jsonl
    model_name = extract_model_name(options.model_id)
    cot_suffix = "_no_cot" if not options.cot else "_cot"
    output_file = Path(options.out) / f"generations_{model_name}{cot_suffix}.jsonl"

    run_summary = (
        ("Input file", options.input_file),
        ("Output file", output_file),
        ("Target domain", options.target_domain),
        ("Persona domain", options.persona_domain),
        ("Use CoT", options.cot),
        ("Model", options.model_id),
    )
    for label, value in run_summary:
        logger.info(f"{label}: {value}")

    # Process the cross-domain data
    process_cross_domain_data(
        input_file=options.input_file,
        output_file=str(output_file),
        model_id=options.model_id,
        use_cot=options.cot,
        domain=options.target_domain,
        resume=options.resume,
        redo=options.redo,
    )


if __name__ == "__main__":
    main()