"""
Pre-transformations framework for modifying graphs before main task transformations.

This module provides:
1. A registry of pre-transformations that can be applied to graphs
2. Functions to apply transformations with parameters
3. Utilities for composition of transformations
"""

import random
import networkx as nx
from typing import Dict, List, Callable, Any, Tuple, Optional

from scripts.utils.properties import has_degree

# Type definitions
# A transformer is called with a graph and a parameter dictionary and is
# annotated as returning a graph.
GraphTransformer = Callable[[nx.Graph, Dict[str, Any]], nx.Graph]
# Properties are referred to by name; a property list is a list of strings.
PropertyList = List[str]


class PreTransformation:
    """
    A named pre-transformation that can be applied to a graph.

    Bundles the transforming function with its metadata: the properties the
    input graph must have, the properties the result gains (a static list
    plus an optional parameter-dependent list), and a parameter schema.
    """

    def __init__(
        self,
        name: str,
        function: GraphTransformer,
        required_properties: Optional[PropertyList] = None,
        provided_properties: Optional[PropertyList] = None,
        provided_properties_fn: Optional[
            Callable[[Dict[str, Any]], PropertyList]
        ] = None,
        parameter_schema: Optional[Dict[str, Any]] = None,
    ):
        """
        Create a pre-transformation.

        Parameters:
        - name: Name of the transformation
        - function: Callable invoked as function(G, params) by apply()
        - required_properties: Properties the input graph must have
          (default: none)
        - provided_properties: Properties always provided by the
          transformation (default: none)
        - provided_properties_fn: Optional callable mapping the parameter
          dictionary to additional provided properties
        - parameter_schema: Dictionary describing the accepted parameters
          (default: empty)
        """
        self.name = name
        self.function = function
        # Store copies so that later mutation of the caller's containers
        # cannot silently change this transformation's metadata.
        self.required_properties = list(required_properties or [])
        self.provided_properties = list(provided_properties or [])
        self.provided_properties_fn = provided_properties_fn
        self.parameter_schema = dict(parameter_schema or {})

    def get_provided_properties(
        self, params: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """
        Get the properties provided by this transformation with the given parameters.

        Parameters:
        - params: Dictionary of parameters for the transformation
          (None is treated as an empty dictionary)

        Returns:
        - New list holding the static provided properties followed by the
          ones returned by provided_properties_fn(params), if one is set
        """
        params = params or {}

        # Copy the static list so callers can modify the result freely
        result = list(self.provided_properties)

        # Append the parameter-dependent properties, if any
        if self.provided_properties_fn:
            result.extend(self.provided_properties_fn(params))

        return result

    def apply(self, G: nx.Graph, params: Optional[Dict[str, Any]] = None) -> nx.Graph:
        """
        Apply the transformation to a graph with the given parameters.

        Parameters:
        - G: NetworkX graph to transform
        - params: Dictionary of parameters for the transformation
          (None is treated as an empty dictionary)

        Returns:
        - Whatever the wrapped function returns for (G, params)
        """
        params = params or {}
        return self.function(G, params)

    def __repr__(self) -> str:
        return f"PreTransformation(name={self.name})"

    def __str__(self) -> str:
        # Same text as repr(); kept explicit so str() and repr() stay in sync
        return self.__repr__()


# Transformation registry: maps each registered name to its PreTransformation.
# Entries are added by the @register_pretransformation decorator below.
PRETRANSFORMATIONS: Dict[str, PreTransformation] = {}


def register_pretransformation(
    name: str,
    required_properties: PropertyList = None,
    provided_properties: PropertyList = None,
    provided_properties_fn: Callable[[Dict[str, Any]], PropertyList] = None,
    parameter_schema: Dict[str, Any] = None,
):
    """
    Build a decorator that registers a function as a pre-transformation.

    The decorated function is wrapped in a PreTransformation, stored in
    PRETRANSFORMATIONS under the given name (replacing any entry already
    stored under that name), and returned unchanged.

    Parameters:
    - name: Name under which the transformation is registered
    - required_properties: List of properties required for the transformation
    - provided_properties: List of properties the transformation adds to the graph
    - provided_properties_fn: Function that returns properties based on parameters
    - parameter_schema: Dictionary describing the parameters for the transformation

    Returns:
    - Decorator taking the transformation function and returning it as-is
    """

    def _register(fn: GraphTransformer):
        entry = PreTransformation(
            name=name,
            function=fn,
            required_properties=required_properties,
            provided_properties=provided_properties,
            provided_properties_fn=provided_properties_fn,
            parameter_schema=parameter_schema,
        )
        PRETRANSFORMATIONS[name] = entry
        # Hand back the undecorated function so it stays directly callable
        return fn

    return _register


# -------------------------
# Basic pre-transformations
# -------------------------


@register_pretransformation(
    name="color_random_node",
    provided_properties=["has_colored_node"],
    parameter_schema={"color": "string"},
)
def color_random_node(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors a random node in the graph.

    The graph is modified in place: the chosen node's "color" attribute is
    set to the requested color.

    Parameters:
    - G: NetworkX graph (must contain at least one node)
    - params: Dictionary with parameter "color" (default: "blue")

    Returns:
    - The same graph object with one randomly colored node

    Raises:
    - ValueError: if the graph has no nodes
    """
    color = params.get("color", "blue")
    nodes = list(G.nodes)

    # random.choice would fail with an opaque IndexError on an empty list;
    # report invalid input as ValueError like the other transformations do.
    if not nodes:
        raise ValueError("Cannot color a node in a graph with no nodes")

    G.nodes[random.choice(nodes)]["color"] = color
    return G


@register_pretransformation(
    name="color_bipartition_seeds",
    required_properties=["bipartite"],
    provided_properties=["bipartition_seeds_colored"],
    parameter_schema={"first_color": "string", "second_color": "string"},
)
def color_bipartition_seeds(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Picks one random "seed" node per side of a bipartite graph and colors it.

    The graph is modified in place.

    Parameters:
    - G: NetworkX graph to transform (must be bipartite)
    - params: Dictionary with parameters:
      - "first_color": color for the seed of the first partition (default: "blue")
      - "second_color": color for the seed of the second partition (default: "orange")

    Returns:
    - The same graph object with one colored node in each partition

    Raises:
    - ValueError: if the graph is not bipartite, if the two partitions
      cannot be determined, or if either partition is empty
    """
    seed_colors = (
        params.get("first_color", "blue"),
        params.get("second_color", "orange"),
    )

    if not nx.is_bipartite(G):
        raise ValueError("Graph must be bipartite")

    # NOTE(review): nx.bipartite.sets is documented to reject graphs whose
    # bipartition is ambiguous (e.g. disconnected ones) -- confirm whether
    # "connected" should also be a required property.
    try:
        partitions = nx.bipartite.sets(G)
        first_part, second_part = partitions
    except Exception as exc:
        raise ValueError("Failed to identify bipartite partitions") from exc

    if not (first_part and second_part):
        raise ValueError("Both partitions must have at least one node")

    # Draw both seeds first (first partition, then second), then color them
    seeds = [random.choice(list(part)) for part in (first_part, second_part)]
    for seed, seed_color in zip(seeds, seed_colors):
        G.nodes[seed]["color"] = seed_color

    return G


@register_pretransformation(
    name="color_some_leaves",
    # presumably has_degree(1) names the "has a degree-1 node" property
    required_properties=["connected", has_degree(1)],
    provided_properties=["has_colored_leaves"],
    provided_properties_fn=lambda params: [
        f"has_colored_leaves_{params.get('count', 2)}"
    ],
    parameter_schema={"color": "string", "count": "int"},
)
def color_existing_leaves(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors a random selection of the graph's leaf nodes (nodes of degree 1).

    The graph is modified in place. Registered under the name
    "color_some_leaves".

    Parameters:
    - G: NetworkX graph containing at least "count" leaf nodes
    - params: Dictionary with parameters:
      - "color": color to use (default: "blue")
      - "count": number of leaves to color (default: 2)

    Returns:
    - The same graph object with "count" of its leaves colored

    Raises:
    - ValueError: if the graph has fewer than "count" leaf nodes
    """
    color = params.get("color", "blue")
    count = params.get("count", 2)

    # Leaves are the nodes whose degree is exactly one
    leaves = [node for node, degree in G.degree() if degree == 1]
    available = len(leaves)

    if count > available:
        raise ValueError(
            f"Not enough leaf nodes ({available}) to color {count} leaves"
        )

    # Draw "count" distinct leaves and color each of them
    for leaf in random.sample(leaves, count):
        G.nodes[leaf]["color"] = color

    return G


@register_pretransformation(
    name="color_some_random_nodes",
    provided_properties=["has_colored_node"],
    provided_properties_fn=lambda params: (
        [f"has_colored_nodes_{params.get('count', 2)}_{params.get('color', 'blue')}"]
        if params.get("count")
        else []
    ),
    parameter_schema={
        "color": "string",
        "count": "int",
        "min_count": "int",
        "max_count": "int",
    },
)
def color_some_random_nodes(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors a specified number of random nodes in the graph.

    The graph is modified in place. If "count" is given it is used as-is;
    otherwise the number of nodes is drawn uniformly between "min_count" and
    the smaller of "max_count" and the number of nodes in the graph.

    Parameters:
    - G: NetworkX graph
    - params: Dictionary with parameters:
      - "color": color to use (default: "blue")
      - "count": exact number of nodes to color (optional)
      - "min_count": minimum nodes to color (default: 2); ignored if "count" is set
      - "max_count": maximum nodes to color (default: n); ignored if "count" is set

    Returns:
    - The same graph object with the selected nodes colored

    Raises:
    - ValueError: if the requested range is empty, or if the resulting
      number of nodes is negative or larger than the graph
    """
    color = params.get("color", "blue")
    count = params.get("count")
    total = len(G.nodes)

    # Determine number of nodes to color
    if count is not None:
        nodes_to_color = count
    else:
        min_count = params.get("min_count", 2)
        max_count = params.get("max_count", total)

        # The upper bound can never exceed the number of nodes
        max_possible = min(max_count, total)
        if min_count > max_possible:
            # Name the bound that was actually violated
            if min_count > total:
                raise ValueError(
                    f"min_count ({min_count}) exceeds available nodes ({total})"
                )
            raise ValueError(
                f"min_count ({min_count}) exceeds max_count ({max_count})"
            )
        nodes_to_color = random.randint(min_count, max_possible)

    if nodes_to_color < 0:
        # random.sample would reject this too, but with an unrelated message
        raise ValueError(
            f"Cannot color a negative number of nodes ({nodes_to_color})"
        )

    if nodes_to_color > total:
        raise ValueError(
            f"Cannot color {nodes_to_color} nodes in a graph with only {total} nodes"
        )

    # Randomly select distinct nodes and color them
    for node in random.sample(list(G.nodes), nodes_to_color):
        G.nodes[node]["color"] = color

    return G


@register_pretransformation(
    name="color_all_nodes",
    provided_properties=["all_nodes_colored"],
    provided_properties_fn=lambda params: [
        f"has_multiple_colors_{params.get('num_colors', 2)}"
    ],
    parameter_schema={"num_colors": "int", "colors": "list"},
)
def color_all_nodes(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Colors all nodes in the graph using multiple colors.

    The graph is modified in place. Nodes are shuffled and then assigned
    colors round-robin, so each color is used roughly equally often.

    Parameters:
    - G: NetworkX graph
    - params: Dictionary with parameters:
      - "num_colors": number of different colors to use (default: 2,
        must be at least 1)
      - "colors": list of specific colors to use (optional); only its first
        "num_colors" entries are used. When omitted, the first "num_colors"
        entries of a built-in palette of 8 colors are used.

    Returns:
    - The same graph object with all nodes colored

    Raises:
    - ValueError: if "num_colors" is less than 1, or if fewer than
      "num_colors" colors are available
    """
    num_colors = params.get("num_colors", 2)
    colors = params.get("colors")

    # Zero would make the round-robin index below divide by zero, and a
    # negative value would silently slice the palette from the wrong end.
    if num_colors < 1:
        raise ValueError(f"num_colors must be at least 1, got {num_colors}")

    if not colors:
        # Default color palette
        default_colors = [
            "blue",
            "orange",
            "red",
            "green",
            "purple",
            "yellow",
            "pink",
            "brown",
        ]
        colors = default_colors[:num_colors]

    if len(colors) < num_colors:
        raise ValueError(f"Need {num_colors} colors but only {len(colors)} provided")

    # Distribute colors roughly evenly over a randomized node order
    nodes = list(G.nodes)
    random.shuffle(nodes)

    for i, node in enumerate(nodes):
        G.nodes[node]["color"] = colors[i % num_colors]

    return G


@register_pretransformation(
    name="recolor_nodes",
    required_properties=["has_colored_node"],
    parameter_schema={"from_color": "string", "to_color": "string"},
)
def recolor_nodes(G: nx.Graph, params: Dict[str, Any]) -> nx.Graph:
    """
    Replaces one node color with another throughout the graph.

    The graph is modified in place. A node without a "color" attribute is
    treated as if its color were "grey".

    Parameters:
    - G: NetworkX graph
    - params: Dictionary with parameters:
      - "from_color": color to change from (default: "orange")
      - "to_color": color to change to (default: "blue")

    Returns:
    - The same graph object with the matching nodes recolored
    """
    source_color = params.get("from_color", "orange")
    target_color = params.get("to_color", "blue")

    # Collect the matching nodes first, then recolor them
    matching = [
        node
        for node, attrs in G.nodes(data=True)
        if attrs.get("color", "grey") == source_color
    ]
    for node in matching:
        G.nodes[node]["color"] = target_color

    return G


# --------------------------------------
# Functions to work with transformations
# --------------------------------------


def apply_pretransformation(
    G: nx.Graph, transformation_name: str, params: Dict[str, Any] = None
) -> nx.Graph:
    """
    Look up a registered pre-transformation by name and apply it to a graph.

    The graph is handed to the transformation as-is (no copy is made here).

    Parameters:
    - G: NetworkX graph to transform
    - transformation_name: Name of the registered transformation
    - params: Parameters for the transformation

    Returns:
    - Transformed NetworkX graph

    Raises:
    - ValueError: if no transformation is registered under the given name
    """
    transformation = PRETRANSFORMATIONS.get(transformation_name)
    if transformation is None:
        raise ValueError(f"Unknown transformation: {transformation_name}")
    return transformation.apply(G, params)


def apply_pretransformations(
    G: nx.Graph, transformations: List[Tuple[str, Dict[str, Any]]]
) -> nx.Graph:
    """
    Run a chain of pre-transformations, in order, on a copy of a graph.

    The chain starts from G.copy(); each step receives the output of the
    previous one.

    Parameters:
    - G: NetworkX graph to transform
    - transformations: List of (transformation_name, params) tuples

    Returns:
    - Graph produced by the last transformation (the copy of G if the list
      is empty)
    """
    current = G.copy()
    for step in transformations:
        step_name, step_params = step
        current = apply_pretransformation(current, step_name, step_params)
    return current


def list_pretransformations() -> List[str]:
    """Return the names of all registered pre-transformations as a new list."""
    return [registered_name for registered_name in PRETRANSFORMATIONS]


def get_pretransformation(name: str) -> Optional[PreTransformation]:
    """Return the pre-transformation registered under name, or None if absent."""
    try:
        return PRETRANSFORMATIONS[name]
    except KeyError:
        return None
