#!/usr/bin/env python3
"""
Style checker using pydocstyle, flake8, and Black (plus a Ruff helper) for deterministic constraint checking.

Defines the constraint set C for the code style experiment.
"""

import contextlib
import json
import os
import re
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Tuple


@dataclass()
class StyleViolation:
    """One rule violation reported by a style backend.

    Backends that do not report a column (or a line) store 0 in that field.
    """

    rule: str  # rule code, e.g. a flake8/pydocstyle code or "BLACK"
    line: int  # 1-based line reported by the tool; 0 when not applicable
    column: int  # column reported by the tool; 0 when not reported
    message: str  # human-readable description of the violation
    fixable: bool = False  # whether the violation is considered auto-fixable


@dataclass()
class CheckResult:
    """Outcome of running the style checker over one piece of code."""

    passed: bool  # True only when no violations were found
    violations: List[StyleViolation]  # every violation collected from all backends
    score: float  # 0-1 score based on violations (1.0 means no violations)


class StyleChecker:
    """
    Deterministic style checker built on external linters and Black.

    ``check`` combines three backends, each run on a temporary copy of the code:

    - pydocstyle (``--convention=google``): D rules (docstrings)
    - flake8 (``--select=E,F,N``): pycodestyle errors, Pyflakes, pep8-naming
    - Black (``--check``): whole-file formatting, reported as rule ``BLACK``

    ``_check_ruff`` runs Ruff with ``self.ruff_rules`` (E, F, I, N, D1, D2), but
    ``check`` does not currently call it, so isort (I) rules are not enforced.

    A backend whose executable is not installed is skipped silently and
    contributes no violations.
    """

    # pydocstyle location line: "<path>:<line> <context>:". The path is matched
    # lazily up to the first ":<digits>" that is followed by whitespace, ":" or
    # end of line, so Windows drive letters (C:\...) and paths starting with "D"
    # are parsed correctly.
    _PYDOCSTYLE_LOCATION_RE = re.compile(r'^.+?:(\d+)(?=\s|:|$)')
    # pydocstyle rule line: "D###: <message>".
    _PYDOCSTYLE_RULE_RE = re.compile(r'^(D\d+):\s*(.*)$')

    def __init__(self, max_line_length: int = 88):
        """Create a checker.

        Args:
            max_line_length: Line length passed to flake8, Black and Ruff.
        """
        self.max_line_length = max_line_length
        # Ruff rules to check (Google-style approximation); used by _check_ruff.
        self.ruff_rules = [
            "E",    # pycodestyle errors
            "F",    # Pyflakes
            "I",    # isort
            "N",    # pep8-naming
            "D1",   # Missing docstrings
            "D2",   # Docstring formatting
        ]

    def check(self, code: str) -> CheckResult:
        """
        Check code against style constraints.

        Runs pydocstyle, flake8 and Black and merges their violations.

        Args:
            code: Python source text to check.

        Returns:
            CheckResult whose ``passed`` is True only when there are no
            violations, and whose ``score`` is ``max(0, 1 - violations / 20)``.
        """
        violations: List[StyleViolation] = []

        # Docstring violations (D rules).
        violations.extend(self._check_pydocstyle(code))

        # flake8 (E, F, N rules).
        violations.extend(self._check_flake8(code))

        # Black reports a single whole-file violation when reformatting is needed.
        if self._check_black(code):
            violations.append(StyleViolation(
                rule="BLACK",
                line=0,
                column=0,
                message="Code does not match Black formatting",
                fixable=True
            ))

        # 1.0 = perfect; each violation costs 0.05, floored at 0.0.
        score = max(0.0, 1.0 - len(violations) / 20.0)

        return CheckResult(
            passed=len(violations) == 0,
            violations=violations,
            score=score
        )

    @staticmethod
    @contextlib.contextmanager
    def _temp_source(code: str) -> Iterator[str]:
        """Write ``code`` (UTF-8) to a temporary ``.py`` file and yield its path.

        The file is closed before the path is yielded, so external tools can
        open it on every platform. It is removed on exit, even if writing fails.
        """
        fd, temp_path = tempfile.mkstemp(suffix='.py')
        try:
            # UTF-8 because Python tooling reads source files as UTF-8 by
            # default, regardless of the locale encoding.
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                f.write(code)
            yield temp_path
        finally:
            os.unlink(temp_path)

    @staticmethod
    def _run_tool(args: List[str]) -> Optional[subprocess.CompletedProcess]:
        """Run an external tool with captured text output.

        Returns None when the executable is not installed (FileNotFoundError),
        so callers can skip that backend.
        """
        try:
            return subprocess.run(args, capture_output=True, text=True)
        except FileNotFoundError:
            return None

    def _check_pydocstyle(self, code: str) -> List[StyleViolation]:
        """Run pydocstyle (Google convention) for docstring violations."""
        with self._temp_source(code) as temp_path:
            result = self._run_tool(['pydocstyle', '--convention=google', temp_path])
        if result is None:
            return []
        return self._parse_pydocstyle_output(result.stdout or '')

    @classmethod
    def _parse_pydocstyle_output(cls, output: str) -> List[StyleViolation]:
        """Parse pydocstyle's text output, in which each violation spans two lines.

        Line 1 is ``<path>:<line> <context>:`` and line 2 is ``D###: <message>``.
        Lines that do not form such a pair are ignored.
        """
        lines = [raw.strip() for raw in output.split('\n')]
        violations: List[StyleViolation] = []
        i = 0
        while i < len(lines):
            location = cls._PYDOCSTYLE_LOCATION_RE.match(lines[i])
            rule = None
            if location and i + 1 < len(lines):
                rule = cls._PYDOCSTYLE_RULE_RE.match(lines[i + 1])
            if location and rule:
                violations.append(StyleViolation(
                    rule=rule.group(1),
                    line=int(location.group(1)),
                    column=0,
                    message=rule.group(2).strip(),
                    fixable=True
                ))
                i += 2  # consumed the location line and its rule line
            else:
                i += 1
        return violations

    def _check_flake8(self, code: str) -> List[StyleViolation]:
        """Run flake8 for E, F, N rules."""
        with self._temp_source(code) as temp_path:
            result = self._run_tool(
                ['flake8', '--select=E,F,N', '--max-line-length', str(self.max_line_length),
                 '--format=%(code)s:%(row)d:%(col)d:%(text)s', temp_path]
            )
        if result is None:
            return []
        return self._parse_flake8_output(result.stdout or '')

    @staticmethod
    def _parse_flake8_output(output: str) -> List[StyleViolation]:
        """Parse ``code:row:col:text`` lines produced by the custom flake8 format.

        Lines with fewer than four fields or non-integer row/col are ignored.
        """
        violations: List[StyleViolation] = []
        for raw in output.strip().split('\n'):
            if not raw or ':' not in raw:
                continue
            parts = raw.split(':', 3)  # the message text itself may contain ':'
            if len(parts) < 4:
                continue
            try:
                row, col = int(parts[1]), int(parts[2])
            except ValueError:
                continue
            violations.append(StyleViolation(
                rule=parts[0],
                line=row,
                column=col,
                message=parts[3],
                fixable=True
            ))
        return violations

    def _check_ruff(self, code: str) -> List[StyleViolation]:
        """Run Ruff with ``self.ruff_rules`` and parse its JSON output (if available).

        Not called by ``check``. Returns an empty list when Ruff is not
        installed or its output is not valid JSON.
        """
        rules = ','.join(self.ruff_rules)
        with self._temp_source(code) as temp_path:
            result = self._run_tool(
                ['ruff', 'check', '--select', rules,
                 '--line-length', str(self.max_line_length),
                 '--output-format', 'json', temp_path]
            )
        if result is None or not result.stdout:
            return []
        try:
            ruff_output = json.loads(result.stdout)
        except json.JSONDecodeError:
            return []

        violations: List[StyleViolation] = []
        for item in ruff_output:
            location = item.get('location') or {}
            violations.append(StyleViolation(
                rule=item.get('code', 'UNKNOWN'),
                line=location.get('row', 0),
                column=location.get('column', 0),
                message=item.get('message', ''),
                fixable=item.get('fix') is not None
            ))
        return violations

    def _check_black(self, code: str) -> Optional[str]:
        """Check if code matches Black formatting.

        Returns:
            None when Black exits with status 0 or is not installed; otherwise
            Black's stdout (the diff). If Black exits non-zero without printing
            a diff, this is an empty string, which ``check`` treats as no
            formatting violation.
        """
        with self._temp_source(code) as temp_path:
            result = self._run_tool(
                ['black', '--check', '--diff',
                 '--line-length', str(self.max_line_length), temp_path]
            )
        if result is None or result.returncode == 0:
            return None
        return result.stdout

    def get_violation_summary(self, result: CheckResult) -> Dict[str, int]:
        """Count violations by rule category (the first character of the rule code).

        Rules with an empty code are counted under ``'OTHER'``. Note that the
        whole-file ``BLACK`` violation is counted under ``'B'``.
        """
        summary: Dict[str, int] = {}
        for v in result.violations:
            category = v.rule[0] if v.rule else 'OTHER'
            summary[category] = summary.get(category, 0) + 1
        return summary


class FunctionalChecker:
    """Check that code preserves functionality by running caller-supplied tests."""

    def check(self, code: str, test_code: str, timeout: float = 10) -> Tuple[bool, str]:
        """
        Run test_code against code and return (passed, error_message).

        ``code`` is written to ``module.py`` and ``test_code`` to
        ``test_module.py`` in a fresh temporary directory, so the test can
        ``import module``. The test runs under the current interpreter
        (``sys.executable``) and passes when it exits with status 0.

        Args:
            code: Source of the module under test.
            test_code: Test script that imports/uses functions from ``module``.
            timeout: Seconds to allow the test process before giving up.

        Returns:
            ``(True, "")`` on success. Otherwise ``(False, message)``, where
            the message is the process's stderr (or stdout if stderr is empty),
            or a timeout notice.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            # Write code to module
            code_path = os.path.join(tmpdir, 'module.py')
            with open(code_path, 'w', encoding='utf-8') as f:
                f.write(code)

            # Write test. repr() yields a valid string literal even when the
            # directory contains backslashes (Windows paths) or quotes.
            test_path = os.path.join(tmpdir, 'test_module.py')
            full_test = f"import sys; sys.path.insert(0, {tmpdir!r})\n{test_code}"
            with open(test_path, 'w', encoding='utf-8') as f:
                f.write(full_test)

            # Run test with the same interpreter as the caller; a bare 'python'
            # may be missing from PATH or be a different installation.
            try:
                result = subprocess.run(
                    [sys.executable, test_path],
                    capture_output=True, text=True,
                    timeout=timeout
                )
            except subprocess.TimeoutExpired:
                # Report runaway code (e.g. an infinite loop) as a failure
                # instead of propagating the exception.
                return False, f"Test timed out after {timeout} seconds"

        if result.returncode == 0:
            return True, ""
        return False, result.stderr or result.stdout


if __name__ == '__main__':
    # Smoke test: run the checker on a compliant snippet and a non-compliant one.
    compliant_snippet = '''
def calculate_average(numbers: list) -> float:
    """Calculate the average of a list of numbers.

    Args:
        numbers: List of numbers to average.

    Returns:
        The arithmetic mean.
    """
    if not numbers:
        return 0.0
    return sum(numbers) / len(numbers)
'''

    noncompliant_snippet = '''
def calculateAverage(x):
    if not x:
        return 0.0
    t=sum(x)
    n=len(x)
    return t/n
'''

    style_checker = StyleChecker()
    # (banner, snippet, how many individual violations to list)
    demo_cases = [
        ("=== Good Code ===", compliant_snippet, 0),
        ("\n=== Bad Code ===", noncompliant_snippet, 5),
    ]
    for banner, snippet, detail_limit in demo_cases:
        print(banner)
        outcome = style_checker.check(snippet)
        print(f"Passed: {outcome.passed}, Score: {outcome.score:.2f}")
        print(f"Violations: {len(outcome.violations)}")
        for violation in outcome.violations[:detail_limit]:
            print(f"  {violation.rule}: {violation.message}")
