import argparse
import concurrent.futures
import os
from typing import Dict, Any, Tuple, Optional

from scripts.utils.graph_validation import generate_valid_graph
from scripts.utils.task_definition import get_task, list_tasks, get_task_metadata
from scripts.utils.generate_data import generate_and_process_graph
from scripts.utils.task_compatibility import (
    find_compatible_generators_for_task,
    get_generator_function,
    GRAPH_GENERATORS_FUNCTIONS,
)

# Import task modules to register all tasks
import scripts.tasks.color_tasks  # pylint: disable=unused-import
import scripts.tasks.structure_tasks  # pylint: disable=unused-import


def import_task_modules():
    """
    Return the tasks currently known to the task registry.

    The task modules themselves are imported at module import time (see the
    top-of-file imports), so this function performs no imports of its own; it
    simply forwards the result of list_tasks().
    """
    registered_tasks = list_tasks()
    return registered_tasks


def generate_graph_for_task(
    task_name: str,
    generator_name: str,
    num_nodes: int,
    seed: int,
    num: int,
    task_params: Dict[str, Any] = None,
) -> Tuple[str, str, int, int, bool, Optional[str]]:
    """
    Build one benchmark sample for a (task, graph generator, size) combination.

    Parameters:
    - task_name: Name used to look up the task via get_task().
    - generator_name: Name used to look up the graph generator function.
    - num_nodes: Requested node count, forwarded to the generator.
    - seed: Random seed forwarded to graph generation and processing.
    - num: Iteration index of this sample, forwarded to generate_and_process_graph.
    - task_params: Optional parameters handed to the task's transformation;
      a falsy value is replaced by an empty dict.

    Returns:
    - Tuple of (task_name, generator_name, num_nodes, num, success_flag, error_message).
      error_message is None on success and the stringified exception (or an
      "Unknown task" message) on failure.

    Failures of the listed exception types are reported on stdout and turned
    into a failed result tuple instead of being raised.
    """

    def outcome(succeeded, message=None):
        # Every exit path returns the same 6-tuple shape.
        return task_name, generator_name, num_nodes, num, succeeded, message

    try:
        task = get_task(task_name)
        if not task:
            return outcome(False, f"Unknown task: {task_name}")

        metadata = get_task_metadata(task_name)
        make_graph = get_generator_function(generator_name)
        params = task_params if task_params else {}

        try:
            # Ask the validation helper for a graph that satisfies the task's
            # required properties (pre-transformations are requested here too).
            generated = generate_valid_graph(
                make_graph,
                task.required_properties,
                required_pretransforms=task.required_pretransforms,
                num_nodes=num_nodes,
                seed=seed,
                timeout_seconds=30,
                debug_mode=True,  # Enable debug mode for better error reporting
            )

            # The helper may hand back either a bare graph or a tuple whose
            # first element is the graph followed by extra values.
            if isinstance(generated, tuple):
                input_graph = generated[0]
                extra_values = generated[1:]
            else:
                input_graph = generated
                extra_values = []

            def main_transform_only(G, *_):
                # Apply just the task's main transformation to a copy of the
                # graph; pre-transformations are not re-applied here.
                return task.transformation(G.copy(), params)

            def input_graph_generator(*_, **__):
                # Replay the already-generated graph (plus any extra values)
                # regardless of the arguments the pipeline passes in.
                if not extra_values:
                    return input_graph
                return (input_graph,) + tuple(extra_values)

            generate_and_process_graph(
                name=f"{generator_name}_{task_name}",
                graph_generator=input_graph_generator,
                transformation=main_transform_only,
                num_nodes=num_nodes,
                seed=seed,
                num=num,
                visualization_layouts=["kamada_kawai", "spring"],
                task_description=metadata["description"],
                property_requirements=metadata["property_requirements"],
                transformation_parameters=params,
            )
        except (TimeoutError, ValueError) as exc:
            reason = str(exc)
            print(
                f"⚠️ Failed to generate valid graph for {task_name} with {generator_name}: {reason}"
            )
            return outcome(False, reason)

        return outcome(True)

    except (ValueError, TypeError, RuntimeError, KeyError, IndexError) as exc:
        reason = str(exc)
        print(
            f"❌ Error in {task_name}/{generator_name} with {num_nodes} nodes, iteration {num}: {reason}"
        )
        return outcome(False, reason)


def generate_graphs(
    n_pairs=3,
    node_sizes=None,
    seed=42,
    parallel=True,
    max_workers=None,
    tasks=None,
    graph_types=None,
    task_params=None,
):
    """
    Generates graph-based ARC datasets with various graph types and sizes.

    One job is created per (task, compatible graph type, node size, iteration)
    and executed through generate_graph_for_task, either sequentially or in a
    process pool. A summary of successes and failures (grouped by error
    message) is printed at the end.

    Parameters:
    - n_pairs (int): Number of input-output examples per benchmark. One extra
      iteration is generated on top of this (iterations 1 .. n_pairs + 1) for
      the test case.
    - node_sizes (list): List of node sizes to generate for each benchmark
      (defaults to [5, 10, 15]).
    - seed (int): Base random seed; iteration i uses seed + i.
    - parallel (bool): Whether to use parallel processing.
    - max_workers (int): Maximum number of worker processes. None means
      "CPU count - 1", with a floor of 1.
    - tasks (list): Optional list of specific task names to run.
    - graph_types (list): Optional list of specific graph types to use.
    - task_params (dict): Optional parameters for tasks, keyed by task name.

    Returns:
    - None. Results are reported on stdout only.
    """
    # Default node sizes if none provided
    if node_sizes is None:
        node_sizes = [5, 10, 15]

    # Get available tasks
    available_tasks = import_task_modules()
    print(f"Found {len(available_tasks)} registered tasks")

    # Filter tasks if specified
    selected_tasks = tasks if tasks else available_tasks
    print(f"Selected {len(selected_tasks)} tasks: {', '.join(selected_tasks)}")

    # Build the membership set once instead of scanning the list per generator.
    requested_types = set(graph_types) if graph_types else None

    # Each job is the positional argument tuple for generate_graph_for_task:
    # (task_name, generator_name, num_nodes, seed, num, task_params)
    jobs = []

    for task_name in selected_tasks:
        task = get_task(task_name)
        if not task:
            print(f"Unknown task: {task_name}, skipping")
            continue

        # Get compatible graph types for this task
        compatible_generators = find_compatible_generators_for_task(task_name)

        # Filter by requested graph types if specified
        if requested_types is not None:
            compatible_generators = [
                g for g in compatible_generators if g in requested_types
            ]

        if not compatible_generators:
            print(f"No compatible graph types found for {task_name}, skipping")
            continue

        # Get task parameters if provided
        params = task_params.get(task_name, {}) if task_params else {}

        print(
            f"Task {task_name} will be generated with graph types: {', '.join(compatible_generators)}"
        )

        # Generate for each compatible graph type and node size
        for generator_name in compatible_generators:
            for node_size in node_sizes:
                # Iterations are 1-based: n_pairs examples plus one test case.
                for i in range(1, n_pairs + 2):
                    jobs.append(
                        (task_name, generator_name, node_size, seed + i, i, params)
                    )

    if not jobs:
        print("No matching task-graph type combinations found.")
        return

    print(f"Prepared {len(jobs)} generation jobs")

    # Group jobs by size for better load balancing
    jobs.sort(key=lambda job: job[2])  # Sort by num_nodes

    # Track generation issues
    success_count = 0
    failures = []  # (task_name, generator_name, node_size, num, error_msg)

    def record(result):
        # Fold one generate_graph_for_task result tuple into the tallies.
        nonlocal success_count
        r_task, r_generator, r_size, r_num, ok, error_msg = result
        if ok:
            success_count += 1
        else:
            failures.append((r_task, r_generator, r_size, r_num, error_msg))

    if not parallel:
        # Sequential execution
        for job in jobs:
            record(generate_graph_for_task(*job))
    else:
        # Parallel execution
        if max_workers is None:
            # os.cpu_count() may return None when the count cannot be
            # determined; treat that as a single CPU instead of crashing.
            max_workers = max(1, (os.cpu_count() or 1) - 1)

        print(f"Running in parallel with {max_workers} workers")

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=max_workers
        ) as executor:
            # Remember which job each future belongs to so that a failure
            # raised by the worker can still be attributed in the summary.
            future_to_job = {
                executor.submit(generate_graph_for_task, *job): job for job in jobs
            }

            for future in concurrent.futures.as_completed(future_to_job):
                try:
                    result = future.result()
                except (TimeoutError, ValueError, RuntimeError) as e:
                    print(f"❌ Error during graph generation: {str(e)}")
                    j_task, j_generator, j_size, _, j_num, _ = future_to_job[future]
                    failures.append((j_task, j_generator, j_size, j_num, str(e)))
                else:
                    record(result)

    failure_count = len(failures)

    # Print summary
    print("\n" + "=" * 70)
    print("GRAPH GENERATION SUMMARY:")
    print(f"- Total jobs: {len(jobs)}")
    print(f"- Successful: {success_count}")
    print(f"- Failed: {failure_count}")

    if failures:
        print("\nFailed tasks:")
        # Group failures by error message
        failure_types = {}
        for f_task, f_generator, f_size, f_num, error_msg in failures:
            failure_types.setdefault(error_msg, []).append(
                f"{f_task}/{f_generator} with {f_size} nodes, iteration {f_num}"
            )

        # Print failures grouped by error message
        for error_msg, failed_tasks in failure_types.items():
            print(f"\nError: {error_msg}")
            for entry in failed_tasks[:5]:  # Limit to 5 examples per error type
                print(f"  - {entry}")
            if len(failed_tasks) > 5:
                print(f"  ... and {len(failed_tasks) - 5} more similar failures")

    print("=" * 70 + "\n")


if __name__ == "__main__":
    # Command-line entry point: parse options, optionally list the available
    # tasks / graph generators, otherwise run the generation.
    parser = argparse.ArgumentParser(
        description="Generate graph datasets with various configurations"
    )

    parser.add_argument(
        "--n_pairs",
        type=int,
        default=3,
        help="Number of input-output examples per benchmark (default: 3)",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Base random seed (default: 42)"
    )
    parser.add_argument(
        "--node_sizes",
        type=int,
        nargs="+",
        default=[5, 10, 15],
        help="Node sizes to generate (default: 5, 10, 15)",
    )
    parser.add_argument(
        "--sequential",
        action="store_true",
        help="Run sequentially instead of in parallel",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=None,
        help="Number of worker processes (default: CPU count - 1)",
    )
    parser.add_argument(
        "--tasks", nargs="+", help="Specific tasks to run (default: all)"
    )
    parser.add_argument(
        "--graph_types",
        nargs="+",
        help="Specific graph types to use (default: all compatible types)",
    )
    parser.add_argument(
        "--list_tasks", action="store_true", help="List all available tasks and exit"
    )
    parser.add_argument(
        "--list_graph_types",
        action="store_true",
        help="List all available graph types and exit",
    )

    args = parser.parse_args()

    # Import tasks first to register them
    import_task_modules()

    # List tasks if requested
    if args.list_tasks:
        print("\nAvailable tasks:")
        for task_name in sorted(list_tasks()):
            task = get_task(task_name)
            compatible_generators = find_compatible_generators_for_task(task_name)
            print(f"{task_name}:")
            print(f"  Description: {task.description}")
            print(f"  Required properties: {task.required_properties}")
            print(f"  Compatible generators: {', '.join(compatible_generators)}")
        # Raise SystemExit directly: the bare exit() helper is injected by the
        # site module for interactive use and is not guaranteed to exist.
        raise SystemExit(0)

    # List graph types if requested
    if args.list_graph_types:
        print("\nAvailable graph generators:")
        for gen_name in sorted(GRAPH_GENERATORS_FUNCTIONS.keys()):
            print(f"{gen_name}")
        raise SystemExit(0)

    # Run the graph generation
    generate_graphs(
        n_pairs=args.n_pairs,
        node_sizes=args.node_sizes,
        seed=args.seed,
        parallel=not args.sequential,
        max_workers=args.workers,
        tasks=args.tasks,
        graph_types=args.graph_types,
    )
