import argparse
import os
import glob
import time

from scripts.utils.metadata import extract_prompt_metadata
from scripts.utils.assorted import load_benchmark_names
from scripts.generate_prompts import (
    generate_enhanced_prompt,
    get_prompt_filename,
    format_size_pattern,
    SIZE_PATTERNS,
)
from api_models.openai.query_openai import (
    check_batch_status,
    retrieve_batch_results,
    solve_graphs_batch,
)


def batch_run_all_prompts(
    benchmarks_to_run="all",
    graph_types=None,
    model="gpt-4o-mini-batch-api",
    wait_for_completion=False,
    batch_name=None,
):
    """
    Runs all existing prompts found in the specified benchmarks and graph types using batch API.

    Prompt files are discovered as datasets/<benchmark>/<graph_type>/prompts/*.txt and
    each response is directed to the sibling "responses" directory. Prompts whose
    response file already exists, or whose file cannot be read, are left out of the
    batch and out of the summary counts. A summary is written to
    batch_jobs/<batch_name>_summary.txt before the batch is submitted, and the batch
    ID is appended to it afterwards.

    Parameters:
    - benchmarks_to_run: List of benchmark names or "all"
    - graph_types: List of graph types or None for all
    - model: Model to use with batch API ("-batch-api" is appended if missing)
    - wait_for_completion: Whether to poll every 30 seconds until the batch
      completes and then retrieve its results
    - batch_name: Custom name for the batch (auto-generated from the benchmark
      names and a timestamp when omitted)

    Returns:
    - The batch ID returned by solve_graphs_batch, or None if no prompt was queued
    """
    # Ensure we're using the batch API
    if not model.endswith("-batch-api"):
        model = f"{model}-batch-api"

    # Response filenames and log messages use the model name without the suffix.
    base_model = model.replace("-batch-api", "")

    if benchmarks_to_run == "all":
        benchmarks_to_run = load_benchmark_names()

    print(f"Scanning for existing prompts in benchmarks: {benchmarks_to_run}")

    # Find all existing prompts
    prompt_list = []

    for benchmark in benchmarks_to_run:
        benchmark_dir = f"datasets/{benchmark}"
        if not os.path.exists(benchmark_dir):
            continue

        available_graph_types = [
            d
            for d in os.listdir(benchmark_dir)
            if os.path.isdir(os.path.join(benchmark_dir, d))
        ]

        # Filter graph types if specified
        selected_graph_types = (
            available_graph_types
            if graph_types is None
            else [gt for gt in available_graph_types if gt in graph_types]
        )

        for graph_type in selected_graph_types:
            prompts_dir = f"{benchmark_dir}/{graph_type}/prompts"
            if not os.path.exists(prompts_dir):
                continue

            for prompt_path in glob.glob(f"{prompts_dir}/*.txt"):
                metadata = extract_prompt_metadata(prompt_path)
                # Files for which no metadata could be extracted are ignored.
                if metadata:
                    prompt_list.append(
                        {
                            "benchmark": benchmark,
                            "graph_type": graph_type,
                            "prompt_path": prompt_path,
                            "metadata": metadata,
                        }
                    )

    print(f"Found {len(prompt_list)} existing prompts to run")

    # Prepare a custom batch name if not specified
    if not batch_name:
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        benchmarks_str = "_".join(b[:3] for b in benchmarks_to_run[:3])
        if len(benchmarks_to_run) > 3:
            benchmarks_str += f"_p{len(benchmarks_to_run)}"
        batch_name = f"batch_all_prompts_{benchmarks_str}_{timestamp}"

    # Collect all prompts and output paths
    all_prompts = []
    all_output_paths = []
    task_metadata = []

    # Per-category counts of the prompts that are actually queued
    encodings_used = {}
    patterns_used = {}
    system_prompts_used = {}
    question_types_used = {}
    targets_used = {}

    for prompt_info in prompt_list:
        benchmark = prompt_info["benchmark"]
        graph_type = prompt_info["graph_type"]
        prompt_path = prompt_info["prompt_path"]
        metadata = prompt_info["metadata"]

        encoding = metadata["encoding"]
        size_pattern = metadata["size_pattern"]
        system_prompt = metadata["system_prompt"]
        question_type = metadata["question_type"]
        target = metadata["target"]

        # Construct response path
        response_dir = f"datasets/{benchmark}/{graph_type}/responses"
        os.makedirs(response_dir, exist_ok=True)

        # Generate response filename that includes question type and target, using hyphen-separated format
        if question_type == "full_output":
            response_filename = (
                f"{encoding}-{size_pattern}-{system_prompt}-{base_model}.txt"
            )
        else:
            response_filename = f"{encoding}-{size_pattern}-{system_prompt}-{question_type}-{target}-{base_model}.txt"

        response_path = os.path.join(response_dir, response_filename)

        # Check if response already exists
        if os.path.exists(response_path):
            print(f"Response already exists at {response_path}, skipping.")
            continue

        # Load the prompt
        try:
            with open(prompt_path, "r", encoding="utf-8") as f:
                prompt_content = f.read()
        except (IOError, OSError) as e:
            print(f"❌ Error reading prompt file {prompt_path}: {e}")
            continue

        # Add to batch collection
        all_prompts.append(prompt_content)
        all_output_paths.append(response_path)
        task_metadata.append(
            {
                "benchmark": benchmark,
                "graph_type": graph_type,
                "encoding": encoding,
                "size_pattern": size_pattern,
                "system_prompt": system_prompt,
                "question_type": question_type,
                "target": target,
                "output_path": response_path,
            }
        )

        # Count only after the prompt has been queued, so the summary totals
        # agree with "Total prompts" and exclude skipped/unreadable prompts.
        for counts, key in (
            (encodings_used, encoding),
            (patterns_used, size_pattern),
            (system_prompts_used, system_prompt),
            (question_types_used, question_type),
            (targets_used, target),
        ):
            counts[key] = counts.get(key, 0) + 1

    if not all_prompts:
        print("⚠️ No valid prompts found (or all responses already exist).")
        return None

    # Print batch summary
    print("\n📊 Batch Summary:")
    print(f"- Total prompts: {len(all_prompts)}")
    print(
        f"- Benchmarks: {len(set(info['benchmark'] for info in task_metadata))} unique"
    )
    print(f"- Encodings: {encodings_used}")
    print(f"- Patterns: {len(patterns_used)} unique")
    print(f"- System prompts: {system_prompts_used}")
    print(f"- Question types: {question_types_used}")
    print(f"- Targets: {targets_used}")

    print(
        f"\nSubmitting batch with {len(all_prompts)} prompts using model {base_model}..."
    )

    # Save batch metadata
    os.makedirs("batch_jobs", exist_ok=True)
    batch_summary_path = f"batch_jobs/{batch_name}_summary.txt"

    with open(batch_summary_path, "w", encoding="utf-8") as f:
        f.write(f"Batch: {batch_name}\n")
        f.write(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Model: {base_model}\n")
        f.write(f"Total prompts: {len(all_prompts)}\n\n")

        f.write("Benchmarks:\n")
        benchmark_counts = {}
        for task in task_metadata:
            benchmark = task["benchmark"]
            benchmark_counts[benchmark] = benchmark_counts.get(benchmark, 0) + 1
        for benchmark, count in benchmark_counts.items():
            f.write(f"  - {benchmark}: {count} prompts\n")

        f.write("\nEncodings:\n")
        for encoding, count in encodings_used.items():
            f.write(f"  - {encoding}: {count} prompts\n")

        f.write("\nPatterns:\n")
        for size_pattern, count in patterns_used.items():
            f.write(f"  - {size_pattern}: {count} prompts\n")

        f.write("\nSystem prompts:\n")
        for system_prompt, count in system_prompts_used.items():
            f.write(f"  - {system_prompt}: {count} prompts\n")

        f.write("\nQuestion types:\n")
        for question_type, count in question_types_used.items():
            f.write(f"  - {question_type}: {count} prompts\n")

        f.write("\nTargets:\n")
        for target, count in targets_used.items():
            f.write(f"  - {target}: {count} prompts\n")

    # Submit the batch
    batch_id = solve_graphs_batch(
        prompts=all_prompts,
        output_paths=all_output_paths,
        model=model,
        batch_name=batch_name,
    )

    # Update summary with batch ID and correct retrieval commands
    with open(batch_summary_path, "a", encoding="utf-8") as f:
        f.write(f"\nBatch ID: {batch_id}\n")
        f.write(f"Check status: python -m scripts.batch_run_tasks --check {batch_id}\n")
        f.write(
            f"Retrieve results: python -m scripts.batch_run_tasks --retrieve {batch_id}\n"
        )

    # Optionally wait for completion
    if wait_for_completion:
        print("\nWaiting for batch to complete (this may take a while)...")

        # Check status periodically
        while True:
            status = check_batch_status(batch_id)
            if status["is_complete"]:
                print("\n✅ Batch completed! Retrieving results...")
                retrieve_batch_results(batch_id)
                break

            # Progress update
            print(f"Progress: {status['progress']} - waiting 30 seconds...")
            time.sleep(30)

    return batch_id


def batch_run_benchmarks(
    benchmarks_to_run="all",
    graph_types=None,
    encoding_types=None,
    pattern=None,
    sizes=None,
    system_prompt=None,
    question_type="full_output",
    target="output",
    model="gpt-4o-mini-batch-api",
    wait_for_completion=False,
    batch_name=None,
    generate_missing=False,
):
    """
    Runs multiple benchmark evaluations for graph transformation tasks in a single batch.

    One task is queued for every (benchmark, graph type, encoding) combination
    whose prompt file exists (or can be generated) and whose response file does
    not exist yet. A summary is written to batch_jobs/<batch_name>_summary.txt
    before submission, and the batch ID is appended to it afterwards.

    Parameters:
    - benchmarks_to_run (list or "all"): List of benchmark names to run, or "all".
    - graph_types (list or None): List of graph types to use, or None for all.
    - encoding_types (list or None): List of encoding types to use
      (default: ["adjacency"]).
    - pattern (str): Size pattern name (e.g., "scale_up_3"); looked up in
      SIZE_PATTERNS when `sizes` is not given.
    - sizes (list[int]): Specific sizes for examples and test; takes precedence
      over `pattern`. Falls back to [10, 10] when neither yields sizes.
    - system_prompt (str): System prompt type to use ("none" when omitted).
    - question_type (str): Type of question to ask (default: "full_output").
    - target (str): Target for the question - "input" or "output" (default: "output").
    - model (str): OpenAI model to use ("-batch-api" is appended if missing).
    - wait_for_completion (bool): Whether to poll every 30 seconds until the
      batch completes and then retrieve its results.
    - batch_name (str, optional): Custom name for the batch.
    - generate_missing (bool): Whether to generate missing prompts on-the-fly.

    Returns:
    - The batch ID returned by solve_graphs_batch, or None if no task was queued.
    """
    # Default values if None
    if encoding_types is None:
        encoding_types = ["adjacency"]

    # Ensure we're using the batch API
    if not model.endswith("-batch-api"):
        model = f"{model}-batch-api"

    # Response filenames and log messages use the model name without the suffix.
    base_model = model.replace("-batch-api", "")
    system_str = system_prompt or "none"

    # Determine sizes to use (for filename lookup)
    actual_sizes = None
    if sizes:
        actual_sizes = sizes
    elif pattern:
        actual_sizes = SIZE_PATTERNS.get(pattern)

    if not actual_sizes:
        # Default fallback
        actual_sizes = [10, 10]

    # Get list of benchmarks to run
    if benchmarks_to_run == "all":
        benchmarks_to_run = load_benchmark_names()

    # Prepare a custom batch name if not specified
    if not batch_name:
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        encodings_str = "_".join(
            e[0] for e in encoding_types
        )  # first letter of each encoding
        benchmarks_str = "_".join(
            b[:3] for b in benchmarks_to_run[:3]
        )  # first 3 chars of up to 3 benchmarks
        if len(benchmarks_to_run) > 3:
            benchmarks_str += (
                f"_p{len(benchmarks_to_run)}"  # indicate total count if more than 3
            )
        pattern_str = pattern or format_size_pattern(actual_sizes)
        question_str = (
            f"{question_type}_{target}" if question_type != "full_output" else "full"
        )
        batch_name = f"batch_{benchmarks_str}_{encodings_str}_{pattern_str}_{question_str}_{timestamp}"

    print(f"Preparing batch '{batch_name}' for benchmarks: {benchmarks_to_run}")
    print(f"Encodings: {encoding_types}")
    print(f"Example pattern: {pattern or 'custom'} with sizes {actual_sizes}")
    print(f"Question: {question_type} (target: {target})")

    # Collect all prompts and output paths
    all_prompts = []
    all_output_paths = []
    task_metadata = []  # For logging purposes

    for benchmark in benchmarks_to_run:
        # Find available graph types for this benchmark
        available_graph_types = []
        benchmark_dir = f"datasets/{benchmark}"
        if os.path.exists(benchmark_dir):
            available_graph_types = [
                d
                for d in os.listdir(benchmark_dir)
                if os.path.isdir(os.path.join(benchmark_dir, d))
            ]

        # Filter graph types if specified
        selected_graph_types = available_graph_types
        if graph_types:
            selected_graph_types = [
                gt for gt in available_graph_types if gt in graph_types
            ]

        if not selected_graph_types:
            print(f"No matching graph types found for benchmark {benchmark}, skipping.")
            continue

        for graph_type in selected_graph_types:
            base_dir = f"datasets/{benchmark}/{graph_type}"
            prompts_dir = f"{base_dir}/prompts"
            responses_dir = f"{base_dir}/responses"

            for encoding_type in encoding_types:
                # Generate the prompt filename based on pattern or sizes
                filename = get_prompt_filename(
                    encoding_type,
                    actual_sizes,
                    system_str,
                    question_type,
                    target,
                )

                # Make sure directories exist
                os.makedirs(prompts_dir, exist_ok=True)
                os.makedirs(responses_dir, exist_ok=True)

                prompt_path = os.path.join(prompts_dir, filename)

                # Check if the prompt file exists
                if not os.path.exists(prompt_path):
                    if not generate_missing:
                        print(
                            f"⚠️ Prompt not found: {prompt_path}, skipping. Use --generate_missing to create it."
                        )
                        continue

                    print(f"Generating missing prompt: {prompt_path}")
                    # Generate the prompt on-the-fly
                    prompt_content = generate_enhanced_prompt(
                        directory=f"{benchmark}/{graph_type}",
                        encoding=encoding_type,
                        sizes=actual_sizes,
                        system_prompt_type=system_str,
                        question_type=question_type,
                        target=target,
                    )

                    if not prompt_content:
                        print(
                            f"❌ Failed to generate prompt for {benchmark}/{graph_type}, skipping."
                        )
                        continue

                    # Save the generated prompt
                    with open(prompt_path, "w", encoding="utf-8") as f:
                        f.write(prompt_content)

                # Load the prompt
                try:
                    with open(prompt_path, "r", encoding="utf-8") as f:
                        prompt_content = f.read()
                except (IOError, OSError) as e:
                    print(f"❌ Error reading prompt file {prompt_path}: {e}")
                    continue

                # Construct response filename using hyphen-separated format
                size_pattern_str = format_size_pattern(actual_sizes)
                if question_type == "full_output":
                    response_filename = f"{encoding_type}-{size_pattern_str}-{system_str}-{base_model}.txt"
                else:
                    response_filename = f"{encoding_type}-{size_pattern_str}-{system_str}-{question_type}-{target}-{base_model}.txt"

                response_path = os.path.join(responses_dir, response_filename)

                # Check if response already exists
                if os.path.exists(response_path):
                    print(f"Response already exists at {response_path}, skipping.")
                    continue

                task_name = f"{benchmark}/{graph_type}/{encoding_type}"
                print(f"Adding {task_name} to batch...")

                # Add to batch lists. This must happen inside the innermost loop
                # so every (graph type, encoding) combination is queued, not just
                # the last one visited for the benchmark.
                all_prompts.append(prompt_content)
                all_output_paths.append(response_path)
                task_metadata.append(
                    {
                        "benchmark": benchmark,
                        "graph_type": graph_type,
                        "encoding": encoding_type,
                        "pattern": pattern or size_pattern_str,
                        "system_prompt": system_prompt,
                        "question_type": question_type,
                        "target": target,
                        "output_path": response_path,
                    }
                )

    if not all_prompts:
        print("⚠️ No valid benchmark tasks found with the specified criteria.")
        return None

    # Generate a batch summary
    tasks_by_benchmark = {}
    tasks_by_encoding = {}
    tasks_by_pattern = {}
    tasks_by_question = {}

    for task in task_metadata:
        # Distinct local names so the `pattern` parameter is not rebound.
        benchmark_key = task["benchmark"]
        encoding_key = task["encoding"]
        pattern_key = task["pattern"]
        question_key = f"{task['question_type']}({task['target']})"

        tasks_by_benchmark[benchmark_key] = (
            tasks_by_benchmark.get(benchmark_key, 0) + 1
        )
        tasks_by_encoding[encoding_key] = tasks_by_encoding.get(encoding_key, 0) + 1
        tasks_by_pattern[pattern_key] = tasks_by_pattern.get(pattern_key, 0) + 1
        tasks_by_question[question_key] = tasks_by_question.get(question_key, 0) + 1

    print("\n📊 Batch Summary:")
    print(f"- Total tasks: {len(all_prompts)}")
    print(f"- Benchmarks: {dict(tasks_by_benchmark)}")
    print(f"- Encodings: {dict(tasks_by_encoding)}")
    print(f"- Patterns: {dict(tasks_by_pattern)}")
    print(f"- Questions: {dict(tasks_by_question)}")

    print(
        f"\nSubmitting batch with {len(all_prompts)} tasks using model {base_model}..."
    )

    # Save batch metadata for reference
    os.makedirs("batch_jobs", exist_ok=True)
    batch_summary_path = f"batch_jobs/{batch_name}_summary.txt"

    with open(batch_summary_path, "w", encoding="utf-8") as f:
        f.write(f"Batch: {batch_name}\n")
        f.write(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Model: {base_model}\n")
        f.write(f"Total tasks: {len(all_prompts)}\n\n")

        f.write("Benchmarks:\n")
        for benchmark_key, count in tasks_by_benchmark.items():
            f.write(f"  - {benchmark_key}: {count} tasks\n")

        f.write("\nEncodings:\n")
        for encoding_key, count in tasks_by_encoding.items():
            f.write(f"  - {encoding_key}: {count} tasks\n")

        f.write("\nPatterns:\n")
        for pattern_key, count in tasks_by_pattern.items():
            f.write(f"  - {pattern_key}: {count} tasks\n")

        f.write("\nQuestions:\n")
        for question_key, count in tasks_by_question.items():
            f.write(f"  - {question_key}: {count} tasks\n")

        f.write(f"\nSizes: {actual_sizes}\n")

    print(f"Batch summary saved to {batch_summary_path}")

    # Submit the batch with our custom name
    batch_id = solve_graphs_batch(
        prompts=all_prompts,
        output_paths=all_output_paths,
        model=model,
        batch_name=batch_name,
    )

    # Update summary with batch ID and correct retrieval commands
    with open(batch_summary_path, "a", encoding="utf-8") as f:
        f.write(f"\nBatch ID: {batch_id}\n")
        f.write(f"Check status: python -m scripts.batch_run_tasks --check {batch_id}\n")
        f.write(
            f"Retrieve results: python -m scripts.batch_run_tasks --retrieve {batch_id}\n"
        )

    # Optionally wait for completion
    if wait_for_completion:
        print("\nWaiting for batch to complete (this may take a while)...")

        # Check status periodically
        while True:
            status = check_batch_status(batch_id)
            if status["is_complete"]:
                print("\n✅ Batch completed! Retrieving results...")
                retrieve_batch_results(batch_id)
                break

            # Progress update
            print(f"Progress: {status['progress']} - waiting 30 seconds...")
            time.sleep(30)

    return batch_id


def list_available_benchmarks():
    """Print every known benchmark that has a datasets/ directory, with its graph types.

    Benchmark names come from load_benchmark_names() and are shown in sorted
    order; the graph types are the sub-directories of datasets/<benchmark>,
    also sorted. Benchmarks without a dataset directory are omitted.
    """
    names = load_benchmark_names()

    print("\n📋 Available Benchmarks:")
    for name in sorted(names):
        root = f"datasets/{name}"
        if not os.path.exists(root):
            continue

        # Only directories count as graph types; plain files are ignored.
        subdirs = []
        for entry in os.listdir(root):
            if os.path.isdir(os.path.join(root, entry)):
                subdirs.append(entry)
        subdirs.sort()

        print(f"  - {name}")
        print(f"    Graph types: {', '.join(subdirs)}")


if __name__ == "__main__":
    # Command-line entry point. Informational commands (--list, --list_patterns,
    # --check, --retrieve) are handled first and terminate the process; otherwise
    # a batch is submitted either for all existing prompts or for one specific
    # prompt configuration.
    parser = argparse.ArgumentParser(description="Run benchmark evaluations in batch")

    parser.add_argument(
        "--benchmarks",
        nargs="+",
        default="all",
        help="Specific benchmarks to run (default: all)",
    )
    parser.add_argument(
        "--graph_types",
        nargs="+",
        default=None,
        help="Specific graph types to use (default: all available)",
    )
    parser.add_argument(
        "--encodings",
        nargs="+",
        default=["adjacency"],
        choices=["adjacency", "incident", "expert", "all"],
        help="Encoding types to use (default: adjacency, use 'all' for all encodings)",
    )
    parser.add_argument(
        "--example_sizes",
        nargs="+",
        type=int,
        default=None,
        help="Specific sizes for examples (e.g., 5 5 15)",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        default=None,
        choices=list(SIZE_PATTERNS.keys()),
        help="Size pattern to use (e.g., scale_up_3)",
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default=None,
        choices=["analyst", "programmer", "teacher", "none"],
        help="System prompt type to prepend",
    )
    parser.add_argument(
        "--question_type",
        type=str,
        default="full_output",
        help="Type of question to ask (default: full_output)",
    )
    parser.add_argument(
        "--target",
        type=str,
        default="output",
        choices=["input", "output"],
        help="Target for the question - input or output graph (default: output)",
    )
    parser.add_argument(
        "--model",
        default="o3-mini-batch-api",
        choices=[
            "gpt-4o-mini-batch-api",
            "o3-mini-batch-api",
            "o4-mini-batch-api",
            "gpt-4.1-nano-batch-api",
            "gpt-4.1-mini-batch-api",
            "o3-batch-api",
        ],
        help="Model to use with Batch API (default: o3-mini-batch-api)",
    )
    parser.add_argument(
        "--generate_missing",
        action="store_true",
        help="Generate missing prompts on-the-fly",
    )
    parser.add_argument(
        "--run_all_prompts",
        action="store_true",
        help="Run all existing prompts for the specified benchmarks",
    )
    parser.add_argument(
        "--wait",
        action="store_true",
        help="Wait for batch to complete and retrieve results",
    )
    parser.add_argument(
        "--check",
        type=str,
        default=None,
        help="Check status of a previously submitted batch",
    )
    parser.add_argument(
        "--retrieve",
        type=str,
        default=None,
        help="Retrieve results from a completed batch",
    )
    parser.add_argument(
        "--list",
        action="store_true",
        help="List all available benchmarks and graph types",
    )
    parser.add_argument(
        "--batch_name",
        type=str,
        default=None,
        help="Custom name for the batch (default: auto-generated)",
    )
    parser.add_argument(
        "--list_patterns",
        action="store_true",
        help="List available size patterns",
    )

    args = parser.parse_args()

    # Handle special commands first.
    # SystemExit is raised directly instead of calling exit(): that helper is
    # injected by the `site` module for interactive use and is not guaranteed
    # to exist (e.g. when running with `python -S`).
    if args.list:
        list_available_benchmarks()
        raise SystemExit(0)

    if args.list_patterns:
        print("Available size patterns:")
        for name, sizes in SIZE_PATTERNS.items():
            print(f"  {name}: {sizes}")
        raise SystemExit(0)

    if args.check:
        print(f"Checking status of batch {args.check}...")
        check_batch_status(args.check)
        raise SystemExit(0)

    if args.retrieve:
        print(f"Retrieving results from batch {args.retrieve}...")
        retrieve_batch_results(args.retrieve)
        raise SystemExit(0)

    # An explicit `--benchmarks all` is parsed as ["all"] because of nargs="+";
    # map it back to the "all" sentinel the batch functions compare against,
    # mirroring how 'all' is handled for encodings below.
    if "all" in args.benchmarks:
        args.benchmarks = "all"

    # Process 'all' option for encodings
    if "all" in args.encodings:
        args.encodings = ["adjacency", "incident", "expert"]

    # Run all existing prompts by default if no specific configuration is provided
    no_specific_config = (
        args.pattern is None
        and args.example_sizes is None
        and args.question_type == "full_output"
        and args.target == "output"
        and args.system_prompt is None
    )

    if args.run_all_prompts or no_specific_config:
        batch_id_to_return = batch_run_all_prompts(
            benchmarks_to_run=args.benchmarks,
            graph_types=args.graph_types,
            model=args.model,
            wait_for_completion=args.wait,
            batch_name=args.batch_name,
        )
    else:
        # Run the benchmarks with specified configuration (original code)
        batch_id_to_return = batch_run_benchmarks(
            benchmarks_to_run=args.benchmarks,
            graph_types=args.graph_types,
            encoding_types=args.encodings,
            pattern=args.pattern,
            sizes=args.example_sizes,
            system_prompt=args.system_prompt,
            question_type=args.question_type,
            target=args.target,
            model=args.model,
            wait_for_completion=args.wait,
            batch_name=args.batch_name,
            generate_missing=args.generate_missing,
        )

    if batch_id_to_return:
        print(f"\nBatch ID: {batch_id_to_return}")
        print("Save this ID to check status or retrieve results later")
