import os
import argparse
import json
from typing import List, Optional
from scripts.utils.assorted import load_benchmark_names

# =============================================================================
# PROMPT TEMPLATE COMPONENTS
# =============================================================================
# These components are assembled in order to create the final prompt structure

# 1. System prompt prefix. Optional: it is only prepended when a non-empty
#    system prompt has been selected.
SYSTEM_PROMPT_TEMPLATE = "{system_prompt}\n\n"

# 2. Minimal, flexible introduction announcing how many example pairs follow.
INTRODUCTION_TEMPLATE = """Below are {n_examples} examples of input graphs and their corresponding output graphs.

"""

# 3. Early format instruction, placed right after the introduction.
#    {answer_format_specific} is filled with the answer format of the chosen
#    question type.
FORMAT_INSTRUCTION_TEMPLATE = """IMPORTANT: Structure your response using XML tags as follows:

<thinking>
Your step-by-step analysis and reasoning here. Explain your thought process, identify patterns from the examples, and work through the problem systematically.
</thinking>

<answer>
{answer_format_specific}
</answer>

"""

# 4. One example pair; formatted once per example (n_examples times in total).
EXAMPLE_TEMPLATE = """Input {example_num}:
{input_content}

Output {example_num}:
{output_content}

"""

# 5. Final input presentation: the graph the question is asked about.
FINAL_INPUT_TEMPLATE = """Using these examples, and this final input graph, answer the following question:

Input {final_num}:
{final_input_content}

"""

# 6. Question instruction (the text varies with question_type and target).
QUESTION_INSTRUCTION_TEMPLATE = "{question_text}"

# 7. Final format reminder, appended verbatim as the last part of the prompt.
FINAL_FORMAT_REMINDER = """

Remember: Use <thinking></thinking> tags for your analysis and <answer></answer> tags for your final response."""

# =============================================================================
# QUESTION-SPECIFIC ANSWER FORMAT TEMPLATES (NEW)
# =============================================================================

# Answer-format phrases shared by every question type of the same kind.
_NUMERIC_ANSWER_FORMAT = "Only the number."
_YES_NO_ANSWER_FORMAT = "Only 'Yes' or 'No'."

# Maps each question type to the text inserted inside the <answer> tags of the
# format instruction.
ANSWER_FORMAT_TEMPLATES = {
    "full_output": "The complete output graph in the same format as the examples.",
    "node_count": _NUMERIC_ANSWER_FORMAT,
    "edge_count": _NUMERIC_ANSWER_FORMAT,
    "blue_node_count": _NUMERIC_ANSWER_FORMAT,
    "colored_node_count": _NUMERIC_ANSWER_FORMAT,
    "is_connected": _YES_NO_ANSWER_FORMAT,
    "is_tree": _YES_NO_ANSWER_FORMAT,
    "has_cycles": _YES_NO_ANSWER_FORMAT,
    "max_degree": _NUMERIC_ANSWER_FORMAT,
    "min_degree": _NUMERIC_ANSWER_FORMAT,
    "component_count": _NUMERIC_ANSWER_FORMAT,
}

# =============================================================================
# QUESTION TYPES DICTIONARY (updated with new answer formats)
# =============================================================================

# One row per question type: (name, question about the input graph, question
# about the output graph). A None question means the target is not applicable.
_QUESTION_TEXTS = (
    (
        "full_output",
        None,  # Not applicable for input
        "What is the corresponding output graph?",
    ),
    (
        "node_count",
        "How many nodes are there in this input graph?",
        "How many nodes will be in the output graph after applying the transformation pattern?",
    ),
    (
        "edge_count",
        "How many edges are there in this input graph?",
        "How many edges will be in the output graph after applying the transformation pattern?",
    ),
    (
        "blue_node_count",
        "How many blue nodes are there in this input graph?",
        "How many blue nodes will be in the output graph after applying the transformation pattern?",
    ),
    (
        "colored_node_count",
        "How many colored nodes (non-grey nodes) are there in this input graph?",
        "How many colored nodes (non-grey nodes) will be in the output graph after applying the transformation pattern?",
    ),
    (
        "is_connected",
        "Is this input graph connected (i.e., there is a path between every pair of nodes)?",
        "Will the output graph be connected (i.e., there is a path between every pair of nodes) after applying the transformation pattern?",
    ),
    (
        "is_tree",
        "Is this input graph a tree (i.e., connected and acyclic)?",
        "Will the output graph be a tree (i.e., connected and acyclic) after applying the transformation pattern?",
    ),
    (
        "has_cycles",
        "Does this input graph contain any cycles?",
        "Will the output graph contain any cycles after applying the transformation pattern?",
    ),
    (
        "max_degree",
        "What is the maximum degree (number of connections) of any node in this input graph?",
        "What will be the maximum degree (number of connections) of any node in the output graph after applying the transformation pattern?",
    ),
    (
        "min_degree",
        "What is the minimum degree (number of connections) of any node in this input graph?",
        "What will be the minimum degree (number of connections) of any node in the output graph after applying the transformation pattern?",
    ),
    (
        "component_count",
        "How many connected components are there in this input graph?",
        "How many connected components will be in the output graph after applying the transformation pattern?",
    ),
)


def _question_entry(question_type, question):
    """Pair a question with the answer format of its question type (None stays None)."""
    if question is None:
        return None
    return {
        "question": question,
        "answer_format": ANSWER_FORMAT_TEMPLATES[question_type],
    }


# question type -> target ("input" / "output") -> question text and answer format.
QUESTION_TYPES = {
    name: {
        "input": _question_entry(name, input_question),
        "output": _question_entry(name, output_question),
    }
    for name, input_question, output_question in _QUESTION_TEXTS
}

# =============================================================================
# SIZE PATTERNS AND SYSTEM PROMPTS (unchanged from original)
# =============================================================================

# Named size patterns. Each list gives graph sizes in prompt order: every entry
# except the last is the size of one example pair, and the last entry is the
# size of the final (test) input graph.
SIZE_PATTERNS = {
    "scale_up_4": [5, 10, 15, 15],
    "scale_up_3": [5, 10, 15],
    "scale_up_2": [5, 15],
    "mixed_3": [5, 10, 5],
    "cap10_3": [10, 10, 10],
    "cap25_3": [10, 10, 25],
    "cap50_3": [10, 10, 50],
    "cap100_3": [10, 10, 100],
    "cap250_3": [10, 10, 250],
    "small_4": [5, 4, 5, 5],  # Small examples pattern with 4 examples
    "large_2": [15, 15],  # Large examples with 2 examples
    "progressive_5": [3, 5, 8, 10, 15],  # Progressive increase with 5 examples
}

# Selectable system prompts, keyed by system prompt type. An empty value means
# no system prompt is prepended to the generated prompt.
SYSTEM_PROMPTS = {
    "analyst": "You are a graph analyst. Study the following graph examples carefully and answer the question that follows.",
    "programmer": "You are a graph algorithm developer. Analyze the example graphs and their patterns, then answer the question about the given input.",
    "teacher": "You are a mathematics teacher. Examine these graph examples to understand any patterns, then answer the question clearly and methodically.",
    "none": "",  # Default empty prompt
}


def format_size_pattern(sizes: List[int]) -> str:
    """Return a compact, filename-friendly label for a list of sizes.

    An empty (or missing) list yields "default". A list equal to one of the
    predefined SIZE_PATTERNS yields that pattern's name. Any other list is
    rendered as its sizes joined by hyphens.
    """
    if not sizes:
        return "default"

    # Prefer the name of a predefined pattern when the sizes match one exactly.
    named_pattern = next(
        (name for name, preset in SIZE_PATTERNS.items() if sizes == preset),
        None,
    )
    if named_pattern is not None:
        return named_pattern

    # No predefined pattern matches: spell the sizes out.
    return "-".join(str(size) for size in sizes)


def validate_sizes(benchmark: str, graph_type: str, sizes: List[int]) -> tuple:
    """
    Check the dataset's specs.json for graphs of every requested size.

    A size that occurs k times in `sizes` needs at least k pairs listed
    under that size.

    Returns:
    - (is_valid, message): False with an explanatory message when a size is
      missing, has too few pairs, or the specs file cannot be read or parsed;
      otherwise (True, "All sizes available").
    """
    specs_file = f"datasets/{benchmark}/{graph_type}/specs.json"

    try:
        with open(specs_file, "r", encoding="utf-8") as handle:
            specs = json.load(handle)

        graphs_by_size = specs["graphs_by_size"]
        known_sizes = list(graphs_by_size.keys())

        # The specs file keys sizes as strings, so compare in string form.
        requested = list(map(str, sizes))

        absent = [wanted for wanted in requested if wanted not in known_sizes]
        if absent:
            problem = (
                f"Missing graph sizes: {', '.join(absent)}. "
                f"Available sizes: {', '.join(known_sizes)}"
            )
            return (False, problem)

        # Each distinct size must offer as many pairs as it is requested.
        for wanted in set(requested):
            required = requested.count(wanted)
            found = len(graphs_by_size[wanted]["pairs"])
            if found < required:
                problem = (
                    f"Not enough pairs for size {wanted}. "
                    f"Found {found}, need at least {required}."
                )
                return (False, problem)

        return (True, "All sizes available")

    except (FileNotFoundError, json.JSONDecodeError, KeyError) as err:
        return (False, f"Error validating sizes: {str(err)}")


def generate_question_section(question_type: str, target: str) -> tuple:
    """
    Look up the question and answer format for a question type and target.

    Parameters:
    - question_type: Key of QUESTION_TYPES selecting the kind of question
    - target: Either 'input' or 'output'

    Returns:
    - Tuple of (question_text, answer_format_text)

    Raises:
    - ValueError: If the question type is unknown, the target is neither
      'input' nor 'output', or the question type has no entry for the target.
    """
    if question_type not in QUESTION_TYPES:
        known_types = list(QUESTION_TYPES.keys())
        raise ValueError(
            f"Unknown question type: {question_type}. Available: {known_types}"
        )

    if target not in ("input", "output"):
        raise ValueError(f"Target must be 'input' or 'output', got: {target}")

    entry = QUESTION_TYPES[question_type][target]
    if entry is None:
        # Some question types are defined for only one of the two targets.
        raise ValueError(
            f"Question type '{question_type}' is not available for target '{target}'"
        )

    question_text = entry["question"]
    answer_format_text = entry["answer_format"]
    return question_text, answer_format_text


def _first_unused_pair(sorted_pairs, start, used_indices):
    """Return the first pair, scanning from position `start` and wrapping around, whose "index" is not in `used_indices`; None if every pair is used."""
    total = len(sorted_pairs)
    for offset in range(total):
        candidate = sorted_pairs[(start + offset) % total]
        if candidate["index"] not in used_indices:
            return candidate
    return None


def generate_enhanced_prompt(
    directory: str,
    encoding: str,
    sizes: Optional[List[int]] = None,
    pattern: Optional[str] = None,
    system_prompt_type: Optional[str] = None,
    question_type: str = "full_output",
    target: str = "output",
) -> Optional[str]:
    """
    Generates a dynamic prompt with customizable example sizes, system prompts, and questions.

    The prompt is assembled from the module-level templates in this order:
    optional system prompt, introduction, format instruction, one block per
    example pair, the final input, the question, and the format reminder.

    Graph selection never reuses a graph within one prompt:
    - Example i prefers the pair at sorted position i (modulo the number of
      pairs of that size). If that pair was already used for the same size,
      the next unused pair of that size is taken instead.
    - The final input is the first pair of the test size that was not used as
      an example and whose input file does not duplicate an example input.

    Parameters:
    - directory (str): The dataset directory (e.g., "colorLeaves/star").
    - encoding (str): The encoding format (e.g., "adjacency", "incident").
    - sizes (List[int], optional): Explicit sizes; the last one is the test size.
    - pattern (str, optional): Predefined size pattern (e.g., "scale_up_3"),
      used only when `sizes` is empty or None. Falls back to [10, 10].
    - system_prompt_type (str, optional): Type of system prompt to prepend.
    - question_type (str): Type of question to ask (default: "full_output").
    - target (str): Target for the question - "input" or "output" (default: "output").

    Returns:
    - str: The formatted prompt.
    - None: If validation fails, a required file or specs key is missing, or
      no duplicate-free selection of graphs exists. The reason is printed.
    """
    # Resolve the size pattern to use
    actual_sizes = sizes or (
        SIZE_PATTERNS.get(pattern, [10, 10]) if pattern else [10, 10]
    )

    # Parse the directory to get task and graph type
    parts = directory.split("/")
    if len(parts) >= 2:
        task_name, graph_type = parts[0], parts[1]
    else:
        task_name, graph_type = parts[0], "unknown"

    # Validate that all requested sizes exist
    is_valid, message = validate_sizes(task_name, graph_type, actual_sizes)
    if not is_valid:
        print(
            f"Error generating prompt for {directory} with sizes {actual_sizes}: {message}"
        )
        return None

    # Validate question type and target
    try:
        question_text, answer_format = generate_question_section(question_type, target)
    except ValueError as e:
        print(f"Error generating question section: {e}")
        return None

    system_prompt = SYSTEM_PROMPTS.get(system_prompt_type, "")
    n_examples = len(actual_sizes) - 1  # Last size is for the test case

    # Prompt sections are collected in order and joined once at the end.
    sections = []

    # 1. System prompt (only when one is selected)
    if system_prompt:
        sections.append(SYSTEM_PROMPT_TEMPLATE.format(system_prompt=system_prompt))

    # 2. Introduction
    sections.append(INTRODUCTION_TEMPLATE.format(n_examples=n_examples))

    # 3. Early format instruction with the question-specific answer format
    sections.append(
        FORMAT_INSTRUCTION_TEMPLATE.format(answer_format_specific=answer_format)
    )

    dataset_root = os.path.join("datasets", task_name, graph_type)
    specs_path = f"datasets/{task_name}/{graph_type}/specs.json"

    try:
        with open(specs_path, "r", encoding="utf-8") as f:
            specs = json.load(f)
        graphs_by_size = specs["graphs_by_size"]

        # Files already placed in the prompt, and pair indices consumed per size
        used_inputs = set()
        used_outputs = set()
        used_indices_by_size = {}

        # 4. Add input-output pairs based on specified sizes
        for i in range(n_examples):
            size = str(actual_sizes[i])
            sorted_pairs = sorted(
                graphs_by_size[size]["pairs"], key=lambda p: p["index"]
            )
            used_indices = used_indices_by_size.setdefault(size, set())

            # Start at position i so selections that never collided stay the
            # same as before; probe forward only when that pair is taken.
            pair = _first_unused_pair(sorted_pairs, i, used_indices)
            if pair is None:
                print(
                    f"Error: Not enough distinct graph pairs of size {size} "
                    f"for example {i + 1}."
                )
                return None

            input_path = os.path.join(dataset_root, pair["files"]["input"][encoding])
            output_path = os.path.join(
                dataset_root, pair["files"]["output"][encoding]
            )

            # Distinct pairs could still point at the same file; refuse those.
            if input_path in used_inputs or output_path in used_outputs:
                print(
                    f"Error: Duplicate input or output graph detected in prompt generation. "
                    f"Input: {input_path}, Output: {output_path}"
                )
                return None

            used_inputs.add(input_path)
            used_outputs.add(output_path)
            used_indices.add(pair["index"])

            try:
                with open(input_path, "r", encoding="utf-8") as f:
                    input_content = f.read().strip()
                with open(output_path, "r", encoding="utf-8") as f:
                    output_content = f.read().strip()
            except FileNotFoundError as e:
                print(f"Warning: File not found: {e.filename}")
                return None

            sections.append(
                EXAMPLE_TEMPLATE.format(
                    example_num=i + 1,
                    input_content=input_content,
                    output_content=output_content,
                )
            )

        # 5. Add final input for prediction/analysis
        test_size = str(actual_sizes[-1])
        sorted_test_pairs = sorted(
            graphs_by_size[test_size]["pairs"], key=lambda p: p["index"]
        )
        used_test_indices = used_indices_by_size.get(test_size, set())

        # Take the first pair that was not shown as an example, checking both
        # the pair index and the resolved input file.
        test_input_path = None
        for candidate in sorted_test_pairs:
            if candidate["index"] in used_test_indices:
                continue
            candidate_path = os.path.join(
                dataset_root, candidate["files"]["input"][encoding]
            )
            if candidate_path in used_inputs:
                continue
            test_input_path = candidate_path
            break

        if test_input_path is None:
            # Emitting a test input that also appears as an example would leak
            # its answer, so fail instead of falling back to a duplicate.
            print(
                f"Error: No unused test graph available for size {test_size}; "
                f"every candidate duplicates an example."
            )
            return None

        try:
            with open(test_input_path, "r", encoding="utf-8") as f:
                final_input_content = f.read().strip()
        except FileNotFoundError:
            print(f"Error: Test input file not found for size {test_size}")
            return None

        sections.append(
            FINAL_INPUT_TEMPLATE.format(
                final_num=n_examples + 1, final_input_content=final_input_content
            )
        )

        # 6. Add the question instruction
        sections.append(
            QUESTION_INSTRUCTION_TEMPLATE.format(question_text=question_text)
        )

        # 7. Add the final format reminder
        sections.append(FINAL_FORMAT_REMINDER)

    except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
        print(f"Error loading specs file: {e}")
        return None

    return "".join(sections)


def get_prompt_filename(
    encoding: str,
    sizes: List[int],
    system_prompt_type: str,
    question_type: str = "full_output",
    target: str = "output",
) -> str:
    """Build the standardized, hyphen-delimited filename of a prompt.

    Layout: encoding-pattern-system-pairs[-question-target].txt, where
    "pairs" is the number of example pairs (len(sizes) - 1) and a missing
    system prompt type is written as "none".
    """
    pattern_label = format_size_pattern(sizes)
    system_label = system_prompt_type or "none"
    example_pairs = len(sizes) - 1

    stem = f"{encoding}-{pattern_label}-{system_label}-{example_pairs}"

    # Hyphens (not underscores) separate the fields because question types
    # contain underscores. For backward compatibility, full_output prompts
    # carry no question-target suffix.
    if question_type != "full_output":
        stem = f"{stem}-{question_type}-{target}"

    return f"{stem}.txt"


def load_prompt_if_exists(
    benchmark: str,
    graph_type: str,
    encoding: str,
    sizes: List[int],
    system_prompt_type: str,
    question_type: str = "full_output",
    target: str = "output",
) -> Optional[str]:
    """Return the text of a previously saved prompt, or None when unavailable.

    The file is looked up under datasets/<benchmark>/<graph_type>/prompts/
    using the name produced by get_prompt_filename. None is returned when the
    file does not exist or cannot be read (a read error is also printed).
    """
    prompt_name = get_prompt_filename(
        encoding, sizes, system_prompt_type, question_type, target
    )
    prompt_path = f"datasets/{benchmark}/{graph_type}/prompts/{prompt_name}"

    # Nothing saved under that name: nothing to load.
    if not os.path.exists(prompt_path):
        return None

    try:
        with open(prompt_path, "r", encoding="utf-8") as handle:
            saved_prompt = handle.read()
    except (OSError, IOError) as err:
        print(f"Error reading prompt file: {err}")
        return None

    return saved_prompt


def _parse_prompt_filename(filename: str) -> Optional[dict]:
    """Split a prompt filename into its fields; return None if it does not fit the naming scheme.

    Layout: encoding-size_pattern-system_prompt-n_pairs[-question_type-target].txt
    The size pattern may itself contain hyphens (custom sizes are written as
    e.g. "5-7-9"), so the fixed fields are read from both ends and whatever
    remains in the middle is the pattern.
    """
    suffix = ".txt"
    if not filename.endswith(suffix):
        return None

    parts = filename[: -len(suffix)].split("-")

    # The optional question/target suffix is recognizable by its last field;
    # without it the name ends in the number of example pairs.
    if len(parts) >= 6 and parts[-1] in ("input", "output"):
        question_type, target = parts[-2], parts[-1]
        parts = parts[:-2]
    else:
        question_type, target = "full_output", "output"

    if len(parts) < 4:
        return None

    return {
        "encoding": parts[0],
        "pattern": "-".join(parts[1:-2]),
        "system_prompt": parts[-2],
        "n_pairs": parts[-1],
        "question_type": question_type,
        "target": target,
    }


def collect_in_json(
    benchmarks: Optional[List[str]] = None,
    patterns: Optional[List[str]] = None,  # Changed to support multiple patterns
    system_prompt: Optional[str] = None,
    question_type: Optional[str] = None,
    target: Optional[str] = None,
    output_file: str = "datasets/prompts/prompts.json",
) -> Optional[str]:
    """
    Collects all existing prompts matching the criteria into a single JSON file
    for batch inference with local models like Qwen.

    Prompt files are read from datasets/<benchmark>/<graph_type>/prompts/*.txt.
    Directories and files are visited in sorted order so the output is
    reproducible.

    The output JSON is a list of prompt entries. Each entry includes:
      - id: unique prompt identifier
      - system_prompt: the system prompt text for the prompt's system prompt type
      - text: the prompt content with that system prompt prefix removed
      - metadata: additional information about the prompt

    Parameters:
    - benchmarks: List of benchmark names to include (None = all)
    - patterns: List of size patterns to filter (None = all)
    - system_prompt: System prompt type to filter (None = all)
    - question_type: Question type to filter (None = all)
    - target: Target to filter (None = all)
    - output_file: Path where to save the collected prompts JSON

    Returns:
    - The output file path on success; None when the datasets directory is
      missing or no prompt matches the criteria.
    """
    print("🔍 Collecting existing prompts for batch inference...")

    if benchmarks is None:
        # Get all available benchmarks
        datasets_dir = "datasets"
        if not os.path.exists(datasets_dir):
            print("❌ No datasets directory found!")
            return None
        benchmarks = sorted(
            d
            for d in os.listdir(datasets_dir)
            if os.path.isdir(os.path.join(datasets_dir, d)) and d != "prompts"
        )

    collected_prompts = []

    for benchmark in benchmarks:
        benchmark_dir = os.path.join("datasets", benchmark)
        if not os.path.isdir(benchmark_dir):
            print(f"⚠️ No dataset directory found for {benchmark}")
            continue

        available_graph_types = sorted(
            d
            for d in os.listdir(benchmark_dir)
            if os.path.isdir(os.path.join(benchmark_dir, d))
        )
        found_prompts_in_benchmark = False
        for graph_type in available_graph_types:
            prompts_dir = os.path.join(benchmark_dir, graph_type, "prompts")
            if not os.path.isdir(prompts_dir):
                continue
            found_prompts_in_benchmark = True
            print(f"📂 Scanning {benchmark}/{graph_type} prompts...")

            for filename in sorted(os.listdir(prompts_dir)):
                fields = _parse_prompt_filename(filename)
                if fields is None:
                    continue

                # Apply filters (an unset filter matches everything)
                if patterns and fields["pattern"] not in patterns:
                    continue
                if (
                    system_prompt is not None
                    and fields["system_prompt"] != system_prompt
                ):
                    continue
                if question_type and fields["question_type"] != question_type:
                    continue
                if target and fields["target"] != target:
                    continue

                # Read the prompt content
                prompt_path = os.path.join(prompts_dir, filename)
                try:
                    with open(prompt_path, "r", encoding="utf-8") as f:
                        prompt_content = f.read().strip()
                except (OSError, UnicodeDecodeError) as e:
                    print(f"⚠️ Error reading {prompt_path}: {e}")
                    continue

                # Strip only the trailing extension when forming the unique ID
                stem = filename[: -len(".txt")]
                collected_prompts.append(
                    {
                        "id": f"{benchmark}_{graph_type}_{stem}",
                        "text": prompt_content,
                        "metadata": {
                            "benchmark": benchmark,
                            "graph_type": graph_type,
                            "encoding": fields["encoding"],
                            "pattern": fields["pattern"],
                            "system_prompt": fields["system_prompt"],
                            "question_type": fields["question_type"],
                            "target": fields["target"],
                            "n_pairs": fields["n_pairs"],
                            "filename": filename,
                            "prompt_path": prompt_path,
                        },
                    }
                )

        if not found_prompts_in_benchmark:
            print(f"⚠️ No prompts directory found for {benchmark}")

    if not collected_prompts:
        print("❌ No prompts found matching the criteria!")
        return None

    # Create the output structure expected by llm-inference.py:
    # a list of prompt entries, each with its own system_prompt and body text (sans prefix)
    output_data = []
    for prompt in collected_prompts:
        sys_key = prompt["metadata"].get("system_prompt", "")
        sys_text = SYSTEM_PROMPTS.get(sys_key, "")
        prefix = f"{sys_text}\n\n" if sys_text else ""
        full_text = prompt["text"]
        body_text = (
            full_text[len(prefix) :]
            if prefix and full_text.startswith(prefix)
            else full_text
        )
        output_data.append(
            {
                "id": prompt["id"],
                "system_prompt": sys_text,
                "text": body_text,
                "metadata": prompt["metadata"],
            }
        )

    # Create the output directory if needed. A bare filename has an empty
    # dirname, which os.makedirs would reject.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the collected prompts
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)

    print(
        f"✅ Collected {len(collected_prompts)} prompts from {len(benchmarks)} benchmarks"
    )
    print(f"💾 Saved to: {output_file}")
    print("📊 Breakdown by benchmark:")

    # Show breakdown by benchmark
    benchmark_counts = {}
    for prompt in collected_prompts:
        name = prompt["metadata"]["benchmark"]
        benchmark_counts[name] = benchmark_counts.get(name, 0) + 1

    for name, count in sorted(benchmark_counts.items()):
        print(f"   - {name}: {count} prompts")

    # Show breakdown by pattern if multiple patterns were used
    if patterns and len(patterns) > 1:
        pattern_counts = {}
        for prompt in collected_prompts:
            pattern = prompt["metadata"]["pattern"]
            pattern_counts[pattern] = pattern_counts.get(pattern, 0) + 1

        print("📋 Patterns:")
        for pattern, count in sorted(pattern_counts.items()):
            print(f"   - {pattern}: {count} prompts")

    # Show breakdown by question type if multiple
    question_counts = {}
    for prompt in collected_prompts:
        qt = prompt["metadata"]["question_type"]
        question_counts[qt] = question_counts.get(qt, 0) + 1

    if len(question_counts) > 1:
        print("📋 Question types:")
        for qt, count in sorted(question_counts.items()):
            print(f"   - {qt}: {count} prompts")

    return output_file


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser shared by generation, collection and listing modes."""
    parser = argparse.ArgumentParser(
        description="Generate prompts for graph transformation tasks or collect existing prompts for batch inference"
    )

    # Original prompt generation arguments
    parser.add_argument(
        "--benchmarks", nargs="+", help="Benchmark names to generate prompts for"
    )
    parser.add_argument(
        "--pattern",
        "--patterns",  # Allow both --pattern and --patterns
        nargs="+",  # Accept multiple patterns
        choices=list(SIZE_PATTERNS.keys()),
        default=None,
        help="Size pattern(s) for examples. Can specify multiple patterns for --collect mode (default: scale_up_3 for generation; in --collect mode includes all patterns when unspecified)",
    )
    parser.add_argument(
        "--sizes", nargs="+", type=int, help="Custom example sizes (overrides pattern)"
    )
    parser.add_argument(
        "--system_prompt",
        choices=list(SYSTEM_PROMPTS.keys()),
        default="none",
        help="System prompt to use (default: none)",
    )
    parser.add_argument(
        "--question_type",
        choices=list(QUESTION_TYPES.keys()),
        default="full_output",
        help="Type of question to ask (default: full_output)",
    )
    parser.add_argument(
        "--target",
        choices=["input", "output"],
        help="Target for the question (input or output graph). Auto-determined if not specified.",
    )
    parser.add_argument(
        "--encoding",
        choices=["adjacency", "incident"],
        default="adjacency",
        help="Graph encoding format (default: adjacency)",
    )

    # Generation controls. The generation loop reads these attributes, so they
    # must exist on the parsed namespace; all defaults keep existing command
    # lines working unchanged.
    parser.add_argument(
        "--encodings",
        nargs="+",
        choices=["adjacency", "incident"],
        default=None,
        help="Generate prompts for several encodings at once (default: the single --encoding value)",
    )
    parser.add_argument(
        "--graph_types",
        nargs="+",
        default=None,
        help="Restrict generation to these graph-type subdirectories (default: all found)",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Regenerate prompts even if the prompt file already exists",
    )
    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="Generate prompts without writing anything to disk",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print per-prompt progress messages"
    )

    # New collection functionality
    parser.add_argument(
        "--collect",
        action="store_true",
        help="Collect existing prompts into a single JSON file for batch inference",
    )
    parser.add_argument(
        "--output_file",
        default="llm-inference/prompts.json",
        help="Output file for collected prompts (default: llm-inference/prompts.json)",
    )

    # Listing options
    parser.add_argument(
        "--list_patterns", action="store_true", help="List available size patterns"
    )
    parser.add_argument(
        "--list_questions",
        action="store_true",
        help="List available question types and their targets",
    )
    parser.add_argument(
        "--list_system_prompts",
        action="store_true",
        help="List available system prompts",
    )
    return parser


def _handle_listing_flags(args: argparse.Namespace) -> bool:
    """Print the listing requested by a --list_* flag; return True if one was handled."""
    if args.list_patterns:
        print("Available size patterns:")
        for name, sizes in SIZE_PATTERNS.items():
            print(f"  {name}: {sizes}")
        return True

    if args.list_system_prompts:
        print("Available system prompts:")
        for name, prompt in SYSTEM_PROMPTS.items():
            if not prompt:
                # Guard against empty/None prompt text, which would otherwise
                # break the len()/slice below.
                print(f"  {name}: (empty)")
                continue
            preview = prompt[:100] + "..." if len(prompt) > 100 else prompt
            print(f"  {name}: {preview}")
        return True

    if args.list_questions:
        print("Available question types and targets:")
        for qt_name, qt_data in QUESTION_TYPES.items():
            # A target mapped to None is treated as unavailable, matching the
            # validation performed before generation.
            targets = [target for target, data in qt_data.items() if data is not None]
            print(f"  {qt_name}: {targets}")
        return True

    return False


def main() -> None:
    """Command-line entry point.

    Depending on the flags this either prints one of the ``--list_*`` listings,
    collects existing prompts into a JSON file (``--collect``), or generates
    prompt files under ``datasets/<benchmark>/<graph_type>/prompts`` for every
    selected benchmark, graph type and encoding.
    """
    args = _build_arg_parser().parse_args()

    if _handle_listing_flags(args):
        return

    # Handle collection mode
    if args.collect:
        collect_in_json(
            benchmarks=args.benchmarks,
            patterns=args.pattern,
            system_prompt=args.system_prompt,
            question_type=(
                args.question_type if args.question_type != "full_output" else None
            ),
            target=args.target,
            output_file=args.output_file,
        )
        return

    # Original prompt generation logic
    if not args.benchmarks:
        print(
            "❌ Please specify benchmarks to generate prompts for, or use --collect to collect existing prompts"
        )
        return

    # --pattern uses nargs="+", so it is either None or a non-empty list.
    # Generation works with exactly one pattern name (a hashable string key).
    if args.pattern is None:
        current_pattern = "scale_up_3"
    else:
        if len(args.pattern) > 1:
            # For non-collect mode, only use the first pattern if multiple are provided
            print(
                f"⚠️ Multiple patterns specified for generation mode. Using only: {args.pattern[0]}"
            )
        current_pattern = args.pattern[0]

    # Validate the question type / target combination only when a target was
    # given explicitly; an omitted target is passed through as None.
    # NOTE(review): assumes the downstream helpers auto-determine a None target,
    # as the --target help text states — confirm.
    if (
        args.target is not None
        and QUESTION_TYPES[args.question_type].get(args.target) is None
    ):
        print(
            f"Error: Question type '{args.question_type}' is not available for target '{args.target}'"
        )
        available_targets = [
            target
            for target, data in QUESTION_TYPES[args.question_type].items()
            if data is not None
        ]
        print(f"Available targets for '{args.question_type}': {available_targets}")
        return

    # Process benchmarks argument
    if "all" in args.benchmarks:
        benchmarks = load_benchmark_names()
    else:
        benchmarks = args.benchmarks

    # Resolve example sizes once: explicit --sizes wins, then the pattern,
    # then a fixed fallback.
    sizes = args.sizes or SIZE_PATTERNS.get(current_pattern) or [10, 10]

    # Fall back to the single --encoding value when --encodings is not given.
    encodings = args.encodings or [args.encoding]

    # Process each benchmark
    success_count = 0
    failure_count = 0
    skipped_count = 0

    for benchmark in benchmarks:
        # Determine graph types available for this benchmark
        benchmark_dir = f"datasets/{benchmark}"
        available_graph_types = []
        if os.path.isdir(benchmark_dir):
            available_graph_types = [
                d
                for d in os.listdir(benchmark_dir)
                if os.path.isdir(os.path.join(benchmark_dir, d))
            ]

        # Filter graph types if specified
        if args.graph_types:
            selected_graph_types = [
                gt for gt in available_graph_types if gt in args.graph_types
            ]
        else:
            selected_graph_types = available_graph_types

        if not selected_graph_types:
            print(f"No matching graph types found for benchmark {benchmark}, skipping.")
            continue

        for graph_type in selected_graph_types:
            directory = f"{benchmark}/{graph_type}"
            prompt_dir = f"datasets/{benchmark}/{graph_type}/prompts"

            for encoding_type in encodings:
                # Check if prompt already exists
                filename = get_prompt_filename(
                    encoding_type,
                    sizes,
                    args.system_prompt,
                    args.question_type,
                    args.target,
                )
                prompt_path = os.path.join(prompt_dir, filename)

                if os.path.exists(prompt_path) and not args.overwrite:
                    if args.verbose:
                        print(
                            f"⏩ Prompt already exists: {prompt_path} (use --overwrite to regenerate)"
                        )
                    skipped_count += 1
                    continue

                if args.verbose:
                    print(
                        f"Generating prompt for {directory}, encoding: {encoding_type}, sizes: {sizes}, question: {args.question_type}({args.target})"
                    )

                # Generate the prompt
                prompt = generate_enhanced_prompt(
                    directory=directory,
                    encoding=encoding_type,
                    sizes=sizes,
                    system_prompt_type=args.system_prompt,
                    question_type=args.question_type,
                    target=args.target,
                )

                if not prompt:
                    print(
                        f"❌ Failed to generate prompt for {directory} with encoding {encoding_type}"
                    )
                    failure_count += 1
                    continue

                # Save the prompt
                if not args.dry_run:
                    # Create the directory only when actually writing, so a
                    # dry run leaves the filesystem untouched.
                    os.makedirs(prompt_dir, exist_ok=True)
                    with open(prompt_path, "w", encoding="utf-8") as f:
                        f.write(prompt)

                    if args.verbose:
                        print(f"✅ Saved prompt to {prompt_path}")
                elif args.verbose:
                    print(f"🔍 Dry run: Would have saved prompt to {prompt_path}")
                    print(f"Prompt preview:\n{prompt[:500]}...\n")

                success_count += 1

    print(
        f"\nPrompt generation complete: {success_count} successes, {failure_count} failures, {skipped_count} skipped"
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
