"""
Enhanced batch prompt generation for large-scale experiments.
Generates multiple prompt configurations in a single run with task-question compatibility checking.
"""

import argparse
import itertools
import os
from typing import Any, Dict, List, Optional, Set

from scripts.utils.assorted import load_benchmark_names
from scripts.generate_prompts import (
    generate_enhanced_prompt,
    get_prompt_filename,
    SIZE_PATTERNS,
    SYSTEM_PROMPTS,
    QUESTION_TYPES,
)


# Task-Question Compatibility Matrix
# Maps every known task to the set of question types that may be asked about it.
# The matrix is assembled from three shared question profiles so that each
# question name is spelled exactly once.

# Questions about the plain graph structure; every listed task accepts these.
_STRUCTURAL_QUESTIONS = frozenset(
    (
        "full_output",
        "node_count",
        "edge_count",
        "is_connected",
        "is_tree",
        "has_cycles",
        "max_degree",
        "min_degree",
        "component_count",
    )
)

# Structural questions plus the count of colored nodes.
_COLORED_QUESTIONS = _STRUCTURAL_QUESTIONS | {"colored_node_count"}

# Colored profile plus the count of blue nodes.
_BLUE_AND_COLORED_QUESTIONS = _COLORED_QUESTIONS | {"blue_node_count"}

# (task, profile) pairs; the order here is the key order of the final dict.
_TASK_PROFILES = (
    # Color-based tasks
    ("colorDegree1", _COLORED_QUESTIONS),
    ("colorDegree2", _COLORED_QUESTIONS),
    ("colorDegree3", _COLORED_QUESTIONS),
    ("colorMaxDegree", _COLORED_QUESTIONS),
    ("colorMinDegree", _COLORED_QUESTIONS),
    ("colorInternal", _COLORED_QUESTIONS),
    ("colorLeaves", _COLORED_QUESTIONS),
    ("colorNeighbors", _BLUE_AND_COLORED_QUESTIONS),
    ("colorPath", _COLORED_QUESTIONS),
    ("colorComponents", _BLUE_AND_COLORED_QUESTIONS),
    # Structure-based tasks - no color questions for these transformations
    ("addHub", _STRUCTURAL_QUESTIONS),
    ("edgeToNode", _STRUCTURAL_QUESTIONS),
    ("removeDegree1", _STRUCTURAL_QUESTIONS),
    ("removeDegree2", _STRUCTURAL_QUESTIONS),
    ("removeDegree3", _STRUCTURAL_QUESTIONS),
    # Bipartition completion accepts both color-count questions
    ("bipartitionCompletion", _BLUE_AND_COLORED_QUESTIONS),
)

# Each task receives its own mutable copy so that editing one entry can never
# leak into another task that happens to share the same profile.
TASK_QUESTION_COMPATIBILITY = {
    task_name: set(profile) for task_name, profile in _TASK_PROFILES
}


def get_compatible_questions(task: str, target: Optional[str] = None) -> Set[str]:
    """Return the question types that are compatible with a task and target.

    Entries in ``TASK_QUESTION_COMPATIBILITY`` come in two forms:

    * ``"question_type"`` - valid for every target;
    * ``"question_type:target"`` - valid only when ``target`` equals the part
      after the first colon (e.g. ``"blue_node_count:output"``).

    Args:
        task: Task name to look up in ``TASK_QUESTION_COMPATIBILITY``.
        target: Target the question refers to. Only consulted for entries
            that carry a ``:target`` restriction.

    Returns:
        A new set of question type names (restriction suffixes stripped).
        Tasks missing from the matrix get a default set of structural
        questions that contains no color-related question.
    """
    if task not in TASK_QUESTION_COMPATIBILITY:
        # Default to safe questions for unknown tasks
        return {
            "full_output",
            "node_count",
            "edge_count",
            "is_connected",
            "is_tree",
            "has_cycles",
            "max_degree",
            "min_degree",
            "component_count",
        }

    compatible = set()
    for entry in TASK_QUESTION_COMPATIBILITY[task]:
        # partition() splits on the first colon only, so an entry with extra
        # colons cannot break tuple unpacking the way split(":") would.
        question_type, restricted, required_target = entry.partition(":")
        if not restricted or target == required_target:
            compatible.add(question_type)

    return compatible


def get_available_benchmarks_and_graph_types() -> Dict[str, List[str]]:
    """Map each benchmark that has a dataset directory to its graph types.

    A benchmark returned by ``load_benchmark_names()`` is included only when
    ``datasets/<benchmark>`` is a directory. Its graph types are the names of
    the sub-directories found directly inside that directory.

    Returns:
        Dictionary ``{benchmark: [graph_type, ...]}``. Graph types are sorted
        so that repeated runs iterate them in the same order regardless of
        the order in which the file system lists the entries.
    """
    benchmark_graph_types = {}

    for benchmark in load_benchmark_names():
        benchmark_dir = f"datasets/{benchmark}"
        # isdir() rather than exists(): a regular file with the benchmark's
        # name would otherwise make os.listdir() raise NotADirectoryError.
        if not os.path.isdir(benchmark_dir):
            continue

        benchmark_graph_types[benchmark] = sorted(
            entry
            for entry in os.listdir(benchmark_dir)
            if os.path.isdir(os.path.join(benchmark_dir, entry))
        )

    return benchmark_graph_types


def _resolve_question_targets(question_types, targets, verbose):
    """Pair each usable question type with the requested targets it supports.

    A target is supported when it is a key of ``QUESTION_TYPES[question_type]``
    whose value is not ``None``. Question types that are unknown or support
    none of the requested targets are dropped (with a warning when verbose).
    """
    resolved = []
    for question_type in question_types:
        valid_targets = []
        if question_type in QUESTION_TYPES:
            per_target = QUESTION_TYPES[question_type]
            valid_targets = [
                target
                for target in targets
                if target in per_target and per_target[target] is not None
            ]

        if valid_targets:
            resolved.append((question_type, valid_targets))
        elif verbose:
            print(
                f"   ⚠️ Question {question_type} has no valid targets in {targets}"
            )
    return resolved


def generate_batch_prompts(
    benchmarks: Optional[List[str]] = None,
    encodings: Optional[List[str]] = None,
    patterns: Optional[List[str]] = None,
    system_prompts: Optional[List[str]] = None,
    question_types: Optional[List[str]] = None,
    targets: Optional[List[str]] = None,
    use_compatibility_filter: bool = True,
    dry_run: bool = False,
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    Generate prompts in batch with all specified combinations.

    Every combination of benchmark, graph type, encoding, size pattern,
    system prompt, question type and target is considered. Each prompt is
    written to ``datasets/<benchmark>/<graph_type>/prompts/<filename>``;
    files that already exist are left untouched and counted as skipped.

    Args:
        benchmarks: Benchmarks to process (default: every available one).
            Requested benchmarks that are not available are reported and
            skipped.
        encodings: Encoding names (default: ``["adjacency"]``).
        patterns: Keys of ``SIZE_PATTERNS`` (default: ``["scale_up_3"]``).
            Unknown patterns are reported once and skipped.
        system_prompts: System prompt names (default: ``["none"]``).
        question_types: Question type names (default: ``["full_output"]``).
        targets: Question targets (default: ``["output"]``).
        use_compatibility_filter: Skip question types that
            ``get_compatible_questions`` does not list for the benchmark.
        dry_run: Only print the paths that would be written. Such prompts
            are counted as "generated" in all counters.
        verbose: Print the configuration and per-prompt progress.

    Returns:
    - Dictionary with generation statistics: the integer counters
      ``total_combinations``, ``filtered_out``, ``already_exist``,
      ``generated`` and ``failed``, plus ``by_task`` (per-benchmark
      generated/failed/skipped counts) and ``by_question`` (per question
      type generated/failed counts).
    """

    # Scan the dataset tree once and reuse the result both for the default
    # benchmark list and for the benchmark -> graph types lookup.
    available = get_available_benchmarks_and_graph_types()

    # Set defaults
    if benchmarks is None:
        benchmarks = list(available)

    if encodings is None:
        encodings = ["adjacency"]

    if patterns is None:
        patterns = ["scale_up_3"]

    if system_prompts is None:
        system_prompts = ["none"]

    if question_types is None:
        question_types = ["full_output"]

    if targets is None:
        targets = ["output"]

    # Filter to requested benchmarks
    selected_benchmarks = {
        b: graph_types for b, graph_types in available.items() if b in benchmarks
    }

    # A misspelled benchmark would otherwise yield an empty run with no hint.
    unavailable = [b for b in benchmarks if b not in available]
    if unavailable:
        print(f"⚠️ Requested benchmarks not available, skipping: {unavailable}")

    if verbose:
        print("📊 Batch prompt generation:")
        print(f"   Benchmarks: {len(selected_benchmarks)}")
        print(f"   Encodings: {encodings}")
        print(f"   Patterns: {patterns}")
        print(f"   System prompts: {system_prompts}")
        print(f"   Question types: {question_types}")
        print(f"   Targets: {targets}")
        print(f"   Compatibility filtering: {use_compatibility_filter}")

    # Pattern and question/target validity do not depend on the benchmark,
    # so resolve them once instead of re-checking (and re-warning) inside
    # every iteration of the combination loop.
    known_patterns = []
    for pattern in patterns:
        if pattern in SIZE_PATTERNS:
            known_patterns.append(pattern)
        else:
            print(f"⚠️ Unknown pattern {pattern}, skipping")

    question_targets = _resolve_question_targets(question_types, targets, verbose)

    stats = {
        "total_combinations": 0,
        "filtered_out": 0,
        "already_exist": 0,
        "generated": 0,
        "failed": 0,
        "by_task": {},
        "by_question": {},
    }

    # Generate all combinations
    for benchmark, graph_types in selected_benchmarks.items():
        task_stats = {"generated": 0, "failed": 0, "skipped": 0}
        stats["by_task"][benchmark] = task_stats

        # product() varies the right-most argument fastest, i.e. the same
        # order as nesting graph type > encoding > pattern > system prompt.
        for graph_type, encoding, pattern, system_prompt in itertools.product(
            graph_types, encodings, known_patterns, system_prompts
        ):
            sizes = SIZE_PATTERNS[pattern]
            prompt_dir = f"datasets/{benchmark}/{graph_type}/prompts"

            for question_type, valid_targets in question_targets:
                for target in valid_targets:
                    stats["total_combinations"] += 1

                    # Check task-question compatibility
                    if (
                        use_compatibility_filter
                        and question_type
                        not in get_compatible_questions(benchmark, target)
                    ):
                        stats["filtered_out"] += 1
                        continue

                    # Track by question type
                    question_stats = stats["by_question"].setdefault(
                        question_type, {"generated": 0, "failed": 0}
                    )

                    # Check if prompt already exists
                    filename = get_prompt_filename(
                        encoding, sizes, system_prompt, question_type, target
                    )
                    prompt_path = os.path.join(prompt_dir, filename)

                    if os.path.exists(prompt_path):
                        stats["already_exist"] += 1
                        task_stats["skipped"] += 1
                        continue

                    if dry_run:
                        print(f"   🔍 Would generate: {prompt_path}")
                        # Keep the per-task and per-question breakdowns in
                        # step with the global counter during a dry run.
                        stats["generated"] += 1
                        task_stats["generated"] += 1
                        question_stats["generated"] += 1
                        continue

                    # Generate the prompt
                    if verbose:
                        print(
                            f"   🔨 Generating: {benchmark}/{graph_type}/{encoding}-{pattern}-{system_prompt}-{question_type}-{target}"
                        )

                    try:
                        # Directory creation is inside the try so that an
                        # OSError here is counted as a failure for this
                        # prompt instead of aborting the whole batch.
                        os.makedirs(prompt_dir, exist_ok=True)

                        prompt = generate_enhanced_prompt(
                            directory=f"{benchmark}/{graph_type}",
                            encoding=encoding,
                            sizes=sizes,
                            system_prompt_type=system_prompt,
                            question_type=question_type,
                            target=target,
                        )

                        if prompt:
                            with open(prompt_path, "w", encoding="utf-8") as f:
                                f.write(prompt)
                    except OSError as e:
                        outcome = "failed"
                        print(f"   ❌ Error generating {prompt_path}: {e}")
                    else:
                        if prompt:
                            outcome = "generated"
                        else:
                            # An empty result from the generator is a failure.
                            outcome = "failed"
                            print(f"   ❌ Failed: {prompt_path}")

                    stats[outcome] += 1
                    task_stats[outcome] += 1
                    question_stats[outcome] += 1

    return stats


def main():
    """Command-line entry point: parse options, run the batch, print a summary."""
    parser = argparse.ArgumentParser(
        description="Generate prompts in batch for large-scale experiments"
    )

    parser.add_argument(
        "--tasks",
        nargs="+",
        default=None,
        help="Tasks to generate prompts for (default: all)",
    )

    # Multi-value options: (flag, default, allowed choices, help text).
    list_options = (
        (
            "--encodings",
            ["adjacency"],
            ["adjacency", "incident"],
            "Encoding types (default: adjacency)",
        ),
        (
            "--patterns",
            ["scale_up_3"],
            list(SIZE_PATTERNS.keys()),
            "Size patterns (default: scale_up_3)",
        ),
        (
            "--system-prompts",
            ["none"],
            list(SYSTEM_PROMPTS.keys()),
            "System prompts (default: none)",
        ),
        (
            "--question-types",
            ["full_output"],
            list(QUESTION_TYPES.keys()),
            "Question types (default: full_output)",
        ),
        (
            "--targets",
            ["output"],
            ["input", "output"],
            "Question targets (default: output)",
        ),
    )
    for flag, default, choices, help_text in list_options:
        parser.add_argument(
            flag, nargs="+", default=default, choices=choices, help=help_text
        )

    # Boolean switches: (flag, help text).
    switches = (
        ("--all-system-prompts", "Use all available system prompts"),
        ("--all-questions", "Use all compatible question types"),
        ("--all-targets", "Use both input and output targets"),
        (
            "--no-compatibility-filter",
            "Disable task-question compatibility filtering",
        ),
        (
            "--dry-run",
            "Show what would be generated without actually generating",
        ),
        ("--verbose", "Show detailed progress"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    args = parser.parse_args()

    # An --all-* switch overrides the corresponding explicit list.
    system_prompts = (
        list(SYSTEM_PROMPTS.keys()) if args.all_system_prompts else args.system_prompts
    )
    question_types = (
        list(QUESTION_TYPES.keys()) if args.all_questions else args.question_types
    )
    targets = ["input", "output"] if args.all_targets else args.targets

    stats = generate_batch_prompts(
        benchmarks=args.tasks,
        encodings=args.encodings,
        patterns=args.patterns,
        system_prompts=system_prompts,
        question_types=question_types,
        targets=targets,
        use_compatibility_filter=not args.no_compatibility_filter,
        dry_run=args.dry_run,
        verbose=args.verbose,
    )

    # Summary report
    rule = "=" * 60
    print("\n" + rule)
    print("BATCH PROMPT GENERATION SUMMARY")
    print(rule)

    headline_counters = (
        ("Total combinations considered", "total_combinations"),
        ("Filtered out (incompatible)", "filtered_out"),
        ("Already existed", "already_exist"),
        ("Successfully generated", "generated"),
        ("Failed", "failed"),
    )
    for label, key in headline_counters:
        print(f"{label}: {stats[key]:,}")

    per_task = stats["by_task"]
    if per_task:
        print("\nBy task:")
        for task, counts in per_task.items():
            print(
                f"  {task}: {counts['generated']} generated, {counts['failed']} failed, {counts['skipped']} skipped"
            )

    per_question = stats["by_question"]
    if per_question:
        print("\nBy question type:")
        for question, counts in per_question.items():
            print(
                f"  {question}: {counts['generated']} generated, {counts['failed']} failed"
            )

    print(rule)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
