# precedent_store.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import re
import os

# Optional OpenAI embedding backend. `_openai_client` is None whenever the
# package cannot be imported or constructing the client raises; callers must
# check for None before using it.
try:
    from openai import OpenAI
    _openai_client = OpenAI()
except Exception:
    _openai_client = None

# Optional sentence-transformers backend, used when OpenAI embeddings are not.
# Both names are bound to None if the import fails for any reason.
try:
    from sentence_transformers import SentenceTransformer
    from sentence_transformers import util as st_util
except Exception:
    SentenceTransformer = st_util = None


def _tokenize(s: str) -> List[str]:
    """Return the lowercase runs of ASCII letters, digits and underscores in *s*.

    Any falsy input (``None``, ``""``) yields an empty list.
    """
    if not s:
        return []
    lowered = s.lower()
    return re.findall(r"[a-zA-Z0-9_]+", lowered)


def _jaccard(a: str, b: str) -> float:
    """Return the Jaccard similarity of the token sets of *a* and *b*.

    Two token-less strings count as identical (1.0); if exactly one side has
    no tokens the similarity is 0.0.
    """
    left, right = set(_tokenize(a)), set(_tokenize(b))
    if not (left or right):
        return 1.0
    if not (left and right):
        return 0.0
    shared = left.intersection(right)
    combined = left.union(right)
    return len(shared) / len(combined)


def _state_summary(state: Dict[str, Any]) -> str:
    """Render selected fields of *state* as newline-separated ``name=value`` lines.

    Fields are emitted in a fixed order and skipped when missing or None.
    ``last_obs`` is appended last, stringified and cut to 200 characters, and
    only when it is truthy.
    """
    # keep it small + stable; add more features later
    ordered_fields = (
        "scenario",
        "site",
        "intent",
        "failure_mode",
        "last_action_type",
        "budget_left",
    )
    lines = []
    for name in ordered_fields:
        value = state.get(name)
        if value is not None:
            lines.append(f"{name}={value}")

    last_obs = state.get("last_obs")
    if last_obs:
        lines.append(f"last_obs={str(last_obs)[:200]}")
    return "\n".join(lines)


@dataclass
class RetrievalHit:
    """Single best-matching precedent, as returned by ``PrecedentStore.retrieve``."""
    key: str  # "site|intent|failure_mode" key of the matched precedent
    score: float  # similarity of the query to the precedent (cosine or Jaccard, depending on the store's backend)
    action: Dict[str, Any]  # approved action stored with the precedent
    source_task_id: Optional[int] = None  # Task ID where precedent was created (for CTHR)


class PrecedentStore:
    """
    Bounded store of (state, approved_action) precedents with nearest-match retrieval.

    Each precedent is kept as a tuple ``(key, action, text_for_match, task_id)``
    where ``key`` is ``"site|intent|failure_mode"`` and ``text_for_match`` is the
    text that similarity is computed against.

    Similarity backends (selected in ``__init__``):
      - OpenAI embeddings + cosine (model names starting with "text-embedding")
      - SentenceTransformer embeddings + cosine (any other model name)
      - token Jaccard fallback when no embedding backend is usable

    The store also keeps per-pattern success/failure counters (see
    ``record_pattern_outcome``).
    """

    def __init__(
        self,
        *,
        use_semantic: bool = True,
        retrieve_threshold: float = 0.65,
        model_name: str = "text-embedding-3-small",
        max_capacity: int = 100,
        eviction_policy: str = "fifo",
    ):
        """
        Args:
            use_semantic: Request embedding-based similarity. Silently downgraded
                to token Jaccard when no suitable embedding backend is available
                (check ``self.use_semantic`` afterwards).
            retrieve_threshold: Stored as an attribute for callers; it is NOT
                applied inside ``retrieve``.
            model_name: Embedding model. Names starting with "text-embedding" are
                treated as OpenAI models; anything else is loaded through
                SentenceTransformer.
            max_capacity: Maximum number of precedents kept.
            eviction_policy: "fifo" (drop oldest) or "random". Any other value
                disables eviction, so the store can grow past ``max_capacity``.
        """
        self.use_semantic = bool(use_semantic)
        self.retrieve_threshold = float(retrieve_threshold)
        self.max_capacity = int(max_capacity)
        self.eviction_policy = eviction_policy
        self.model_name = model_name

        # (key, action, text_for_match, task_id)
        self._items: List[Tuple[str, Dict[str, Any], str, Optional[int]]] = []
        # Row i is the normalized embedding of self._items[i][2]
        # (numpy array for OpenAI, tensor for SentenceTransformer); None when unused.
        self._embs = None

        # Pattern statistics for theta learning: pattern_id -> {count, success, fail}
        self._pattern_stats: Dict[str, Dict[str, int]] = {}

        self._model = None
        self._use_openai = False

        if self.use_semantic:
            wants_openai = isinstance(model_name, str) and model_name.startswith("text-embedding")
            if wants_openai and _openai_client is not None:
                self._use_openai = True
            elif not wants_openai and SentenceTransformer is not None:
                # Legacy (non-OpenAI) model names go through sentence-transformers.
                self._model = SentenceTransformer(model_name)
            else:
                # Either nothing is installed, or an OpenAI model was requested
                # without a usable client. An OpenAI model name must not be handed
                # to SentenceTransformer (it is not a model it can load), so fall
                # back to Jaccard instead of failing in the constructor.
                self.use_semantic = False

    # ========== Internal helpers ==========

    def _get_openai_embeddings(self, texts: List[str]) -> np.ndarray:
        """Return raw (un-normalized) OpenAI embeddings, one row per text.

        An empty input yields an empty 2-D array so row-wise operations on the
        result stay valid.
        """
        if not texts:
            return np.empty((0, 0))
        response = _openai_client.embeddings.create(
            model=self.model_name,
            input=texts,
        )
        return np.array([item.embedding for item in response.data])

    def _embed_texts(self, texts: List[str]) -> Any:
        """Return L2-normalized embeddings for *texts* using the active backend."""
        if self._use_openai:
            raw = self._get_openai_embeddings(texts)
            norms = np.linalg.norm(raw, axis=1, keepdims=True)
            # Clip so an all-zero vector cannot cause a division by zero.
            return raw / np.clip(norms, 1e-8, None)
        return self._model.encode(texts, convert_to_tensor=True, normalize_embeddings=True)

    def _embed_items(self, items: List[Tuple[str, Dict[str, Any], str, Optional[int]]]) -> Any:
        """Return embeddings for the match texts of *items*, or None if not applicable."""
        if not self.use_semantic or not items:
            return None
        return self._embed_texts([entry[2] for entry in items])

    def _make_item(
        self,
        *,
        site: str,
        intent: str,
        failure_mode: str,
        approved_action: Dict[str, Any],
        state: Optional[Dict[str, Any]],
        task_id: Optional[int],
    ) -> Tuple[str, Dict[str, Any], str, Optional[int]]:
        """Build the stored tuple for one precedent (no mutation of the store)."""
        summary = _state_summary(state or {"site": site, "intent": intent, "failure_mode": failure_mode})
        action_text = approved_action.get("text", str(approved_action))
        text_for_match = f"{summary}\nACTION:\n{action_text}"
        key = f"{site}|{intent}|{failure_mode}"
        # Shallow-copy the action so later caller-side mutation does not alter the store.
        return (key, dict(approved_action), text_for_match, task_id)

    def _evict_overflow(self, items: List[Tuple[str, Dict[str, Any], str, Optional[int]]]) -> None:
        """Drop one entry from *items* (in place) if it exceeds ``max_capacity``."""
        if len(items) <= self.max_capacity:
            return
        if self.eviction_policy == "fifo":
            items.pop(0)  # Remove oldest
        elif self.eviction_policy == "random":
            import random
            items.pop(random.randint(0, len(items) - 1))
        # Unknown policies intentionally evict nothing (original behavior).

    def _commit(self, items: List[Tuple[str, Dict[str, Any], str, Optional[int]]]) -> None:
        """Embed *items* and install them as the new store contents.

        Embeddings are computed before any attribute is assigned, so a failing
        embedding call leaves the previous items and embeddings untouched and
        still aligned.
        """
        embs = self._embed_items(items)
        self._items = items
        self._embs = embs

    def _key_text(self, *, state: Dict[str, Any], action_text: str) -> str:
        """Render *state* plus *action_text* as newline-separated ``name=value`` lines.

        ``last_obs`` and ``constraint_text`` are stringified before truncation
        (300 and 200 characters) so non-string values do not raise.
        """
        parts = []
        for k in ["scenario", "site", "intent", "failure_mode", "budget_left", "t"]:
            if k in state:
                parts.append(f"{k}={state[k]}")
        last_obs = state.get("last_obs", "") or ""
        if last_obs:
            parts.append(f"last_obs={str(last_obs)[:300]}")  # Truncate
        constraint = state.get("constraint_text", "")  # Pass constraint text in context
        if constraint:
            parts.append(f"constraint={str(constraint)[:200]}")
        parts.append(f"action={action_text}")
        return "\n".join(parts)

    # ========== Store / retrieve ==========

    def add(
        self,
        *,
        site: str,
        intent: str,
        failure_mode: str,
        approved_action: Dict[str, Any],
        bad_action_text: str,
        state: Optional[Dict[str, Any]] = None,
        task_id: Optional[int] = None,
    ) -> None:
        """
        Store one precedent and refresh the embeddings.

        Args:
            site, intent, failure_mode: Joined with "|" to form the precedent key;
                also used as the state when *state* is empty or None.
            approved_action: Action to store; its "text" entry (or ``str()`` of
                the whole dict when absent) becomes part of the match text.
            bad_action_text: Accepted for interface compatibility; currently unused.
            state: Optional state dict summarized into the match text.
            task_id: Optional id of the task that produced the precedent.
        """
        items = list(self._items)
        items.append(
            self._make_item(
                site=site,
                intent=intent,
                failure_mode=failure_mode,
                approved_action=approved_action,
                state=state,
                task_id=task_id,
            )
        )
        self._evict_overflow(items)
        # Full re-embedding of the store on every add (small store => ok).
        self._commit(items)

    def retrieve(
        self,
        state: Dict[str, Any],
        action_text: str,
        top_k: int = 1,
        current_task_id: Optional[int] = None,
    ) -> "Optional[RetrievalHit]":
        """
        Return the single most similar precedent, or None if the store is empty.

        The score is NOT compared with ``retrieve_threshold``; callers decide
        whether the hit is good enough. ``top_k`` and ``current_task_id`` are
        accepted but currently unused.
        """
        if not self._items:
            return None

        query = f"{_state_summary(state)}\nBAD ACTION:\n{action_text}"

        if self.use_semantic:
            q = self._embed_texts([query])
            if self._use_openai:
                # Both sides are unit-normalized, so the dot product is cosine similarity.
                sims = np.dot(self._embs, q.T).flatten()
                best_i = int(np.argmax(sims))
                best_score = float(sims[best_i])
            else:
                sims = st_util.cos_sim(q, self._embs)[0]  # shape: [N]
                best_i = int(sims.argmax().item())
                best_score = float(sims[best_i].item())
        else:
            scores = [_jaccard(query, entry[2]) for entry in self._items]
            best_i = int(np.argmax(scores))
            best_score = float(scores[best_i])

        key, action, _, source_task_id = self._items[best_i]
        return RetrievalHit(key=key, score=best_score, action=action, source_task_id=source_task_id)

    def shuffle(self) -> None:
        """
        Shuffle the precedent store order.

        This is used as a control to test if temporal ordering matters for learning.
        After shuffling, retrieval still works but the temporal learning signal is destroyed.

        Existing embeddings are reordered with the same permutation instead of
        being recomputed. The resulting item order is the one
        ``random.shuffle`` would produce on the item list for the same RNG state.
        """
        import random
        if not self._items:
            return

        order = list(range(len(self._items)))
        random.shuffle(order)
        shuffled = [self._items[i] for i in order]

        if self.use_semantic:
            if self._embs is not None and len(self._embs) == len(order):
                # Row-select by index list; works for numpy arrays and tensors.
                new_embs = self._embs[order]
            else:
                # Embeddings missing or out of step with the items: rebuild.
                new_embs = self._embed_items(shuffled)
            self._embs = new_embs
        self._items[:] = shuffled

    def size(self) -> int:
        """Return number of precedents in store."""
        return len(self._items)

    # ========== Pattern Statistics for Theta Learning ==========

    def record_pattern_outcome(self, pattern_id: str, success: bool) -> None:
        """
        Record success/failure outcome for a repair pattern.
        Used for computing f3 (confidence) feature in theta scoring.

        Args:
            pattern_id: Canonical pattern identifier (e.g., "format|static_fix").
                Empty/None ids are ignored.
            success: Whether the repair was successful
        """
        if not pattern_id:
            return
        stats = self._pattern_stats.setdefault(pattern_id, {"count": 0, "success": 0, "fail": 0})
        stats["count"] += 1
        stats["success" if success else "fail"] += 1

    def get_pattern_stats(self, pattern_id: str) -> Dict[str, int]:
        """
        Get statistics for a repair pattern.

        Args:
            pattern_id: Canonical pattern identifier

        Returns:
            A fresh dict with keys: count, success, fail (all 0 if not found).
            Mutating it does not affect the store.
        """
        stats = self._pattern_stats.get(pattern_id)
        if stats is None:
            return {"count": 0, "success": 0, "fail": 0}
        return dict(stats)

    def get_all_pattern_stats(self) -> Dict[str, Dict[str, int]]:
        """Export a copy of all pattern statistics for analysis."""
        return {pattern_id: dict(stats) for pattern_id, stats in self._pattern_stats.items()}

    def clear(self) -> None:
        """Clear all precedents and pattern statistics (for no-persistence control)."""
        print(f"    [PRECEDENT_STORE.clear()] Clearing {len(self._items)} items from store id={id(self)}")
        self._items = []
        self._embs = None
        self._pattern_stats = {}

    # ========== Export / import ==========

    def export_to_dict(self) -> List[Dict[str, Any]]:
        """
        Export all precedents to JSON-serializable format.

        The key is split on "|" back into site/intent/failure_mode; if it does
        not split into exactly three parts the whole key is exported as ``site``
        and the other two fields are empty strings.

        Returns:
            List of precedent dictionaries for transfer learning. Each
            ``approved_action`` is a shallow copy of the stored action.
        """
        exported = []
        for item in self._items:
            key, action, text_for_match = item[0], item[1], item[2]
            task_id = item[3] if len(item) > 3 else None
            parts = key.split("|")
            if len(parts) == 3:
                site, intent, failure_mode = parts
            else:
                site, intent, failure_mode = key, "", ""

            exported.append({
                "site": site,
                "intent": intent,
                "failure_mode": failure_mode,
                "approved_action": dict(action),
                "text_for_match": text_for_match,
                "task_id": task_id,
            })
        return exported

    def import_from_dict(self, precedents: List[Dict[str, Any]], source_task_id: Optional[int] = None) -> None:
        """
        Import precedents from JSON format (for transfer learning).

        All records are appended (with the usual per-record eviction) and the
        embeddings are recomputed once for the whole batch.

        NOTE(review): a record's ``text_for_match`` is not reused; the match
        text is rebuilt from ``state`` (or from site/intent/failure_mode when the
        record has no state), so fields that were only in the original state are
        not restored. Confirm this is intended.

        Args:
            precedents: List of precedent dictionaries from export_to_dict()
            source_task_id: task_id assigned to records whose own ``task_id`` is
                missing or None (for transfer)
        """
        items = list(self._items)
        imported_any = False
        for p in precedents:
            # Extract fields with defaults
            site = p.get("site", "unknown")
            intent = p.get("intent", "unknown")
            failure_mode = p.get("failure_mode", "violation")
            # Support both "approved_action" and "action" keys for backwards compatibility
            approved_action = p.get("approved_action") or p.get("action") or {}
            # export_to_dict always writes a "task_id" key (possibly None), so a
            # None value must fall back to source_task_id as well.
            task_id = p.get("task_id")
            if task_id is None:
                task_id = source_task_id

            items.append(
                self._make_item(
                    site=site,
                    intent=intent,
                    failure_mode=failure_mode,
                    approved_action=approved_action,
                    state=p.get("state"),
                    task_id=task_id,
                )
            )
            self._evict_overflow(items)
            imported_any = True

        if imported_any:
            self._commit(items)