"""
Protocol Operator for Multi-Agent Judgment Projecting.

JO intervention layer that validates messages between agents
and enforces protocol contracts at message boundaries.

Key responsibilities:
1. Track browser navigation (visited URLs, page content)
2. Validate writer output against schema contract
3. Block/retry malformed messages with feedback
"""

import json
import re
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field

try:
    from .protocol_evaluator import ProtocolEvaluator, REQUIRED_FIELDS
except ImportError:
    from protocol_evaluator import ProtocolEvaluator, REQUIRED_FIELDS


@dataclass
class InterventionResult:
    """Outcome of a single operator intervention on one message.

    Attributes:
        action: Directive for the caller: "pass", "block", or "retry".
        violations: Protocol violations found in the message (empty if none).
        feedback: Optional human-readable guidance attached to the result.
        modified_content: Optional replacement content for the message;
            nothing visible in this module populates it.
    """

    action: str
    violations: List[str] = field(default_factory=list)
    feedback: Optional[str] = field(default=None)
    modified_content: Optional[str] = field(default=None)


class ProtocolOperator:
    """
    JO Protocol Operator for multi-agent coordination.

    Sits between agents and validates messages against protocol contracts.
    Provides structured feedback for self-correction.

    Browser messages are recorded (visited URLs, page content) and mirrored
    into the wrapped ``ProtocolEvaluator``. Writer messages are validated and
    either passed, or sent back for retry with feedback until the retry
    budget is spent. Messages from any other sender pass through untouched.
    """

    def __init__(
        self,
        k_required: int = 3,
        retry_budget: int = 3,
        fuzzy_quote_threshold: float = 0.8,
        min_quote_length: int = 10,
        strict_mode: bool = True
    ):
        """
        Args:
            k_required: Number of claims required
            retry_budget: Max retries before passing through
            fuzzy_quote_threshold: Word overlap for fuzzy quote matching
                (forwarded to ProtocolEvaluator)
            min_quote_length: Minimum quote characters
                (forwarded to ProtocolEvaluator)
            strict_mode: If True, block on any violation. If False, only block critical.
        """
        self.k_required = k_required
        self.retry_budget = retry_budget
        # NOTE(review): stored but not consulted anywhere in this class;
        # intercept() never returns a "block" action.
        self.strict_mode = strict_mode

        self.evaluator = ProtocolEvaluator(
            fuzzy_quote_threshold=fuzzy_quote_threshold,
            min_quote_length=min_quote_length
        )

        # Per-task state, cleared by reset().
        self.visited_urls: List[str] = []  # first-visit order, no duplicates
        self.page_cache: Dict[str, str] = {}  # url -> latest non-empty content
        self.retry_count = 0  # failed writer validations in this task
        self.intervention_log: List[Dict] = []

    def _sync_evaluator(self) -> None:
        """Mirror the tracked URLs and page cache into the evaluator."""
        self.evaluator.visited_urls = set(self.visited_urls)
        # Shared reference: in-place writes to page_cache are seen by the evaluator.
        self.evaluator.page_cache = self.page_cache

    def reset(self):
        """Reset state for a new task.

        The evaluator is re-synced as well. Without that it would keep the
        previous task's URL set and page dict until the next browser
        observation arrived.
        """
        self.visited_urls = []
        self.page_cache = {}
        self.retry_count = 0
        self.intervention_log = []
        self._sync_evaluator()

    def register_browser_observation(
        self,
        url: str,
        content: Optional[str] = None,
        action_type: Optional[str] = None
    ):
        """
        Register browser navigation for URL/content tracking.

        Args:
            url: URL visited; falsy values are ignored.
            content: Page content (if available); empty content does not
                overwrite a previously cached page.
            action_type: Type of action (goto, get_content, etc.). Accepted
                for interface compatibility; not used here.
        """
        if url and url not in self.visited_urls:
            self.visited_urls.append(url)

        if url and content:
            self.page_cache[url] = content

        self._sync_evaluator()

    def _generate_feedback(self, violations: List[str]) -> str:
        """Generate human-readable feedback for violations.

        Violations are bucketed by substring into count, schema (missing or
        empty field), URL and quote issues. Anything outside those buckets
        is echoed verbatim, so the writer still learns what went wrong.
        """
        feedback_parts = ["PROTOCOL VIOLATION - Please fix and retry:\n"]

        quote_issues: List[str] = []
        url_issues: List[str] = []
        schema_issues: List[str] = []
        count_issues: List[str] = []
        other_issues: List[str] = []

        for v in violations:
            if "quote_not_found" in v:
                quote_issues.append(v)
            elif "url_not_visited" in v:
                url_issues.append(v)
            elif "missing_field" in v or "empty_field" in v:
                schema_issues.append(v)
            elif "count_error" in v:
                count_issues.append(v)
            else:
                other_issues.append(v)

        if count_issues:
            feedback_parts.append(
                f"- WRONG COUNT: Must have exactly {self.k_required} claims"
            )

        if schema_issues:
            # Assumes the field name is the last ':'-separated token of the
            # violation string -- TODO confirm against ProtocolEvaluator.
            # Sorted so the feedback text is deterministic across runs.
            missing = sorted({
                v.split(":")[-1].strip() for v in schema_issues if "missing" in v
            })
            empty = sorted({
                v.split(":")[-1].strip() for v in schema_issues if "missing" not in v
            })
            if missing:
                feedback_parts.append(
                    f"- MISSING FIELDS: Each claim needs: {', '.join(missing)}"
                )
            if empty:
                feedback_parts.append(
                    f"- EMPTY FIELDS: These fields must not be empty: {', '.join(empty)}"
                )

        if url_issues:
            feedback_parts.append(
                f"- URL ERROR: source_url must be from visited URLs: {self.visited_urls}"
            )

        if quote_issues:
            feedback_parts.append(
                "- QUOTE ERROR: quote_span must be EXACT text from the source page. "
                "Copy verbatim, do not paraphrase."
            )

        if other_issues:
            feedback_parts.append(
                "- OTHER: " + "; ".join(str(v) for v in other_issues)
            )

        feedback_parts.append(f"\nRequired fields per claim: {REQUIRED_FIELDS}")

        return "\n".join(feedback_parts)

    def validate_writer_output(
        self,
        output: str
    ) -> Tuple[bool, List[str], Optional[str]]:
        """
        Validate writer output against protocol.

        Args:
            output: Raw writer output string

        Returns:
            (is_valid, violations, feedback_message). On success this is
            (True, [], None).
        """
        metrics = self.evaluator.compute_metrics(output, self.k_required)

        if metrics['valid']:
            return True, [], None

        violations = metrics['violations']
        feedback = self._generate_feedback(violations)

        return False, violations, feedback

    def intercept(
        self,
        message: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None
    ) -> "InterventionResult":
        """
        Main JO intercept point for messages.

        - ``browser`` messages: the URL and page content are recorded, then
          the message passes.
        - ``writer`` messages: the content is validated. Invalid output
          triggers a "retry" with feedback until ``retry_budget`` is
          exceeded, after which it passes with the violations attached.
        - Any other sender passes through and is not logged.

        Args:
            message: Message dict with 'sender', 'content', etc.
            context: Optional context (task info, etc.). A ``k_claims`` entry
                overrides ``self.k_required`` persistently.

        Returns:
            InterventionResult with action and feedback
        """
        sender = message.get("sender", "unknown")
        content = message.get("content", "")
        if content is None:
            # Treat an explicit null payload like an empty one.
            content = ""

        if context and "k_claims" in context:
            self.k_required = context["k_claims"]

        if sender == "browser":
            action = message.get("action")
            if not isinstance(action, dict):
                # Only a dict-shaped "action" carries url/content; anything
                # else (absent, None, a string) falls back to top-level keys.
                action = {}
            url = action.get("url") or message.get("url", "")
            page_content = action.get("content") or message.get("page_content", "")

            if url:
                self.register_browser_observation(url, page_content)

            # Log but pass through browser messages
            self.intervention_log.append({
                "sender": sender,
                "action": "pass",
                "reason": "browser_observation"
            })
            return InterventionResult(action="pass")

        if sender == "writer":
            is_valid, violations, feedback = self.validate_writer_output(content)

            if is_valid:
                self.intervention_log.append({
                    "sender": sender,
                    "action": "pass",
                    "reason": "valid_output"
                })
                return InterventionResult(action="pass")

            # Invalid output: spend one retry and decide the action.
            self.retry_count += 1
            exhausted = self.retry_count > self.retry_budget

            self.intervention_log.append({
                "sender": sender,
                "action": "pass" if exhausted else "retry",
                "violations": violations,
                "retry_count": self.retry_count
            })

            if exhausted:
                # Exhausted retries, let it through to fail at evaluator
                return InterventionResult(
                    action="pass",
                    violations=violations,
                    feedback="Max retries exceeded, passing through."
                )

            return InterventionResult(
                action="retry",
                violations=violations,
                feedback=feedback
            )

        # Unknown sender, pass through
        return InterventionResult(action="pass")

    def get_stats(self) -> Dict[str, Any]:
        """Get intervention statistics.

        ``interventions`` counts logged entries whose action is "retry".
        ``total_messages`` counts every logged entry; unknown senders are
        never logged, so they are not included.
        """
        total_interventions = sum(
            1 for log in self.intervention_log if log.get("action") == "retry"
        )

        return {
            "total_messages": len(self.intervention_log),
            "interventions": total_interventions,
            "retry_count": self.retry_count,
            "visited_urls": len(self.visited_urls),
            "cached_pages": len(self.page_cache)
        }


class NoOperator:
    """Pass-through baseline that performs no projecting at all.

    Offers the same reset / register_browser_observation / intercept /
    get_stats methods as ProtocolOperator so it can be swapped in, but every
    message passes and nothing is ever recorded.
    """

    def __init__(self, **kwargs):
        # Constructor options are accepted for interface parity and discarded.
        self.intervention_log = list()

    def reset(self):
        """Replace the intervention log with a fresh empty list."""
        self.intervention_log = list()

    def register_browser_observation(self, url: str, content: Optional[str] = None, **kwargs):
        """Ignore the observation; the baseline tracks no browser state."""
        return None

    def intercept(self, message: Dict[str, Any], context: Optional[Dict] = None) -> InterventionResult:
        """Let every message through, regardless of *message* or *context*."""
        passthrough = InterventionResult(action="pass")
        return passthrough

    def get_stats(self) -> Dict[str, Any]:
        """Report zero activity (the baseline never intervenes)."""
        return dict(total_messages=0, interventions=0)


def create_operator(
    method: str,
    k_required: int = 3,
    **kwargs
) -> Any:
    """
    Factory function to create operator by method name.

    Args:
        method: "NO", "no_operator" or "NO_fewshot" for the pass-through
            baseline; "JO_protocol", "JO_static" or "JO_dynamic" for the
            protocol operator (all three currently build the same operator).
        k_required: Number of claims required (protocol operator only)
        **kwargs: Additional ProtocolOperator arguments; ignored for the baseline

    Returns:
        Operator instance

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    # Tuples (not sets) so an unhashable ``method`` still reaches the
    # ValueError below instead of raising TypeError.
    no_operator_methods = ("NO", "no_operator", "NO_fewshot")
    protocol_methods = ("JO_protocol", "JO_static", "JO_dynamic")

    if method in no_operator_methods:
        return NoOperator()
    if method in protocol_methods:
        return ProtocolOperator(k_required=k_required, **kwargs)

    supported = ", ".join(no_operator_methods + protocol_methods)
    raise ValueError(f"Unknown operator method: {method}. Supported: {supported}")


# Smoke tests: executed only when this file is run directly as a script.
if __name__ == "__main__":
    EINSTEIN_URL = "http://localhost:9999/wiki/Albert_Einstein"

    def _claim(entity, predicate, value, source_url, quote_span, confidence):
        """Build one claim dict with a fixed key order (stable JSON payload)."""
        return dict(
            entity=entity,
            predicate=predicate,
            value=value,
            source_url=source_url,
            quote_span=quote_span,
            confidence=confidence,
        )

    def _writer_msg(claims):
        """Wrap a list of claims as a writer message with a JSON body."""
        return {"sender": "writer", "content": json.dumps({"claims": claims})}

    jo = ProtocolOperator(k_required=3, retry_budget=2)
    jo.register_browser_observation(
        url=EINSTEIN_URL,
        content="Albert Einstein was born in Ulm on 14 March 1879. He graduated from ETH Zurich."
    )

    # Case 1: three well-formed claims, all grounded in the visited page.
    good_msg = _writer_msg([
        _claim("Albert Einstein", "born_in", "Ulm", EINSTEIN_URL,
               "Albert Einstein was born in Ulm on 14 March 1879", 0.95),
        _claim("Albert Einstein", "birth_date", "14 March 1879", EINSTEIN_URL,
               "born in Ulm on 14 March 1879", 0.90),
        _claim("Albert Einstein", "graduated_from", "ETH Zurich", EINSTEIN_URL,
               "He graduated from ETH Zurich", 0.85),
    ])
    outcome = jo.intercept(good_msg)
    print(f"Valid output: action={outcome.action}")
    assert outcome.action == "pass"

    # Case 2: a fresh task whose single claim cites a URL that was never visited.
    jo.reset()
    jo.register_browser_observation(
        url=EINSTEIN_URL,
        content="Albert Einstein was born in Ulm."
    )
    bad_msg = _writer_msg([
        _claim("Einstein", "born_in", "Ulm", "http://localhost:9999/wiki/WRONG_URL",
               "Albert Einstein was born in Ulm", 0.9),
    ])
    outcome = jo.intercept(bad_msg)
    print(f"Invalid URL: action={outcome.action}, violations={outcome.violations[:2]}")
    assert outcome.action == "retry"
    assert "url_not_visited" in str(outcome.violations)

    # Case 3: keep resubmitting the bad output until the retry budget runs out.
    for _attempt in range(3):
        outcome = jo.intercept(bad_msg)

    print(f"After retry exhaustion: action={outcome.action}")
    assert outcome.action == "pass"  # Passes through after budget exhausted

    print("\nAll tests passed!")
    print(f"Stats: {jo.get_stats()}")
