# learned_critic.py
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# Try OpenAI embeddings first
# Optional dependency: a module-level OpenAI client, constructed at import
# time. It stays None when the package is missing or the constructor raises;
# LearnedCritic checks for None and otherwise falls back to sentence-transformers.
_openai_client = None
try:
    from openai import OpenAI
    _openai_client = OpenAI()
except Exception:
    # Best-effort: any failure (missing package, or a client that cannot be
    # constructed — NOTE(review): presumably e.g. no API key configured,
    # confirm) silently leaves _openai_client as None.
    pass

# Fallback to sentence-transformers
# SentenceTransformer is set to None when the package cannot be imported;
# LearnedCritic checks for None before using it.
try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SentenceTransformer = None


# Ordered outcome labels. A label's position in this list is its class id,
# i.e. the index of the corresponding logit produced by the critic head.
OUTCOMES = ["ALLOW", "EDIT", "ESCALATE", "DENY"]
OUTCOME_TO_ID = {name: idx for idx, name in enumerate(OUTCOMES)}
ID_TO_OUTCOME = dict(enumerate(OUTCOMES))


@dataclass
class CriticPrediction:
    """Classification result for one text, as built by ``LearnedCritic.predict``."""

    # Winning label (the outcome with the highest probability).
    outcome: str
    # Probability of ``outcome``, i.e. the largest entry of ``probs``.
    confidence: float
    # Softmax probability per outcome, indexed by outcome id.
    probs: List[float]


class LearnedCriticHead(nn.Module):
    """A tiny MLP classifier on top of frozen sentence embeddings.

    Architecture: ``Linear(emb_dim, hidden) -> ReLU -> Linear(hidden, num_classes)``.

    Args:
        emb_dim: width of each input embedding vector (default 1536).
        hidden: width of the hidden layer.
        num_classes: number of output logits (default 4, one per outcome).
    """

    def __init__(self, emb_dim: int = 1536, hidden: int = 256, num_classes: int = 4):
        super().__init__()
        self.proj = nn.Linear(emb_dim, hidden)
        self.classifier = nn.Linear(hidden, num_classes)

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        """Return logits of shape [B, num_classes] for ``emb`` of shape [D] or [B, D].

        A 1-D input is treated as a batch of one, so the output is always 2-D.
        """
        if emb.dim() == 1:
            emb = emb.unsqueeze(0)
        hidden = F.relu(self.proj(emb))
        return self.classifier(hidden)


class LearnedCritic:
    """
    Text -> outcome classifier: a frozen embedder feeding a trainable head.

    Wrapper:
      - OpenAI or SentenceTransformer embedder (frozen)
      - trainable head (``LearnedCriticHead``) producing logits over OUTCOMES
    """

    # Known output widths of OpenAI embedding models. Any other
    # "text-embedding*" model is measured with a single probe request.
    _OPENAI_EMB_DIMS: Dict[str, int] = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        hidden: int = 256,
        device: Optional[str] = None,
        emb_dim: Optional[int] = None,
    ):
        """
        Args:
            model_name: OpenAI embedding model (used when the module-level
                OpenAI client exists and the name starts with
                "text-embedding"); otherwise a sentence-transformers model name.
            hidden: hidden width of the classifier head.
            device: torch device for the head and the sentence-transformers
                model; defaults to "cuda" when available, else "cpu".
            emb_dim: explicit embedding width. If None, it is detected from
                the embedder instead of being assumed.

        Raises:
            RuntimeError: if neither OpenAI nor sentence-transformers is available.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self._use_openai = False
        self.embedder = None

        if _openai_client is not None and model_name.startswith("text-embedding"):
            self._use_openai = True
        elif SentenceTransformer is not None:
            # NOTE(review): if model_name is an OpenAI name but the OpenAI
            # client failed to initialise, this tries to load that name as a
            # sentence-transformers model — confirm that is intended.
            self.embedder = SentenceTransformer(model_name, device=self.device)
        else:
            raise RuntimeError("No embedding model available (OpenAI or sentence-transformers)")

        if emb_dim is None:
            emb_dim = self._detect_emb_dim()
        self.head = LearnedCriticHead(emb_dim=emb_dim, hidden=hidden).to(self.device)

    def _detect_emb_dim(self) -> int:
        """Return the width of the vectors produced by the configured embedder."""
        if self._use_openai:
            known = self._OPENAI_EMB_DIMS.get(self.model_name)
            if known is not None:
                return known
            # Unknown OpenAI model: one probe request reveals its output width.
            return int(self._get_openai_embedding("dimension probe").shape[-1])
        dim = self.embedder.get_sentence_embedding_dimension()
        if dim is None:
            # Some models do not report a width; measure a real encoding.
            dim = self.embedder.encode("", convert_to_tensor=True).shape[-1]
        return int(dim)

    def _get_openai_embedding(self, text: str) -> np.ndarray:
        """Get the (unnormalised) embedding of ``text`` from the OpenAI API."""
        response = _openai_client.embeddings.create(
            model=self.model_name,
            input=[text]
        )
        return np.array(response.data[0].embedding)

    @torch.no_grad()
    def embed(self, text: str) -> torch.Tensor:
        """Return the L2-normalised embedding of ``text`` as a 1-D tensor on ``self.device``."""
        if self._use_openai:
            emb = self._get_openai_embedding(text)
            emb = emb / np.clip(np.linalg.norm(emb), 1e-8, None)  # normalize; guard zero vectors
            return torch.tensor(emb, dtype=torch.float32).to(self.device)
        emb = self.embedder.encode(
            text,
            convert_to_tensor=True,
            normalize_embeddings=True,
        )
        return emb.to(self.device)

    def predict(self, text: str) -> CriticPrediction:
        """Classify ``text``; return the argmax outcome, its probability and all probabilities."""
        with torch.no_grad():
            logits = self.head(self.embed(text))  # [1,4]
            probs = torch.softmax(logits, dim=-1).squeeze(0)  # [4]
            conf, idx = torch.max(probs, dim=-1)
        return CriticPrediction(
            outcome=ID_TO_OUTCOME[int(idx.item())],
            confidence=float(conf.item()),
            probs=probs.tolist(),
        )

    @staticmethod
    def _resolve_label(label) -> int:
        """Map an outcome id (int-like) or outcome name to a validated class id."""
        import operator

        if isinstance(label, str):
            if label not in OUTCOME_TO_ID:
                raise ValueError(f"unknown outcome label {label!r}; expected one of {OUTCOMES}")
            return OUTCOME_TO_ID[label]
        try:
            idx = operator.index(label)
        except TypeError:
            raise ValueError(f"label must be an int id or outcome name, got {label!r}") from None
        if not 0 <= idx < len(OUTCOMES):
            raise ValueError(f"label id {idx} out of range 0..{len(OUTCOMES) - 1}")
        return idx

    def train_on_text_labels(
        self,
        data: List[Tuple[str, int]],
        *,
        epochs: int = 5,
        lr: float = 1e-3,
        batch_size: int = 16,
        seed: int = 0,
    ) -> List[float]:
        """
        Fit the head with cross-entropy on (text, label) pairs; the embedder stays frozen.

        data: List[(text, label_id)]
        label_id in {0..3} corresponding to OUTCOMES; an outcome name such as
        "DENY" is also accepted.

        Args:
            epochs, lr, batch_size: Adam training hyperparameters.
            seed: seeds the per-epoch shuffle only (head weight init is not seeded).

        Returns:
            The mean training loss of each epoch (also printed).

        Raises:
            ValueError: if batch_size < 1 or a label is not a valid outcome id/name.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Validate labels before paying for any (possibly remote) embedding calls.
        labels = [self._resolve_label(y) for (_, y) in data]

        rng = random.Random(seed)
        opt = torch.optim.Adam(self.head.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()

        # pre-embed once (fast + stable): the embedder is frozen, so each
        # text's vector is the same in every epoch.
        embs = [self.embed(t) for (t, _) in data]
        n = len(embs)
        X_all = torch.stack(embs, dim=0) if n else None  # [N,D]; unused when n == 0
        y_all = torch.tensor(labels, dtype=torch.long, device=self.device)  # [N]

        history: List[float] = []
        self.head.train()
        try:
            for ep in range(epochs):
                order = list(range(n))
                rng.shuffle(order)
                total = 0.0

                for start in range(0, n, batch_size):
                    batch = torch.tensor(order[start : start + batch_size], device=self.device)
                    logits = self.head(X_all[batch])  # [B,4]
                    loss = loss_fn(logits, y_all[batch])

                    opt.zero_grad()
                    loss.backward()
                    opt.step()

                    # Weight by batch size so the epoch mean is per-example.
                    total += float(loss.item()) * len(batch)

                avg = total / max(1, n)
                history.append(avg)
                print(f"[critic] epoch={ep+1}/{epochs} loss={avg:.4f}")
        finally:
            # Leave the head in inference mode even if training is interrupted.
            self.head.eval()
        return history