import networkx as nx
from typing import Dict, Any

from scripts.utils.task_definition import register_task
from scripts.utils.properties import (
    has_colored_leaves,
    has_degree,
    has_components,
    has_colored_nodes,
    has_equidistant_node,
)


@register_task(
    name="colorDegree1",
    required_properties=[has_degree(1)],
    preferred_generators=["random", "randomConnected", "randomTree", "star"],
    parameter_schema={"color": "string"},
    description="Colors all nodes with degree 1.",
)
def color_degree_1(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors every node whose degree is exactly 1.

    Parameters:
    - G: NetworkX graph to transform; it is modified in place
    - params: Dictionary with optional parameter "color" (default: "blue")

    Returns:
    - The same graph object, with the "color" attribute of every degree-1
      node set to the requested color

    Index Transformation:
    - No index changes: only the "color" node attribute is written
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    fill = params.get("color", "blue")

    # Iterating G.degree() yields (node, degree) pairs; gather matches first,
    # then paint them.
    degree_one_nodes = [n for n, deg in G.degree() if deg == 1]
    for n in degree_one_nodes:
        G.nodes[n]["color"] = fill

    return G


@register_task(
    name="colorDegree2",
    required_properties=[has_degree(2)],
    preferred_generators=["random", "randomConnected", "randomTree"],
    parameter_schema={"color": "string"},
    description="Colors all nodes with degree 2.",
)
def color_degree_2(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors every node whose degree is exactly 2.

    Parameters:
    - G: NetworkX graph to transform; it is modified in place
    - params: Dictionary with optional parameter "color" (default: "blue")

    Returns:
    - The same graph object, with the "color" attribute of every degree-2
      node set to the requested color

    Index Transformation:
    - No index changes: only the "color" node attribute is written
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    paint = params.get("color", "blue")

    for vertex, deg in G.degree():
        if deg != 2:
            continue
        # Writing a node attribute does not touch adjacency, so it is safe
        # while iterating the degree view.
        G.nodes[vertex]["color"] = paint

    return G


@register_task(
    name="colorDegree3",
    required_properties=[has_degree(3)],
    preferred_generators=["random", "randomConnected", "randomTree"],
    parameter_schema={"color": "string"},
    description="Colors all nodes with degree 3.",
)
def color_degree_3(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors every node whose degree is exactly 3.

    Parameters:
    - G: NetworkX graph to transform; it is modified in place
    - params: Dictionary with optional parameter "color" (default: "blue")

    Returns:
    - The same graph object, with the "color" attribute of every degree-3
      node set to the requested color

    Index Transformation:
    - No index changes: only the "color" node attribute is written
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    paint = params.get("color", "blue")

    # Map each degree-3 node to its new color and apply in one call.
    updates = {node: paint for node, deg in G.degree() if deg == 3}
    nx.set_node_attributes(G, updates, name="color")

    return G


@register_task(
    name="colorMaxDegree",
    preferred_generators=["random", "randomConnected", "randomTree"],
    parameter_schema={"color": "string"},
    description="Colors all nodes with maximum degree.",
)
def color_max_degree(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors all nodes whose degree equals the graph's maximum degree.

    Ties are all colored: every node attaining the maximum is recolored.

    Parameters:
    - G: NetworkX graph to transform (must have at least one node); it is
      modified in place
    - params: Dictionary with parameter "color" (default: "blue")

    Returns:
    - Graph with maximum degree nodes colored

    Raises:
    - ValueError: if G has no nodes, so no maximum degree exists

    Index Transformation:
    - No index changes: This transformation only modifies node colors
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    color = params.get("color", "blue")

    degrees = dict(G.degree())
    if not degrees:
        # max() over an empty sequence raises an opaque builtin ValueError;
        # keep the same exception type but make the cause explicit.
        raise ValueError("Cannot color maximum-degree nodes of an empty graph")
    max_degree = max(degrees.values())

    for node, deg in degrees.items():
        if deg == max_degree:
            G.nodes[node]["color"] = color

    return G


@register_task(
    name="colorMinDegree",
    preferred_generators=["random", "randomConnected", "randomTree"],
    parameter_schema={"color": "string"},
    description="Colors all nodes with minimum degree.",
)
def color_min_degree(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors all nodes whose degree equals the graph's minimum degree.

    Ties are all colored: every node attaining the minimum is recolored.

    Parameters:
    - G: NetworkX graph to transform (must have at least one node); it is
      modified in place
    - params: Dictionary with parameter "color" (default: "blue")

    Returns:
    - Graph with minimum degree nodes colored

    Raises:
    - ValueError: if G has no nodes, so no minimum degree exists

    Index Transformation:
    - No index changes: This transformation only modifies node colors
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    color = params.get("color", "blue")

    degrees = dict(G.degree())
    if not degrees:
        # min() over an empty sequence raises an opaque builtin ValueError;
        # keep the same exception type but make the cause explicit.
        raise ValueError("Cannot color minimum-degree nodes of an empty graph")
    min_degree = min(degrees.values())

    for node, deg in degrees.items():
        if deg == min_degree:
            G.nodes[node]["color"] = color

    return G


@register_task(
    name="colorInternal",
    required_properties=["has_internal_node"],
    preferred_generators=["random", "randomConnected", "randomTree", "star"],
    parameter_schema={"color": "string"},
    description="Colors all non-leaf (internal) nodes.",
)
def color_internal_nodes(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors all internal nodes, i.e. nodes with degree greater than 1.

    Leaves (degree 1) and isolated nodes (degree 0) keep their color. The
    graph does not have to be a tree; the degree test works on any graph.

    Parameters:
    - G: NetworkX graph to transform; it is modified in place
    - params: Dictionary with optional parameter "color" (default: "blue")

    Returns:
    - The same graph object, with every internal node's "color" set

    Index Transformation:
    - No index changes: only the "color" node attribute is written
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    paint = params.get("color", "blue")

    for vertex, deg in G.degree():
        if deg <= 1:
            continue  # leaf or isolated node: not internal
        G.nodes[vertex]["color"] = paint

    return G


@register_task(
    name="colorNeighbors",
    required_properties=["has_colored_node"],
    required_pretransforms=[("color_random_node", {"color": "orange"})],
    preferred_generators=["random", "randomConnected", "randomTree", "star"],
    parameter_schema={"target_color": "string", "source_color": "string"},
    description="Colors all neighbors of an orange node.",
)
def color_neighbors(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors all neighbors of every node that has a specific color.

    Parameters:
    - G: NetworkX graph to transform; it is modified in place
    - params: Dictionary with parameters:
      - "source_color": color of the source node(s) (default: "orange")
      - "target_color": color to apply to neighbors (default: "blue")

    Returns:
    - Graph with neighbors colored

    Raises:
    - ValueError: if no node has source_color

    Notes:
    - Neighbors that already have source_color are left unchanged, so
      adjacent source nodes never recolor each other.
    - Nodes without a "color" attribute are treated as uncolored.

    Index Transformation:
    - No index changes: This transformation only modifies node colors
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    source_color = params.get("source_color", "orange")
    target_color = params.get("target_color", "blue")

    # Find the source node(s); uncolored nodes may lack a "color" key entirely.
    source_nodes = [
        node for node, data in G.nodes(data=True) if data.get("color") == source_color
    ]

    if not source_nodes:
        raise ValueError(f"No nodes with color '{source_color}' found in the graph")

    # Color neighbors of each source node
    for source in source_nodes:
        for neighbor in G.neighbors(source):
            # .get(): indexing ["color"] raised KeyError for neighbors that
            # were never assigned a color attribute.
            if G.nodes[neighbor].get("color") not in (source_color, target_color):
                G.nodes[neighbor]["color"] = target_color

    return G


@register_task(
    name="colorPath",
    required_properties=["connected", "acyclic", has_colored_leaves(2), has_degree(1)],
    required_pretransforms=[("color_some_leaves", {"color": "blue", "count": 2})],
    preferred_generators=["randomTree", "star"],
    parameter_schema={"color": "string", "leaf_color": "string"},
    description="Colors all nodes on the path between two colored leaves.",
)
def color_path(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors every node on the shortest path between the two marked leaves.

    A marked leaf is a degree-1 node whose "color" equals leaf_color. Exactly
    two such leaves must exist. Both endpoints are included in the path and
    are recolored too.

    Parameters:
    - G: NetworkX graph to transform; it is modified in place. The task is
      registered for connected, acyclic graphs, where the path is unique.
    - params: Dictionary with parameters:
      - "color": color to apply to path nodes (default: "blue")
      - "leaf_color": color identifying the two endpoint leaves (default: "blue")

    Returns:
    - Graph with path nodes colored

    Raises:
    - ValueError: if the number of marked leaves is not exactly two

    Index Transformation:
    - No index changes: only the "color" node attribute is written
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    path_color = params.get("color", "blue")
    leaf_color = params.get("leaf_color", "blue")

    marked_leaves = [
        n
        for n, attrs in G.nodes(data=True)
        if attrs.get("color") == leaf_color and G.degree[n] == 1
    ]

    found = len(marked_leaves)
    if found < 2:
        raise ValueError(f"Need exactly two {leaf_color} leaves, but found fewer")
    if found > 2:
        raise ValueError(f"Need exactly two {leaf_color} leaves, but found more")

    # Exactly two remain at this point, so unpack them directly.
    start, end = marked_leaves

    for n in nx.shortest_path(G, source=start, target=end):
        G.nodes[n]["color"] = path_color

    return G


@register_task(
    name="colorComponents",
    required_properties=[has_components(2)],
    # No required_pretransforms: presumably the "random2Component" generator
    # already places the two colored seed nodes this task looks for — confirm.
    preferred_generators=["random2Component"],
    parameter_schema={"first_color": "string", "second_color": "string"},
    description="Colors all nodes in one component and nodes in another component another color.",
)
def color_components(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Floods two connected components with the colors of their seed nodes.

    Nodes are scanned in graph order. The first node colored first_color
    becomes the first seed, and the first node colored second_color becomes
    the second seed. Every node in the first seed's connected component is
    then colored first_color, and every node in the second seed's component
    is colored second_color. If both seeds share a component, that whole
    component ends up second_color, because it is painted last.

    Parameters:
    - G: NetworkX graph to transform; it is modified in place
    - params: Dictionary with parameters:
      - "first_color": color of the first seed and its component (default: "blue")
      - "second_color": color of the second seed and its component (default: "orange")

    Returns:
    - Graph with nodes colored by component

    Raises:
    - ValueError: if no seed of either color can be found

    Index Transformation:
    - No index changes: only the "color" node attribute is written
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    first_color = params.get("first_color", "blue")
    second_color = params.get("second_color", "orange")

    first_seed = None
    second_seed = None
    for node, attrs in G.nodes(data=True):
        shade = attrs.get("color")
        if first_seed is None and shade == first_color:
            first_seed = node
        elif second_seed is None and shade == second_color:
            second_seed = node

    if first_seed is None or second_seed is None:
        raise ValueError(
            f"Graph must contain at least one {first_color} node and one {second_color} node"
        )

    # Recoloring never changes connectivity, so painting one component before
    # looking up the next one is safe.
    for seed, shade in ((first_seed, first_color), (second_seed, second_color)):
        for node in nx.node_connected_component(G, seed):
            G.nodes[node]["color"] = shade

    return G


@register_task(
    name="colorDistanceAtLeast2",
    required_properties=["has_colored_node"],
    required_pretransforms=[("color_some_random_nodes", {"color": "blue", "count": 2})],
    preferred_generators=["random", "randomConnected", "randomTree"],
    parameter_schema={
        "source_color": "string",
        "target_color": "string",
        "min_distance": "int",
    },
    description="Colors nodes that are at least distance 2 from a marked node.",
)
def color_distance_at_least_2(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors all nodes whose distance to the nearest marked node is at least
    a minimum distance.

    Distance is the unweighted hop count of a shortest path. Marked (source)
    nodes are never recolored. Nodes unreachable from every marked node are
    left unchanged.

    Parameters:
    - G: NetworkX graph with at least one colored node; it is modified in place
    - params: Dictionary with parameters:
      - "source_color": color of the marked node(s) (default: "blue")
      - "target_color": color to apply to distant nodes (default: "blue")
      - "min_distance": minimum distance required (default: 2)

    Returns:
    - Graph with nodes at minimum distance colored

    Raises:
    - ValueError: if no node has source_color

    Index Transformation:
    - No index changes: This transformation only modifies node colors
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    source_color = params.get("source_color", "blue")
    target_color = params.get("target_color", "blue")
    min_distance = params.get("min_distance", 2)

    # Find source nodes (uncolored nodes are treated as "grey")
    source_nodes = [
        node
        for node, data in G.nodes(data=True)
        if data.get("color", "grey") == source_color
    ]

    if not source_nodes:
        raise ValueError(f"No nodes with color '{source_color}' found in the graph")

    # Run one BFS per source, instead of one shortest-path search per
    # (node, source) pair, and keep the smallest hop count seen for each node.
    # Nodes reachable from no source never enter the map.
    nearest = {}
    for source in source_nodes:
        for node, dist in nx.single_source_shortest_path_length(G, source).items():
            nearest[node] = min(dist, nearest.get(node, dist))

    # Set membership instead of scanning the source list for every node.
    sources = set(source_nodes)
    for node, dist in nearest.items():
        if node not in sources and dist >= min_distance:
            G.nodes[node]["color"] = target_color

    return G


@register_task(
    name="colorEquidistant",
    required_properties=[
        "connected",
        has_colored_nodes(2, "blue"),
        has_equidistant_node(),
    ],
    required_pretransforms=[("color_some_random_nodes", {"color": "blue", "count": 2})],
    preferred_generators=["randomConnected", "randomTree"],
    parameter_schema={"source_color": "string", "target_color": "string"},
    description="Colors nodes equidistant from two blue nodes.",
)
def color_equidistant(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors all nodes that are equidistant from two reference nodes.

    Distance is the unweighted hop count of a shortest path. The two
    reference nodes themselves are never recolored. Nodes unreachable from
    either reference node are skipped.

    Parameters:
    - G: NetworkX graph with exactly two source_color nodes; it is modified in place
    - params: Dictionary with parameters:
      - "source_color": color of the two reference nodes (default: "blue")
      - "target_color": color for equidistant nodes (default: "red")

    Returns:
    - Graph with equidistant nodes colored

    Raises:
    - ValueError: if there are not exactly two source_color nodes, or if no
      other node is equidistant from them

    Index Transformation:
    - No index changes: This transformation only modifies node colors
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    source_color = params.get("source_color", "blue")
    target_color = params.get("target_color", "red")

    # Find the two reference nodes (uncolored nodes are treated as "grey")
    blue_nodes = [
        node
        for node, data in G.nodes(data=True)
        if data.get("color", "grey") == source_color
    ]

    if len(blue_nodes) != 2:
        raise ValueError(
            f"Expected exactly 2 {source_color} nodes, found {len(blue_nodes)}"
        )

    v, w = blue_nodes

    # Run one BFS from each reference node up front, instead of two fresh
    # shortest-path searches per candidate node. Unreachable nodes are absent
    # from these maps.
    dist_from_v = nx.single_source_shortest_path_length(G, v)
    dist_from_w = nx.single_source_shortest_path_length(G, w)

    # Color all nodes that are equidistant from v and w; error if none found
    colored_any = False
    for node in G.nodes:
        if node in blue_nodes or node not in dist_from_v or node not in dist_from_w:
            continue
        if dist_from_v[node] == dist_from_w[node]:
            G.nodes[node]["color"] = target_color
            colored_any = True

    if not colored_any:
        raise ValueError("No equidistant nodes found to color")

    return G
