#!/usr/bin/env python3
"""
Judgment Operator for Code Style Conversion.

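Implements the JO framework for code style enforcement:
- Allow: the input already passes the style checker
- Edit: apply precedent-based patches + formatters until the checker passes
- Escalate: an edit round makes no progress, or the iteration limit is reached
- Deny: an edit makes the code syntactically invalid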
"""

import ast
import difflib
import os
import re
import subprocess
import tempfile
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Any, Optional, Tuple

from style_checker import StyleChecker, StyleViolation, CheckResult


class JOOutcome(Enum):
    """Verdict the Judgment Operator reaches for one piece of code."""

    # The input passed the checker without any edits.
    ALLOW = "allow"
    # The code was edited and the edited version passes.
    EDIT = "edit"
    # Edits made no progress or the iteration budget ran out.
    ESCALATE = "escalate"
    # An edit produced syntactically invalid Python.
    DENY = "deny"


@dataclass
class JOResult:
    """
    Result of JO processing.

    Attributes:
        outcome: Verdict reached for the input code.
        output_code: Code handed back by the run (on DENY, the last version
            that still parsed).
        iterations: Number of edit rounds counted for the run.
        violations_fixed: Number of violations reported as fixed.
        edit_distance: Token-level change count between input and output.
    """

    outcome: JOOutcome
    output_code: str
    iterations: int
    violations_fixed: int
    edit_distance: int


@dataclass
class Precedent:
    """
    A style correction precedent: how to patch one checker rule.

    Attributes:
        rule: The rule this fixes (e.g. "N802", "D103").
        pattern: AST pattern or context signature the rule concerns.
        patch_type: Type of patch: "format", "rename", "docstring" or "import".
        description: Human-readable summary of the fix.
    """

    rule: str
    pattern: str
    patch_type: str
    description: str


class StyleJO:
    """
    Judgment Operator for code style enforcement.

    Each edit round applies, in order:
    1. Precedent-guided patches for naming. These are line-targeted, so they
       run on exactly the code the checker reported on.
    2. Canonical formatter pass (Black, then isort with its Black profile).
    3. Template docstrings for every definition missing one, then Black.

    Patches are spliced into the source text instead of regenerating it with
    ``ast.unparse``. Comments and layout outside the patched spans survive.
    """

    def __init__(self, max_iterations: int = 5, line_length: int = 88):
        """
        Configure the operator.

        Args:
            max_iterations: Maximum number of edit rounds before escalating.
            line_length: Line length handed to the checker, Black and isort.
        """
        self.max_iterations = max_iterations
        self.line_length = line_length
        self.checker = StyleChecker(max_line_length=line_length)
        self.precedents = self._build_precedents()

    def _build_precedents(self) -> "Dict[str, Precedent]":
        """Build precedent library for common violations."""
        return {
            # Missing docstrings
            "D100": Precedent("D100", "module", "docstring", "Add module docstring"),
            "D101": Precedent("D101", "class", "docstring", "Add class docstring"),
            "D102": Precedent("D102", "method", "docstring", "Add method docstring"),
            "D103": Precedent("D103", "function", "docstring", "Add function docstring"),

            # Naming conventions
            "N802": Precedent("N802", "function", "rename", "Function name should be lowercase"),
            "N803": Precedent("N803", "argument", "rename", "Argument name should be lowercase"),
            "N806": Precedent("N806", "variable", "rename", "Variable should be lowercase"),

            # Import issues
            "I001": Precedent("I001", "import", "import", "Import block unsorted"),

            # Formatting (handled by Black)
            "E501": Precedent("E501", "line", "format", "Line too long"),
            "E302": Precedent("E302", "blank", "format", "Expected 2 blank lines"),
        }

    def process(self, code: str) -> "JOResult":
        """
        Process code through the JO pipeline.

        Every round of edits is re-checked immediately. The outcome therefore
        always reflects the checker's verdict on the code that is returned:

        - ALLOW: the input passes as-is (``iterations == 0``).
        - EDIT: the edited code passes, including after the final round.
        - ESCALATE: a round changed nothing, or ``max_iterations`` rounds ran
          without the code passing.
        - DENY: an edit produced invalid Python. The last valid code is
          returned.

        Args:
            code: Python source to judge.

        Returns:
            JOResult with the outcome and processed code. ``violations_fixed``
            is the total drop in checker-reported violations over the rounds
            that were kept.
        """
        original_code = code
        current_code = code
        total_fixed = 0
        result = self.checker.check(current_code)

        for iteration in range(self.max_iterations):
            if result.passed:
                outcome = JOOutcome.ALLOW if iteration == 0 else JOOutcome.EDIT
                return self._make_result(
                    outcome, original_code, current_code, iteration, total_fixed
                )

            new_code, _ = self._apply_edits(current_code, result.violations)

            if new_code == current_code:
                # No progress - ESCALATE
                return self._make_result(
                    JOOutcome.ESCALATE, original_code, current_code,
                    iteration + 1, total_fixed,
                )

            if not self._is_valid_python(new_code):
                # Edit broke syntax - DENY, keeping the last valid code
                return self._make_result(
                    JOOutcome.DENY, original_code, current_code,
                    iteration + 1, total_fixed,
                )

            # Re-check now: the last round's edits must be judged too, and the
            # drop in violations is the measured number of fixes.
            new_result = self.checker.check(new_code)
            total_fixed += max(0, len(result.violations) - len(new_result.violations))
            current_code, result = new_code, new_result

        if result.passed:
            outcome = JOOutcome.EDIT if self.max_iterations > 0 else JOOutcome.ALLOW
        else:
            # Max iterations reached - ESCALATE
            outcome = JOOutcome.ESCALATE
        return self._make_result(
            outcome, original_code, current_code, self.max_iterations, total_fixed
        )

    def _make_result(self, outcome: "JOOutcome", original_code: str,
                     current_code: str, iterations: int,
                     violations_fixed: int) -> "JOResult":
        """Bundle the final state of a ``process`` run into a JOResult."""
        return JOResult(
            outcome=outcome,
            output_code=current_code,
            iterations=iterations,
            violations_fixed=violations_fixed,
            edit_distance=self._compute_edit_distance(original_code, current_code),
        )

    @staticmethod
    def _is_formatter_rule(rule: str) -> bool:
        """Return True for rules the Black/isort pass is expected to fix."""
        return rule == 'BLACK' or rule.startswith(('E', 'I'))

    def _apply_edits(self, code: str,
                     violations: "List[StyleViolation]") -> Tuple[str, int]:
        """
        Apply one round of edits aimed at ``violations``.

        Stages:
        1. Precedent patches (naming). ``violation.line`` refers to ``code``
           exactly as checked, so these run before anything that moves lines.
           Renames never add or remove lines, so later targets stay valid.
        2. Black + isort (formatting + imports).
        3. Template docstrings for all definitions missing one (D1xx),
           followed by another Black pass to normalise the inserted text.

        Args:
            code: Source the violations were reported against.
            violations: Checker findings for ``code``.

        Returns:
            ``(new_code, fixed)``. ``fixed`` is a heuristic count of the
            violations this round addressed; ``process`` measures the real
            drop by re-checking.
        """
        fixed = 0

        # Stage 1: precedent-guided patches (naming, etc.)
        for v in violations:
            if self._is_formatter_rule(v.rule) or v.rule.startswith('D1'):
                continue
            precedent = self.precedents.get(v.rule)
            if precedent is None:
                continue
            new_code = self._apply_precedent(code, v, precedent)
            if new_code != code:
                code = new_code
                fixed += 1

        # Stage 2: formatter pass. Credit formatter rules only if it changed something.
        formatted = self._run_isort(self._run_black(code))
        if formatted != code:
            code = formatted
            fixed += sum(1 for v in violations if self._is_formatter_rule(v.rule))

        # Stage 3: add ALL missing docstrings at once
        docstring_violations = [v for v in violations if v.rule.startswith('D1')]
        if docstring_violations:
            new_code = self._add_all_docstrings(code)
            if new_code != code:
                code = self._run_black(new_code)
                fixed += len(docstring_violations)

        return code, fixed

    def _run_formatter(self, command: List[str], code: str) -> str:
        """
        Run a formatter that rewrites a file in place, and return the result.

        ``code`` is written to a temporary ``.py`` file. ``command`` is run
        with that path appended, and the file is read back. Returns ``code``
        unchanged if the tool cannot be started or the file cannot be
        written or read. A non-zero exit status is not treated as an error:
        whatever the file holds afterwards is returned.
        """
        fd, temp_path = tempfile.mkstemp(suffix='.py')
        try:
            # newline='' keeps line endings exactly as given, in both directions.
            with os.fdopen(fd, 'w', encoding='utf-8', newline='') as f:
                f.write(code)
            subprocess.run([*command, temp_path], capture_output=True, check=False)
            with open(temp_path, encoding='utf-8', newline='') as f:
                return f.read()
        except (OSError, subprocess.SubprocessError, UnicodeError):
            return code
        finally:
            try:
                os.unlink(temp_path)
            except OSError:
                pass

    def _run_black(self, code: str) -> str:
        """Run Black formatter; returns ``code`` unchanged if Black can't run."""
        return self._run_formatter(
            ['black', '--quiet', '--line-length', str(self.line_length)], code
        )

    def _run_isort(self, code: str) -> str:
        """
        Run isort for import sorting; returns ``code`` unchanged if it can't run.

        ``--profile black`` makes isort wrap imports the way Black does. This
        keeps the Black and isort passes from undoing each other.
        """
        return self._run_formatter(
            ['isort', '--quiet', '--profile', 'black',
             '--line-length', str(self.line_length)],
            code,
        )

    @staticmethod
    def _parse(code: str) -> Optional[ast.Module]:
        """Parse ``code``, returning None if it cannot be parsed as Python."""
        try:
            return ast.parse(code)
        except (SyntaxError, ValueError, RecursionError, MemoryError):
            return None

    @staticmethod
    def _split_source_lines(code: str) -> List[str]:
        """
        Split ``code`` into lines (line endings kept) as ``ast`` numbers them.

        Unlike ``str.splitlines`` this breaks only on CR, LF and CRLF. Form
        feeds and other Unicode separators therefore don't shift ``lineno``
        lookups. Joining the result reproduces ``code`` exactly.
        """
        return re.findall(r'[^\r\n]*(?:\r\n|\r|\n)|[^\r\n]+\Z', code)

    def _add_all_docstrings(self, code: str) -> str:
        """
        Add template docstrings to ALL functions/classes/module missing them.

        Each docstring is spliced in just before the first statement of the
        body it documents, ahead of that statement's decorators if it has
        any. It is indented to match that statement. A one-line body such as
        ``def f(): return 1`` is split so the docstring gets its own line.
        All other source text, including comments, is kept verbatim.

        Returns:
            The patched source. ``code`` is returned unchanged if it does not
            parse, if nothing lacks a docstring, or if the patch would not
            parse.
        """
        tree = self._parse(code)
        if tree is None:
            return code

        lines = self._split_source_lines(code)
        newline = '\r\n' if '\r\n' in code else '\n'
        documentable = (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
        prepend: Dict[int, str] = {}   # line index -> text inserted before it
        replace: Dict[int, str] = {}   # line index -> split one-liner

        for node in ast.walk(tree):
            if not isinstance(node, documentable) or not node.body:
                continue
            if ast.get_docstring(node, clean=False) is not None:
                continue

            first = node.body[0]
            decorators = getattr(first, 'decorator_list', None) or []
            # Decorators precede the `def` line; the docstring must go above them.
            start = min([first.lineno] + [d.lineno for d in decorators])
            idx = start - 1
            line = lines[idx]
            leading = line[:len(line) - len(line.lstrip())]
            raw = line.encode('utf-8')  # col_offset is a UTF-8 byte offset
            head = '' if decorators else raw[:first.col_offset].decode('utf-8')

            if head.strip():
                # Body shares the header's line: split it after the colon.
                indent = leading + '    '
                text = self._docstring_text(node, indent, newline)
                rest = raw[first.col_offset:].decode('utf-8')
                replace[idx] = (
                    f'{head.rstrip()}{newline}{indent}"""{text}"""{newline}{indent}{rest}'
                )
            else:
                text = self._docstring_text(node, leading, newline)
                prepend[idx] = prepend.get(idx, '') + f'{leading}"""{text}"""{newline}'

        if not prepend and not replace:
            return code
        # Splits use original byte columns, so apply them before prepending.
        for idx, new_line in replace.items():
            lines[idx] = new_line
        for idx, text in prepend.items():
            lines[idx] = text + lines[idx]

        new_code = ''.join(lines)
        return new_code if self._parse(new_code) is not None else code

    def _docstring_text(self, node: ast.AST, indent: str, newline: str) -> str:
        """
        Build the template docstring body for ``node``.

        ``indent`` is the indentation of the body the docstring sits in. The
        Google-style ``Args:`` section and the closing quotes align to it.
        """
        if isinstance(node, ast.Module):
            return "Module docstring."
        if isinstance(node, ast.ClassDef):
            return f"{node.name} class."
        title = node.name.replace('_', ' ').strip().title()
        params = self._documented_params(node.args)
        if not params:
            return f"{title}."
        args_lines = newline.join(
            f"{indent}    {p}: Parameter description." for p in params
        )
        return f"{title}.{newline}{newline}{indent}Args:{newline}{args_lines}{newline}{indent}"

    @staticmethod
    def _documented_params(args: ast.arguments) -> List[str]:
        """
        List the parameters a Google-style ``Args:`` section should cover.

        Includes positional-only, regular and keyword-only parameters, plus
        ``*args``/``**kwargs`` written with their stars. A leading ``self``
        or ``cls`` is skipped.
        """
        params = [a.arg for a in args.posonlyargs + args.args]
        if params and params[0] in ('self', 'cls'):
            params = params[1:]
        if args.vararg is not None:
            params.append(f"*{args.vararg.arg}")
        params.extend(a.arg for a in args.kwonlyargs)
        if args.kwarg is not None:
            params.append(f"**{args.kwarg.arg}")
        return params

    def _apply_precedent(self, code: str, violation: "StyleViolation",
                         precedent: "Precedent") -> str:
        """Apply a precedent-based patch for a specific violation."""
        if precedent.patch_type == "docstring":
            return self._add_all_docstrings(code)
        if precedent.patch_type == "rename":
            return self._fix_naming(code, violation)
        return code

    def _fix_naming(self, code: str, violation: "StyleViolation") -> str:
        """
        Fix naming convention violations (convert to snake_case).

        Whatever the violation's rule, only a ``def``/``async def`` whose
        keyword sits on ``violation.line`` is renamed. The new name is
        applied in three places:

        - the definition itself;
        - every bare-name reference in the module, including ones that
          precede the definition and recursive calls;
        - for a method, ``self.<name>`` / ``cls.<name>`` accesses.

        Edits are spliced into the source, so comments and layout are kept.

        Returns:
            The patched source. ``code`` is returned unchanged in any of
            these cases:

            - it does not parse;
            - no definition on that line needs renaming;
            - the snake_case name is already used in the module (renaming
              would merge two bindings);
            - some occurrence cannot be located exactly in the source text.
        """
        tree = self._parse(code)
        if tree is None:
            return code

        target = next(
            (n for n in ast.walk(tree)
             if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
             and n.lineno == violation.line),
            None,
        )
        if target is None:
            return code
        old_name = target.name
        new_name = self._to_snake_case(old_name)
        if new_name == old_name or new_name in self._names_in_use(tree):
            return code

        is_method = any(
            isinstance(n, ast.ClassDef) and target in n.body for n in ast.walk(tree)
        )
        lines = self._split_source_lines(code)
        old_bytes = old_name.encode('utf-8')

        def_span = self._def_name_span(lines, target)
        if def_span is None:
            return code
        spans = [def_span]  # (lineno, start_byte, end_byte)
        for n in ast.walk(tree):
            if isinstance(n, ast.Name) and n.id == old_name:
                spans.append((n.lineno, n.col_offset, n.end_col_offset))
            elif (is_method and isinstance(n, ast.Attribute) and n.attr == old_name
                  and isinstance(n.value, ast.Name) and n.value.id in ('self', 'cls')):
                # An Attribute node ends exactly where its attr name ends.
                end = n.end_col_offset
                spans.append((n.end_lineno, end - len(old_bytes), end))

        # Verify every span before editing, so a miss can't leave a
        # half-renamed module (e.g. NFKC-normalised identifiers).
        for lineno, start, end in spans:
            if lines[lineno - 1].encode('utf-8')[start:end] != old_bytes:
                return code

        new_bytes = new_name.encode('utf-8')
        # Right-to-left within a line keeps the earlier byte offsets valid.
        for lineno, start, end in sorted(set(spans), reverse=True):
            raw = lines[lineno - 1].encode('utf-8')
            lines[lineno - 1] = (raw[:start] + new_bytes + raw[end:]).decode('utf-8')

        new_code = ''.join(lines)
        return new_code if self._parse(new_code) is not None else code

    @staticmethod
    def _to_snake_case(name: str) -> str:
        """Convert a camelCase/PascalCase identifier to snake_case."""
        partial = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
        return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', partial).lower()

    @staticmethod
    def _def_name_span(lines: List[str],
                       node: ast.AST) -> Optional[Tuple[int, int, int]]:
        """
        Locate a def's name as ``(lineno, start_byte, end_byte)``.

        ``node.col_offset`` points at ``def``/``async``. The name follows on
        the same line. Returns None if that layout is not found.
        """
        raw = lines[node.lineno - 1].encode('utf-8')
        tail = raw[node.col_offset:].decode('utf-8')
        match = re.match(r'(?:async\s+)?def\s+', tail)
        if match is None:
            return None
        start = node.col_offset + len(tail[:match.end()].encode('utf-8'))
        return node.lineno, start, start + len(node.name.encode('utf-8'))

    @staticmethod
    def _names_in_use(tree: ast.AST) -> set:
        """Collect identifiers bound, referenced or accessed anywhere in ``tree``."""
        names = set()
        for n in ast.walk(tree):
            if isinstance(n, ast.Name):
                names.add(n.id)
            elif isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                names.add(n.name)
            elif isinstance(n, ast.arg):
                names.add(n.arg)
            elif isinstance(n, ast.Attribute):
                names.add(n.attr)
            elif isinstance(n, ast.alias):
                names.add((n.asname or n.name).split('.')[0])
            elif isinstance(n, (ast.Global, ast.Nonlocal)):
                names.update(n.names)
        return names

    def _is_valid_python(self, code: str) -> bool:
        """Check if code is syntactically valid Python."""
        return self._parse(code) is not None

    def _compute_edit_distance(self, original: str, modified: str) -> int:
        """
        Compute an approximate token-level edit distance.

        Tokens are whitespace-separated chunks. The two token sequences are
        aligned with ``difflib.SequenceMatcher``. Each non-matching region
        costs the length of its longer side: an insertion or deletion costs
        its length, and replacing k tokens with m costs max(k, m). A single
        inserted token therefore costs 1, instead of misaligning every later
        token.
        """
        orig_tokens = original.split()
        mod_tokens = modified.split()
        # autojunk=False: code repeats tokens such as "=" and "self" heavily,
        # and the popularity heuristic would otherwise ignore them.
        matcher = difflib.SequenceMatcher(None, orig_tokens, mod_tokens, autojunk=False)
        return sum(
            max(i2 - i1, j2 - j1)
            for tag, i1, i2, j1, j2 in matcher.get_opcodes()
            if tag != 'equal'
        )


if __name__ == '__main__':
    # Smoke test: run the JO over a snippet with naming, spacing and
    # docstring problems, then print the verdict and the resulting code.
    sample_source = '''
def calculateAverage(x):
    if not x:
        return 0.0
    t=sum(x)
    n=len(x)
    return t/n

class myClass:
    def doSomething(self,val):
        return val*2
'''

    report = StyleJO().process(sample_source)

    summary = (
        ("Outcome", report.outcome.value),
        ("Iterations", report.iterations),
        ("Violations fixed", report.violations_fixed),
        ("Edit distance", report.edit_distance),
    )
    for label, value in summary:
        print(f"{label}: {value}")
    print("\n=== Output Code ===")
    print(report.output_code)
