import argparse
import os
import glob

# from models.qwen_local.qwen_interface import solve_graph as solve_graph_qwen
from models.google_gemini.query_gemini import solve_graph as solve_graph_gemini
from models.openai.query_openai import solve_graph as solve_graph_openai
from scripts.utils.metadata import extract_prompt_metadata
from scripts.utils.assorted import load_benchmark_names
from scripts.generate_prompts import (
    generate_enhanced_prompt,
    get_prompt_filename,
    format_size_pattern,
    SIZE_PATTERNS,
)  


def run_all_prompts(
    benchmarks_to_run="all",
    graph_types=None,
    model_backend="gemini-2.0-flash-lite",
):
    """
    Run every existing prompt file found under the selected benchmarks.

    Prompts are discovered at ``datasets/<benchmark>/<graph_type>/prompts/*.txt``.
    Their settings (encoding, size pattern, system prompt, question type,
    target) are read with ``extract_prompt_metadata``. Files for which it
    returns a falsy value are ignored. Each response is written to
    ``datasets/<benchmark>/<graph_type>/responses/`` using the hyphen-based
    filename format. Prompts whose response file already exists are skipped.

    Parameters:
    - benchmarks_to_run: List of benchmark names, a single benchmark name, or "all"
    - graph_types: List of graph types (or a single name), or None for all
    - model_backend: Model to use
    """
    if benchmarks_to_run == "all":
        benchmarks_to_run = load_benchmark_names()
    elif isinstance(benchmarks_to_run, str):
        # A bare string would otherwise be iterated character by character.
        benchmarks_to_run = [benchmarks_to_run]

    if isinstance(graph_types, str):
        graph_types = [graph_types]
    wanted_graph_types = None if graph_types is None else set(graph_types)

    print(f"Scanning for existing prompts in benchmarks: {benchmarks_to_run}")

    # Collect all existing prompts up front so progress can be reported as [i/N].
    prompt_list = []
    for benchmark in benchmarks_to_run:
        benchmark_dir = f"datasets/{benchmark}"
        if not os.path.isdir(benchmark_dir):
            continue

        # Sorted so that runs (and the progress numbering) are reproducible:
        # os.listdir and glob.glob return entries in arbitrary order.
        available_graph_types = sorted(
            d
            for d in os.listdir(benchmark_dir)
            if os.path.isdir(os.path.join(benchmark_dir, d))
        )

        for graph_type in available_graph_types:
            if wanted_graph_types is not None and graph_type not in wanted_graph_types:
                continue

            prompts_dir = f"{benchmark_dir}/{graph_type}/prompts"
            if not os.path.exists(prompts_dir):
                continue

            for prompt_path in sorted(glob.glob(f"{prompts_dir}/*.txt")):
                metadata = extract_prompt_metadata(prompt_path)
                if metadata:
                    prompt_list.append(
                        {
                            "benchmark": benchmark,
                            "graph_type": graph_type,
                            "prompt_path": prompt_path,
                            "metadata": metadata,
                        }
                    )

    total_prompts = len(prompt_list)
    print(f"Found {total_prompts} existing prompts to run")

    for i, prompt_info in enumerate(prompt_list, 1):
        benchmark = prompt_info["benchmark"]
        graph_type = prompt_info["graph_type"]
        prompt_path = prompt_info["prompt_path"]
        metadata = prompt_info["metadata"]

        encoding = metadata["encoding"]
        size_pattern = metadata["size_pattern"]
        system_prompt = metadata["system_prompt"]
        question_type = metadata["question_type"]
        target = metadata["target"]

        print(
            f"[{i}/{total_prompts}] Processing {benchmark}/{graph_type}/{encoding}-{size_pattern}-{system_prompt}-{question_type}-{target}..."
        )

        # Response filename in the hyphen-based format; full_output responses
        # omit the question type and target segments.
        response_dir = f"datasets/{benchmark}/{graph_type}/responses"
        if question_type == "full_output":
            response_filename = f"{encoding}-{size_pattern}-{system_prompt}-{model_backend}.txt"
        else:
            response_filename = f"{encoding}-{size_pattern}-{system_prompt}-{question_type}-{target}-{model_backend}.txt"
        response_path = os.path.join(response_dir, response_filename)

        # Skip before touching the prompt file: nothing to do if already answered.
        if os.path.exists(response_path):
            print(f"Response already exists at {response_path}, skipping.")
            continue

        try:
            with open(prompt_path, "r", encoding="utf-8") as f:
                prompt = f.read()
        except OSError as e:
            print(f"❌ Error reading prompt file {prompt_path}: {e}")
            continue

        os.makedirs(response_dir, exist_ok=True)

        # Query the selected model; the solver writes the response to response_path.
        # NOTE: the qwen-local backend is disabled (its import is commented out).
        if model_backend.startswith("gemini"):
            solve_graph_gemini(prompt, response_path, model_backend)
        elif model_backend.startswith(("gpt-4", "o3")):
            solve_graph_openai(prompt, response_path, model_backend)
        else:
            print(f"⚠️ Unrecognized model_backend: {model_backend}\n")


def _selected_graph_types(benchmark_dir, graph_types):
    """Return the sorted graph-type subdirectories of *benchmark_dir* to process.

    A falsy *graph_types* (None or empty) selects every subdirectory. Otherwise
    only subdirectories whose names appear in *graph_types* are kept. A missing
    (or non-directory) *benchmark_dir* yields an empty list.
    """
    if not os.path.isdir(benchmark_dir):
        return []
    available = sorted(
        entry
        for entry in os.listdir(benchmark_dir)
        if os.path.isdir(os.path.join(benchmark_dir, entry))
    )
    if not graph_types:
        return available
    if isinstance(graph_types, str):
        graph_types = [graph_types]
    wanted = set(graph_types)
    return [gt for gt in available if gt in wanted]


def run_benchmarks(
    benchmarks_to_run="all",
    graph_types=None,
    encoding_types=None,
    pattern=None,
    sizes=None,
    system_prompt=None,
    question_type="full_output",
    target="output",
    model_backend="gemini-2.0-flash-lite",
    generate_missing=False,
):
    """
    Run benchmark evaluations using pre-generated prompts.

    For every (benchmark, graph type, encoding) combination, this:
    1. Locates the prompt file in ``datasets/<benchmark>/<graph_type>/prompts``.
       When the file is missing and ``generate_missing`` is set, the prompt is
       generated with ``generate_enhanced_prompt`` and saved first.
    2. Queries the selected model.
    3. Writes the answer to ``datasets/<benchmark>/<graph_type>/responses``
       using the hyphen-based filename format.

    Combinations whose response file already exists are skipped.

    Parameters:
    - benchmarks_to_run (list[str] or str): Benchmark names, a single benchmark name, or "all".
    - graph_types (list[str] or None): Graph types to use; None (or empty) for all.
    - encoding_types (list[str] or None): Encoding types to use (default: ["adjacency"]).
    - pattern (str): Size pattern name (e.g., "scale_up_3"); ignored when ``sizes`` is given.
    - sizes (list[int]): Specific sizes for examples and test.
    - system_prompt (str): System prompt type to use (None is treated as "none").
    - question_type (str): Type of question to ask (default: "full_output").
    - target (str): Target for the question - "input" or "output" (default: "output").
    - model_backend (str): Which model to use.
    - generate_missing (bool): Whether to generate missing prompts on-the-fly.
    """
    if encoding_types is None:
        encoding_types = ["adjacency"]

    if benchmarks_to_run == "all":
        benchmarks_to_run = load_benchmark_names()
    elif isinstance(benchmarks_to_run, str):
        # A bare string would otherwise be iterated character by character.
        benchmarks_to_run = [benchmarks_to_run]

    print(f"Running benchmarks for: {benchmarks_to_run}")
    print(f"Encoding types: {encoding_types}")
    print(f"Question type: {question_type} (target: {target})")

    # Sizes drive both the prompt filename lookup and the response filename.
    actual_sizes = None
    if sizes:
        actual_sizes = sizes
    elif pattern:
        actual_sizes = SIZE_PATTERNS.get(pattern)
        if not actual_sizes:
            print(f"⚠️ Unknown size pattern '{pattern}', falling back to default sizes [10, 10].")
    if not actual_sizes:
        actual_sizes = [10, 10]

    # Enumerate every task exactly once so the progress total always matches
    # what is actually processed (graph types are listed a single time).
    tasks = []
    for benchmark in benchmarks_to_run:
        selected_graph_types = _selected_graph_types(f"datasets/{benchmark}", graph_types)
        if not selected_graph_types:
            print(f"No matching graph types found for benchmark {benchmark}, skipping.")
            continue
        tasks.extend(
            (benchmark, graph_type, encoding_type)
            for graph_type in selected_graph_types
            for encoding_type in encoding_types
        )

    total_tasks = len(tasks)
    if not tasks:
        return

    # Loop-invariant pieces of the prompt/response filenames.
    system_str = system_prompt or "none"
    size_pattern_str = format_size_pattern(actual_sizes)

    for task_number, (benchmark, graph_type, encoding_type) in enumerate(tasks, 1):
        base_dir = f"datasets/{benchmark}/{graph_type}"
        prompts_dir = f"{base_dir}/prompts"
        responses_dir = f"{base_dir}/responses"

        filename = get_prompt_filename(
            encoding_type, actual_sizes, system_str, question_type, target
        )
        prompt_path = os.path.join(prompts_dir, filename)

        if not os.path.exists(prompt_path):
            if not generate_missing:
                print(
                    f"⚠️ Prompt not found: {prompt_path}, skipping. Use --generate_missing to create it."
                )
                continue

            print(f"Generating missing prompt: {prompt_path}")
            generated = generate_enhanced_prompt(
                directory=f"{benchmark}/{graph_type}",
                encoding=encoding_type,
                sizes=actual_sizes,
                system_prompt_type=system_str,
                question_type=question_type,
                target=target,
            )
            if not generated:
                print(
                    f"❌ Failed to generate prompt for {benchmark}/{graph_type}, skipping."
                )
                continue

            # Only create prompts/ when there is actually something to save.
            os.makedirs(prompts_dir, exist_ok=True)
            with open(prompt_path, "w", encoding="utf-8") as f:
                f.write(generated)

        # Always load from disk so a freshly generated prompt is sent exactly as
        # later runs will read it back.
        try:
            with open(prompt_path, "r", encoding="utf-8") as f:
                prompt = f.read()
        except OSError as e:
            print(f"❌ Error reading prompt file {prompt_path}: {e}")
            continue

        # full_output responses omit the question type and target segments.
        if question_type == "full_output":
            response_filename = f"{encoding_type}-{size_pattern_str}-{system_str}-{model_backend}.txt"
        else:
            response_filename = f"{encoding_type}-{size_pattern_str}-{system_str}-{question_type}-{target}-{model_backend}.txt"
        response_path = os.path.join(responses_dir, response_filename)

        print(
            f"[{task_number}/{total_tasks}] Processing {benchmark}/{graph_type}/{encoding_type} with pattern {size_pattern_str}, question {question_type}({target})..."
        )

        if os.path.exists(response_path):
            print(f"Response already exists at {response_path}, skipping.")
            continue

        os.makedirs(responses_dir, exist_ok=True)

        # Query the selected model; the solver writes the response to response_path.
        # NOTE: the qwen-local backend is disabled (its import is commented out).
        if model_backend.startswith("gemini"):
            solve_graph_gemini(prompt, response_path, model_backend)
        elif model_backend.startswith(("gpt-4", "o3")):
            solve_graph_openai(prompt, response_path, model_backend)
        else:
            print(f"⚠️ Unrecognized model_backend: {model_backend}\n")


if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description="Run benchmark evaluations")

    parser.add_argument(
        "--benchmarks",
        nargs="+",
        default="all",
        help="Specific benchmarks to run (default: all)",
    )
    parser.add_argument(
        "--graph_types",
        nargs="+",
        default=None,
        help="Specific graph types to use (default: all available)",
    )
    parser.add_argument(
        "--encodings",
        nargs="+",
        # None (rather than ["adjacency"]) lets us tell an explicit choice apart
        # from the default; run_benchmarks maps None to ["adjacency"].
        default=None,
        choices=["adjacency", "incident", "expert", "all"],
        help="Encoding types to use (default: adjacency, use 'all' for all encodings)",
    )
    parser.add_argument(
        "--example_sizes",
        nargs="+",
        type=int,
        default=None,
        help="Specific sizes for examples (e.g., 5 5 15)",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        default=None,
        help="Size pattern to use (e.g., scale_up_3)",
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default=None,
        choices=["analyst", "programmer", "teacher", "none"],
        help="System prompt type to prepend",
    )
    parser.add_argument(
        "--question_type",
        type=str,
        default="full_output",
        help="Type of question to ask (default: full_output)",
    )
    parser.add_argument(
        "--target",
        type=str,
        default="output",
        choices=["input", "output"],
        help="Target for the question - input or output graph (default: output)",
    )
    parser.add_argument(
        "--model_backend",
        default="gemini-2.0-flash-lite",
        choices=[
            "gpt-4o-mini",
            "gpt-4.1-nano",
            "o3-mini",
            "gemini-2.0-flash-lite",
            "gemini-2.5-pro-preview-03-25",
            "qwen-local",
        ],
        help="Model backend to use (default: gemini-2.0-flash-lite)",
    )
    parser.add_argument(
        "--generate_missing",
        action="store_true",
        help="Generate missing prompts on-the-fly",
    )
    parser.add_argument(
        "--run_all_prompts",
        action="store_true",
        help="Run all existing prompts for the specified benchmarks",
    )
    parser.add_argument(
        "--list",
        action="store_true",
        help="List all available benchmarks and their graph types",
    )

    args = parser.parse_args()

    # List available benchmarks if requested
    if args.list:
        benchmarks = load_benchmark_names()
        print("\n📋 Available Benchmarks:")
        for benchmark in sorted(benchmarks):
            benchmark_dir = f"datasets/{benchmark}"
            if os.path.exists(benchmark_dir):
                graph_types = [
                    d
                    for d in os.listdir(benchmark_dir)
                    if os.path.isdir(os.path.join(benchmark_dir, d))
                ]
                print(f"  - {benchmark}")
                print(f"    Graph types: {', '.join(sorted(graph_types))}")
        # sys.exit rather than the site-provided exit(), which is not guaranteed
        # to exist (e.g. under `python -S`).
        sys.exit(0)

    # The qwen-local solver is not wired up (its import is commented out), so
    # fail fast instead of warning once per prompt.
    if args.model_backend == "qwen-local":
        parser.error("model backend 'qwen-local' is currently disabled")

    # Process 'all' option for encodings
    if args.encodings is not None and "all" in args.encodings:
        args.encodings = ["adjacency", "incident", "expert"]

    # Run all existing prompts by default if no specific configuration is provided.
    uses_default_config = (
        args.pattern is None
        and args.example_sizes is None
        and args.question_type == "full_output"
        and args.target == "output"
        and args.system_prompt is None
        and args.encodings is None
        and not args.generate_missing
    )

    if args.run_all_prompts or uses_default_config:
        if args.encodings is not None or args.generate_missing:
            print(
                "⚠️ --encodings and --generate_missing are ignored when running all existing prompts."
            )
        run_all_prompts(
            benchmarks_to_run=args.benchmarks,
            graph_types=args.graph_types,
            model_backend=args.model_backend,
        )
    else:
        # Run with specific prompt configuration
        run_benchmarks(
            benchmarks_to_run=args.benchmarks,
            graph_types=args.graph_types,
            encoding_types=args.encodings,
            pattern=args.pattern,
            sizes=args.example_sizes,
            system_prompt=args.system_prompt,
            question_type=args.question_type,
            target=args.target,
            model_backend=args.model_backend,
            generate_missing=args.generate_missing,
        )
