import networkx as nx
from typing import Dict, Any

from scripts.utils.task_definition import register_task
from scripts.utils.properties import has_degree, has_components, has_colored_nodes, all_nodes_colored_with


@register_task(
    name="addHub",
    preferred_generators=["random", "randomConnected", "randomTree", "star"],
    parameter_schema={"color": "string"},
    description="Adds a new colored hub node connected to all existing nodes."
)
def add_hub(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Adds a new hub node, colored with params["color"], connected to all existing nodes.

    The input graph is modified in place and the same object is returned.

    Parameters:
    - G: NetworkX graph to transform; node labels are assumed to be integers,
         because the hub's label is derived from the largest existing label
    - params: Dictionary with optional parameter "color" (default: "blue")

    Returns:
    - Graph with a new hub node

    Index Transformation:
    - Original nodes: All original node indices (0, 1, 2, ...) are preserved exactly as they were
    - New node: A new node with index max(G.nodes) + 1 is added to the graph
      (index 0 if the input graph has no nodes)
    - Example: If input graph has nodes [0, 1, 2, 3, 4], the output graph will have nodes
               [0, 1, 2, 3, 4, 5], where node 5 is the new hub node
    """
    color = params.get("color", "blue")

    # Snapshot the existing nodes before the hub is inserted, so the hub is
    # never connected to itself.
    existing_nodes = list(G.nodes)

    # New node ID is max existing node + 1; default=-1 gives an empty graph hub ID 0
    # instead of max() raising ValueError.
    new_node = max(existing_nodes, default=-1) + 1
    G.add_node(new_node, color=color)
    G.add_edges_from((new_node, node) for node in existing_nodes)

    return G


@register_task(
    name="edgeToNode",
    preferred_generators=["random", "randomConnected", "randomTree", "star"],
    description="Replaces every edge with a new intermediate node connected to both endpoints."
)
def edge_to_node(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph: # pylint: disable=unused-argument
    """
    Replaces every edge in the graph with a new intermediate node connected to both endpoints.

    The input graph is modified in place and the same object is returned.

    Parameters:
    - G: NetworkX graph to transform (integer node labels are assumed)
    - params: Not used

    Returns:
    - Graph with edges replaced by nodes

    Index Transformation:
    - Original nodes: All original node indices are preserved exactly as they were
    - New nodes: For each edge (u,v) in the original graph, a new node is added with
                  an index starting from max(G.nodes) + 1 and incrementing by 1 for each edge
    - Edge ordering: The order of edge processing is deterministic and depends only on the
      graph structure, not on the order in which edges were inserted:
      - Each edge is keyed by its endpoints in ascending order, (min(u,v), max(u,v))
      - Edges are ordered by the smaller endpoint index, then by the larger endpoint index
      - For example, edges will be processed in this exact order: [(0,1), (0,2), (0,3), (1,2), (1,3), (2,3)]
    - Processing sequence:
      1. All edges are collected from G.edges and sorted by the key above
      2. All original edges are removed
      3. The i-th edge (u,v) in that order (i = 0, 1, 2, ...) is replaced by a new node
         with ID max(G.nodes) + 1 + i, connected to u and v
    - Example: If input graph has nodes [0, 1, 2, 3] with edges [(0,1), (1,2), (2,3)],
               the output graph will have nodes [0, 1, 2, 3, 4, 5, 6], where:
               - Node 4 replaces edge (0,1) and connects to nodes 0 and 1
               - Node 5 replaces edge (1,2) and connects to nodes 1 and 2
               - Node 6 replaces edge (2,3) and connects to nodes 2 and 3
    - A graph with no nodes is returned unchanged
    """
    # default=-1: a graph with no nodes has no edges to replace, so it passes
    # through unchanged instead of max() raising ValueError.
    next_id = max(G.nodes, default=-1) + 1

    # Snapshot the edges before mutating G. G.edges follows adjacency insertion
    # order, so sort by the ascending endpoint pair to guarantee the documented
    # processing order for any insertion history.
    original_edges = sorted(G.edges, key=lambda edge: tuple(sorted(edge)))
    G.remove_edges_from(original_edges)

    new_edges = []
    for u, v in original_edges:
        # Intermediate node replacing edge (u, v), connected to both original endpoints
        G.add_node(next_id)
        new_edges.append((next_id, u))
        new_edges.append((next_id, v))
        next_id += 1

    # Add all new edges at once
    G.add_edges_from(new_edges)

    return G


@register_task(
    name="removeDegree1",
    required_properties=[has_degree(1)],
    preferred_generators=["random", "randomConnected", "randomTree", "star"],
    description="Removes all nodes with degree 1 from the graph."
)
def remove_degree_1(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph: # pylint: disable=unused-argument
    """
    Deletes every node whose degree is exactly 1 (the graph is modified in place).

    Parameters:
    - G: NetworkX graph to transform
    - params: Not used

    Returns:
    - The same graph object with all degree 1 nodes removed

    Index Transformation:
    - Removed nodes: Every node whose degree in the input graph is 1 is deleted
    - Remaining nodes: Every other node keeps its original index
    - The resulting graph may have non-consecutive node indices
    - Example: If input graph has nodes [0, 1, 2, 3, 4] and nodes 0 and 4 have degree 1,
               the output graph will have nodes [1, 2, 3] with their original indices
    """
    # Degrees are read from the input graph all at once, before any removal,
    # so deleting a node never changes which other nodes qualify.
    degree_one = [node for node, deg in G.degree if deg == 1]
    G.remove_nodes_from(degree_one)
    return G


@register_task(
    name="removeDegree2",
    required_properties=[has_degree(2)],
    preferred_generators=["random", "randomConnected", "randomTree"],
    description="Removes all nodes with degree 2 from the graph."
)
def remove_degree_2(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph: # pylint: disable=unused-argument
    """
    Deletes every node whose degree is exactly 2 (the graph is modified in place).

    Parameters:
    - G: NetworkX graph to transform
    - params: Not used

    Returns:
    - The same graph object with all degree 2 nodes removed

    Index Transformation:
    - Removed nodes: Every node whose degree in the input graph is 2 is deleted
    - Remaining nodes: Every other node keeps its original index
    - The resulting graph may have non-consecutive node indices
    - Example: If input graph has nodes [0, 1, 2, 3, 4] and nodes 1 and 3 have degree 2,
               the output graph will have nodes [0, 2, 4] with their original indices
    """
    # Degrees are read from the input graph all at once, before any removal,
    # so deleting a node never changes which other nodes qualify.
    degree_two = [node for node, deg in G.degree if deg == 2]
    G.remove_nodes_from(degree_two)
    return G

@register_task(
    name="removeDegree3",
    required_properties=[has_degree(3)],
    preferred_generators=["random", "randomConnected", "randomTree"],
    description="Removes all nodes with degree 3 from the graph."
)
def remove_degree_3(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph: # pylint: disable=unused-argument
    """
    Deletes every node whose degree is exactly 3 (the graph is modified in place).

    Parameters:
    - G: NetworkX graph to transform
    - params: Not used

    Returns:
    - The same graph object with all degree 3 nodes removed

    Index Transformation:
    - Removed nodes: Every node whose degree in the input graph is 3 is deleted
    - Remaining nodes: Every other node keeps its original index
    - The resulting graph may have non-consecutive node indices
    - Example: If input graph has nodes [0, 1, 2, 3, 4] and nodes 0 and 2 have degree 3,
               the output graph will have nodes [1, 3, 4] with their original indices
    """
    # Degrees are read from the input graph all at once, before any removal,
    # so deleting a node never changes which other nodes qualify.
    degree_three = [node for node, deg in G.degree if deg == 3]
    G.remove_nodes_from(degree_three)
    return G


@register_task(
    name="bipartitionCompletion",
    required_properties=["bipartite"],
    required_pretransforms=[("color_bipartition_seeds", {"first_color": "blue", "second_color": "orange"})],
    preferred_generators=["bipartite"],
    parameter_schema={"first_color": "string", "second_color": "string"},
    description="Colors remaining nodes in a bipartite graph based on the colored seeds."
)
def bipartition_completion(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph: # pylint: disable=unused-argument
    """
    Colors remaining nodes in a bipartite graph based on existing seed colors.
    Assumes each partition already has at least one colored (non-grey) seed node;
    every grey or uncolored node receives the color of its partition's seed.

    Parameters:
    - G: NetworkX bipartite graph to transform (modified in place and returned)
    - params: Dictionary with optional color parameters (not actually used; the
              colors are read from the seed nodes already present in G)

    Returns:
    - Graph with nodes colored by partition

    Raises:
    - ValueError: if one of the two partitions contains no colored seed node
    - Errors from nx.algorithms.bipartite.sets propagate unchanged (it cannot
      split a disconnected graph unambiguously and raises in that case)

    Index Transformation:
    - No index changes: This transformation only modifies node colors
    - All node indices remain exactly the same as in the input graph
    - The graph structure (nodes and edges) is completely preserved
    """
    # Identify partitions
    part0, part1 = nx.algorithms.bipartite.sets(G)

    def seed_color(partition, label):
        """Return the color of the first non-grey node found in ``partition``."""
        for node in partition:
            color = G.nodes[node].get("color", "grey")
            if color != "grey":
                return color
        # Raise a descriptive error rather than letting a bare StopIteration escape.
        raise ValueError(f"No colored seed node found in the {label} partition")

    # Read both seed colors before recoloring anything
    part0_color = seed_color(part0, "first")
    part1_color = seed_color(part1, "second")

    # Color all remaining (grey/uncolored) nodes in each partition
    for partition, fill_color in ((part0, part0_color), (part1, part1_color)):
        for node in partition:
            if G.nodes[node].get("color", "grey") == "grey":
                G.nodes[node]["color"] = fill_color

    return G


@register_task(
    name="blueSubgraph",
    required_properties=["has_colored_node"],
    required_pretransforms=[("color_some_random_nodes", {"color": "blue", "min_count": 2})], # No max_count defaults to n
    preferred_generators=["random", "randomConnected", "randomTree", "star"],
    parameter_schema={"target_color": "string"},
    description="Returns the subgraph induced by blue nodes (removes all non-blue nodes)."
)
def blue_subgraph(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Returns the subgraph induced by nodes of a specific color (default: blue).

    The input graph is not modified; an independent copy is returned.

    Parameters:
    - G: NetworkX graph to transform
    - params: Dictionary with parameter "target_color" (default: "blue")

    Returns:
    - Subgraph containing only nodes of the target color and edges between them,
      of the same graph class as G (empty if no node has the target color)

    Index Transformation:
    - Removed nodes: All nodes that don't have the target color are removed
    - Remaining nodes: Nodes with the target color retain their original indices
    - The resulting graph will have non-consecutive node indices if non-colored nodes are removed
    - Example: If input has nodes [0, 1, 2, 3, 4] and nodes 1, 3 are blue,
               output will have nodes [1, 3] with their original indices
    """
    target_color = params.get("target_color", "blue")

    # Nodes without a "color" attribute count as grey.
    selected_nodes = [node for node, color in G.nodes(data="color", default="grey")
                      if color == target_color]

    # An empty selection needs no special case: the induced subgraph copy is then
    # an empty graph of the same class as G, just like the non-empty path.
    return G.subgraph(selected_nodes).copy()


@register_task(
    name="mergeAtBlue",
    required_properties=[has_components(2), has_colored_nodes(2, "blue")],
    required_pretransforms=[("recolor_nodes", {"from_color": "orange", "to_color": "blue"})],
    preferred_generators=["random2Component"],
    parameter_schema={"merge_color": "string"},
    description="Merges two graph components at their blue nodes into a single node."
)
def merge_at_blue(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Merges two components of a graph at their marker-colored (default: blue) nodes.
    Each component must have exactly one node of the merge color.

    The input graph is not modified; a merged copy is returned.

    Parameters:
    - G: NetworkX graph with exactly 2 components
    - params: Dictionary with parameter "merge_color" (default: "blue")

    Returns:
    - Graph where the two marker-colored nodes are merged into one

    Raises:
    - ValueError: if G does not have exactly 2 connected components, or a component
      does not contain exactly one node of the merge color

    Index Transformation:
    - The marked node from the first component (as returned by nx.connected_components)
      keeps its index
    - The marked node from the second component is removed
    - All edges from the second marked node are redirected to the first marked node,
      keeping their edge attributes; a self-loop on it becomes a self-loop on the kept node
    - All other nodes retain their original indices
    - Example: If components have blue nodes at indices 2 and 7,
               node 7 is removed and all its edges now connect to node 2
    """
    merge_color = params.get("merge_color", "blue")

    # Get connected components
    components = list(nx.connected_components(G))
    if len(components) != 2:
        raise ValueError(f"Expected exactly 2 components, found {len(components)}")

    # Find the single marked node in each component
    blue_nodes = []
    for component in components:
        component_blue_nodes = [node for node in component
                                if G.nodes[node].get("color", "grey") == merge_color]
        if len(component_blue_nodes) != 1:
            raise ValueError(f"Each component must have exactly one {merge_color} node")
        blue_nodes.append(component_blue_nodes[0])

    result = G.copy()

    # Keep the first marked node, merge the second into it
    keep_node, merge_node = blue_nodes

    # Redirect every edge of merge_node to keep_node. The two nodes lie in different
    # components, so keep_node is never a neighbor of merge_node; the only special
    # case is a self-loop on merge_node, which must become a loop on keep_node
    # (otherwise it would be deleted together with merge_node below).
    for _, neighbor, edge_data in G.edges(merge_node, data=True):
        target = keep_node if neighbor == merge_node else neighbor
        result.add_edge(keep_node, target, **edge_data)

    # Remove the merged node
    result.remove_node(merge_node)

    return result


@register_task(
    name="complementGraph",
    preferred_generators=["random", "randomConnected", "randomTree", "star"],
    parameter_schema={},
    description="Returns the complement graph (inverts edge presence)."
)
def complement_graph(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Builds the complement of the input graph: two distinct nodes are adjacent in
    the result exactly when they are not adjacent in G.

    The input graph is not modified; a new graph is returned.

    Parameters:
    - G: NetworkX graph to transform
    - params: Not used

    Returns:
    - Complement graph with the same nodes (and node attributes) but the inverted edge set

    Index Transformation:
    - No index changes: Every node keeps its original index
    - Only the edge set changes (it is inverted)
    - Node colors and any other node attributes are copied over from G
    - Example: If input has edges [(0,1), (1,2)], complement might have
               [(0,2), (0,3), (1,3), (2,3)] depending on total nodes
    """
    inverted = nx.complement(G)

    # nx.complement carries node identities but not their attribute dicts,
    # so merge each original node's attributes (e.g. "color") back in.
    for node, attrs in G.nodes(data=True):
        inverted.nodes[node].update(attrs)

    return inverted


@register_task(
    name="removeSameColorEdges",
    required_properties=[all_nodes_colored_with(2)],
    required_pretransforms=[("color_all_nodes", {"num_colors": 2})],
    preferred_generators=["random", "randomConnected", "randomTree"],
    parameter_schema={},
    description="Removes edges between nodes of the same color (requires all nodes colored with 2+ colors)."
)
def remove_same_color_edges(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Drops every edge whose two endpoints share the same color.
    Every node must be colored (non-grey) and at least 2 distinct colors must appear.

    The input graph is not modified; a pruned copy is returned.

    Parameters:
    - G: NetworkX graph where all nodes are colored
    - params: Not used

    Returns:
    - Copy of G without any same-color edges

    Raises:
    - ValueError: if some node is grey/uncolored, or fewer than 2 colors are used

    Index Transformation:
    - No index changes: Every node keeps its original index
    - Only edges joining same-colored nodes are removed
    - All node attributes (including colors) are kept
    - Example: If nodes 0,1 are blue and 2,3 are orange, edges (0,1) and (2,3)
               would be removed if they exist
    """
    # Nodes without a "color" attribute count as grey.
    colors = dict(G.nodes(data="color", default="grey"))

    if "grey" in colors.values():
        raise ValueError("All nodes must be colored (non-grey) for this transformation")
    if len(set(colors.values())) < 2:
        raise ValueError("At least 2 different colors must be used")

    monochrome_edges = [(u, v) for u, v in G.edges() if colors[u] == colors[v]]

    pruned = G.copy()
    pruned.remove_edges_from(monochrome_edges)
    return pruned
