import os
import re
import networkx as nx


def decode_graph_from_text(filepath, encoding_type=None, response=False):
    """
    Main entry point for decoding a graph from a text file or string.

    Determines the encoding format (unless one is given) and dispatches to the
    matching decoder.

    Parameters:
    - filepath (str): Path to a .txt file containing the graph description.
                      If it does not name an existing regular file, the string
                      itself is treated as the graph description (useful for tests).
    - encoding_type (str, optional): The encoding format ('adjacency', 'incident',
                                     'expert'). If None, it is detected from the
                                     content, falling back to 'adjacency'.
    - response (bool): Whether the text is a response whose <answer>...</answer>
                       section must be extracted before decoding.

    Returns:
    - G (networkx.Graph): The reconstructed graph with edges and node colors.

    Raises:
    - ValueError: If the answer section cannot be extracted (response=True), or
      if no decoder exists for the encoding type. 'expert' is recognised by
      detection but has no decoder in this module, so it raises as well.
    """
    # isfile(), not exists(): exists() is also true for directories, which
    # open() cannot read (IsADirectoryError) -- treat those as inline text.
    if os.path.isfile(filepath):
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read()
    else:
        # Not a readable file: assume the argument is the content itself.
        content = filepath

    if response:
        try:
            content = extract_answer_section(content)
        except ValueError as e:
            # Chain the original error so the root cause stays in the traceback.
            raise ValueError(f"Failed to extract answer section: {e}") from e

    if encoding_type is None:
        if "In an undirected graph, (i,j) means" in content:
            encoding_type = "adjacency"
        elif "Node" in content and any(
            marker in content
            # "connected to node" also covers "connected to nodes".
            for marker in ("connected to node", "connected to:", "connected to [")
        ):
            encoding_type = "incident"
        elif "among A, B, " in content or "among nodes A, B, " in content:
            encoding_type = "expert"
        else:
            # No recognised marker: default to the adjacency format.
            encoding_type = "adjacency"

    if encoding_type == "adjacency":
        return decode_adjacency_format(content)
    if encoding_type == "incident":
        return decode_incident_format(content)
    if encoding_type == "expert":
        raise ValueError(
            "Encoding type 'expert' was requested or detected, but no expert-format "
            "decoder is available; supported types are 'adjacency' and 'incident'."
        )
    raise ValueError(f"Unknown encoding type: {encoding_type}")


def decode_adjacency_format(text):
    """
    Decodes a graph representation in adjacency format using numeric node names.

    Format example:
    In an undirected graph, (i,j) means that node i and node j are connected with an undirected edge.
    G describes a graph among nodes 0, 1, 2, 3, 4.
    The edges in G are: (0,1) (1,2) (2,3) (3,4).

    Or with no edges:
    G describes a graph among nodes 1 and 3.
    The edges in G are: [There are no edges.]

    Edge tuples may contain whitespace around the numbers and the comma,
    e.g. "(0, 1)", "(0 ,1)" or "( 0,1 )".

    Parameters:
    - text (str): The text containing the graph description

    Returns:
    - G (networkx.Graph): The constructed graph, with node colors applied
      by _extract_node_colors.
    """
    G = nx.Graph()

    # Explicit node list: the text after the first "nodes" up to the first
    # period or newline, e.g. "0, 1, 2, 3, 4" or "1 and 3".
    node_match = re.search(r"nodes\s+(.*?)[.\n]", text, re.DOTALL)
    if node_match:
        # Handle "nodes 1 and 3" / "0, 1, and 2" by turning " and " into a separator.
        nodes_str = node_match.group(1).replace(" and ", ", ")
        for token in nodes_str.split(","):
            # Convert "node 1" to just "1".
            token = token.strip().replace("node ", "")
            try:
                G.add_node(int(token))
            except ValueError:
                # Non-numeric fragments (empty strings, stray words) are ignored.
                continue

    # Any "node N" mention elsewhere (case-insensitive) also adds N; this
    # helps with sparse descriptions that omit the node list.
    for node in re.findall(r"\bnode (\d+)\b", text, re.IGNORECASE):
        G.add_node(int(node))

    # NOTE: plain substring checks over the whole text -- any occurrence
    # suppresses edge parsing entirely. ("there are no edges" is covered by
    # "no edges".)
    lowered = text.lower()
    no_edges = any(
        phrase in lowered
        for phrase in (
            "no edges",
            "edges in g are: none",
            "edges in g are: []",
            "edges in g are: [ ]",
        )
    )

    if not no_edges:
        # Every "(u,v)" tuple in the text is an edge. Whitespace is tolerated
        # inside the parentheses so "(0 ,1)" or "( 0, 1 )" are not dropped.
        for u, v in re.findall(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)", text):
            G.add_edge(int(u), int(v))

    return _extract_node_colors(G, text)


def decode_incident_format(text):
    """
    Decodes a graph representation in incident format using numeric node names.

    Handles various formats like:
    - Node 0 is connected to nodes 1, 2.
    - Node 3 is connected to node 9.
    - Node 1 is connected to: 5, 6, 7.
    - Node 4 is connected to nodes 7 and 8.
    - Node 11 is connected to [14].
    - Node 3 is connected to no nodes.

    The node list accepts the same forms as the adjacency format, e.g.
    "among nodes 0, 1, 2, and 3" or "among nodes 1 and 3".
    Self-loops (a node listed among its own neighbours) are skipped.

    Parameters:
    - text (str): The text containing the graph description

    Returns:
    - G (networkx.Graph): The constructed graph, with node colors applied
      by _extract_node_colors.
    """
    G = nx.Graph()

    # Explicit node list: text after the first "nodes" up to the first period
    # or newline. " and " is normalised to a separator (as in
    # decode_adjacency_format) so headers like "nodes 1 and 3" keep their nodes.
    node_match = re.search(r"nodes\s+(.*?)[.\n]", text, re.DOTALL)
    if node_match:
        nodes_str = node_match.group(1).replace(" and ", ", ")
        for token in nodes_str.split(","):
            token = token.strip().replace("node ", "")
            try:
                G.add_node(int(token))
            except ValueError:
                # Non-numeric fragments (empty strings, stray words) are ignored.
                continue

    # Every "Node N" mention is a node, even if it has no connections.
    for node in re.findall(r"Node (\d+)", text):
        G.add_node(int(node))

    for line in text.split("\n"):
        if "Node" not in line or "connected to" not in line:
            continue

        # The first "Node N" on the line is the source of its edges.
        source_match = re.search(r"Node (\d+)", line)
        if not source_match:
            continue
        source_node = int(source_match.group(1))

        # Explicitly isolated node, e.g. "connected to no nodes" / "connected to []".
        if re.search(r"connected to (?:no nodes|\[ *\]|nothing)", line):
            continue

        # Every number after the first "connected to" (up to a second one,
        # if present) is treated as a neighbour.
        connected_part = line.split("connected to")[1]
        for target in re.findall(r"\b(\d+)\b", connected_part):
            target_node = int(target)
            if target_node != source_node:  # avoid unintended self-loops
                G.add_edge(source_node, target_node)

    return _extract_node_colors(G, text)


def _extract_node_colors(G, text):
    """
    Helper function to extract node colors from text and apply them to the graph.
    Works for both adjacency and incident formats with numeric node IDs.

    Recognised statements:
    - "None of the nodes are colored" / "All nodes are uncolored": nothing is set.
    - "The following nodes are colored: 1, 2, 3.": every number on that line
      is colored "blue" (single-color assumption).
    - Lines starting with "Node", e.g. "Node 1 is blue." or
      "Node 1 is colored red. Node 2 is green." (several statements per line).
      Incident-format edge lines such as "Node 0 is connected to nodes 1, 2."
      and negations such as "Node 3 is not colored." do not set a color.

    Nodes named in color statements are added to G if missing. Every node
    still without a color afterwards gets "grey".

    Parameters:
    - G (networkx.Graph): The graph to apply colors to (modified in place)
    - text (str): The text containing color information

    Returns:
    - G (networkx.Graph): The same graph object, with colors applied
    """
    # Words that can follow "Node N is" without naming a color. "connected" is
    # the important one: otherwise every incident-format line would be read as
    # a color assignment ("Node 0 is connected to ..." -> color "connected").
    non_color_words = {"connected", "not", "uncolored"}

    for raw_line in text.splitlines():
        line = raw_line.strip()

        if "None of the nodes are colored" in line or "All nodes are uncolored" in line:
            # Explicitly nothing colored; the grey default below applies.
            continue

        if "The following nodes are colored" in line:
            # Pattern: "The following nodes are colored: 1, 2, 3."
            for node in re.findall(r"\b(\d+)\b", line):
                node_id = int(node)
                if node_id not in G:
                    G.add_node(node_id)
                G.nodes[node_id]["color"] = "blue"  # default single-color assumption
            continue

        if line.startswith("Node"):
            # Pattern: "Node 1 is blue." / "Node 1 is colored blue."
            for match in re.finditer(r"\bNode (\d+) is (?:colored )?(\w+)", line):
                color = match.group(2).lower()
                if color in non_color_words:
                    continue
                node_id = int(match.group(1))
                if node_id not in G:
                    G.add_node(node_id)
                G.nodes[node_id]["color"] = color

    # Default to grey for uncolored nodes.
    for node in G.nodes:
        if "color" not in G.nodes[node]:
            G.nodes[node]["color"] = "grey"

    return G


def extract_answer_section(content):
    """
    Return the text enclosed by the first <answer>...</answer> pair in a response.

    Tag names are matched case-insensitively, whitespace inside the tags is
    allowed (e.g. "< answer >", "</ answer>"), and the enclosed text may span
    several lines. The returned text has surrounding whitespace stripped.

    Parameters:
    - content (str): The full text content of the response file.

    Returns:
    - str: The stripped text between the first opening and closing answer tags.

    Raises:
    - ValueError: If the tags enclose only whitespace, or if no well-formed
      tag pair exists; the message names which tag is missing when it can tell.
    """
    text = content.strip()

    found = re.search(
        r"<\s*answer\s*>(.*?)<\s*/\s*answer\s*>",
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if found is not None:
        inner = found.group(1).strip()
        if inner:
            return inner
        raise ValueError(
            "Found <answer></answer> tags but they contain no content."
        )

    # No well-formed pair: work out which piece is missing for a clearer message.
    lowered = text.lower()
    has_open = "<answer" in lowered
    has_close = "</answer>" in lowered

    if has_open and not has_close:
        message = "Found opening <answer> tag but missing closing </answer> tag."
    elif has_open:
        message = (
            "Found <answer> tags but could not extract content. Check for malformed XML."
        )
    elif has_close:
        message = "Found closing </answer> tag but missing opening <answer> tag."
    else:
        message = (
            "No <answer></answer> XML tags found in response. "
            "Expected format: <answer>content</answer>"
        )
    raise ValueError(message)
