import json
import logging
import re
from typing import Any, Dict, List, Tuple

from json_repair import repair_json

from .base import BaseAgent
from .common_prompts import (
    TASK_CONTEXT, DEBUG_OPERATIONS_FORMAT, DEBUG_JSON_OUTPUT_FORMAT,
    DEBUG_ANALYSIS_GUIDE, RTL_CODING_STANDARDS
)

# =============================================================================
# SYNTAX DEBUG AGENT PROMPT COMPONENTS
# =============================================================================

# Role description for the syntax debug agent. It is concatenated into
# SYNTAX_DEBUG_COMPLETE_PROMPT, which becomes the system part of the prompt
# sent to the LLM, so any wording change here changes model behavior.
# It tells the model to emit delete_block/add_block operations; the exact
# operation/JSON schema is presumably spelled out by DEBUG_OPERATIONS_FORMAT
# and DEBUG_JSON_OUTPUT_FORMAT from .common_prompts (not visible here).
# NOTE(review): the assembled prompt is embedded in a template handed to
# self.build_prompt alongside {placeholder} fields; if build_prompt uses
# str.format, any literal braces added to this text must be doubled — confirm.
SYNTAX_DEBUG_AGENT_BASE_PROMPT = """
<SYNTAX_DEBUG_AGENT_ROLE>
# Syntax Debug Agent - Verilog Syntax Error Resolution Specialist

## **Your Role in the System**
You are the **Syntax Debug Agent** in this multi-agent collaborative system. As the syntax correction specialist, you create debug operations from syntax errors, error messages, and documents to resolve Verilog compilation issues and ensure code correctness.

## **Primary Mission**
Analyze syntax errors in generated Verilog code and generate precise debug operations that fix compilation issues while maintaining functional integrity and architectural compliance.

## **Key Responsibilities**
* **Syntax Error Analysis**: Identify root causes of compilation failures and syntax violations
* **Document-Based Resolution**: Use allocated document sections to determine correct signal/module names
* **Precise Operations**: Generate targeted debug operations (delete_block/add_block) for syntax fixes
* **Architectural Preservation**: Ensure fixes maintain overall code structure and integration

## **Input Information Processing**
You will receive:
* **Document Context**: Relevant document sections containing correct signal/module specifications
* **Error Code**: Verilog code with line numbers showing compilation failures
* **Syntax Error Messages**: Detailed compiler error output for targeted fixes

## **Syntax-Specific Guidelines**
* **Context Usage**: Use document context to determine correct signal/module names
* **Port Connection Rules**: Follow comma rules strictly - no comma after last port
* **Error Pattern Recognition**: 
  - "Mixing positional and named connection" → comma after last port issue
  - "Instance attempts to connect to variable" → missing port declaration in module header
  - "Can't find definition of variable" → maybe wrong usage of macros (without backtick)
* **Focus Areas**: Undefined signals, incorrect module syntax, missing declarations
* **Output Format**: MUST return valid JSON format exactly as specified
</SYNTAX_DEBUG_AGENT_ROLE>
"""

# Full system prompt for this agent, assembled in a fixed order: shared task
# context, this agent's role text, then the shared debug-operation format,
# analysis guide and JSON output contract.
# RTL_CODING_STANDARDS is deliberately left out of this string:
# _fix_syntax_errors passes it separately as its own "RTL CODING STANDARDS"
# section of the user prompt.
SYNTAX_DEBUG_COMPLETE_PROMPT = "".join((
    TASK_CONTEXT,
    SYNTAX_DEBUG_AGENT_BASE_PROMPT,
    DEBUG_OPERATIONS_FORMAT,
    DEBUG_ANALYSIS_GUIDE,
    DEBUG_JSON_OUTPUT_FORMAT,
))

class SyntaxDebugAgent(BaseAgent):
    """Agent for debugging and fixing Verilog syntax errors.

    The LLM receives the failing code with ``[NNNN]`` line-number prefixes,
    the compiler error messages and relevant document context. It must answer
    with a JSON object whose ``fixes`` list holds ``delete_block`` /
    ``add_block`` operations. Those operations are applied to the original
    code, and every line number is interpreted against the numbering the LLM
    was shown.
    """

    # Fenced ```json blocks containing a JSON object. The lazy ``.*?`` stops at
    # the first "}" that is directly followed by a closing fence, so nested
    # braces inside the object are still captured as a whole.
    _JSON_FENCE_RE = re.compile(r'```json\s*(\{.*?\})\s*```', re.DOTALL | re.IGNORECASE)
    # Fallback used only when no ```json block exists: unlabeled ``` fences.
    _PLAIN_FENCE_RE = re.compile(r'```\s*(\{.*?\})\s*```', re.DOTALL)

    def __init__(self, llm_client, logger=None):
        super().__init__(llm_client, "SyntaxDebugAgent", logger)

    def run(self, debug_input: Dict[str, Any]) -> Dict[str, Any]:
        """
        Debug and fix Verilog syntax errors

        Args:
            debug_input: {
                "error_code": "verilog code with errors",
                "error_messages": "syntax error messages from verification",
                "document_fragments": "relevant document sections for context",
                "task_name": "task name for file info (optional)"
            }

        Returns:
            {"fix_operations": [...], "success": bool, "fixed_code": str}

            ``success`` is True only if at least one fix operation was
            actually applied. If fix generation raises, ``success`` is False,
            ``fixed_code`` is the unchanged input and an extra ``"error"`` key
            carries the exception message.

        Raises:
            KeyError: if ``error_code`` or ``error_messages`` is missing (they
                are read before the guarded section).
        """
        error_code = debug_input["error_code"]
        error_messages = debug_input["error_messages"]
        document_fragments = debug_input.get("document_fragments", "")
        task_name = debug_input.get("task_name", "unknown")

        self.logger.info(f"Starting syntax debug for code with {len(error_code)} characters")
        self.logger.info(f"Error messages: {error_messages[:200]}...")
        self.logger.info(f"Document context: {len(document_fragments)} characters")
        self.logger.info(f"Target file: {task_name}_top.v")

        try:
            # Generate syntax fixes
            self.logger.info("Starting syntax error analysis and fix generation")
            fix_operations = self._fix_syntax_errors(error_code, error_messages, document_fragments, task_name)
            self.logger.info(f"Generated {len(fix_operations)} fix operations")

            # Apply fixes
            if fix_operations:
                self.logger.info("Applying syntax fixes to code")
                fixed_code, applied, _ = self._apply_fixes_with_stats(error_code, fix_operations)
                # A list of operations that all failed leaves the code untouched;
                # that must not be reported as a successful fix.
                success = applied > 0
                if success:
                    self.logger.info(f"Successfully applied {applied} of {len(fix_operations)} syntax fixes")
                else:
                    self.logger.warning("None of the generated syntax fixes could be applied - returning original code")
            else:
                fixed_code = error_code
                success = False
                self.logger.warning("No syntax fixes generated - returning original code")

            return {
                "fix_operations": fix_operations,
                "success": success,
                "fixed_code": fixed_code
            }

        except Exception as e:
            self.logger.error("="*80)
            self.logger.error("[SyntaxDebugAgent] EXCEPTION OCCURRED:")
            self.logger.error("-"*80)
            self.logger.error(f"Exception type: {type(e).__name__}")
            self.logger.error(f"Exception message: {str(e)}")
            self.logger.error("="*80)

            return {
                "fix_operations": [],
                "success": False,
                "fixed_code": error_code,
                "error": str(e)
            }

    def _fix_syntax_errors(self, code: str, error_messages: str, document_fragments: str, task_name: str) -> List[Dict[str, Any]]:
        """Ask the LLM for syntax fixes and return its validated ``fixes`` list.

        Raises:
            ValueError: if the response contains no parsable JSON object, or the
                object lacks a ``fixes`` list.
        """
        self.logger.info("Preparing input for syntax error analysis")
        code_with_lines = self._add_line_numbers(code)

        line_count = len(code.split('\n'))
        self.logger.info(f"Code lines count: {line_count}")
        self.logger.info(f"Error messages length: {len(error_messages)} characters")
        self.logger.info(f"Document context length: {len(document_fragments)} characters")

        # Build complete prompt using pre-assembled components
        system_prompt = SYNTAX_DEBUG_COMPLETE_PROMPT

        prompt = self.build_prompt(
            system_prompt + "\n\n" +
            "TASK: Fix syntax errors in the provided Verilog code.\n\n" +
            "--- DOCUMENT CONTEXT ---\n{document_fragments}\n\n" +
            "--- ERROR CODE (with line numbers) ---\n{code_with_lines}\n\n" +
            "--- SYNTAX ERROR MESSAGES ---\n{error_messages}\n\n" +
            "--- RTL CODING STANDARDS ---\n{rtl_coding_standards}\n\n" +
            "Think through your analysis step by step, then provide the fix JSON:",
            code_with_lines=code_with_lines,
            error_messages=error_messages,
            document_fragments=document_fragments,
            rtl_coding_standards=RTL_CODING_STANDARDS,
            task_name=task_name
        )

        self.logger.info("Calling LLM for syntax error analysis...")
        response = self.llm_complete(prompt)
        self.logger.info(f"Received LLM response with {len(response)} characters")

        # Parse JSON response - extract the last JSON block
        self.logger.info("Parsing JSON response...")
        try:
            result = self._parse_json_response(response)
            self.logger.info(f"Successfully parsed JSON with keys: {list(result.keys())}")
        except Exception as parse_error:
            self.logger.error(f"JSON parsing failed: {parse_error}")
            self.logger.error(f"Raw response (first 500 chars): {response[:500]}...")
            raise

        # Validate required fields
        if "fixes" not in result:
            self.logger.error(f"Missing 'fixes' field in parsed result: {result}")
            raise ValueError(f"Missing 'fixes' field in syntax error response: {result}")

        if not isinstance(result["fixes"], list):
            self.logger.error(f"'fixes' field is not a list: {type(result['fixes'])}")
            raise ValueError(f"Field 'fixes' must be a list: {result['fixes']}")

        self.logger.info(f"Validation successful - found {len(result['fixes'])} fix operations")
        return result["fixes"]

    def _parse_json_response(self, response: str) -> Dict[str, Any]:
        """Extract and parse the last fenced JSON object in an LLM response.

        ```json fences (case-insensitive) are preferred. Only when none exist
        are unlabeled ``` fences considered. If the chosen block is not valid
        JSON, it is passed through ``json_repair.repair_json`` once.

        Raises:
            ValueError: if no fenced JSON object is found, the JSON cannot be
                parsed even after repair, or the top-level value is not a JSON
                object.
        """
        self.logger.debug(f"Parsing JSON response: {response[:200]}...")

        json_blocks = self._JSON_FENCE_RE.findall(response) or self._PLAIN_FENCE_RE.findall(response)
        if not json_blocks:
            raise ValueError("No ```json ``` block found in LLM response")

        # The prompt asks for analysis first and the fix JSON last, so the
        # final block is taken as the answer.
        json_str = json_blocks[-1]

        try:
            result = json.loads(json_str)
        except json.JSONDecodeError as e:
            self.logger.warning(f"Direct JSON parsing failed: {e}, attempting repair")
            try:
                result = json.loads(repair_json(json_str))
            except Exception as repair_error:
                raise ValueError(f"JSON repair failed: {repair_error}, original error: {e}") from repair_error

        # Callers index the result as a mapping; a repaired list/string would
        # otherwise fail later with a confusing AttributeError.
        if not isinstance(result, dict):
            raise ValueError(f"Expected a JSON object in LLM response, got {type(result).__name__}")
        return result

    def _add_line_numbers(self, code: str) -> str:
        """Prefix every line with its 1-based number as ``[NNNN]`` for the LLM."""
        return '\n'.join(f"[{i:4d}]{line}" for i, line in enumerate(code.split('\n'), 1))

    def _apply_fixes(self, code: str, fixes: List[Dict[str, Any]]) -> str:
        """Apply fix operations to ``code`` and return the patched code.

        Operations that cannot be applied are logged and skipped. See
        ``_apply_fixes_with_stats`` for the exact semantics.
        """
        fixed_code, _, _ = self._apply_fixes_with_stats(code, fixes)
        return fixed_code

    def _apply_fixes_with_stats(self, code: str, fixes: List[Dict[str, Any]]) -> Tuple[str, int, int]:
        """Apply fix operations to ``code`` and count how many succeeded.

        All line numbers in ``fixes`` are 1-based and refer to the ORIGINAL
        ``code``, which is the numbering shown to the LLM. The function does
        not edit the line list in place one operation at a time. Instead, it
        first collects a plan: a set of deleted lines plus "insert after
        line k" buckets. It then builds the result in a single pass.

        This gives correct results regardless of the order in which
        operations are listed, including when:
        - deletions overlap;
        - several ``add_block``s target the same line (their content keeps
          the listed order);
        - an ``add_block`` falls inside a deleted range.

        Malformed, out-of-range or unknown operations are logged and skipped.

        Returns:
            ``(fixed_code, successful_fixes, failed_fixes)``.
        """
        lines = code.split('\n')
        original_line_count = len(lines)
        self.logger.info(f"Applying {len(fixes)} fixes to code with {original_line_count} lines")

        deleted_lines = set()  # 1-based original line numbers to drop
        # Original line number k -> lines emitted right after it (k=0: file start).
        insertions: Dict[int, List[str]] = {}
        successful_fixes = 0
        failed_fixes = 0

        for fix in fixes:
            if not isinstance(fix, dict):
                self.logger.warning(f"⚠️  Fix is not a JSON object, skipping: {fix}")
                failed_fixes += 1
                continue

            operation = fix.get("operation", "")
            if operation not in ("delete_block", "add_block"):
                self.logger.warning(f"⚠️  Unknown operation '{operation}', skipping fix: {fix}")
                failed_fixes += 1
                continue

            try:
                if operation == "delete_block":
                    start, end = self._resolve_delete_range(fix, original_line_count)
                    if start > end:
                        raise ValueError(
                            f"Invalid line range {fix.get('start_line')}-{fix.get('end_line')} for delete_block"
                        )
                    deleted_lines.update(range(start, end + 1))
                    self.logger.info(f"✅ Applied delete_block at lines {fix.get('start_line')}-{fix.get('end_line')}")
                else:
                    position, content = self._resolve_add_block(fix, original_line_count)
                    insertions.setdefault(position, []).extend(content)
                    self.logger.info(f"✅ Applied add_block at line {fix.get('line')}")
            except Exception as e:
                self.logger.error(f"❌ Failed to apply fix {operation}: {e}")
                self.logger.error(f"   Fix details: {fix}")
                self.logger.info("   Skipping this fix and continuing with others...")
                failed_fixes += 1
            else:
                successful_fixes += 1

        # Materialise the plan: start-of-file insertions, then every surviving
        # original line followed by whatever was inserted after it.
        patched = list(insertions.get(0, []))
        for line_no, text in enumerate(lines, 1):
            if line_no not in deleted_lines:
                patched.append(text)
            patched.extend(insertions.get(line_no, []))

        self.logger.info(f"Fix application summary: {successful_fixes} successful, {failed_fixes} failed")
        self.logger.info(f"Line count change: {original_line_count} -> {len(patched)} lines")

        return '\n'.join(patched), successful_fixes, failed_fixes

    @staticmethod
    def _to_line_number(value: Any, field: str, fix: Dict[str, Any]) -> int:
        """Coerce an LLM-supplied line number (int, float or numeric string) to ``int``.

        Raises:
            ValueError: if ``value`` is a bool or cannot be converted.
        """
        if isinstance(value, bool):
            raise ValueError(f"Field '{field}' must be an integer line number: {fix}")
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(f"Field '{field}' must be an integer line number: {fix}") from exc

    def _resolve_delete_range(self, fix: Dict[str, Any], line_count: int) -> Tuple[int, int]:
        """Return a delete_block's 1-based inclusive range clamped to ``1..line_count``.

        The returned start is greater than the end when the requested range
        is reversed or lies entirely outside the code.

        Raises:
            ValueError: if ``start_line``/``end_line`` is missing or not numeric.
        """
        start_line = fix.get("start_line")
        end_line = fix.get("end_line")

        if start_line is None or end_line is None:
            raise ValueError(f"delete_block requires both 'start_line' and 'end_line': {fix}")

        start = max(1, self._to_line_number(start_line, "start_line", fix))
        end = min(line_count, self._to_line_number(end_line, "end_line", fix))
        return start, end

    def _resolve_add_block(self, fix: Dict[str, Any], line_count: int) -> Tuple[int, List[str]]:
        """Validate an add_block and return ``(insert_after_line, content_lines)``.

        ``insert_after_line`` is clamped to ``0..line_count``:
        - 0 means the start of the file;
        - ``line_count`` or anything larger appends at the end;
        - negative values are treated as 0;
        - a missing ``line`` defaults to 0.

        ``content`` may be a list of strings or one newline-separated string.

        Raises:
            ValueError: if ``content`` is missing/empty or not text, or ``line``
                is not numeric.
        """
        content = fix.get("content", [])

        if not content:
            raise ValueError(f"Missing 'content' field in add_block operation: {fix}")

        if isinstance(content, str):
            content = content.split('\n')

        if not isinstance(content, list) or not all(isinstance(item, str) for item in content):
            raise ValueError(f"Field 'content' must be a list of strings: {content}")

        line_num = self._to_line_number(fix.get("line", 0), "line", fix)
        return max(0, min(line_num, line_count)), content

    def _apply_delete_block(self, lines: List[str], fix: Dict[str, Any]) -> List[str]:
        """Delete a delete_block's line range from ``lines`` in place.

        Line numbers are interpreted against the CURRENT contents of ``lines``.
        An out-of-range or reversed range only logs a warning. ``_apply_fixes``
        does not use this helper, because it plans all operations against the
        original numbering instead.

        Returns:
            The same ``lines`` list.

        Raises:
            ValueError: if ``start_line``/``end_line`` is missing or not numeric.
        """
        start_line = fix.get("start_line")
        end_line = fix.get("end_line")
        start, end = self._resolve_delete_range(fix, len(lines))

        if start > end:
            self.logger.warning(f"Invalid line range {start_line}-{end_line} for delete_block")
            return lines

        # Delete lines (inclusive range)
        del lines[start - 1:end]
        self.logger.info(f"Deleted lines {start_line} to {end_line} (inclusive)")

        return lines

    def _apply_add_block(self, lines: List[str], fix: Dict[str, Any]) -> List[str]:
        """Insert an add_block's content into ``lines`` in place.

        The content goes after line ``fix["line"]`` of the CURRENT ``lines``,
        clamped as described in ``_resolve_add_block``.

        Returns:
            The same ``lines`` list.

        Raises:
            ValueError: if the add_block is malformed.
        """
        insert_idx, content = self._resolve_add_block(fix, len(lines))

        # Slice assignment inserts the whole block in one step.
        lines[insert_idx:insert_idx] = content

        self.logger.info(f"Inserted {len(content)} lines at position {insert_idx}")

        return lines
