"""
Protocol Evaluator for Structured Claim Extraction.

Validates writer output against strict schema contracts:
- Required fields present
- Exact claim count
- Quote spans match source text
- URLs were actually visited

All checks are deterministic - no LLM judgment required.
"""

import json
import re
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, field


# Keys every claim dict must carry. ProtocolEvaluator.validate_claim reports a
# violation for each key that is absent, null, or a blank string.
REQUIRED_FIELDS = [
    "entity",
    "predicate",
    "value",
    "source_url",
    "quote_span",
    "confidence",
]


@dataclass
class ValidationResult:
    """
    Result of validating a writer response.

    Returned by ``ProtocolEvaluator.validate_response``.
    """
    # Set by validate_response to True when it recorded no violations.
    valid: bool
    # Violation strings such as "count_error:expected_3_got_2" or
    # "claim_0:missing_field:entity"; empty when the response is valid.
    violations: List[str] = field(default_factory=list)
    # Free-form extra data. The code visible in this file never fills it in,
    # so it stays an empty dict unless a caller supplies one.
    metrics: Dict[str, Any] = field(default_factory=dict)


class ProtocolEvaluator:
    """
    Evaluator for structured claim extraction protocol.

    Validates claims against:
    1. Schema compliance (all required fields present, non-null, non-blank)
    2. Count compliance (exactly K claims)
    3. Quote verification (quote_span in source text)
    4. URL verification (source_url was visited)
    5. Uniqueness (no repeated entity/predicate/value triple)

    Violations are plain strings. Claim-level violations have the form
    ``claim_<index>:<type>[:<detail>]``; response-level violations have the
    form ``<type>:<detail>``.
    """

    # Violation types that compute_metrics() reports as "has_schema_error".
    _SCHEMA_VIOLATION_TYPES = (
        "missing_field",
        "null_field",
        "empty_field",
        "invalid_type",
    )

    # Compiled once; these are used on every normalisation / parse call.
    _WHITESPACE_RE = re.compile(r'\s+')
    _FENCED_JSON_RE = re.compile(r'```(?:json)?\s*([\s\S]*?)```')
    _BRACED_JSON_RE = re.compile(r'\{[\s\S]*\}')

    def __init__(
        self,
        page_cache: Optional[Dict[str, str]] = None,
        visited_urls: Optional[List[str]] = None,
        fuzzy_quote_threshold: float = 0.8,
        min_quote_length: int = 10
    ):
        """
        Args:
            page_cache: {url: page_text} mapping from browser observations.
                The mapping is used as-is (not copied), so add_visited_url()
                writes into the caller's dict.
            visited_urls: List of URLs browser actually visited
            fuzzy_quote_threshold: Word overlap threshold for fuzzy quote matching
            min_quote_length: Minimum characters for valid quote
        """
        # "is None" rather than "or": an empty dict passed by the caller must
        # be kept (and shared) exactly like a non-empty one.
        self.page_cache = page_cache if page_cache is not None else {}
        self.visited_urls = set(visited_urls) if visited_urls is not None else set()
        self.fuzzy_quote_threshold = fuzzy_quote_threshold
        self.min_quote_length = min_quote_length

    def add_visited_url(self, url: str, content: Optional[str] = None):
        """
        Register a URL as visited, optionally with page content.

        Empty or missing content is not cached; quote checks against such a
        URL report "source_not_cached".
        """
        self.visited_urls.add(url)
        if content:
            self.page_cache[url] = content

    def _normalize_text(self, text: str) -> str:
        """Lower-case, trim, and collapse every whitespace run to one space."""
        return self._WHITESPACE_RE.sub(' ', text.lower().strip())

    def _check_quote_in_source(self, quote: str, source_url: str) -> Tuple[bool, str]:
        """
        Check if quote exists in source page.

        Matching is done on normalized text (see _normalize_text). An exact
        substring match is tried first, then a word-overlap fallback.

        Returns:
            (found, reason) where reason is one of "source_not_cached",
            "empty_quote", "exact_match", "fuzzy_match_<ratio>" or
            "no_match_overlap_<ratio>".
        """
        if source_url not in self.page_cache:
            return False, "source_not_cached"

        quote_normalized = self._normalize_text(quote)

        # Must run before the substring test: "" is a substring of every
        # string, so a blank quote would otherwise count as an exact match.
        if not quote_normalized:
            return False, "empty_quote"

        page_text = self._normalize_text(self.page_cache[source_url])

        # Exact substring match
        if quote_normalized in page_text:
            return True, "exact_match"

        # Fuzzy match: fraction of distinct quote words present on the page.
        # NOTE(review): this is a bag-of-words test over the whole page; word
        # order and adjacency are ignored, so a quote assembled from words
        # scattered across the page can pass. Confirm this is intended.
        quote_words = set(quote_normalized.split())
        page_words = set(page_text.split())

        overlap = len(quote_words & page_words) / len(quote_words)
        if overlap >= self.fuzzy_quote_threshold:
            return True, f"fuzzy_match_{overlap:.2f}"

        return False, f"no_match_overlap_{overlap:.2f}"

    def validate_claim(self, claim: Dict, index: int = 0) -> Tuple[bool, List[str]]:
        """
        Validate a single claim against protocol.

        Checks, in order: required fields (missing / null / blank),
        confidence type and range, source_url and quote_span types,
        source_url visited, quote length, quote present in the cached page.
        If any required field is missing, only the field violations are
        returned and the remaining checks are skipped.

        Args:
            claim: Claim dictionary
            index: Claim index for error messages

        Returns:
            (is_valid, list_of_violations)
        """
        violations: List[str] = []
        prefix = f"claim_{index}"

        # Check required fields
        has_missing_field = False
        for name in REQUIRED_FIELDS:
            if name not in claim:
                violations.append(f"{prefix}:missing_field:{name}")
                has_missing_field = True
                continue
            value = claim[name]
            if value is None:
                violations.append(f"{prefix}:null_field:{name}")
            elif isinstance(value, str) and not value.strip():
                violations.append(f"{prefix}:empty_field:{name}")

        # Early return if missing critical fields
        if has_missing_field:
            return False, violations

        # Check confidence type and range. bool is a subclass of int, so it
        # has to be rejected explicitly or True/False would pass as 1/0.
        conf = claim.get("confidence")
        if conf is not None:
            if isinstance(conf, bool) or not isinstance(conf, (int, float)):
                violations.append(f"{prefix}:invalid_type:confidence")
            elif not (0 <= conf <= 1):
                violations.append(f"{prefix}:confidence_out_of_range:{conf}")

        # source_url and quote_span are used as a set/dict key and as text
        # below, so anything that is not a string is reported here instead of
        # crashing (unhashable key, missing .strip()) or slipping through
        # unchecked (falsy values such as [] or 0).
        source_url = claim.get("source_url")
        if source_url is not None and not isinstance(source_url, str):
            violations.append(f"{prefix}:invalid_type:source_url")
            source_url = None

        quote = claim.get("quote_span")
        if quote is not None and not isinstance(quote, str):
            violations.append(f"{prefix}:invalid_type:quote_span")
            quote = None

        # Check URL was visited
        if source_url and source_url not in self.visited_urls:
            violations.append(f"{prefix}:url_not_visited:{source_url}")

        # Check quote length; report the same (stripped) length that is tested.
        if quote:
            quote_length = len(quote.strip())
            if quote_length < self.min_quote_length:
                violations.append(f"{prefix}:quote_too_short:{quote_length}")

        # Check quote exists in source
        if quote and source_url:
            found, reason = self._check_quote_in_source(quote, source_url)
            if not found:
                violations.append(f"{prefix}:quote_not_found:{reason}")

        return len(violations) == 0, violations

    @classmethod
    def _extract_json(cls, text: str) -> Tuple[bool, Any]:
        """
        Pull a JSON payload out of raw writer text.

        Tries, in order: the whole text as a JSON object, the first markdown
        code fence, the span from the first "{" to the last "}".

        Returns:
            (found, payload); found is False when no candidate was located.

        Raises:
            json.JSONDecodeError: a fenced or braced candidate is not valid JSON.
        """
        # Whole-text parse first, so a valid JSON object whose string values
        # happen to contain "```" is not torn apart by the fence regex. Only
        # a dict is accepted here; other results fall through to the
        # candidate searches below.
        try:
            whole = json.loads(text)
        except json.JSONDecodeError:
            whole = None
        if isinstance(whole, dict):
            return True, whole

        fenced = cls._FENCED_JSON_RE.search(text)
        if fenced:
            return True, json.loads(fenced.group(1))

        braced = cls._BRACED_JSON_RE.search(text)
        if braced:
            return True, json.loads(braced.group())

        return False, None

    @staticmethod
    def _violation_type(violation: str) -> str:
        """
        Return the type token of a violation string.

        For "claim_<i>:<type>:..." this is the second colon-separated part;
        for "<type>:..." it is the first; a string with no colon is returned
        unchanged.
        """
        parts = violation.split(":")
        if len(parts) < 2:
            return violation
        return parts[1] if parts[0].startswith("claim_") else parts[0]

    def validate_response(
        self,
        response: Any,
        k_required: int
    ) -> "ValidationResult":
        """
        Validate full response against protocol contract.

        Parse and structure failures (no JSON, invalid JSON, not a dict, no
        "claims" key, "claims" not a list) return immediately with a single
        violation. Otherwise the claim count, every claim, and duplicate
        entity/predicate/value triples are checked and all violations are
        collected.

        Args:
            response: Parsed JSON response or raw string
            k_required: Required number of claims

        Returns:
            ValidationResult with validity and violations
        """
        all_violations: List[str] = []

        # Handle string input
        if isinstance(response, str):
            try:
                found, response = self._extract_json(response)
            except json.JSONDecodeError as e:
                return ValidationResult(
                    valid=False,
                    violations=[f"parse_error:invalid_json:{str(e)[:50]}"]
                )
            if not found:
                return ValidationResult(
                    valid=False,
                    violations=["parse_error:no_json_found"]
                )

        # Check top-level structure
        if not isinstance(response, dict):
            return ValidationResult(
                valid=False,
                violations=["structure_error:not_a_dict"]
            )

        if "claims" not in response:
            return ValidationResult(
                valid=False,
                violations=["structure_error:missing_claims_array"]
            )

        claims = response["claims"]

        if not isinstance(claims, list):
            return ValidationResult(
                valid=False,
                violations=["structure_error:claims_not_array"]
            )

        # Check count
        actual_count = len(claims)
        if actual_count != k_required:
            all_violations.append(
                f"count_error:expected_{k_required}_got_{actual_count}"
            )

        # Validate each claim
        seen_claims = set()
        for i, claim in enumerate(claims):
            if not isinstance(claim, dict):
                all_violations.append(f"claim_{i}:not_a_dict")
                continue

            _, violations = self.validate_claim(claim, i)
            all_violations.extend(violations)

            # Check for duplicates. A tuple keeps the three parts separate,
            # so values containing a delimiter character cannot collide;
            # str() makes unhashable values (lists, dicts) usable as a key.
            claim_key = tuple(
                str(claim.get(part, ""))
                for part in ("entity", "predicate", "value")
            )
            if claim_key in seen_claims:
                all_violations.append(f"claim_{i}:duplicate")
            seen_claims.add(claim_key)

        return ValidationResult(
            valid=len(all_violations) == 0,
            violations=all_violations
        )

    def compute_metrics(
        self,
        response: Any,
        k_required: int
    ) -> Dict[str, Any]:
        """
        Compute detailed metrics for analysis.

        The has_* flags are derived from the violation type token, not from a
        substring search, so free text embedded in a violation (for example a
        URL) cannot set an unrelated flag. "has_schema_error" covers
        missing_field, null_field, empty_field and invalid_type.

        Args:
            response: Response to evaluate
            k_required: Required claim count

        Returns:
            Dict with validity, violation counts, and breakdown
        """
        result = self.validate_response(response, k_required)

        # Categorize violations
        violations_by_type: Dict[str, int] = {}
        for violation in result.violations:
            vtype = self._violation_type(violation)
            violations_by_type[vtype] = violations_by_type.get(vtype, 0) + 1

        return {
            "valid": result.valid,
            "n_violations": len(result.violations),
            "violations": result.violations,
            "violations_by_type": violations_by_type,
            "has_parse_error": "parse_error" in violations_by_type,
            "has_count_error": "count_error" in violations_by_type,
            "has_quote_error": "quote_not_found" in violations_by_type,
            "has_url_error": "url_not_visited" in violations_by_type,
            "has_schema_error": any(
                vtype in violations_by_type
                for vtype in self._SCHEMA_VIOLATION_TYPES
            ),
        }


def evaluate_protocol_task(
    writer_output: str,
    k_required: int,
    page_cache: Dict[str, str],
    visited_urls: List[str]
) -> Dict[str, Any]:
    """
    One-shot helper: build a ProtocolEvaluator and score a single output.

    The evaluator is created with default quote-matching settings.

    Args:
        writer_output: Raw writer output string
        k_required: Number of claims required
        page_cache: {url: content} from browser
        visited_urls: URLs browser visited

    Returns:
        The dict produced by ProtocolEvaluator.compute_metrics
    """
    protocol_checker = ProtocolEvaluator(
        visited_urls=visited_urls,
        page_cache=page_cache,
    )
    metrics = protocol_checker.compute_metrics(
        response=writer_output,
        k_required=k_required,
    )
    return metrics


# Test cases
if __name__ == "__main__":
    # Self-test: one cached page, then five scenarios. The first must be
    # fully valid; each of the others must raise the matching error flag.
    einstein_url = "http://localhost:9999/wiki/Albert_Einstein"
    test_page = """
    Albert Einstein was born in Ulm, in the Kingdom of Württemberg
    in the German Empire, on 14 March 1879. He graduated from ETH
    Zurich in 1900 with a degree in physics.
    """

    checker = ProtocolEvaluator(
        page_cache={einstein_url: test_page},
        visited_urls=[einstein_url],
    )

    # Scenario 1: three well-formed claims quoting the cached page.
    birthplace_claim = {
        "entity": "Albert Einstein",
        "predicate": "born_in",
        "value": "Ulm, German Empire",
        "source_url": einstein_url,
        "quote_span": "Albert Einstein was born in Ulm, in the Kingdom of Württemberg",
        "confidence": 0.95,
    }
    graduation_claim = {
        "entity": "Albert Einstein",
        "predicate": "graduated_from",
        "value": "ETH Zurich",
        "source_url": einstein_url,
        "quote_span": "He graduated from ETH Zurich in 1900",
        "confidence": 0.90,
    }
    birth_date_claim = {
        "entity": "Albert Einstein",
        "predicate": "birth_date",
        "value": "14 March 1879",
        "source_url": einstein_url,
        "quote_span": "on 14 March 1879",
        "confidence": 0.95,
    }
    valid_response = {
        "claims": [birthplace_claim, graduation_claim, birth_date_claim]
    }

    metrics = checker.compute_metrics(valid_response, k_required=3)
    print(f"Valid response: valid={metrics['valid']}, violations={metrics['n_violations']}")
    assert metrics['valid'] is True, f"Expected valid, got violations: {metrics['violations']}"

    # Scenario 2: only two claims supplied where three are required.
    too_few = {"claims": [birthplace_claim, graduation_claim]}
    metrics = checker.compute_metrics(too_few, k_required=3)
    print(f"Wrong count: valid={metrics['valid']}, has_count_error={metrics['has_count_error']}")
    assert metrics['has_count_error'] is True

    # Scenario 3: the claim cites a URL that was never registered as visited.
    unvisited = {"claims": [
        dict(birthplace_claim, source_url="http://localhost:9999/wiki/FAKE")
    ]}
    metrics = checker.compute_metrics(unvisited, k_required=1)
    print(f"Bad URL: valid={metrics['valid']}, has_url_error={metrics['has_url_error']}")
    assert metrics['has_url_error'] is True

    # Scenario 4: the quoted span does not occur on the cached page.
    fabricated = {"claims": [
        dict(birthplace_claim, quote_span="This text does not exist in the source at all")
    ]}
    metrics = checker.compute_metrics(fabricated, k_required=1)
    print(f"Bad quote: valid={metrics['valid']}, has_quote_error={metrics['has_quote_error']}")
    assert metrics['has_quote_error'] is True

    # Scenario 5: source_url, quote_span and confidence are left out.
    incomplete = {"claims": [
        {"entity": "Einstein", "predicate": "born_in", "value": "Ulm"}
    ]}
    metrics = checker.compute_metrics(incomplete, k_required=1)
    print(f"Missing fields: valid={metrics['valid']}, has_schema_error={metrics['has_schema_error']}")
    assert metrics['has_schema_error'] is True

    print("\nAll tests passed!")
