"""
Waveform tracing tool, adapted from VerilogCoder as a standalone utility.
Used to analyze semantic error signals in Verilog code.
"""

import os
import re
import shutil
import subprocess
from typing import List, Dict, Tuple, Any, Optional
from vcd_waveform_analyzer import parse_mismatch, get_tabular
from debug_graph_analyzer import DebugGraph


class WaveformTracer:
    """Standalone waveform trace analyzer.

    Given a DUT's Verilog source, simulator mismatch messages and the VCD
    dumped by a previous simulation, this class backtraces the control
    signals feeding the mismatched outputs (via ``DebugGraph``) and renders
    their waveforms as a text table (via ``get_tabular``) together with
    debugging hints.
    """

    # Net/variable type keywords that may appear between ``input`` and the
    # port name; they must be skipped rather than reported as a port.
    _PORT_TYPE_KEYWORDS = frozenset({
        "wire", "reg", "logic", "bit", "byte", "int", "integer", "shortint",
        "longint", "signed", "unsigned", "var", "tri", "tri0", "tri1",
        "wand", "wor", "uwire", "real", "time",
    })
    # Direction keywords that end a run of input-port names (ANSI port lists).
    _OTHER_DIRECTIONS = frozenset({"output", "inout", "ref"})
    # A simple (non-escaped) Verilog identifier at the start of a token.
    _IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")
    # Bracketed ranges such as [7:0] or [ W-1 : 0 ].
    _RANGE_RE = re.compile(r"\[[^\]]*\]")

    def __init__(self, workdir: str = "./waveform_trace_tmp/"):
        """Create the tracer and ensure ``workdir`` exists.

        Args:
            workdir: Directory created eagerly and removed by ``cleanup()``.
        """
        self.workdir = workdir
        self.graph_tracer: Optional["DebugGraph"] = None
        # Verilog source the cached graph_tracer was built from.
        self.cur_graph_verilog = ""
        # Populated by analyze_error_signals(); empty until the first call so
        # helper methods do not raise AttributeError beforehand.
        self.vcd_file_path = ""
        self.verilog_file_path = ""

        # Ensure working directory exists.
        os.makedirs(workdir, exist_ok=True)

    def analyze_error_signals(
        self,
        verilog_code: str,
        error_messages: str,
        task_name: str,
        verification_path: str,
        trace_level: int = 2
    ) -> str:
        """Trace the signals behind simulation mismatches and report them.

        Copies ``<verification_path>/logs/wave.vcd`` and writes
        ``verilog_code`` into ``<verification_path>/waveform_trace_tmp/``,
        then runs the waveform trace.

        Args:
            verilog_code: Verilog source of the DUT.
            error_messages: Simulator output containing mismatch reports.
            task_name: Used to name the saved ``<task_name>_top.sv`` file.
            verification_path: Directory holding ``logs/wave.vcd``.
            trace_level: Backtrace depth (``k``) for control-signal tracing.

        Returns:
            The trace report, or an ``[Error] ...`` string if the VCD is
            missing.

        Raises:
            Any exception from mismatch parsing, graph building or waveform
            tabulation propagates to the caller.
        """
        tmp_dir = os.path.join(verification_path, "waveform_trace_tmp")
        expected_vcd = os.path.join(verification_path, "logs", "wave.vcd")
        self.vcd_file_path = os.path.join(tmp_dir, "wave.vcd")
        self.verilog_file_path = os.path.join(tmp_dir, f"{task_name}_top.sv")
        if not os.path.exists(expected_vcd):
            return f"[Error] wave.vcd not found at {expected_vcd}! Please run simulation first!"

        # Only create the scratch directory once we know there is work to do.
        os.makedirs(tmp_dir, exist_ok=True)
        shutil.copy2(expected_vcd, self.vcd_file_path)
        with open(self.verilog_file_path, "w", encoding="utf-8") as f:
            f.write(verilog_code)

        return self._perform_waveform_trace(error_messages, trace_level)

    def _perform_waveform_trace(self, error_messages: str, trace_level: int) -> str:
        """Run the waveform trace analysis on the saved Verilog/VCD files.

        Note: the ``mismatch_columns`` list returned by ``parse_mismatch`` is
        extended in place with the DUT's input ports before being passed to
        ``get_tabular`` as ``ori_mismatch_columns``.
        """
        if self._check_functionality(error_messages):
            return "[Waveform Tracer]: No mismatched signals!"

        mismatch_columns, offset = parse_mismatch(test_output=error_messages)
        if not mismatch_columns:
            return "[Waveform Tracer]: No mismatch signals found in error messages!"

        # Rebuild the AST graph only when none exists yet or the DUT changed.
        # Checking for None matters: an empty source equals the initial cache.
        verilog_content = self._get_verilog_content()
        if self.graph_tracer is None or self.cur_graph_verilog != verilog_content:
            # Set include path to parse sd_defines.v correctly.
            include_path = os.path.dirname(os.path.dirname(self.verilog_file_path))
            self.graph_tracer = DebugGraph([self.verilog_file_path], include_dirs=[include_path])
            self.cur_graph_verilog = verilog_content

        traced_signals_map, signal_level_tracer = self.graph_tracer.get_k_control_signals(
            target_signals=mismatch_columns,
            k=trace_level,
            signal_only=True
        )
        traced_signal_str = self._generate_signal_trace_string(signal_level_tracer)

        all_traced_signals = list(traced_signals_map.keys())

        # Always show the DUT input ports alongside the traced signals.
        input_ports = self._get_input_ports()
        for input_port in input_ports:
            if input_port not in mismatch_columns:
                mismatch_columns.append(input_port)
            if input_port not in all_traced_signals:
                all_traced_signals.append(input_port)

        waveform_table_str = get_tabular(
            method='dataframe',
            vcd_path=self.vcd_file_path,
            mismatch_columns=all_traced_signals,
            offset=offset,
            ori_mismatch_columns=mismatch_columns
        )

        return self._build_trace_output(
            traced_signal_str,
            waveform_table_str,
            input_ports,
            all_traced_signals,
            trace_level,
            offset
        )

    def _check_functionality(self, error_messages: str) -> bool:
        """Return True when no ``Output <sig> has N mismatches`` line is present."""
        return re.search(
            r"Output ['\"]?([^'\s\"]+)['\"]? has \d+ mismatches", error_messages
        ) is None

    def _get_verilog_content(self) -> str:
        """Return the saved Verilog source, or "" if it has not been written."""
        if self.verilog_file_path and os.path.isfile(self.verilog_file_path):
            with open(self.verilog_file_path, "r", encoding="utf-8") as f:
                return f.read()
        return ""

    def _get_input_ports(self) -> List[str]:
        """Extract input port names from the saved Verilog source.

        Handles ANSI (``module m(input wire clk, input [7:0] d, ...)``) and
        non-ANSI (``input [3:0] a, b;``) declarations on a single line. Type
        keywords, bracketed ranges, line comments and trailing punctuation
        are skipped. Declarations continued onto following lines are not
        recognised. Names are returned once, in first-seen order.
        """
        verilog_content = self._get_verilog_content()
        if not verilog_content:
            return []

        input_ports: List[str] = []
        for raw_line in verilog_content.splitlines():
            # Ignore anything after a line comment.
            code = raw_line.split("//", 1)[0]
            if "input" not in code:
                continue
            # Drop ranges so identifiers inside them (e.g. [N-1:0]) are ignored.
            code = self._RANGE_RE.sub(" ", code)
            # ';' ends a declaration: "input a; wire b;" must not report b.
            for statement in code.split(";"):
                tokens = re.sub(r"[(),]", " ", statement).split()
                collecting = False
                for token in tokens:
                    if token == "input":
                        collecting = True
                    elif token in self._OTHER_DIRECTIONS:
                        collecting = False
                    elif collecting and token not in self._PORT_TYPE_KEYWORDS:
                        match = self._IDENTIFIER_RE.match(token)
                        if match and match.group(0) not in input_ports:
                            input_ports.append(match.group(0))
        return input_ports

    def _generate_signal_trace_string(self, signal_level_tracer: List[List[str]]) -> str:
        """Render the backtrace levels, deepest first, one relation per line.

        Each relation is prefixed with one '-' per level above the deepest.
        Relations at level 0 are tagged as the last output-port level.
        """
        depth = len(signal_level_tracer)
        lines = ["[Signal Traces] Backtrace control signal relations."]
        for level in range(depth - 1, -1, -1):
            prefix = "-" * (depth - 1 - level)
            suffix = " (*last output port level)" if level == 0 else ""
            lines.extend(f"{prefix}{rel}{suffix}" for rel in signal_level_tracer[level])
        return "\n".join(lines) + "\n\n"

    def _build_trace_output(
        self,
        traced_signal_str: str,
        waveform_table_str: str,
        input_ports: List[str],
        all_traced_signals: List[str],
        trace_level: int,
        offset: int
    ) -> str:
        """Assemble the waveform description, table and hint text.

        ``traced_signal_str``, ``input_ports``, ``trace_level`` and ``offset``
        are accepted for interface compatibility but are not currently
        included in the report, and neither is the DUT source.
        """
        waveform_desc = (
            f"[Signal Waveform]: <signal>_tb is the given testbench signal and can not be changed! "
            f"<signal>_ref is the golden, and <signal>_dut is the generated verilog file waveform. "
            f"Check the mismatched signal waveform and its traced signals. "
            f"The clock cycle (clk) is 10ns and toggles every 5ns. "
            f"'-' means unknown during simulation. "
            f"If the '-' is the reason of mismatched signal, please check the reset and assignment block.\n"
            f"[Traced Signals]: {', '.join(all_traced_signals)}\n"
            f"[Table Waveform in hexadecimal format]\n{waveform_table_str}"
        )
        hint = (
            f"[Hint] Firstly, identify the time of mismatched signals, and only focus on the mismatched signals in the waveform firstly. "
            f"Then, explain the related signals and their transitions in the waveform table. "
            f"Don't correct signals without mismatch in the table waveforms. "
        )
        return waveform_desc + hint

    def cleanup(self):
        """Remove ``self.workdir`` and its contents if it exists.

        Note: analyze_error_signals() writes under
        ``<verification_path>/waveform_trace_tmp``, which is only removed here
        if it happens to be ``workdir``.
        """
        if os.path.exists(self.workdir):
            shutil.rmtree(self.workdir)
