# inference.py
# End-to-end inference pipeline wiring Phase I, router, Phase II (constrained decoding), compliance checks.
# Minimal example wiring with placeholders for real models.

from typing import Optional, Dict, Any, Tuple
import torch
from config import DEVICE
from phase1_parsing import PhaseIParser
from audio_encoder import AudioEncoder
from router import Router
from constrained_decoder import ConstrainedDecoder, hf_logprob_fn_from_model
from templates import get_templates, ConstraintSet
from compliance_checks import run_compliance, escalate
from metaphor_detector import MetaphorDetector

# For demonstration, we provide a small text-embedder fallback (sentence-transformers if available)
try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SentenceTransformer = None

class InferenceService:
    """
    Orchestrates the pipeline:
      text/audio/meta -> Phase I parser -> Router -> choose pathway ->
        - prompt-only: format template + LM.generate (no constraint)
        - constrained: use ConstrainedDecoder
        - clinician: escalate/human
      then compliance check, possible escalation

    This class only performs inference. Components that are ``torch.nn.Module``
    instances are put into eval mode at construction, and every request runs
    under ``torch.no_grad()``.
    """

    def __init__(
        self,
        device: Optional[str] = None,
        *,
        max_len: int = 40,
        beam_size: int = 4,
        lambda_penalty: float = 8.0,
    ):
        """
        Build all pipeline components.

        Args:
            device: device string handed to components that accept one;
                defaults to ``config.DEVICE``.
            max_len: maximum generation length for the constrained pathway.
            beam_size: beam width for the constrained pathway.
            lambda_penalty: constraint-penalty weight for the constrained pathway.
        """
        import numpy as np  # used by the demo uniform_logprob closure below

        self.device = device or DEVICE
        self.max_len = max_len
        self.beam_size = beam_size
        self.lambda_penalty = lambda_penalty

        # Optional sentence embedder; _embed_text falls back when it is unavailable.
        self.text_embedder = SentenceTransformer("all-MiniLM-L6-v2") if SentenceTransformer is not None else None
        self.phase1 = PhaseIParser()  # default dims
        self.audio_encoder = AudioEncoder(device=self.device)
        self.router = Router()
        self.metaphor_detector = MetaphorDetector()
        # nn.Module instances start in training mode (dropout / batch-norm active);
        # inference must run them in eval mode. Non-modules are left untouched.
        self._set_eval_mode(self.phase1, self.audio_encoder, self.router, self.metaphor_detector)

        self.constraint_set = ConstraintSet()

        # Demo-only LM stand-in: uniform log-probabilities over a 20k vocabulary.
        # To hook a real HF model instead (sketch; presumably ConstrainedDecoder
        # wants an id-ordered token list as `vocab` — confirm):
        #   tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        #   model = GPT2LMHeadModel.from_pretrained("gpt2").to(self.device)
        #   vocab = tokenizer.convert_ids_to_tokens(list(range(len(tokenizer))))
        #   self.decoder = ConstrainedDecoder(
        #       base_logprob_fn=hf_logprob_fn_from_model(tokenizer, model),
        #       vocab=vocab, constraint_set=self.constraint_set)
        # (HF tokenizers have no get_vocab_list(); get_vocab() returns a dict.)
        def uniform_logprob(prefix):
            vocab_size = 20000
            return np.log(np.ones(vocab_size) / vocab_size)

        self.decoder = ConstrainedDecoder(base_logprob_fn=uniform_logprob, vocab=None, constraint_set=self.constraint_set)

    @staticmethod
    def _set_eval_mode(*components) -> None:
        """Call ``.eval()`` on every component that is a ``torch.nn.Module``."""
        for component in components:
            if isinstance(component, torch.nn.Module):
                component.eval()

    @staticmethod
    def _to_list(value):
        """Convert tensors/arrays to plain lists for the debug dict; pass other values through."""
        return value.tolist() if hasattr(value, "tolist") else value

    def _embed_text(self, text: str) -> torch.Tensor:
        """
        Embed *text* as a float32 tensor.

        Uses the sentence-transformers model when it is available. Otherwise it
        returns a random 768-element vector (demo fallback), so results are
        non-deterministic in that case.
        """
        if self.text_embedder is not None:
            return torch.from_numpy(self.text_embedder.encode([text])[0]).float()
        # NOTE(review): all-MiniLM-L6-v2 emits 384-dim vectors but this fallback is
        # 768-dim; confirm which size PhaseIParser expects.
        return torch.randn(768)

    def handle_request(self, text: str, audio_waveform: Optional[torch.Tensor] = None, meta: Optional[Dict] = None) -> Tuple[str, bool, Dict]:
        """
        Main API: run the full pipeline for one request.

        Args:
            text: user input text.
            audio_waveform: currently unused (see NOTE below).
            meta: currently unused (see NOTE below).

        Returns:
            (response_text, escalated_flag, debug_info)
        """
        # NOTE(review): audio_waveform / meta are accepted but never used; Phase I gets a
        # zero meta tensor and self.audio_encoder is not called. Wire these in once the
        # AudioEncoder / meta-feature APIs are settled.
        debug: Dict[str, Any] = {}

        # Pure inference: no autograd bookkeeping for any model call.
        with torch.no_grad():
            c, pathway = self._analyze(text, debug)
            response, escalated = self._phase2(pathway, c, debug)

        # 7) compliance / audit checks
        passed, reasons = run_compliance(response, constraint_set=self.constraint_set)
        debug["compliance_passed"] = passed
        debug["compliance_reasons"] = reasons
        if not passed:
            escalated = True
            escalate(response, reasons)

        return response, escalated, debug

    def _analyze(self, text: str, debug: Dict[str, Any]):
        """Steps 1-5: embed, lexical/meta features, Phase I parse, metaphor score, routing.

        Records intermediate results in *debug*; returns (phase1_result, pathway).
        """
        # 1) text embedding
        text_embed = self._embed_text(text)
        debug["text_embed_shape"] = tuple(text_embed.shape)

        # 2) hlex / hmeta features (local import kept from the original wiring)
        from preprocess_text import compute_hlex_hmeta
        hm = compute_hlex_hmeta(text)
        hlex, hmeta = hm["hlex"], hm["hmeta"]
        debug["hlex"] = self._to_list(hlex)
        debug["hmeta"] = self._to_list(hmeta)

        # 3) Phase I parsing
        c = self.phase1.parse(text_embed, hlex, hmeta, meta_tensor=torch.zeros(8))
        debug["phase1"] = c

        # 4) metaphor detection (batch dimension added for the detector)
        if callable(self.metaphor_detector):
            m_score = float(self.metaphor_detector(text_embed.unsqueeze(0)))
        else:
            m_score = 0.0
        debug["metaphor_score"] = m_score

        # 5) routing decision
        pathway, probs = self.router.route(c)
        debug["pathway"] = pathway
        debug["route_probs"] = self._to_list(probs)
        return c, pathway

    def _phase2(self, pathway: str, c, debug: Dict[str, Any]) -> Tuple[str, bool]:
        """Step 6: produce the response for *pathway*; returns (response, escalated)."""
        if pathway == "clinician":
            # NOTE(review): the user is told the case was escalated, but escalate() is only
            # called on compliance failure in handle_request — confirm the clinician
            # hand-off happens elsewhere.
            debug["phase2"] = "escalated_to_clinician"
            return "Your case has been escalated to a clinician for review.", True

        if pathway == "prompt_only":
            # Template-only response (a lightweight LM generation could replace this).
            templates = get_templates("low", c["n"])
            debug["phase2"] = "template_response"
            return templates[0], False

        # Any other pathway: constrained generation from an empty token prefix (demo).
        seq, score = self.decoder.constrained_generate(
            [],
            max_len=self.max_len,
            beam_size=self.beam_size,
            lambda_penalty=self.lambda_penalty,
        )
        response = self.decoder._tokens_to_text(seq)
        debug["phase2"] = {"seq_len": len(seq), "score": score}
        return response, False


# convenience function wrapper
# Lazily created shared service. Constructing InferenceService loads the embedder
# and models, so it is built once and reused instead of on every call.
_DEFAULT_SERVICE: Optional["InferenceService"] = None


def handle_request(text: str, audio=None, meta=None):
    """
    Module-level convenience entry point around ``InferenceService.handle_request``.

    The first call constructs a default ``InferenceService``; later calls reuse it.

    Args:
        text: user input text.
        audio: optional audio waveform, forwarded as ``audio_waveform``.
        meta: optional metadata dict, forwarded as ``meta``.

    Returns:
        (response_text, escalated_flag, debug_info) from the service.
    """
    global _DEFAULT_SERVICE
    if _DEFAULT_SERVICE is None:
        _DEFAULT_SERVICE = InferenceService()
    return _DEFAULT_SERVICE.handle_request(text, audio, meta)
