#!/usr/bin/env python3
import argparse
import json
from pathlib import Path
from typing import Dict, Any, List
from collections import Counter



# Generic expert templates for direct answer questions
GENERIC_EXPERT_TEMPLATES = {
    "math": """You are not a mathematician with expertise in problem solving and logical reasoning.
Think carefully through the problem and provide your final answer.

Question:
{question}""",
    
    "medical": """You are not a physician with deep knowledge across all medical specialties.
Think carefully through the problem and provide your final answer.

Question:
{question}""",
    
    "psychology": """You are not a distinguished psychologist with broad expertise in cognitive, developmental, and clinical psychology.
Think carefully through the problem and provide your final answer.

Question:
{question}""",
    
    "legal": """You are not an experienced legal scholar with deep knowledge of U.S. law and legal reasoning.
Think carefully through the problem and provide your final answer.

Question:
{question}"""
}


def load_jsonl(filepath: str) -> list[Dict[str, Any]]:
    """Load data from JSONL file."""
    data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line.strip()))
    return data


def save_jsonl(data: list[Dict[str, Any]], filepath: str) -> None:
    """Save data to JSONL file."""
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
    
    with open(filepath, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


def build_prompts(item: Dict[str, Any]) -> Dict[str, Any]:
    """Build generic expert prompt for a single item."""
    question = item["question"]
    domain = item["domain"]
    
    # Generate generic expert prompt for the domain
    generic_template = GENERIC_EXPERT_TEMPLATES.get(domain, GENERIC_EXPERT_TEMPLATES["math"])
    generic_expert_prompt = generic_template.format(question=question)
    
    # Build output item with simplified schema
    output_item = {
        "id": item["id"],
        "domain": item["domain"],
        "question": item["question"],
        "gold_answer": item["gold_answer"],
        "generic_expert_prompt": generic_expert_prompt
    }
    
    return output_item


def print_character_count_table(items: list[Dict[str, Any]]) -> None:
    """Print character count table for first 3 items."""
    print("\nCharacter count table for first 3 items:")
    
    for i, item in enumerate(items[:3]):
        domain = item["domain"]
        
        print(f"\nItem {i} (Domain: {domain}):")
        print(f"  Generic Expert: {len(item['generic_expert_prompt'])} chars")


def print_preview(items: List[Dict[str, Any]]) -> None:
    """Print preview of first 2 rows."""
    print(f"\nPreview of first 2 rows:")
    
    for i, item in enumerate(items[:2]):
        print(f"\n--- Row {i+1} ---")
        print(f"Question: {item['question']}")
        print(f"Gold Answer: {item['gold_answer']}")
        
        # Show first 160 chars of generic expert prompt
        print(f"Generic Expert prompt (first 160 chars): {item['generic_expert_prompt'][:160]}...")


def main():
    parser = argparse.ArgumentParser(description="Build prompt families from direct answer test data")
    parser.add_argument("--in", required=True, dest="input_file", help="Input JSONL file path")
    parser.add_argument("--out", required=True, help="Output JSONL file path")
    parser.add_argument("--limit", type=int, help="Limit number of items to process")
    args = parser.parse_args()
    
    # Load input data
    print(f"Loading data from {args.input_file}...")
    input_data = load_jsonl(args.input_file)
    print(f"Loaded {len(input_data)} items")
    
    # Apply limit if specified
    if args.limit and len(input_data) > args.limit:
        print(f"Limiting to first {args.limit} items")
        input_data = input_data[:args.limit]
    
    # Validate all items
    print("Validating data...")
    for i, item in enumerate(input_data):
        # Check required fields
        required_fields = ["id", "domain", "question", "gold_answer"]
        for field in required_fields:
            if field not in item:
                raise ValueError(f"Item {item.get('id', 'unknown')} missing field: {field}")
            if not item[field] and field != "id":  # id can be 0
                raise ValueError(f"Item {item.get('id', 'unknown')} empty field: {field}")
    
    print(f"Validation passed for {len(input_data)} items")
    
    # Build prompts for all items
    print("Building prompts...")
    output_data = []
    for item in input_data:
        try:
            output_item = build_prompts(item)
            output_data.append(output_item)
        except Exception as e:
            raise ValueError(f"Failed to build prompts for item {item['id']}: {e}")
    
    print(f"Built prompts for {len(output_data)} items")
    
    # Print preview
    print_preview(output_data)
    
    # Print final statistics
    domain_counts = Counter(item["domain"] for item in output_data)
    print(f"\nFinal statistics:")
    print(f"Total items processed: {len(output_data)}")
    print(f"Items by domain: {dict(domain_counts)}")
    
    # Save output data
    print(f"\nSaving prompts to {args.out}...")
    save_jsonl(output_data, args.out)
    print(f"Successfully saved {len(output_data)} items to {args.out}")


if __name__ == "__main__":
    main()