import sys
import os
pyverilog_path = os.path.join(os.path.dirname(__file__), 'pyverilog')
sys.path.insert(0, pyverilog_path)
from pyverilog.vparser.parser import parse
from pyverilog.vparser.ast import Assign, Always, BlockingSubstitution, NonblockingSubstitution, Partselect, Pointer
from io import StringIO

def find_root_signals(verilog_code, error_signals, preprocess_include=None, logger=None):
    """
    Find the root error signals among a list of error signals.

    The Verilog source is parsed and a signal dependency graph is built from the
    first module definition, covering continuous ``assign`` statements as well as
    blocking and non-blocking assignments. An error signal is a root when no
    *other* error signal appears on the right-hand side of an assignment to it.
    Only direct assignments are considered; dependencies are not followed
    transitively.

    Args:
        verilog_code (str): The Verilog code as a string.
        error_signals (list): Signal names that have errors. Single quotes and
            surrounding whitespace are stripped from every name; names that do
            not occur in the first module, and repeated names, are dropped.
        preprocess_include (list, optional): List of include paths for
            preprocessing. Defaults to empty list.
        logger (optional): Object exposing ``info`` and ``error`` methods, used
            for diagnostics. Nothing is logged when it is omitted.

    Returns:
        dict: Maps each root error signal to the list of names collected from
        the right-hand sides of the assignments that drive it. An empty dict is
        returned when the code cannot be parsed or contains no module
        definitions.

    Raises:
        AssertionError: If none of the error signals is present in the code.

    Side effects:
        When parsing fails, the offending code is appended to ``debug.v`` in the
        current working directory.
    """
    error_signals = [e.replace("'", "").strip() for e in error_signals]
    if logger:
        logger.info(f"error_signals: {error_signals}")
    if preprocess_include is None:
        preprocess_include = []

    # Dependency graph: key=source signal, value=set of signals driven by it.
    dependency_graph = {}

    def extract_signal_name(lvalue):
        """Return the name of the signal an lvalue node refers to, or None."""
        if isinstance(lvalue, (Partselect, Pointer)):
            # Bit/part selects wrap the real identifier; unwrap recursively.
            return extract_signal_name(lvalue.var)
        return getattr(lvalue, 'name', None)

    def collect_sources(node):
        """Recursively collect every string ``name`` found in an expression."""
        sources = set()
        if isinstance(node, (Partselect, Pointer)):
            sources.update(collect_sources(node.var))
        elif hasattr(node, 'name') and isinstance(node.name, str):
            sources.add(node.name)

        for child in node.children():
            sources.update(collect_sources(child))
        return sources

    def build_dependency_graph(ast_node):
        """Record source -> target edges for every assignment below ast_node."""
        # Continuous assignments and assignments inside always blocks are
        # handled identically: every name on the right drives the left target.
        if isinstance(ast_node, (Assign, BlockingSubstitution, NonblockingSubstitution)):
            target = extract_signal_name(ast_node.left.var)
            # Targets without a resolvable name are skipped.
            if target:
                for source in collect_sources(ast_node.right):
                    dependency_graph.setdefault(source, set()).add(target)

        for child in ast_node.children():
            build_dependency_graph(child)

    def collect_all_signals(node, signals):
        """DFS over the AST, adding every string ``name`` to signals."""
        if hasattr(node, 'name') and isinstance(node.name, str):
            signals.add(node.name)
        for child in node.children():
            collect_all_signals(child, signals)

    # Parse Verilog code. This is a best-effort boundary: any parser failure is
    # reported, the input is dumped for later inspection, and an empty result
    # is returned.
    try:
        ast, _ = parse(StringIO(verilog_code), preprocess_include=preprocess_include)
    except Exception as e:
        print(f"Error parsing Verilog code: {e}")
        # Explicit encoding so the dump cannot fail on non-ASCII source text
        # under a non-UTF-8 locale.
        with open("debug.v", "a", encoding="utf-8") as f:
            f.write(verilog_code)
            f.write("\n")
        return {}

    # Validate module exists
    if not ast.description.definitions:
        print("No module definitions found in the Verilog code")
        return {}

    # Only the first module definition is analysed.
    top_module = ast.description.definitions[0]

    all_real_signals = set()
    collect_all_signals(top_module, all_real_signals)

    # Keep only error signals that are present in the code. dict.fromkeys
    # removes repeated names while preserving the caller's order, so a
    # duplicated input name cannot duplicate entries in the result.
    error_signals = list(dict.fromkeys(
        s for s in error_signals if s in all_real_signals
    ))
    if not error_signals:
        # The logger is optional; only emit diagnostics when one was supplied.
        if logger:
            logger.error(f"All real signals: {all_real_signals}")
            logger.error(f"error_signals: {error_signals}")
            logger.error(f"verilog_code =\n{verilog_code}")
            logger.error(f"preprocess_include: {preprocess_include}")
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError("No error signals found in the Verilog code")

    build_dependency_graph(top_module)

    error_signals_set = set(error_signals)

    # A root is an error signal that no OTHER error signal drives. A signal
    # that only feeds back into itself (e.g. ``q <= q + 1``) is still a root.
    root_signals = [
        signal
        for signal in error_signals
        if not any(
            source != signal and signal in dependency_graph.get(source, ())
            for source in error_signals_set
        )
    ]

    # For every root, list the signals that drive it.
    root_signals_2_driven = {
        root: [
            source
            for source, targets in dependency_graph.items()
            if root in targets
        ]
        for root in root_signals
    }

    return root_signals_2_driven

def analyze_verilog_file(file_path, error_signals, preprocess_include=None):
    """
    Read a Verilog file, run the root-signal analysis on it and print a report.

    Kept for backward compatibility and testing purposes.

    Args:
        file_path (str): Path to the Verilog file
        error_signals (list): List of error signal names
        preprocess_include (list, optional): Include paths for preprocessing

    Returns:
        The result of ``find_root_signals`` for the file's contents (a dict
        keyed by root error signal), or an empty list when the file does not
        exist or cannot be read.
    """
    if os.path.exists(file_path):
        try:
            with open(file_path, 'r', encoding='utf-8') as source_file:
                verilog_code = source_file.read()
        except Exception as read_error:
            # Unreadable files are reported and treated like missing ones.
            print(f"Error reading file: {read_error}")
            return []
    else:
        print(f"Error: File '{file_path}' does not exist.")
        return []

    root_signals = find_root_signals(verilog_code, error_signals, preprocess_include)

    # Emit the human-readable summary, framed by separator lines.
    separator = "="*60
    report_lines = (
        separator,
        f"Root Signal Analysis for Verilog File: {file_path}",
        separator,
        f"Input error signals: {error_signals}",
        f"Root error signals: {root_signals}",
        separator,
    )
    for line in report_lines:
        print(line)

    return root_signals
