import os
import subprocess
import tempfile
import logging
import re
from typing import Dict, Any, Tuple, Optional, List
import shutil

class VerilogVerifier:
    """Verilog code verifier based on the RealBench flow.

    The candidate design is written over the problem's top file inside a
    prepared verification directory, ``make all`` is run there, and the tool
    output is scanned: ``%Error``/``%Warning`` lines on stderr count as syntax
    failures, and testbench ``Hint: Output ... mismatches`` lines on stdout
    count as semantic failures.
    """

    def __init__(self, logger: logging.Logger, verification_timeout: float = 4 * 60):
        """
        Args:
            logger: Logger used for progress and error messages.
            verification_timeout: Seconds ``make all`` may run before it is
                killed (defaults to 4 minutes).
        """
        self.logger = logger
        self.verification_timeout = verification_timeout

    def verify_code(self, code: str, problem_id: str, verification_dir: str) -> Dict[str, Any]:
        """
        Verify Verilog code.

        The code overwrites the problem's top file (``<problem_id>_top.sv`` or
        ``<problem_id>_top.v``; if neither exists, ``<problem_id>_top.sv`` is
        created) and ``make all`` is then run in ``verification_dir``.

        Args:
            code: Verilog code to verify
            problem_id: Problem ID
            verification_dir: Verification directory path (must contain a Makefile)

        Returns:
            Dict with integer flags ``syntax`` and ``semantic`` (1 = pass,
            0 = fail) and the messages ``syntax_err`` / ``semantic_err``. When
            ``make`` actually ran, it also carries ``return_code``, ``stdout``
            and ``stderr``. Any exception is reported as a failure dict rather
            than raised.
        """
        self.logger.info(f"Starting verification for problem ID: {problem_id}")
        try:
            top_file = self._find_top_file(verification_dir, problem_id)
            if not top_file:
                # No reference top file: create one at the default path and go on.
                top_file = os.path.join(verification_dir, f"{problem_id}_top.sv")
                self.logger.warning(f"Top file not found, creating {top_file}")

            # Overwrite the top file with the candidate design.
            with open(top_file, 'w', encoding='utf-8') as f:
                f.write(code)

            return self._run_verification(verification_dir, top_file)

        except Exception as e:
            # Top-level boundary: every failure is reported as a failed verification.
            self.logger.error(f"Verification failed: {e}")
            return {
                "syntax": 0,
                "semantic": 0,
                "syntax_err": f"Verification error: {str(e)}",
                "semantic_err": f"Verification error: {str(e)}"
            }

    def _find_top_file(self, temp_dir: str, problem_id: str) -> Optional[str]:
        """Return the path of the existing top file in ``temp_dir``, or ``None``.

        ``<problem_id>_top.sv`` is preferred over ``<problem_id>_top.v``.
        """
        possible_names = [
            f"{problem_id}_top.sv",
            f"{problem_id}_top.v",
        ]

        for name in possible_names:
            path = os.path.join(temp_dir, name)
            if os.path.exists(path):
                return path
        return None

    def _run_verification(self, temp_dir: str, top_file: str) -> Dict[str, Any]:
        """Run ``make all`` in ``temp_dir`` and analyze its output.

        Args:
            temp_dir: Directory containing the Makefile.
            top_file: Path of the written top file. It is not used here; the
                Makefile presumably references it itself (TODO confirm).

        Returns:
            The dict produced by ``_analyze_verification_result``.

        Raises:
            FileNotFoundError: If ``temp_dir`` contains no Makefile.
            RuntimeError: If ``make`` times out or cannot be started.
        """
        # Raise real exceptions: ``assert`` is stripped under ``python -O``,
        # which used to turn these failures into a silent ``None`` result.
        makefile_path = os.path.join(temp_dir, 'Makefile')
        if not os.path.exists(makefile_path):
            raise FileNotFoundError("Makefile not found")

        try:
            # argv list + cwd= instead of a "cd ... && make" shell string, so
            # directory names with spaces or shell metacharacters are safe.
            result = subprocess.run(
                ["make", "all"],
                cwd=temp_dir,
                timeout=self.verification_timeout,
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE,
                text=True,
                errors="replace",  # tool output is not guaranteed to be valid text
            )
        except subprocess.TimeoutExpired as exc:
            raise RuntimeError("Verification timed out") from exc
        except OSError as exc:
            raise RuntimeError(f"Verification command failed: {exc}") from exc

        return self._analyze_verification_result(result)

    def _analyze_verification_result(self, result: subprocess.CompletedProcess) -> Dict[str, Any]:
        """Turn a finished ``make`` run into pass/fail flags and error messages.

        Args:
            result: Completed process with captured ``stdout``/``stderr``.

        Returns:
            Dict with ``syntax``/``semantic`` flags (1 = pass, 0 = fail),
            ``syntax_err``/``semantic_err`` messages, and the raw
            ``return_code``, ``stdout`` and ``stderr``.
        """
        syntax = 1
        semantic = 1
        syntax_err = ""
        semantic_err = ""

        # Syntax errors are reported on stderr.
        if result.stderr:
            syntax_errors = self._extract_syntax_errors(result.stderr)
            if syntax_errors:
                syntax = 0
                semantic = 0  # A design that does not compile cannot pass semantically.
                syntax_err = "\n".join(syntax_errors)

        # Testbench mismatch reports are printed on stdout.
        if result.stdout:
            semantic_errors = self._extract_semantic_errors(result.stdout)
            if semantic_errors:
                semantic = 0
                semantic_err = "\n".join(semantic_errors)

        if result.returncode != 0:
            # A failed build/run is never a semantic pass, even when no
            # recognizable error line was printed.
            semantic = 0
            if syntax == 1:  # No syntax error was recognized in stderr.
                syntax = 0
                syntax_err += f"\nReturn code: {result.returncode}"

        return {
            "syntax": syntax,
            "semantic": semantic,
            "syntax_err": syntax_err.strip(),
            "semantic_err": semantic_err.strip(),
            "return_code": result.returncode,
            "stdout": result.stdout,
            "stderr": result.stderr
        }

    @staticmethod
    def _to_text(output: Any) -> str:
        """Return ``output`` as ``str``, decoding ``bytes`` (invalid UTF-8 is replaced)."""
        if isinstance(output, bytes):
            return output.decode(errors="replace")
        return output

    def _extract_syntax_errors(self, stderr: str) -> List[str]:
        """Extract syntax errors from stderr (``str`` or ``bytes``).

        Every line starting with ``%Error`` or ``%Warning`` is reported, so
        warnings also count as syntax failures. Duplicates are dropped,
        keeping first-seen order.
        """
        errors = [
            line
            for line in self._to_text(stderr).split('\n')
            if line.startswith(("%Error", "%Warning"))
        ]
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return list(dict.fromkeys(errors))

    def _extract_semantic_errors(self, stdout: str) -> List[str]:
        """Extract testbench mismatch reports from stdout (``str`` or ``bytes``).

        Lines containing ``Hint: Output`` and ``mismatches`` (but not
        ``no mismatches``) are reported, with everything up to and including
        the ``Hint: `` prefix removed. Duplicates are dropped, keeping
        first-seen order.
        """
        marker = "Hint: Output"
        prefix_len = len("Hint: ")
        errors = []
        for line in self._to_text(stdout).split('\n'):
            if marker not in line or "no mismatches" in line:
                continue
            if "mismatches" in line:
                # Slice relative to the hint rather than a fixed offset, so any
                # text printed before it does not garble the message.
                errors.append(line[line.index(marker) + prefix_len:])

        return list(dict.fromkeys(errors))
    