"""
Strong baselines for comparison with JO.

Implements:
1. Constrained Decoding (Outlines-style grammar constraints)
2. Schema Validation + Repair (Pydantic-style)
3. Guardrails (rule-based policy engine)
4. Shield Controller (action blocking/filtering)
"""

import re
import json
from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass


# =============================================================================
# 1. CONSTRAINED DECODING BASELINE
# =============================================================================
# Simulates Outlines/Guidance-style constrained decoding by:
# - Defining a valid action grammar
# - Forcing output to match the grammar
# - Falling back to a "noop" action type when the action type is invalid

@dataclass
class ActionGrammar:
    """Defines the set of valid action types and enforces it on action dicts.

    Attributes:
        valid_types: Allowed values for ``action["action_type"]``. ``None``
            (the default) is replaced in ``__post_init__`` with
            ``["navigate", "click", "answer", "noop", "inform"]``; a fresh
            list is built per instance, so instances never share it.
    """

    # Explicitly Optional: None is the "use the default list" sentinel
    # (PEP 484 no longer treats a None default as an implicit Optional).
    valid_types: Optional[List[str]] = None

    def __post_init__(self):
        if self.valid_types is None:
            self.valid_types = ["navigate", "click", "answer", "noop", "inform"]

    def is_valid(self, action: Dict[str, Any]) -> bool:
        """Return True if *action* is a dict whose ``action_type`` is allowed.

        Non-dict inputs are invalid; a missing ``action_type`` is looked up
        as ``""``.
        """
        if not isinstance(action, dict):
            return False
        action_type = action.get("action_type", "")
        return action_type in self.valid_types

    def force_valid(self, action: Dict[str, Any], default_type: str = "noop") -> Dict[str, Any]:
        """Return a version of *action* with an allowed action type.

        Args:
            action: Candidate action; any value is accepted.
            default_type: Action type substituted when the given one is not
                in ``valid_types``. It is not itself checked against
                ``valid_types``.

        Returns:
            *action* itself (same object) if already valid; otherwise a
            shallow copy of a dict input with ``action_type`` replaced by
            *default_type*, or a fresh placeholder dict for non-dict input.
            The caller's object is never mutated.
        """
        if self.is_valid(action):
            return action
        if isinstance(action, dict):
            fixed = dict(action)  # copy so the caller's dict is untouched
            if fixed.get("action_type") not in self.valid_types:
                fixed["action_type"] = default_type
            return fixed
        return {"action_type": default_type, "text": "Invalid action replaced"}


class ConstrainedDecodingBaseline:
    """
    Simulates Outlines/Guidance constrained decoding.

    In practice, this would constrain the LLM's token generation. Here it is
    simulated by post-processing each proposed action:

    1. An action whose ``action_type`` fails the grammar is rewritten by
       ``ActionGrammar.force_valid`` (default replacement type ``"noop"``).
    2. The text of an ``answer`` action that does not match
       ``format_pattern`` is rewritten into pipe-separated form when it can
       be split on commas, semicolons or newlines.
    """

    # Default answer format: the text must contain at least one pipe.
    DEFAULT_FORMAT_PATTERN = r".*\|.*"

    def __init__(self, format_pattern: Optional[str] = None):
        """
        Args:
            format_pattern: Regex that ``answer`` text must match (checked
                with ``re.search``). ``None`` selects
                ``DEFAULT_FORMAT_PATTERN``; an empty string disables the
                answer-format check.
        """
        self.grammar = ActionGrammar()
        # Only None means "use the default" so that "" can disable the check;
        # `format_pattern or default` made the guard in process_action dead.
        self.format_pattern = (
            self.DEFAULT_FORMAT_PATTERN if format_pattern is None else format_pattern
        )
        self.stats = {"total": 0, "format_fixes": 0, "type_fixes": 0}

    def process_action(self, action: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], bool]:
        """
        Process *action* through simulated constrained decoding.

        Args:
            action: Proposed action. A non-dict value is wrapped as
                ``{"action_type": "noop", "text": str(action)}``.
            context: Unused; accepted for interface parity with the other
                baselines.

        Returns:
            ``(processed_action, was_modified)``. ``was_modified`` is True only
            if the action type or the answer text actually changed. The input
            dict is never mutated.
        """
        self.stats["total"] += 1
        modified = False
        result = dict(action) if isinstance(action, dict) else {"action_type": "noop", "text": str(action)}

        # 1. Force a valid action type.
        if not self.grammar.is_valid(result):
            result = self.grammar.force_valid(result)
            self.stats["type_fixes"] += 1
            modified = True

        # 2. Force the answer format (skipped when format_pattern is "").
        if result.get("action_type") == "answer" and self.format_pattern:
            raw_text = result.get("text")
            # None / non-str text would make re.search raise TypeError.
            text = "" if raw_text is None else str(raw_text)
            if not re.search(self.format_pattern, text):
                converted = self._convert_to_format(text)
                # Count a fix only when the conversion really changed the text;
                # single-segment text is returned unchanged by the converter.
                if converted != text:
                    result["text"] = converted
                    self.stats["format_fixes"] += 1
                    modified = True

        return result, modified

    def _convert_to_format(self, text: str) -> str:
        """Convert *text* to pipe-separated form.

        Splits on commas/semicolons (plus trailing whitespace) or newline
        runs and rejoins the non-empty stripped parts with ``" | "``. Text
        with no delimiter is returned unchanged.
        """
        parts = re.split(r'[,;]\s*|\n+', text)
        if len(parts) > 1:
            return " | ".join(p.strip() for p in parts if p.strip())
        return text


# =============================================================================
# 2. SCHEMA VALIDATION + REPAIR (PYDANTIC-STYLE)
# =============================================================================

@dataclass
class ActionSchema:
    """Schema describing the expected structure of an action dict.

    Attributes:
        required_fields: Keys that must be present
            (default ``["action_type"]``).
        optional_fields: Keys that may be present
            (default ``["text", "url", "element_id"]``).
        type_constraints: Expected Python type per key (default ``str`` for
            ``action_type``, ``text`` and ``url``).

    Each attribute defaults to ``None`` and is replaced by a fresh container
    in ``__post_init__``, so instances never share mutable defaults.
    """

    # Explicitly Optional: None is the "use the default" sentinel
    # (PEP 484 no longer treats a None default as an implicit Optional).
    required_fields: Optional[List[str]] = None
    optional_fields: Optional[List[str]] = None
    type_constraints: Optional[Dict[str, type]] = None

    def __post_init__(self):
        if self.required_fields is None:
            self.required_fields = ["action_type"]
        if self.optional_fields is None:
            self.optional_fields = ["text", "url", "element_id"]
        if self.type_constraints is None:
            self.type_constraints = {
                "action_type": str,
                "text": str,
                "url": str,
            }


class SchemaValidationBaseline:
    """
    Pydantic-style schema validation with automatic repair.

    Validates the structure of an action against an ``ActionSchema`` and
    auto-repairs common issues: missing required fields, wrong field types,
    and (optionally) missing citations or pipe formatting in ``answer`` text.
    """

    def __init__(self, require_citation: bool = False, require_format: bool = False):
        """
        Args:
            require_citation: Append ``" CITATION: <url>"`` to answers that
                do not match ``citation_pattern``.
            require_format: Rewrite answers lacking a pipe into
                pipe-separated form where possible.
        """
        self.schema = ActionSchema()
        self.require_citation = require_citation
        self.require_format = require_format
        self.format_pattern = r".*\|.*"
        self.citation_pattern = r"(CITATION:|Source:|http://|https://)"
        self.stats = {"total": 0, "schema_fixes": 0, "citation_adds": 0, "format_fixes": 0}

    def validate_and_repair(self, action: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], List[str]]:
        """
        Validate *action* against the schema and repair it if invalid.

        Args:
            action: Proposed action; a non-dict is treated as ``{}``.
            context: Optional dict; its ``"url"`` is used for added citations
                (fallback ``"http://source"``).

        Returns:
            ``(repaired_action, repairs)`` where ``repairs`` lists tags such
            as ``added_missing_<field>``, ``coerced_<field>``,
            ``reset_<field>``, ``added_citation`` and ``fixed_format``, in
            the order applied. The input dict is never mutated.
        """
        self.stats["total"] += 1
        repairs = []
        result = dict(action) if isinstance(action, dict) else {}

        # 1. Ensure required fields.
        for field in self.schema.required_fields:
            if field not in result:
                result[field] = self._get_default(field)
                repairs.append(f"added_missing_{field}")
                self.stats["schema_fixes"] += 1

        # 2. Type coercion.
        for field, expected_type in self.schema.type_constraints.items():
            if field not in result or isinstance(result[field], expected_type):
                continue
            value = result[field]
            if value is None:
                # str(None) would inject the literal "None"; treat as missing.
                result[field] = self._get_default(field)
                repairs.append(f"reset_{field}")
                continue
            try:
                coerced = expected_type(value)
            except Exception:
                # Narrower than a bare except: lets KeyboardInterrupt and
                # SystemExit propagate. Any conversion failure (including one
                # from a custom __str__) resets the field.
                result[field] = self._get_default(field)
                repairs.append(f"reset_{field}")
            else:
                result[field] = coerced
                repairs.append(f"coerced_{field}")

        # 3. Citation requirement.
        if self.require_citation and result.get("action_type") == "answer":
            text = result.get("text", "")
            if not re.search(self.citation_pattern, text):
                url = context.get("url", "http://source") if context else "http://source"
                result["text"] = f"{text} CITATION: {url}"
                repairs.append("added_citation")
                self.stats["citation_adds"] += 1

        # 4. Format requirement.
        if self.require_format and result.get("action_type") == "answer":
            text = result.get("text", "")
            if not re.search(self.format_pattern, text):
                result["text"] = self._convert_to_pipe_format(text)
                repairs.append("fixed_format")
                self.stats["format_fixes"] += 1

        return result, repairs

    def _get_default(self, field: str) -> Any:
        """Return the repair default for *field* (``""`` unless overridden)."""
        defaults = {
            "action_type": "noop",
            "text": "",
            "url": "",
        }
        return defaults.get(field, "")

    def _convert_to_pipe_format(self, text: str) -> str:
        """Convert *text* to pipe-separated form.

        Prefers ``key: value`` / ``key = value`` pairs; otherwise splits on
        sentence punctuation or commas. Text with neither is returned
        unchanged.
        """
        pairs = re.findall(r'(\w+(?:\s+\w+)?)\s*[:=]\s*([^,;\n]+)', text)
        if pairs:
            return " | ".join(f"{k}: {v.strip()}" for k, v in pairs)
        # Fallback: split by sentences/commas.
        parts = re.split(r'[.!?]\s+|,\s+', text)
        if len(parts) > 1:
            return " | ".join(p.strip() for p in parts if p.strip())
        return text


# =============================================================================
# 3. GUARDRAILS BASELINE (RULE-BASED POLICY ENGINE)
# =============================================================================

class GuardrailsBaseline:
    """
    NeMo Guardrails-style rule-based policy engine (output rails only).

    ``check_action`` applies three rails in order:
    - Blocked-pattern rail: BLOCK if the action's ``text`` or ``url`` matches
      any ``blocked_patterns`` regex (case-insensitive).
    - Role rail: MODIFY to ``noop`` if ``context["agent_role"]`` is listed in
      ``role_constraints`` and the action type is not allowed for it.
    - Required-pattern rail: MODIFY if the action type has an entry in
      ``required_patterns`` and its content does not match.

    Each rail's rules can be overridden via the matching ``config`` key.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or {}
        # NOTE(review): these defaults match substrings (e.g. "ssn" inside
        # "classnames"); word boundaries would cut false positives but would
        # change baseline results, so they are left as-is.
        self.blocked_patterns = self.config.get("blocked_patterns", [
            r"(password|credit.card|ssn|secret)",  # PII
            r"(hack|exploit|attack)",  # Security
        ])
        self.required_patterns = self.config.get("required_patterns", {
            "answer": r".*",  # matches any text, including ""
        })
        self.role_constraints = self.config.get("role_constraints", {
            "RESEARCHER": ["navigate", "click", "inform"],
            "WRITER": ["answer", "inform"],
            "VERIFIER": ["answer", "inform", "noop"],
        })
        self.stats = {"total": 0, "blocked": 0, "modified": 0, "allowed": 0}

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as a string, mapping None to ''."""
        return "" if value is None else str(value)

    def check_action(self, action: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Tuple[str, Optional[Dict[str, Any]], str]:
        """
        Check *action* against the guardrails.

        Args:
            action: Proposed action dict.
            context: Optional dict; ``"agent_role"`` selects role constraints.

        Returns:
            ``(decision, action_or_none, reason)`` where ``decision`` is
            ``"ALLOW"`` (original action returned), ``"MODIFY"`` (a modified
            copy returned) or ``"BLOCK"`` (``None`` returned).
        """
        self.stats["total"] += 1
        context = context or {}
        action_type = action.get("action_type", "")
        # None-safe: a key present with value None must not reach re.search.
        text = self._as_text(action.get("text"))
        url = self._as_text(action.get("url"))
        # Content checked by the required-pattern rail: text, else url.
        content = text or url
        role = context.get("agent_role", "")

        # 1. Blocked patterns: scan text AND url so benign text cannot mask a
        #    blocked url.
        for pattern in self.blocked_patterns:
            if any(re.search(pattern, value, re.IGNORECASE) for value in (text, url)):
                self.stats["blocked"] += 1
                return "BLOCK", None, f"blocked_pattern:{pattern}"

        # 2. Role constraints.
        if role and role in self.role_constraints:
            allowed_types = self.role_constraints[role]
            if action_type not in allowed_types:
                modified = dict(action)
                modified["action_type"] = "noop"
                modified["text"] = f"[ROLE VIOLATION] {role} cannot {action_type}"
                self.stats["modified"] += 1
                return "MODIFY", modified, f"role_violation:{role}:{action_type}"

        # 3. Required patterns.
        if action_type in self.required_patterns:
            pattern = self.required_patterns[action_type]
            if not re.search(pattern, content):
                self.stats["modified"] += 1
                modified = dict(action)
                modified["text"] = content or "[No content provided]"
                return "MODIFY", modified, "missing_required_pattern"

        self.stats["allowed"] += 1
        return "ALLOW", action, "ok"


# =============================================================================
# 4. SHIELD CONTROLLER (AEGIS-STYLE)
# =============================================================================

class ShieldController:
    """
    Aegis/Proper Shields-style controller for discrete action spaces.

    Implements:
    - Pre-shield: check each proposed action with ``constraint_checker``
      before execution.
    - Repair: try to fix a violating action (pipe formatting, citation,
      RESEARCHER answer -> inform) and re-check the result, chaining up to
      ``MAX_REPAIR_ROUNDS`` repairs so a repaired action is never released
      while it still violates a constraint.
    - Safe fallback: return a copy of ``safe_action`` (a noop) when the
      action cannot be repaired.

    No post-execution check is implemented in this class.
    """

    # Upper bound on chained repairs for one action. The default checker has
    # three repairable reasons; the bound also stops a custom checker that
    # keeps flagging a "repaired" action from looping forever.
    MAX_REPAIR_ROUNDS = 3

    def __init__(self, constraint_checker=None):
        # constraint_checker(action, context) -> (violates: bool, reason: str)
        self.constraint_checker = constraint_checker or self._default_checker
        self.safe_action = {"action_type": "noop", "text": "Blocked by shield"}
        self.violation_history = []
        # "blocked" counts every action that failed its first check, including
        # those later repaired; "repaired" counts the successful repairs.
        self.stats = {"total": 0, "blocked": 0, "allowed": 0, "repaired": 0}

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as a string, mapping None to ''."""
        return "" if value is None else str(value)

    def _default_checker(self, action: Dict[str, Any], context: Dict[str, Any]) -> Tuple[bool, str]:
        """Default constraint checker.

        Checks, in order: pipe format (``context["require_format"]``) and
        citation (``context["require_citation"]``) for answers, then that a
        RESEARCHER does not answer. Returns ``(violates, reason)`` for the
        first violation found, or ``(False, "ok")``.
        """
        action_type = action.get("action_type")
        if action_type == "answer":
            text = self._as_text(action.get("text"))
            if context.get("require_format", False) and "|" not in text:
                return True, "format_violation"
            if context.get("require_citation", False):
                if not re.search(r"(CITATION|Source|http)", text):
                    return True, "citation_missing"

        role = context.get("agent_role", "")
        if role == "RESEARCHER" and action_type == "answer":
            return True, "role_leakage"

        return False, "ok"

    def pre_shield(self, action: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], bool, str]:
        """
        Pre-execution shield.

        Returns:
            ``(action, was_blocked, reason)``:
            - no violation: the original action, False, ``"ok"``;
            - repaired: the repaired copy, False,
              ``"repaired:<r1>+<r2>..."`` listing each violation fixed;
            - blocked: a copy of ``safe_action``, True, and the violation
              that could not be repaired.
        """
        self.stats["total"] += 1
        context = context or {}

        violates, reason = self.constraint_checker(action, context)
        if not violates:
            self.stats["allowed"] += 1
            return action, False, "ok"

        self.stats["blocked"] += 1
        self.violation_history.append({
            "action": action,
            "reason": reason,
            "context": context,
        })

        # Try to repair instead of blocking; every repair is re-verified.
        repaired, fixed, unresolved = self._repair_until_clean(action, reason, context)
        if repaired is not None:
            self.stats["repaired"] += 1
            return repaired, False, "repaired:" + "+".join(fixed)

        # Hand out a copy so callers cannot mutate the shared fallback.
        return dict(self.safe_action), True, unresolved

    def _repair_until_clean(self, action: Dict[str, Any], reason: str, context: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]], List[str], str]:
        """Repair *action* repeatedly until the checker accepts it.

        Returns:
            ``(repaired_action, fixed_reasons, last_reason)``.
            ``repaired_action`` is None if a repair is impossible or the round
            limit is hit; ``last_reason`` is then the unresolved violation.
        """
        current = action
        fixed: List[str] = []
        for _ in range(self.MAX_REPAIR_ROUNDS):
            candidate = self._try_repair(current, reason, context)
            if candidate is None:
                return None, fixed, reason
            fixed.append(reason)
            violates, reason = self.constraint_checker(candidate, context)
            if not violates:
                return candidate, fixed, reason
            current = candidate
        return None, fixed, reason

    def _try_repair(self, action: Dict[str, Any], reason: str, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return a repaired copy of *action* for *reason*, or None if not repairable."""
        repaired = dict(action)
        text = self._as_text(repaired.get("text"))

        if reason == "format_violation":
            # Convert to pipe format; needs at least two non-empty segments,
            # otherwise the result would still contain no pipe.
            parts = [p.strip() for p in re.split(r'[,;]\s*|\.\s+', text) if p.strip()]
            if len(parts) > 1:
                repaired["text"] = " | ".join(parts)
                return repaired

        elif reason == "citation_missing":
            url = context.get("url", "http://wikipedia.org")
            repaired["text"] = f"{text} CITATION: {url}"
            return repaired

        elif reason == "role_leakage":
            # Convert answer to inform for wrong role.
            repaired["action_type"] = "inform"
            repaired["text"] = f"[RESEARCHER→WRITER] {text}"
            return repaired

        return None


# =============================================================================
# UNIFIED BASELINE INTERFACE
# =============================================================================

class BaselineOperator:
    """
    Unified interface over the baselines in this module.

    ``baseline_type`` selects exactly one strategy:
    ``"constrained_decoding"``, ``"schema_validation"``, ``"guardrails"`` or
    ``"shield"``. ``project`` returns a dict that always has ``"outcome"``
    and ``"action"`` keys, plus a strategy-specific detail key.
    """

    def __init__(self, baseline_type: str, config: Optional[Dict[str, Any]] = None):
        """
        Args:
            baseline_type: Which baseline to wrap (see class docstring).
            config: Options forwarded to the baseline: ``format_pattern``
                (constrained decoding), ``require_citation`` /
                ``require_format`` (schema validation), or the whole dict
                (guardrails). Unused by the shield.

        Raises:
            ValueError: If *baseline_type* is not recognised.
        """
        self.baseline_type = baseline_type
        self.config = config or {}

        if baseline_type == "constrained_decoding":
            self.baseline = ConstrainedDecodingBaseline(
                format_pattern=self.config.get("format_pattern")
            )
        elif baseline_type == "schema_validation":
            self.baseline = SchemaValidationBaseline(
                require_citation=self.config.get("require_citation", False),
                require_format=self.config.get("require_format", False),
            )
        elif baseline_type == "guardrails":
            self.baseline = GuardrailsBaseline(config=self.config)
        elif baseline_type == "shield":
            self.baseline = ShieldController()
        else:
            raise ValueError(f"Unknown baseline type: {baseline_type}")

    def project(self, action: Dict[str, Any], context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Project *action* through the selected baseline.

        Returns:
            A dict with:
            - constrained_decoding: ``outcome="ALLOW"``, ``action``,
              ``was_modified``;
            - schema_validation: ``outcome="ALLOW"``, ``action``, ``repairs``;
            - guardrails: ``outcome`` in ALLOW/MODIFY/BLOCK, ``action`` (the
              original action on BLOCK), ``reason``;
            - shield: ``outcome`` in ALLOW/DENY, ``action``, ``reason``.

        Raises:
            ValueError: If ``baseline_type`` is not a known strategy.
        """
        context = context or {}

        if self.baseline_type == "constrained_decoding":
            processed, was_modified = self.baseline.process_action(action, context)
            return {"outcome": "ALLOW", "action": processed, "was_modified": was_modified}

        if self.baseline_type == "schema_validation":
            processed, repairs = self.baseline.validate_and_repair(action, context)
            return {"outcome": "ALLOW", "action": processed, "repairs": repairs}

        if self.baseline_type == "guardrails":
            decision, modified, reason = self.baseline.check_action(action, context)
            # Fall back to the original only when guardrails returned no action
            # (BLOCK); an empty-but-present dict must not be discarded.
            chosen = modified if modified is not None else action
            return {"outcome": decision, "action": chosen, "reason": reason}

        if self.baseline_type == "shield":
            processed, blocked, reason = self.baseline.pre_shield(action, context)
            outcome = "DENY" if blocked else "ALLOW"
            return {"outcome": outcome, "action": processed, "reason": reason}

        # Only reachable if baseline_type was changed after construction.
        raise ValueError(f"Unknown baseline type: {self.baseline_type}")

    def get_stats(self) -> Dict[str, Any]:
        """Return the wrapped baseline's live ``stats`` dict."""
        return self.baseline.stats
