#!/usr/bin/env python3
import argparse
import json
from pathlib import Path
from typing import Dict, Any, List
from collections import Counter

#qwen: You are an expert in arithmetic applications who meticulously solves problems by precisely calculating each step.
#qwen: You are a legal expert in tort law, contract law, and evidence with over 15 years of experience providing accurate solutions to complex legal questions.
#qwen: You are a senior internal medicine specialist with expertise in renal physiology, pharmacology, surgical interventions, infectious diseases, and gynecological pathology.
#qwen: You are a clinical psychologist with expertise in human aging, cognitive psychology, psychotherapy methods including Gestalt and structural approaches, and longitudinal developmental research, with a solid understanding of APA ethical guidelines and statistical measurement principles.

#gpt: You are a highly skilled math word problem solver specializing in arithmetic, algebra, unit conversions, and real-world calculations.
#gpt: You are an expert bar exam tutor and law professor specializing in evidence, torts, property, and civil procedure, providing highly accurate and precise legal analysis consistent with prevailing U.S. law.
# Generic expert templates for direct answer questions
#gpt: You are an expert medical board exam coach with specialized knowledge in clinical medicine, pathology, surgery, pharmacology, microbiology, and physiology, trained to provide concise, accurate multiple-choice answers based on advanced medical curricula.
#gpt: You are a board-certified clinical psychologist and university professor specializing in developmental, personality, and cognitive psychology, as well as psychological assessment and psychotherapy interventions.

#llama: You are a research psychologist with expertise in human development, psychology, and education, and you have experience in conducting studies on aging, cognitive development, and training programs.
#llama: You are a medical doctor with 20 years of experience and expertise in internal medicine, including urology, endocrinology, and gastroenterology, with a keen interest in medical research and publications.
#llama: You are a seasoned lawyer with expertise in tort law, land use planning, contract law, and evidence, currently serving as a trial judge in a complex case involving multiple witnesses and disputed claims.
#llama: You are an expert in arithmetic applications who meticulously solves problems by precisely calculating each step.
# Prompt templates keyed by domain name. Each value is a ``str.format``
# template: ``{question}`` is used by every entry, while ``{options}`` and
# ``{K}`` (the number of answer options) are used only by the multiple-choice
# domains. Everything inside the triple-quoted literals is sent verbatim as
# prompt text, so whitespace inside them matters.
GENERIC_EXPERT_TEMPLATES = {
    # Math: free-form numeric answer, only {question} is substituted.
    # NOTE(review): the second line of this literal starts with four spaces,
    # which end up in the rendered prompt; the other templates have none.
    "math": """You are an expert in arithmetic applications who meticulously solves problems by precisely calculating each step.
    Think carefully through the problem and end with a line: "Answer: <N>", where N is a number.

Question:
{question}

""",

    # Medical: multiple choice, answer is an option number in 1..K.
    "medical": """You are a medical doctor with 20 years of experience and expertise in internal medicine, including urology, endocrinology, and gastroenterology, with a keen interest in medical research and publications.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",

    # Psychology: multiple choice, answer is an option number in 1..K.
    "psychology": """You are a clinical psychologist with expertise in human aging, cognitive psychology, psychotherapy methods including Gestalt and structural approaches, and longitudinal developmental research, with a solid understanding of APA ethical guidelines and statistical measurement principles.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}""",

    # Legal: multiple choice, answer is an option number in 1..K.
    "legal": """You are a legal expert in tort law, contract law, and evidence with over 15 years of experience providing accurate solutions to complex legal questions.
Think carefully through the problem and end with a line: "Answer: <N>", where N is an integer from 1 to {K}.

Question:
{question}

Options:
{options}"""
}


def load_jsonl(filepath: str) -> list[Dict[str, Any]]:
    """Load records from a JSON Lines file.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (for example extra newlines at the end of the file)
    are skipped instead of being handed to ``json.loads``, which rejects
    empty input.

    Args:
        filepath: Path of the UTF-8 encoded JSONL file to read.

    Returns:
        The parsed records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with ``<filepath>:<line number>`` (1-based).
        OSError: If the file cannot be opened.
    """
    data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            try:
                data.append(json.loads(stripped))
            except json.JSONDecodeError as err:
                # Same exception type as before, but say where the bad line is.
                raise json.JSONDecodeError(
                    f"{filepath}:{line_number}: {err.msg}", err.doc, err.pos
                ) from err
    return data


def save_jsonl(data: list[Dict[str, Any]], filepath: str) -> None:
    """Write ``data`` to ``filepath`` as JSON Lines (one object per line).

    Missing parent directories are created first. The file is opened in
    write mode with UTF-8 encoding, and non-ASCII characters are written
    as-is rather than escaped.
    """
    target_dir = Path(filepath).parent
    target_dir.mkdir(parents=True, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )


def compute_gold_index(gold_label: str) -> int:
    """Convert a letter label into its 1-based option index (A->1, B->2, ...).

    Args:
        gold_label: Exactly one uppercase ASCII letter, ``"A"`` to ``"Z"``.

    Returns:
        The 1-based position of the letter in the alphabet.

    Raises:
        ValueError: If ``gold_label`` is not a string consisting of exactly
            one uppercase ASCII letter.
    """
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    # ``str.index`` performs a substring search, so without this guard an
    # empty label would map to 1 and a label such as "BC" would map to 2.
    if (
        not isinstance(gold_label, str)
        or len(gold_label) != 1
        or gold_label not in letters
    ):
        raise ValueError(
            f"gold_label must be a single letter A-Z, got {gold_label!r}"
        )
    return 1 + letters.index(gold_label)


def format_options(choices: List[str]) -> str:
    """Render ``choices`` as newline-separated lines of the form ``N) text``.

    Numbering starts at 1. An empty list yields an empty string.
    """
    numbered_lines = []
    for position, text in enumerate(choices, start=1):
        numbered_lines.append(f"{position}) {text}")
    return "\n".join(numbered_lines)


def find_modal_choice_count(data: List[Dict[str, Any]]) -> int:
    """Return the most frequent number of choices among non-math items.

    Items whose ``domain`` is ``"math"`` and items without a ``"choices"``
    key are ignored. When nothing is left to count, 4 is returned as the
    default.
    """
    tally = Counter()
    for entry in data:
        if entry.get("domain") == "math":
            continue
        if "choices" in entry:
            tally[len(entry["choices"])] += 1

    if not tally:
        return 4  # default to 4 if no non-math items

    modal_count, _occurrences = tally.most_common(1)[0]
    return modal_count


def filter_by_choice_count(data: List[Dict[str, Any]], modal_K: int) -> List[Dict[str, Any]]:
    """Keep every math item plus the non-math items that have ``modal_K`` choices.

    The returned list holds all math items first (in input order), followed
    by the surviving non-math items (in input order). A non-math item with
    no ``"choices"`` key is counted as having zero choices. A summary of
    what was kept and dropped is printed to stdout.
    """
    math_items: List[Dict[str, Any]] = []
    other_items: List[Dict[str, Any]] = []
    for entry in data:
        bucket = math_items if entry.get("domain") == "math" else other_items
        bucket.append(entry)

    # Only the non-math items are subject to the choice-count filter.
    kept_others = [
        entry for entry in other_items
        if len(entry.get("choices", [])) == modal_K
    ]
    result = math_items + kept_others

    # Report the distribution only when there was something to filter.
    if other_items:
        length_tally = Counter(len(entry.get("choices", [])) for entry in other_items)
        print(f"Choice count distribution (non-math): {dict(length_tally)}")
        print(f"Modal choice count: {modal_K}")
        print(f"Kept {len(kept_others)}/{len(other_items)} non-math items with {modal_K} choices")
        for length, dropped in length_tally.items():
            if length == modal_K:
                continue
            print(f"Dropped {dropped} non-math items with {length} choices")

    print(f"Kept all {len(math_items)} math items (no choice filtering)")
    print(f"Total kept: {len(result)}/{len(data)} items")

    return result


def build_prompts(item: Dict[str, Any], modal_K: int) -> Dict[str, Any]:
    """Build the output record, including the rendered prompt, for one item.

    Math items (``domain == "math"``) have no options: the math template is
    filled with the question only, and the record carries ``id``, ``domain``,
    ``question``, ``gold_answer`` and ``generic_expert_prompt``.

    All other items are multiple choice. Their record additionally carries
    ``choices``, ``gold_label`` and the derived 1-based ``gold_index``.

    Raises:
        KeyError: If a field read from ``item`` is missing.
        ValueError: If a non-math item does not have exactly ``modal_K``
            choices, or its gold index falls outside ``1..modal_K``.
    """
    question = item["question"]
    domain = item["domain"]
    gold_answer = item["gold_answer"]

    # Math: no options, so only the question goes into the template.
    if domain == "math":
        math_prompt = GENERIC_EXPERT_TEMPLATES["math"].format(question=question)
        return {
            "id": item["id"],
            "domain": domain,
            "question": question,
            "gold_answer": gold_answer,
            "generic_expert_prompt": math_prompt,
        }

    # Everything else is treated as multiple choice.
    choices = item["choices"]
    gold_label = item["gold_label"]

    if len(choices) != modal_K:
        raise ValueError(f"Item {item['id']} has {len(choices)} choices, expected {modal_K}")

    gold_index = compute_gold_index(gold_label)
    if gold_index < 1 or gold_index > modal_K:
        raise ValueError(f"Item {item['id']} gold_index {gold_index} not in range [1, {modal_K}]")

    rendered_options = format_options(choices)

    # NOTE(review): a domain with no template of its own falls back to the
    # math template, which has no {options}/{K} placeholders, so the options
    # are silently left out of the prompt. Confirm this is intended.
    template = GENERIC_EXPERT_TEMPLATES.get(domain, GENERIC_EXPERT_TEMPLATES["math"])
    choice_prompt = template.format(
        question=question,
        options=rendered_options,
        K=modal_K,
    )

    return {
        "id": item["id"],
        "domain": domain,
        "question": question,
        "choices": choices,  # pass-through
        "gold_label": gold_label,  # pass-through for compatibility
        "gold_index": gold_index,  # derived from gold_label
        "gold_answer": gold_answer,  # pass-through
        "generic_expert_prompt": choice_prompt,
    }


def print_character_count_table(items: list[Dict[str, Any]]) -> None:
    """Print the prompt length, in characters, of at most the first three items.

    Each item is reported with its zero-based position and its domain.
    """
    print("\nCharacter count table for first 3 items:")

    leading_items = items[:3]
    for position, entry in enumerate(leading_items):
        print(f"\nItem {position} (Domain: {entry['domain']}):")
        prompt_length = len(entry['generic_expert_prompt'])
        print(f"  Generic Expert: {prompt_length} chars")


def print_preview(items: List[Dict[str, Any]], modal_K: int) -> None:
    """Print a short preview of at most the first two items.

    For each row the question and gold answer are shown, then the numbered
    options (only for non-math items that carry a ``"choices"`` key), then
    the first 160 characters of the generic expert prompt followed by
    ``...``.
    """
    print(f"\nPreview of first 2 rows (modal_K={modal_K}):")

    for row_number, entry in enumerate(items[:2], start=1):
        print(f"\n--- Row {row_number} ---")
        print(f"Question: {entry['question']}")
        print(f"Gold Answer: {entry['gold_answer']}")

        # Math items have no options to render.
        has_options = entry["domain"] != "math" and "choices" in entry
        if has_options:
            rendered = format_options(entry["choices"])
            print("Options:", rendered, sep="\n")

        prompt_head = entry['generic_expert_prompt'][:160]
        print(f"Generic Expert prompt (first 160 chars): {prompt_head}...")


def _is_blank(value: Any) -> bool:
    """Return True for ``None`` and for empty strings, lists, tuples or dicts.

    Numbers are never blank, so a numeric gold answer of ``0`` is accepted.
    """
    return value is None or (isinstance(value, (str, list, tuple, dict)) and not value)


def _validate_item(item: Dict[str, Any], index: int, modal_K: int) -> None:
    """Validate one filtered item in place and overwrite its ``id`` with ``index``.

    Math items must carry ``id``, ``domain``, ``question`` and ``gold_answer``.
    All other items must additionally carry ``choices`` and ``gold_label``,
    have exactly ``modal_K`` non-empty string choices, and a ``gold_label``
    that maps to an index in ``1..modal_K``.

    Raises:
        ValueError: If any of the checks above fails.
    """
    is_math = item.get("domain", "") == "math"
    if is_math:
        required_fields = ["id", "domain", "question", "gold_answer"]
    else:
        required_fields = ["id", "domain", "question", "choices", "gold_label", "gold_answer"]

    for field in required_fields:
        if field not in item:
            raise ValueError(f"Item {item.get('id', 'unknown')} missing field: {field}")
        if field == "id":
            # The id is replaced by the item's position in the filtered list.
            item[field] = index
        elif _is_blank(item[field]):
            raise ValueError(f"Item {item.get('id', 'unknown')} empty field: {field}")

    if is_math:
        return

    # Check choices are non-empty strings
    choices = item["choices"]
    if len(choices) != modal_K:
        raise ValueError(f"Item {item['id']} has {len(choices)} choices, expected {modal_K}")

    for j, choice in enumerate(choices):
        if not isinstance(choice, str) or not choice.strip():
            raise ValueError(f"Item {item['id']} choice {j+1} is not a non-empty string")

    # Validate gold_label is valid for modal_K
    gold_label = item["gold_label"]
    try:
        gold_index = compute_gold_index(gold_label)
    except (ValueError, IndexError, TypeError) as e:
        raise ValueError(f"Item {item['id']} has invalid gold_label '{gold_label}': {e}") from e

    # Checked outside the ``try`` so the error is raised once instead of being
    # raised, caught and re-wrapped; the resulting message text is unchanged.
    if not (1 <= gold_index <= modal_K):
        raise ValueError(
            f"Item {item['id']} has invalid gold_label '{gold_label}': "
            f"Item {item['id']} gold_index {gold_index} not in range [1, {modal_K}]"
        )


def main():
    """Command-line entry point.

    Steps: parse ``--in``/``--out``/``--limit``, load the input JSONL,
    optionally truncate it, determine the modal number of choices, drop
    non-math items with a different number of choices, validate what is
    left (re-numbering ids by position), build one prompt record per item,
    print a preview and statistics, and write the records to ``--out``.

    Raises:
        ValueError: If the modal choice count is below 3, nothing survives
            filtering, an item fails validation, or building a prompt fails.
        SystemExit: Raised by argparse on bad arguments, including a
            negative ``--limit``.
    """
    parser = argparse.ArgumentParser(description="Build prompt families from direct answer test data")
    parser.add_argument("--in", required=True, dest="input_file", help="Input JSONL file path")
    parser.add_argument("--out", required=True, help="Output JSONL file path")
    parser.add_argument("--limit", type=int, help="Limit number of items to process")
    args = parser.parse_args()

    # A negative limit would otherwise slice items off the END of the data.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")

    # Load input data
    print(f"Loading data from {args.input_file}...")
    input_data = load_jsonl(args.input_file)
    print(f"Loaded {len(input_data)} items")

    # Apply limit if specified (omitted or 0 means "no limit")
    if args.limit and len(input_data) > args.limit:
        print(f"Limiting to first {args.limit} items")
        input_data = input_data[:args.limit]

    # Find modal choice count
    modal_K = find_modal_choice_count(input_data)

    # Error if modal_K < 3
    if modal_K < 3:
        raise ValueError(f"Modal choice count {modal_K} is less than 3")

    # Filter data to keep only items with modal choice count
    filtered_data = filter_by_choice_count(input_data, modal_K)

    if not filtered_data:
        raise ValueError("No items remaining after filtering by choice count")

    # Validate all remaining items
    print("Validating filtered data...")
    for i, item in enumerate(filtered_data):
        _validate_item(item, i, modal_K)

    print(f"Validation passed for {len(filtered_data)} items")

    # Build prompts for all filtered items
    print("Building prompts...")
    output_data = []
    for item in filtered_data:
        try:
            output_item = build_prompts(item, modal_K)
        except Exception as e:
            # Re-raise with the item id, keeping the original error as the cause.
            raise ValueError(f"Failed to build prompts for item {item['id']}: {e}") from e
        output_data.append(output_item)

    print(f"Built prompts for {len(output_data)} items")

    # Print preview
    print_preview(output_data, modal_K)

    # Print final statistics
    domain_counts = Counter(item["domain"] for item in output_data)
    print("\nFinal statistics:")
    print(f"Modal choice count (K): {modal_K}")
    print(f"Total items kept: {len(output_data)}")
    print(f"Items by domain: {dict(domain_counts)}")

    # Save output data
    print(f"\nSaving prompts to {args.out}...")
    save_jsonl(output_data, args.out)
    print(f"Successfully saved {len(output_data)} items to {args.out}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()