from langchain_google_genai import GoogleGenerativeAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.embeddings import HuggingFaceEmbeddings
from config import GOOGLE_GEMINI_API_KEY, LLM_MODEL_NAME, EMBEDDING_MODEL_NAME
import numpy as np
import time
import traceback

# -----------------------------
# Initialize Google Gemini LLM (for final reasoning)
# -----------------------------
def get_llm(temperature=0.0, max_tokens=2048):
    """
    Build a Gemini chat-model client.

    temperature: sampling temperature forwarded to the client
    max_tokens: forwarded to the client as ``max_output_tokens``
    returns: a ChatGoogleGenerativeAI instance configured with the model name
             and API key taken from config
    """
    # The chat model is preferred for robust content extraction.
    client_options = {
        "model": LLM_MODEL_NAME,
        "google_api_key": GOOGLE_GEMINI_API_KEY,
        "temperature": temperature,
        "max_output_tokens": max_tokens,
    }
    return ChatGoogleGenerativeAI(**client_options)



# Embedding models already constructed, keyed by model name, so repeated
# calls to get_embeddings() reuse one instance instead of rebuilding it.
_EMBEDDING_MODEL_CACHE = {}


def _get_embedding_model(model_name):
    """Return the HuggingFaceEmbeddings instance for model_name, creating it on first use."""
    model = _EMBEDDING_MODEL_CACHE.get(model_name)
    if model is None:
        model = HuggingFaceEmbeddings(model_name=model_name)
        _EMBEDDING_MODEL_CACHE[model_name] = model
    return model


def get_embeddings(texts, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """
    texts: list of strings (None is treated as an empty list)
    model_name: HuggingFace model to embed with; the default is the model
                this function has always used (same as in rag2_retrieval.py)
    returns: list of embeddings, one per non-blank input text

    Blank / empty entries are dropped before embedding, so the result can be
    shorter than ``texts`` and its positions refer to the filtered list, not
    the original one.
    """
    # Filter out empty strings
    cleaned = [text.strip() for text in (texts or ()) if text and text.strip()]
    if not cleaned:
        # Nothing to embed: return early without touching the model.
        return []

    # NOTE(review): EMBEDDING_MODEL_NAME is imported from config at the top of
    # this file but is not used by this function; callers that want the
    # configured model can pass it as model_name.
    embeddings_model = _get_embedding_model(model_name)
    return embeddings_model.embed_documents(cleaned)

# -----------------------------
# Cosine similarity between two vectors
# -----------------------------
def cosine_similarity(vec1, vec2):
    """
    vec1, vec2: array-likes converted with np.array
    returns: dot(vec1, vec2) / (|vec1| * |vec2|) as a Python float,
             or 0.0 when either vector has zero norm
    """
    left = np.array(vec1)
    right = np.array(vec2)

    # Guard each norm in turn so a zero vector never reaches the division.
    left_norm = np.linalg.norm(left)
    if left_norm == 0:
        return 0.0
    right_norm = np.linalg.norm(right)
    if right_norm == 0:
        return 0.0

    return float(np.dot(left, right) / (left_norm * right_norm))

# -----------------------------
# Normalize text (basic)
# -----------------------------
def normalize_text(text):
    """
    text: string to normalize
    returns: the text lower-cased, then with leading/trailing whitespace removed
    """
    lowered = text.lower()
    return lowered.strip()

# -----------------------------
# Wrapper to call LLM for reasoning
# -----------------------------

# utils.py: response-text extraction helper and the LLM call wrapper



def _response_meta(resp):
    """Return resp's ``response_metadata`` as a dict ({} when absent or not a dict)."""
    if isinstance(resp, dict):
        meta = resp.get("response_metadata")
    else:
        meta = getattr(resp, "response_metadata", None)
    return meta if isinstance(meta, dict) else {}


def _join_parts(parts):
    """Join a list of content parts into one newline-separated string.

    Dict parts contribute their "content" or "text" value; any other part is
    stringified. Empty pieces are dropped.
    """
    pieces = []
    for part in parts:
        if isinstance(part, dict):
            piece = part.get("content") or part.get("text")
            if piece:
                pieces.append(str(piece))
        else:
            pieces.append(str(part))
    return "\n".join(p for p in pieces if p)


def _as_text(value):
    """Collapse a list/tuple of content parts to a string; pass anything else through."""
    if isinstance(value, (list, tuple)):
        return _join_parts(value)
    return value


def _extract_text_from_response(resp):
    """
    Best-effort extraction of human-readable text from various SDK response shapes.
    Returns (text, meta_dict) where meta_dict may include finish_reason etc.

    Shapes are tried in this order; the first one that yields text wins:
      1. a plain string
      2. an object with a non-empty ``content`` attribute
      3. an object with a ``message`` (string, object with ``content``, or dict)
      4. an object with a non-empty ``candidates`` list/tuple (first entry only)
      5. an object with a ``generations`` list (first entry only)
      6. a dict with one of the keys content / text / output / message /
         reasoning / candidates

    List-shaped content is joined into a single newline-separated string.
    meta_dict is the response's ``response_metadata`` when that is a dict,
    otherwise {}.

    If nothing matches, the result is ``str(resp)`` -- except when the response
    does expose a text slot (a ``content`` attribute, or a "content"/"text"
    dict key) that is simply empty: then the text is "" so callers can tell
    "no output" apart from real output instead of receiving an object dump.
    Exceptions raised while probing are printed and the function falls back
    to ``str(resp)``.
    """
    try:
        # If it's already a string
        if isinstance(resp, str):
            return resp, {}

        meta = _response_meta(resp)

        # If it's a simple object with content
        content = _as_text(getattr(resp, "content", None))
        if content:
            return content, meta

        # message/content nested
        if hasattr(resp, "message"):
            msg = getattr(resp, "message")
            if isinstance(msg, str) and msg.strip():
                return msg, meta
            # sometimes message has .content
            nested = _as_text(getattr(msg, "content", None))
            if nested:
                return nested, meta
            # if message is dict-like
            if isinstance(msg, dict):
                txt = _as_text(msg.get("content") or msg.get("text"))
                if txt:
                    return txt, meta

        # candidates (list of candidate outputs)
        candidates = getattr(resp, "candidates", None)
        if isinstance(candidates, (list, tuple)) and candidates:
            first = candidates[0]
            if isinstance(first, dict):
                txt = _as_text(first.get("content") or first.get("text"))
                if txt:
                    return txt, meta
            cand_content = _as_text(getattr(first, "content", None))
            if cand_content:
                return cand_content, meta
            # fallback to str
            return str(first), meta

        # generations (langchain-style)
        generations = getattr(resp, "generations", None)
        if isinstance(generations, list) and generations:
            first = generations[0]
            # often first is a list of Generation objects
            if isinstance(first, list) and first:
                cand = first[0]
                txt = getattr(cand, "text", None) or getattr(cand, "message", None)
                if txt:
                    # if message object, check .content
                    if hasattr(txt, "content"):
                        return _as_text(txt.content), meta
                    return str(txt), meta
            # sometimes g is one-level list
            if hasattr(first, "text") and first.text:
                return first.text, meta

        # If it's a mapping / dict-like
        if isinstance(resp, dict):
            for key in ("content", "text", "output", "message", "reasoning", "candidates"):
                val = resp.get(key)
                if not val:
                    continue
                if isinstance(val, (list, tuple)):
                    # join content/text entries; if they hold no text, keep
                    # looking at the remaining keys rather than returning the
                    # list's repr
                    joined = _join_parts(val)
                    if joined:
                        return joined, meta
                elif isinstance(val, dict):
                    txt = _as_text(val.get("content") or val.get("text"))
                    if txt:
                        return txt, meta
                else:
                    return str(val), meta

        # The response has a text slot but it is empty: report "no text"
        # rather than the object's repr, which callers would mistake for
        # model output.
        has_text_slot = hasattr(resp, "content") or (
            isinstance(resp, dict) and ("content" in resp or "text" in resp)
        )
        if has_text_slot:
            return "", meta

        # As a last resort, try string conversion
        return str(resp), meta

    except Exception as e:
        print("[_extract_text_from_response] exception:", e)
        traceback.print_exc()
        try:
            return str(resp), getattr(resp, "response_metadata", {}) or {}
        except Exception:
            return str(resp), {}

def call_llm(prompt: str, temperature=0.0, max_tokens=2048, retry_if_truncated=True):
    """
    Send ``prompt`` to the LLM and return its text output.

    Steps:
      - build a client via get_llm() with the given temperature / max_tokens
      - call it through the first of invoke / generate / predict / __call__
        that the client exposes
      - pull the text and metadata out of the response
      - if the finish reason is 'MAX_TOKENS' and retry_if_truncated is set,
        retry once with max_tokens doubled (capped at 8192)

    Returns the stripped text. Any exception is printed and swallowed; when no
    non-blank text was obtained, a fixed "Marks Awarded: 0" fallback message
    is returned instead.
    """
    output = ""
    try:
        # Create the llm client using the existing factory (keeps backward compatibility)
        client = get_llm(temperature=temperature, max_tokens=max_tokens)

        # Prefer invoke if available (non-deprecated)
        if hasattr(client, "invoke"):
            try:
                response = client.invoke(prompt)
            except TypeError:
                # some versions expect list or different signature
                response = client.invoke([prompt])
        elif hasattr(client, "generate"):
            try:
                response = client.generate([prompt])
            except TypeError:
                response = client.generate(prompt)
        elif hasattr(client, "predict"):
            response = client.predict(prompt)
        else:
            # fallback to calling object
            response = client(prompt)

        output, metadata = _extract_text_from_response(response)

        # Look for a truncation marker in the extracted metadata first.
        stop_cause = None
        if isinstance(metadata, dict):
            for field in ("finish_reason", "stop_reason", "finish_reason_details"):
                stop_cause = metadata.get(field)
                if stop_cause:
                    break

        # Otherwise try response_metadata on the top-level response itself.
        if not stop_cause:
            try:
                raw_meta = getattr(response, "response_metadata", None)
                if not raw_meta:
                    raw_meta = response.get("response_metadata") if isinstance(response, dict) else None
                if isinstance(raw_meta, dict):
                    stop_cause = raw_meta.get("finish_reason")
            except Exception:
                pass

        # If truncated and retry is allowed, do one retry with a bigger token budget
        if stop_cause == "MAX_TOKENS" and retry_if_truncated:
            print("[call_llm] WARNING: response finished due to MAX_TOKENS (truncated). Retrying with larger max_tokens...")
            # Double the budget, capped at 8192; skip the retry if that is no increase.
            bumped_budget = min(8192, max_tokens * 2)
            if bumped_budget > max_tokens:
                time.sleep(0.3)
                return call_llm(
                    prompt,
                    temperature=temperature,
                    max_tokens=bumped_budget,
                    retry_if_truncated=False,
                )

    except Exception as exc:
        print("[call_llm] Exception while calling LLM:", exc)
        traceback.print_exc()
        output = ""

    # Normal path: some non-blank text was obtained.
    if output and str(output).strip():
        return str(output).strip()

    # Deterministic fallback if nothing obtained
    print("[call_llm] WARNING: returning fallback text. See earlier logs for details.")
    return (
        "Marks Awarded: 0\n\n"
        "Reasoning: Unable to generate model rationale at this time. "
        "Check logs printed above; consider increasing max_tokens or inspect response repr."
    )
