#!/usr/bin/env python3
import argparse
import json
from pathlib import Path
from typing import Dict, Any, List
from collections import Counter


# Prompt templates with exact formatting preserved.
# BASELINE_TEMPLATE is the domain-agnostic prompt. build_prompts fills it with
# str.format using three placeholders: {K} (the number of options),
# {question} (the question text) and {options} (the numbered option list).
BASELINE_TEMPLATE = """Solve the following problem. Show your steps and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}"""

# Domain-specific primed templates
PRIMED_TEMPLATES = {
    # Keyed by domain name. Each value is a str.format template taking the
    # placeholders {K}, {question} and {options}. build_prompts falls back to
    # the "math" entry when an item's domain is not one of these keys.
    "math": """This is a mathematics question. Retrieve and apply relevant math knowledge and problem-solving steps. Show your reasoning and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
    
    "medical": """This is a medical question. Apply established medical science and clinical reasoning to provide an accurate answer. Choose exactly one option and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
    
    "psychology": """This is a psychology question. Use evidence-based psychological science and research to determine the correct answer. Choose exactly one option and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
    
    "legal": """This is a legal question. Apply established legal principles and precedent to provide the correct answer. Choose exactly one option and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
    
    "commonsense": """This is a commonsense reasoning question. Please retrieve and apply your broad everyday knowledge and reasoning before answering. Choose exactly one option and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}"""
}

# Domain-specific persona templates.
# Structure: {domain: {persona_name: template}}. Every template is a
# str.format string taking the placeholders {K}, {question} and {options}.
# There is no "commonsense" entry: build_prompts renders every persona from
# every domain for commonsense items, and falls back to the "math" personas
# for any other domain that is not a key here.
PERSONA_TEMPLATES = {
    "math": {
        "generic_math_expert": """You are a brilliant mathematician with expertise in problem solving and logical reasoning.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
        
        "euclid": """You are Euclid, a Greek mathematician known as the father of geometry.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
        
        "terence_tao": """You are Terence Tao, a Fields Medal–winning mathematician known for deep insights across many fields of mathematics.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}"""
    },
    
    "medical": {
        "generic_medical_expert": """You are a highly experienced physician with deep knowledge across all medical specialties.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
        
        "historical_hippocrates": """You are Hippocrates, the ancient Greek physician known as the "Father of Medicine."
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
        
        "modern_atul_gawande": """You are Dr. Atul Gawande, a renowned surgeon, writer, and public health researcher.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}"""
    },
    
    "psychology": {
        "generic_psychology_expert": """You are a distinguished psychologist with broad expertise in cognitive, developmental, and clinical psychology.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
        
        "historical_sigmund_freud": """You are Sigmund Freud, the founder of psychoanalysis.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
        
        "modern_steven_pinker": """You are Steven Pinker, a cognitive psychologist and linguist known for research on language and the mind.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}"""
    },
    
    "legal": {
        "generic_legal_expert": """You are an experienced legal scholar with deep knowledge of U.S. law and legal reasoning.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
        
        "historical_oliver_wendell_holmes_jr": """You are Oliver Wendell Holmes Jr., former Associate Justice of the U.S. Supreme Court, renowned for your influential opinions.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",
        
        "modern_ruth_bader_ginsburg": """You are Ruth Bader Ginsburg, former Associate Justice of the U.S. Supreme Court, celebrated for precise legal analysis and attention to precedent.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}"""
    }
}


def load_jsonl(filepath: str) -> list[Dict[str, Any]]:
    """Load records from a JSONL file, one JSON value per line.

    Blank or whitespace-only lines (for example trailing empty lines at the
    end of the file) are skipped instead of being handed to the JSON parser,
    which would reject them.

    Args:
        filepath: Path of the UTF-8 encoded JSONL file to read.

    Returns:
        The decoded values, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # json.loads("") raises; an empty line carries no record.
                continue
            data.append(json.loads(stripped))
    return data


def save_jsonl(data: list[Dict[str, Any]], filepath: str) -> None:
    """Write every item of *data* to *filepath* as one JSON document per line.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is (not escaped) and the file is encoded as UTF-8.
    """
    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)

    with target.open('w', encoding='utf-8') as out:
        for record in data:
            serialized = json.dumps(record, ensure_ascii=False)
            out.write(f"{serialized}\n")


def compute_gold_index(gold_label: str) -> int:
    """Convert a letter label to its 1-based option index (A->1, B->2, ...).

    Args:
        gold_label: Exactly one uppercase ASCII letter, "A" through "Z".

    Returns:
        The 1-based position of the letter in the alphabet.

    Raises:
        TypeError: If gold_label is not a string.
        ValueError: If gold_label is not exactly one letter A-Z.
    """
    alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    if not isinstance(gold_label, str):
        raise TypeError(
            f"gold_label must be a string, got {type(gold_label).__name__}"
        )
    # str.index performs a substring search, so without this guard "" would
    # map to 1 and a multi-letter label such as "BC" would map to 2.
    if len(gold_label) != 1 or gold_label not in alphabet:
        raise ValueError(
            f"gold_label must be a single letter A-Z, got {gold_label!r}"
        )
    return 1 + alphabet.index(gold_label)


def format_options(choices: List[str]) -> str:
    """Render *choices* as newline-separated lines numbered "1) ", "2) ", ..."""
    numbered_lines = []
    for number, text in enumerate(choices, start=1):
        numbered_lines.append(f"{number}) {text}")
    return "\n".join(numbered_lines)


def find_modal_choice_count(data: List[Dict[str, Any]]) -> int:
    """Return the most common number of choices among the items in *data*.

    When several choice counts are equally common, the one that appears
    first in *data* is returned (Counter.most_common keeps first-seen order
    for equal counts).

    Args:
        data: Items that each have a sized "choices" entry.

    Returns:
        The modal length of the "choices" entries.

    Raises:
        ValueError: If *data* contains no items, so no mode exists.
        KeyError: If an item has no "choices" entry.
    """
    choice_counts = Counter(len(item["choices"]) for item in data)
    if not choice_counts:
        # most_common(1) would be empty and indexing it raises an opaque
        # IndexError; report the actual problem instead.
        raise ValueError("Cannot determine modal choice count: no items provided")
    return choice_counts.most_common(1)[0][0]


def filter_by_choice_count(data: List[Dict[str, Any]], modal_K: int) -> List[Dict[str, Any]]:
    """Return the items of *data* whose "choices" has exactly *modal_K* entries.

    A summary of the choice-count distribution, the number of items kept and
    the number dropped per other choice count is printed to stdout.
    """
    length_tally = Counter()
    kept = []
    for entry in data:
        n_choices = len(entry["choices"])
        length_tally[n_choices] += 1
        if n_choices == modal_K:
            kept.append(entry)

    # Report what the filter did
    print(f"Choice count distribution: {dict(length_tally)}")
    print(f"Modal choice count: {modal_K}")
    print(f"Kept {len(kept)}/{len(data)} items with {modal_K} choices")
    for n_choices, n_items in length_tally.items():
        if n_choices == modal_K:
            continue
        print(f"Dropped {n_items} items with {n_choices} choices")

    return kept


def build_prompts(item: Dict[str, Any], modal_K: int) -> Dict[str, Any]:
    """Render the baseline, primed and persona prompts for one item.

    The item must provide "id", "domain", "question", "choices",
    "gold_label" and "gold_answer". The returned dict passes those fields
    through and adds "gold_index" (derived from "gold_label"),
    "baseline_prompt", "primed_prompt" and "persona_prompts" (persona name
    -> prompt). Domains without their own primed/persona templates use the
    "math" ones; "commonsense" items get the personas of every domain.

    Raises:
        ValueError: If the item does not have exactly modal_K choices or its
            gold index falls outside 1..modal_K.
    """
    question = item["question"]
    domain = item["domain"]
    choices = item["choices"]
    gold_label = item["gold_label"]

    # Guard: the item must match the dataset-wide choice count
    n_choices = len(choices)
    if n_choices != modal_K:
        raise ValueError(f"Item {item['id']} has {n_choices} choices, expected {modal_K}")

    # Guard: the gold label must point at one of the K options
    gold_index = compute_gold_index(gold_label)
    if gold_index < 1 or gold_index > modal_K:
        raise ValueError(f"Item {item['id']} gold_index {gold_index} not in range [1, {modal_K}]")

    options = format_options(choices)

    def render(template: str) -> str:
        # Every template takes the same three placeholders.
        return template.format(question=question, options=options, K=modal_K)

    # Domain-agnostic prompt
    baseline_prompt = render(BASELINE_TEMPLATE)

    # Domain-primed prompt
    primed_prompt = render(PRIMED_TEMPLATES.get(domain, PRIMED_TEMPLATES["math"]))

    # Persona prompts: commonsense draws on the personas of all domains,
    # every other domain only on its own set.
    if domain == "commonsense":
        persona_groups = list(PERSONA_TEMPLATES.values())
    else:
        persona_groups = [PERSONA_TEMPLATES.get(domain, PERSONA_TEMPLATES["math"])]

    persona_prompts = {
        persona_name: render(template)
        for group in persona_groups
        for persona_name, template in group.items()
    }

    # Output record in the new schema
    return {
        "id": item["id"],
        "domain": domain,
        "question": question,
        "choices": choices,  # pass-through
        "gold_label": gold_label,  # pass-through for compatibility
        "gold_index": gold_index,  # derived from gold_label
        "gold_answer": item["gold_answer"],  # pass-through
        "baseline_prompt": baseline_prompt,
        "primed_prompt": primed_prompt,
        "persona_prompts": persona_prompts,
    }


def print_character_count_table(items: list[Dict[str, Any]]) -> None:
    """Print the character length of every prompt for the first three items.

    For each item the domain, the baseline and primed prompt lengths and the
    length of each persona prompt are written to stdout.
    """
    print("\nCharacter count table for first 3 items:")

    for index, entry in enumerate(items[:3]):
        domain = entry["domain"]
        personas = entry["persona_prompts"]

        print(f"\nItem {index} (Domain: {domain}):")
        for label, key in (("Baseline", "baseline_prompt"), ("Primed", "primed_prompt")):
            print(f"  {label}: {len(entry[key])} chars")
        print("  Personas:")
        for persona_name, prompt in personas.items():
            print(f"    {persona_name}: {len(prompt)} chars")


def print_preview(items: List[Dict[str, Any]], modal_K: int) -> None:
    """Print a short preview of the first two built items to stdout.

    For each row the question, the numbered options and the first 160
    characters of the baseline prompt, the primed prompt and the first
    persona prompt (when any exist) are shown.
    """
    print(f"\nPreview of first 2 rows (modal_K={modal_K}):")

    for row_number, row in enumerate(items[:2], start=1):
        print(f"\n--- Row {row_number} ---")
        print(f"Question: {row['question']}")

        # Numbered option list as it appears inside the prompts
        rendered_options = format_options(row["choices"])
        print("Options:")
        print(rendered_options)

        # Leading 160 characters of each prompt type
        print(f"Baseline prompt (first 160 chars): {row['baseline_prompt'][:160]}...")
        print(f"Primed prompt (first 160 chars): {row['primed_prompt'][:160]}...")

        # Only the first persona prompt is previewed
        personas = row['persona_prompts']
        if personas:
            persona_name, persona_prompt = next(iter(personas.items()))
            print(f"{persona_name} prompt (first 160 chars): {persona_prompt[:160]}...")


def _validate_item(item: Dict[str, Any], index: int, modal_K: int) -> None:
    """Validate one filtered item in place; raise ValueError on any problem."""
    required_fields = ["id", "domain", "question", "choices", "gold_label", "gold_answer"]
    for field in required_fields:
        if field not in item:
            raise ValueError(f"Item {item.get('id', 'unknown')} missing field: {field}")
        if field == "id":
            # The incoming id is replaced by the item's position in the
            # filtered list, so later messages and the output use that index.
            # NOTE(review): this discards the original id -- confirm intended.
            item[field] = index
        elif not item[field]:
            raise ValueError(f"Item {item.get('id', 'unknown')} empty field: {field}")

    # Choices must be exactly modal_K non-empty strings
    choices = item["choices"]
    if len(choices) != modal_K:
        raise ValueError(f"Item {item['id']} has {len(choices)} choices, expected {modal_K}")

    for position, choice in enumerate(choices, start=1):
        if not isinstance(choice, str) or not choice.strip():
            raise ValueError(f"Item {item['id']} choice {position} is not a non-empty string")

    # The gold label must be convertible and point at one of the K options.
    # Only the conversion sits in the try block so the range error below is
    # reported as-is rather than re-wrapped as an "invalid gold_label" error.
    gold_label = item["gold_label"]
    try:
        gold_index = compute_gold_index(gold_label)
    except (ValueError, IndexError, TypeError) as e:
        raise ValueError(f"Item {item['id']} has invalid gold_label '{gold_label}': {e}") from e
    if not (1 <= gold_index <= modal_K):
        raise ValueError(f"Item {item['id']} gold_index {gold_index} not in range [1, {modal_K}]")


def main():
    """Command-line entry point: build prompt families from a JSONL dataset.

    Steps: parse --in/--out/--limit, load the input, optionally truncate it
    to the first --limit items, keep only items with the most common number
    of choices (which must be at least 3), validate them, build all prompt
    variants, print a preview and statistics, and write the result as JSONL.

    Raises:
        ValueError: If the input is empty, an item is malformed, the modal
            choice count is below 3, nothing survives filtering, or prompt
            building fails for an item.
        SystemExit: If the command-line arguments are invalid.
    """
    parser = argparse.ArgumentParser(description="Build prompt families from multi-domain test data")
    parser.add_argument("--in", required=True, dest="input_file", help="Input JSONL file path")
    parser.add_argument("--out", required=True, help="Output JSONL file path")
    parser.add_argument("--limit", type=int, help="Limit number of items to process (for medical domain)")
    args = parser.parse_args()

    # A negative limit would make the slice below drop items from the END of
    # the data instead of keeping the first N, so reject it up front.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be non-negative")

    # Load input data
    print(f"Loading data from {args.input_file}...")
    input_data = load_jsonl(args.input_file)
    print(f"Loaded {len(input_data)} items")

    if not input_data:
        raise ValueError(f"No items found in {args.input_file}")

    # Apply limit if specified (typically for medical domain).
    # A limit of 0 is treated the same as no limit.
    if args.limit and len(input_data) > args.limit:
        print(f"Limiting to first {args.limit} items")
        input_data = input_data[:args.limit]

    # "choices" is read by the modal-count and filter steps before the full
    # per-item validation runs, so check it here to fail with a clear message.
    for position, item in enumerate(input_data, start=1):
        if not isinstance(item, dict) or not isinstance(item.get("choices"), list):
            raise ValueError(f"Input item {position} has no list-valued 'choices' field")

    # Find modal choice count
    modal_K = find_modal_choice_count(input_data)

    # Error if modal_K < 3
    if modal_K < 3:
        raise ValueError(f"Modal choice count {modal_K} is less than 3")

    # Filter data to keep only items with modal choice count
    filtered_data = filter_by_choice_count(input_data, modal_K)

    if not filtered_data:
        raise ValueError("No items remaining after filtering by choice count")

    # Validate all remaining items
    print("Validating filtered data...")
    for index, item in enumerate(filtered_data):
        _validate_item(item, index, modal_K)

    print(f"Validation passed for {len(filtered_data)} items")

    # Build prompts for all filtered items
    print("Building prompts...")
    output_data = []
    for item in filtered_data:
        try:
            output_item = build_prompts(item, modal_K)
        except Exception as e:
            # Re-raise with the item id, keeping the original error as cause.
            raise ValueError(f"Failed to build prompts for item {item['id']}: {e}") from e
        output_data.append(output_item)

    print(f"Built prompts for {len(output_data)} items")

    # Print preview
    print_preview(output_data, modal_K)

    # Print final statistics
    domain_counts = Counter(item["domain"] for item in output_data)
    print("\nFinal statistics:")
    print(f"Modal choice count (K): {modal_K}")
    print(f"Total items kept: {len(output_data)}")
    print(f"Items by domain: {dict(domain_counts)}")

    # Save output data
    print(f"\nSaving prompts to {args.out}...")
    save_jsonl(output_data, args.out)
    print(f"Successfully saved {len(output_data)} items to {args.out}")

if __name__ == "__main__":
    main()