import os
import re
import json
import logging
from typing import Dict, Any, List
from .base import BaseAgent
from .common_prompts import (
    TASK_CONTEXT, DEBUG_OPERATIONS_FORMAT, DEBUG_JSON_OUTPUT_FORMAT, 
    DEBUG_ANALYSIS_GUIDE, RTL_CODING_STANDARDS
)

from json_repair import repair_json
from find_root_signal import find_root_signals
from waveform_tracer import WaveformTracer

# =============================================================================
# SEMANTIC DEBUG AGENT PROMPT COMPONENTS
# =============================================================================

# Role section of the system prompt for SemanticDebugAgent. It is concatenated
# with the shared prompt components imported from .common_prompts below.
# NOTE: this text is sent verbatim to the LLM; edits change agent behavior.
SEMANTIC_DEBUG_AGENT_BASE_PROMPT = """
<SEMANTIC_DEBUG_AGENT_ROLE>
# Semantic Debug Agent - Verilog Logic and Functional Verification Specialist

## **Your Role in the System**
You are the **Semantic Debug Agent** in this multi-agent collaborative system. As the functional verification specialist, you create debug operations from semantic errors, error messages, and documents to resolve logic errors, timing issues, and behavioral problems while maintaining architectural integrity.

## **Primary Mission**
Analyze semantic and functional errors in Verilog code and generate precise debug operations that correct logic behavior, timing constraints, and functional specifications while preserving the overall system architecture.

## **Key Responsibilities**
* **Logic Error Analysis**: Identify functional bugs, timing violations, and behavioral mismatches
* **Document-Based Verification**: Use allocated document sections to validate functional requirements
* **Precise Corrections**: Generate targeted debug operations for semantic fixes
* **Architectural Integrity**: Ensure fixes maintain system-level correctness and integration

## **Input Information Processing**
You will receive:
* **Error Code**: Verilog code with line numbers showing functional issues
* **Semantic Error Messages**: Verification failures (simulation mismatches, logic errors, timing violations)
* **Document Context**: Relevant document sections containing functional specifications and timing requirements

## **Semantic-Specific Guidelines**
* **Functional Focus**: Focus on functional correctness and timing requirements
* **Context Analysis**: Use document context to understand expected signal behavior
* **Signal Correction**: Correct signal widths, logic operations, and module connections
* **Flow Verification**: Ensure proper signal flow and state machine behavior
* **Timing Compliance**: Address timing constraints and edge cases mentioned in documents
* **Architectural Integrity**: Maintain architectural integrity while fixing semantic issues
* **Output Format**: MUST return valid JSON format exactly as specified
</SEMANTIC_DEBUG_AGENT_ROLE>
"""

# Complete prompt assembly using + operator for consistency
# The concatenation order is the section order of the final system prompt:
# shared task context, this agent's role, the debug-operation format, the
# analysis guide, the JSON output format and the RTL coding standards.
SEMANTIC_DEBUG_COMPLETE_PROMPT = (
    TASK_CONTEXT +
    SEMANTIC_DEBUG_AGENT_BASE_PROMPT +
    DEBUG_OPERATIONS_FORMAT +
    DEBUG_ANALYSIS_GUIDE +
    DEBUG_JSON_OUTPUT_FORMAT +
    RTL_CODING_STANDARDS
)

class SemanticDebugAgent(BaseAgent):
    """Agent for debugging and fixing Verilog semantic errors.

    The agent shows the LLM the line-numbered code, the verification error
    messages, optional document context and a signal analysis, asks for a list
    of line-based fix operations (``delete_block`` / ``add_block``) and applies
    them to the original code.
    """

    # Fenced ```json blocks holding one JSON object; the last one in a
    # response is used.
    _JSON_BLOCK_RE = re.compile(r'```json\s*(\{.*?\})\s*```', re.DOTALL | re.IGNORECASE)

    # Extracts mismatching output names from tool lines such as
    # "Output 'data_out' has 3 mismatches" (used by the fallback analysis).
    _MISMATCH_SIGNAL_RE = re.compile(r"Output ['\"]?([^'\s\"]+)['\"]? has \d+ mismatches")

    def __init__(self, llm_client, logger=None, output_dir: str = ""):
        super().__init__(llm_client, "SemanticDebugAgent", logger)
        self.waveform_tracer = WaveformTracer()
        self.output_dir = output_dir
        # Signal-analysis section included in the most recent prompt
        # ("" when none was available); kept for debugging / inspection.
        self.last_waveform_analysis = ""

    def run(self, debug_input: Dict[str, Any]) -> Dict[str, Any]:
        """
        Debug and fix Verilog semantic errors

        Args:
            debug_input: {
                "error_code": "verilog code with errors",
                "error_messages": "semantic error messages from verification",
                "document_fragments": "relevant document sections for context",
                "task_name": "task name for file info (optional)"
            }

        Returns:
            {"fix_operations": [...], "success": bool, "fixed_code": str},
            plus an "error" entry when an exception was caught. ``success`` is
            True only if at least one fix operation was actually applied.

        Raises:
            KeyError: if "error_code" or "error_messages" is missing.
        """
        error_code = debug_input["error_code"]
        error_messages = debug_input["error_messages"]
        document_fragments = debug_input.get("document_fragments", "")
        task_name = debug_input.get("task_name", "unknown")

        self.logger.info(f"Starting semantic debug for code with {len(error_code)} characters")
        self.logger.info(f"Error messages: {error_messages[:200]}...")
        self.logger.info(f"Document context: {len(document_fragments)} characters")
        self.logger.info(f"Target file: {task_name}_top.v")

        try:
            # Generate semantic fixes
            self.logger.info("Starting semantic error analysis and fix generation")
            fix_operations = self._fix_semantic_errors(error_code, error_messages, document_fragments, task_name)
            self.logger.info(f"Generated {len(fix_operations)} fix operations")

            if not fix_operations:
                self.logger.warning("No semantic fixes generated - returning original code")
                return {
                    "fix_operations": fix_operations,
                    "success": False,
                    "fixed_code": error_code
                }

            # Apply fixes; report success only if something was really applied.
            self.logger.info("Applying semantic fixes to code")
            fixed_code, applied, _failed = self._apply_fixes_with_stats(error_code, fix_operations)
            success = applied > 0
            if success:
                self.logger.info(f"Successfully applied {applied} of {len(fix_operations)} semantic fixes")
            else:
                self.logger.warning("None of the generated semantic fixes could be applied - returning original code")

            return {
                "fix_operations": fix_operations,
                "success": success,
                "fixed_code": fixed_code
            }

        except Exception as e:
            # Boundary handler: never let one failed debug attempt crash the caller.
            self.logger.error("="*80)
            self.logger.error("[SemanticDebugAgent] EXCEPTION OCCURRED:")
            self.logger.error("-"*80)
            self.logger.error(f"Exception type: {type(e).__name__}")
            self.logger.error(f"Exception message: {str(e)}", exc_info=True)
            self.logger.error("="*80)

            return {
                "fix_operations": [],
                "success": False,
                "fixed_code": error_code,
                "error": str(e)
            }

    def _analyze_error_signals(self, code: str, error_messages: str, task_name: str) -> Any:
        """Analyze the mismatching signals to give the LLM extra debugging context.

        Tries the waveform tracer first. If it raises, falls back to a static
        root-signal search (``find_root_signals``) over the output names found in
        ``error_messages``.

        Returns:
            The waveform tracer's result unchanged on success; otherwise the
            fallback result (returned as-is when it is a dict so the caller can
            format it per root signal, stringified otherwise); or "" when both
            analyses fail, so no failure text is presented to the LLM as analysis.
        """
        self.logger.info("Starting advanced waveform trace analysis for error signals")

        # Build verification path under the configured output directory.
        verification_path = os.path.join(self.output_dir, str(task_name), "verification")

        try:
            waveform_analysis = self.waveform_tracer.analyze_error_signals(
                verilog_code=code,
                error_messages=error_messages,
                task_name=task_name,
                verification_path=verification_path,
                trace_level=2
            )
            self.logger.info("Waveform trace analysis completed successfully")
            return waveform_analysis
        except Exception as e:
            # Best effort: the tracer is optional, fall through to the static analysis.
            self.logger.warning(f"Waveform trace analysis failed: {e}, falling back to simple analysis")

        error_signals = self._MISMATCH_SIGNAL_RE.findall(error_messages)
        self.logger.info(f"Error signals: {error_signals}")
        try:
            root_signals_2_driven = find_root_signals(code, error_signals, [verification_path], self.logger)
        except Exception as fallback_error:
            self.logger.error(f"Fallback analysis also failed: {fallback_error}")
            return ""

        if isinstance(root_signals_2_driven, dict):
            return root_signals_2_driven
        return str(root_signals_2_driven)

    @staticmethod
    def _format_signal_analysis(analysis: Any) -> str:
        """Turn an ``_analyze_error_signals`` result into a prompt section ("" if empty)."""
        if isinstance(analysis, dict) and analysis:
            # Root-signal mapping from the AST-based fallback analysis.
            root_lines = "".join(f"root signal {k} is driven by: {v}\n" for k, v in analysis.items())
            return f"""--- ERROR SIGNALS ANALYSIS ---
The AST analysis tool has found the root error signals(not driven by other error signals).
The root error signals and the signals that drive them are following:
{root_lines}
If you think the root error signals' logic is correct, you can try to analyze the signals that drive them. They might be wrong.
"""
        if isinstance(analysis, str) and analysis:
            # Waveform trace analysis result.
            return f"""--- ADVANCED WAVEFORM TRACE ANALYSIS ---
{analysis}
"""
        return ""

    def _fix_semantic_errors(self, code: str, error_messages: str, document_fragments: str, task_name: str) -> List[Dict[str, Any]]:
        """Ask the LLM for semantic fix operations for ``code``.

        Returns:
            The "fixes" list from the last ```json block of the LLM response.

        Raises:
            ValueError: if no JSON object can be parsed from the response, or
                its "fixes" field is missing or not a list.
        """
        self.logger.info("Preparing input for semantic error analysis")
        code_with_lines = self._add_line_numbers(code)

        self.logger.info(f"Code lines count: {len(code.split(chr(10)))}")
        self.logger.info(f"Error messages length: {len(error_messages)} characters")
        self.logger.info(f"Document context length: {len(document_fragments)} characters")

        analysis_result = self._analyze_error_signals(code, error_messages, task_name)
        error_signals_analysis = self._format_signal_analysis(analysis_result)
        self.last_waveform_analysis = error_signals_analysis

        # Build complete prompt using pre-assembled components
        system_prompt = SEMANTIC_DEBUG_COMPLETE_PROMPT

        prompt = self.build_prompt(
            system_prompt + "\n\n" +
            "TASK: Fix semantic/logic errors in the provided Verilog code.\n\n" +
            "--- DOCUMENT CONTEXT ---\n{document_fragments}\n\n" +
            "--- ERROR CODE (with line numbers) ---\n{code_with_lines}\n\n" +
            "--- SEMANTIC ERROR MESSAGES ---\n{error_messages}\n\n" +
            "{error_signals_analysis}" +
            "Think through your analysis step by step, then provide the fix JSON:",
            code_with_lines=code_with_lines,
            error_messages=error_messages,
            document_fragments=document_fragments,
            error_signals_analysis=error_signals_analysis,
            task_name=task_name
        )

        self.logger.info("Calling LLM for semantic error analysis...")
        response = self.llm_complete(prompt)
        self.logger.info(f"Received LLM response with {len(response)} characters")

        # Parse JSON response - extract the last JSON block
        self.logger.info("Parsing JSON response...")
        try:
            result = self._parse_json_response(response)
            self.logger.info(f"Successfully parsed JSON with keys: {list(result.keys())}")
        except Exception as parse_error:
            self.logger.error(f"JSON parsing failed: {parse_error}")
            self.logger.error(f"Raw response (first 500 chars): {response[:500]}...")
            raise

        # Validate required fields
        if "fixes" not in result:
            self.logger.error(f"Missing 'fixes' field in parsed result: {result}")
            raise ValueError(f"Missing 'fixes' field in semantic error response: {result}")

        if not isinstance(result["fixes"], list):
            self.logger.error(f"'fixes' field is not a list: {type(result['fixes'])}")
            raise ValueError(f"Field 'fixes' must be a list: {result['fixes']}")

        self.logger.info(f"Validation successful - found {len(result['fixes'])} fix operations")
        return result["fixes"]

    def _parse_json_response(self, response: str) -> Dict[str, Any]:
        """Parse the last fenced ```json object in ``response``.

        Falls back to ``repair_json`` when the block is not valid JSON.

        Raises:
            ValueError: if there is no ```json block, the block cannot be
                repaired, or it does not decode to a JSON object.
        """
        self.logger.debug(f"Parsing JSON response: {response[:200]}...")

        json_blocks = self._JSON_BLOCK_RE.findall(response)
        if not json_blocks:
            raise ValueError("No ```json ``` block found in LLM response")

        # Use the last JSON block found (the final answer after any reasoning).
        json_str = json_blocks[-1]

        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError as e:
            self.logger.warning(f"Direct JSON parsing failed: {e}, attempting repair")
            try:
                parsed = json.loads(repair_json(json_str))
            except Exception as repair_error:
                raise ValueError(f"JSON repair failed: {repair_error}, original error: {e}") from repair_error

        # Repair can yield non-object JSON; callers rely on a dict.
        if not isinstance(parsed, dict):
            raise ValueError(f"Expected a JSON object in LLM response, got {type(parsed).__name__}")
        return parsed

    def _add_line_numbers(self, code: str) -> str:
        """Prefix each line with its 1-based number as ``[   N]`` for the LLM prompt."""
        return '\n'.join(f"[{i:4d}]{line}" for i, line in enumerate(code.split('\n'), 1))

    def _apply_fixes(self, code: str, fixes: List[Dict[str, Any]]) -> str:
        """Apply fix operations to ``code`` and return the fixed code.

        See ``_apply_fixes_with_stats`` for the operation semantics.
        """
        fixed_code, _, _ = self._apply_fixes_with_stats(code, fixes)
        return fixed_code

    def _apply_fixes_with_stats(self, code: str, fixes: List[Dict[str, Any]]) -> tuple:
        """Apply fix operations and return ``(fixed_code, applied, failed)``.

        All line numbers refer to the ORIGINAL numbering of ``code`` (the
        numbering shown to the LLM by ``_add_line_numbers``), regardless of the
        order in which operations are listed:

        * ``delete_block`` removes original lines ``start_line``..``end_line``
          (inclusive, clamped to the file).
        * ``add_block`` inserts ``content`` after original line ``line``
          (0 = before the first line, values past the end append), also when
          that line is itself deleted; several add_blocks at the same line keep
          their listed order.

        Malformed or unknown operations are logged, counted as failed and skipped.
        """
        lines = code.split('\n')
        original_line_count = len(lines)

        self.logger.info(f"Applying {len(fixes)} fixes to code with {original_line_count} lines")

        deleted = [False] * original_line_count
        # inserts[k]: lines to emit after original line k (k == 0: before line 1).
        inserts: Dict[int, List[str]] = {}
        successful_fixes = 0
        failed_fixes = 0

        for fix in fixes:
            if not isinstance(fix, dict):
                self.logger.warning(f"⚠️  Malformed fix (expected an object), skipping: {fix!r}")
                failed_fixes += 1
                continue

            operation = fix.get("operation", "")
            try:
                if operation == "delete_block":
                    index_range = self._resolve_delete_range(fix, original_line_count)
                    if index_range is None:
                        failed_fixes += 1
                        continue
                    start_idx, end_idx = index_range
                    for idx in range(start_idx, end_idx + 1):
                        deleted[idx] = True
                    self.logger.info(f"✅ Applied delete_block at lines {fix.get('start_line')}-{fix.get('end_line')}")
                elif operation == "add_block":
                    anchor, content = self._resolve_add_block(fix, original_line_count)
                    inserts.setdefault(anchor, []).extend(content)
                    self.logger.info(f"✅ Applied add_block at line {fix.get('line')}")
                else:
                    self.logger.warning(f"⚠️  Unknown operation '{operation}', skipping fix: {fix}")
                    failed_fixes += 1
                    continue
                successful_fixes += 1
            except ValueError as e:
                self.logger.error(f"❌ Failed to apply fix {operation}: {e}")
                self.logger.error(f"   Fix details: {fix}")
                self.logger.info("   Skipping this fix and continuing with others...")
                failed_fixes += 1

        # Rebuild once so every operation used the original line numbers.
        fixed_lines: List[str] = list(inserts.get(0, []))
        for idx, line in enumerate(lines):
            if not deleted[idx]:
                fixed_lines.append(line)
            fixed_lines.extend(inserts.get(idx + 1, []))

        self.logger.info(f"Fix application summary: {successful_fixes} successful, {failed_fixes} failed")
        self.logger.info(f"Line count change: {original_line_count} -> {len(fixed_lines)} lines")

        return '\n'.join(fixed_lines), successful_fixes, failed_fixes

    @staticmethod
    def _as_line_number(value: Any, field: str) -> int:
        """Coerce a JSON line-number field (int or numeric string) to ``int``.

        Raises:
            ValueError: if ``value`` is missing/None or not an integer.
        """
        try:
            return int(value)
        except (TypeError, ValueError):
            raise ValueError(f"Field '{field}' must be an integer line number, got {value!r}") from None

    def _resolve_delete_range(self, fix: Dict[str, Any], line_count: int):
        """Validate a delete_block fix against a file of ``line_count`` lines.

        Returns:
            ``(start_idx, end_idx)`` as inclusive 0-based indices, or None (after
            logging a warning) when the range selects no lines.

        Raises:
            ValueError: if ``start_line``/``end_line`` is missing or not an integer.
        """
        start_line = fix.get("start_line")
        end_line = fix.get("end_line")

        if start_line is None or end_line is None:
            raise ValueError(f"delete_block requires both 'start_line' and 'end_line': {fix}")

        start_line = self._as_line_number(start_line, "start_line")
        end_line = self._as_line_number(end_line, "end_line")

        # Convert to 0-based indices, clamped to the file.
        start_idx = max(0, start_line - 1)
        end_idx = min(line_count - 1, end_line - 1)

        if start_idx > end_idx or start_idx >= line_count:
            self.logger.warning(f"Invalid line range {start_line}-{end_line} for delete_block")
            return None
        return start_idx, end_idx

    def _resolve_add_block(self, fix: Dict[str, Any], line_count: int) -> tuple:
        """Validate an add_block fix against a file of ``line_count`` lines.

        Returns:
            ``(anchor, content)``: ``anchor`` is the number of lines preceding the
            insertion point, clamped to ``[0, line_count]``.

        Raises:
            ValueError: if ``content`` is empty or not a list of strings, or
                ``line`` is not an integer.
        """
        line_num = fix.get("line", 0)
        content = fix.get("content", [])

        if not content:
            raise ValueError(f"Missing 'content' field in add_block operation: {fix}")

        if not isinstance(content, list) or not all(isinstance(item, str) for item in content):
            raise ValueError(f"Field 'content' must be a list of strings: {content}")

        anchor = min(max(self._as_line_number(line_num, "line"), 0), line_count)
        return anchor, content

    def _apply_delete_block(self, lines: List[str], fix: Dict[str, Any]) -> List[str]:
        """Apply a single delete_block operation to ``lines`` in place and return it.

        Line numbers refer to the current contents of ``lines``; an empty/invalid
        range logs a warning and leaves ``lines`` unchanged.
        """
        index_range = self._resolve_delete_range(fix, len(lines))
        if index_range is not None:
            start_idx, end_idx = index_range
            del lines[start_idx:end_idx + 1]
            self.logger.info(f"Deleted lines {fix.get('start_line')} to {fix.get('end_line')} (inclusive)")
        return lines

    def _apply_add_block(self, lines: List[str], fix: Dict[str, Any]) -> List[str]:
        """Apply a single add_block operation to ``lines`` in place and return it.

        ``content`` is inserted after line ``line`` of the current ``lines``
        (0 = at the beginning; values past the end append).
        """
        insert_idx, content = self._resolve_add_block(fix, len(lines))
        lines[insert_idx:insert_idx] = content
        self.logger.info(f"Inserted {len(content)} lines at position {insert_idx}")
        return lines


def main():
    """Manually exercise SemanticDebugAgent (including waveform tracing) on a toy module.

    Requires a reachable LLM backend. ``create_llm_client`` is imported
    relatively, so this presumably has to run as a package module
    (``python -m ...``) — TODO confirm.
    """
    from .backends import create_llm_client

    # Configure logging (uses the module-level ``logging`` import).
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Create an LLM client (test configuration).
    llm_client = create_llm_client(
        provider="deepseek",
        model="deepseek-ai/DeepSeek-V3",
        temperature=0.6
    )

    # Create the semantic debug agent.
    agent = SemanticDebugAgent(llm_client, logger)

    # Sample test data.
    test_verilog_code = """
module test_module (
    input clk,
    input reset,
    input [7:0] data_in,
    output reg [7:0] data_out,
    output reg valid
);

always @(posedge clk) begin
    if (reset) begin
        data_out <= 8'b0;
        valid <= 1'b0;
    end else begin
        data_out <= data_in;
        valid <= 1'b1;  // Potential logic error here.
    end
end

endmodule
"""

    test_error_messages = """
[Compiled Success]
[Function Check Failed]
==Tool Output==
First mismatch occurred at time 25. Output 'data_out' has 3 mismatches
First mismatch occurred at time 30. Output 'valid' has 2 mismatches
Mismatches: 5
==Tool Output End==
"""

    test_document_fragments = """
Module specification:
- When reset is asserted, outputs should be cleared
- Data should be latched on the rising edge
- valid should be high when valid data arrives
"""

    test_input = {
        "error_code": test_verilog_code,
        "error_messages": test_error_messages,
        "document_fragments": test_document_fragments,
        "task_name": "test_module"
    }

    print("=" * 80)
    print("Testing Semantic Debug Agent waveform tracing")
    print("=" * 80)

    try:
        result = agent.run(test_input)

        print("\nTest result:")
        print(f"Success: {result['success']}")
        print(f"Fix operations count: {len(result['fix_operations'])}")

        if result['fix_operations']:
            print("\nFix operations:")
            for i, fix in enumerate(result['fix_operations'], 1):
                print(f"  {i}. {fix.get('operation', 'unknown')}: {fix.get('description', 'no description')}")

        print(f"\nFixed code length: {len(result['fixed_code'])} characters")

        # Show waveform trace analysis if the agent recorded one; str() guards
        # against non-string analysis results.
        analysis = getattr(agent, 'last_waveform_analysis', None)
        if analysis is not None:
            analysis_text = str(analysis)
            print("\nWaveform trace analysis:")
            print(analysis_text[:500] + "..." if len(analysis_text) > 500 else analysis_text)

    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Cleanup resources, only if the tracer exposes a cleanup hook, so a
        # missing method cannot raise here and mask the test output.
        cleanup = getattr(getattr(agent, 'waveform_tracer', None), 'cleanup', None)
        if callable(cleanup):
            cleanup()

    print("\nTest finished")


# Script entry point: run the manual smoke test in main(). main() uses a
# package-relative import, so this presumably must be launched as a module
# (python -m ...) — TODO confirm.
if __name__ == "__main__":
    main()
