"""Gemini API runner for CIRBench (new Google GenAI SDK).

Uses the unified `google-genai` SDK instead of the deprecated
`google-generativeai` package. Provides a text-only completion interface,
normalized metadata (tokens, latency, finish reason), dynamic output-token
budgeting with optional retries, and full logging hooks.

References:
  * SDK deprecation notice: https://pypi.org/project/google-generativeai/
  * Migration & API guide (unified client, generate/count tokens, config):
    https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/migrate-genai-sdk
"""
from __future__ import annotations
import os, time, json, traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from ..logging_utils import get_logger, debug_on, debug_full_on

from google import genai
from google.genai import types


@dataclass
class Completion:
    """A single model completion: extracted text plus per-call metadata."""

    # Plain-text output extracted from the model response (may be empty).
    text: str
    # Metadata dict; GeminiRunner.generate fills it from `_usage` (finish
    # reason and token counts) and adds a "latency_ms" entry.
    meta: dict


# --- Helpers --------------------------------------------------------------

def _resp_to_dict(resp) -> dict:
    """Best-effort conversion of a response object to a plain dict (for logs)."""
    try:
        # Pydantic models in google-genai expose `.model_dump()` and `.to_dict()`
        if hasattr(resp, "model_dump"):
            return resp.model_dump()
        if hasattr(resp, "to_dict"):
            return resp.to_dict()
    except Exception:
        pass
    # Fallback: probe common fields
    d: Dict[str, Any] = {}
    try:
        cands = getattr(resp, "candidates", None) or []
        d["candidates_len"] = len(cands)
        if cands:
            fr = getattr(cands[0], "finish_reason", None)
            d["finish_reason"] = getattr(fr, "name", None) or str(fr)
    except Exception:
        pass
    try:
        um = getattr(resp, "usage_metadata", None)
        if um:
            d["usage"] = {
                "prompt": getattr(um, "prompt_token_count", None),
                "candidates": getattr(um, "candidates_token_count", None),
                "total": getattr(um, "total_token_count", None),
            }
    except Exception:
        pass
    return d or {"repr": repr(resp)}


def _safe_text(resp) -> str:
    """Extract plain text from a GenAI response."""
    # Many responses provide `.text` directly
    try:
        t = getattr(resp, "text", None)
        if isinstance(t, str) and t:
            return t
    except Exception:
        pass
    # Otherwise, aggregate from candidates/parts
    texts: List[str] = []
    try:
        for cand in getattr(resp, "candidates", []) or []:
            content = getattr(cand, "content", None)
            if not content:
                continue
            for part in getattr(content, "parts", []) or []:
                pt = getattr(part, "text", None)
                if isinstance(pt, str) and pt:
                    texts.append(pt)
                else:
                    try:
                        texts.append(str(part))
                    except Exception:
                        pass
    except Exception:
        return ""
    return "\n".join(texts)


def _usage(resp) -> dict:
    """Compact usage/finish metadata from response (fields vary by SDK version)."""
    d: Dict[str, Any] = {}
    try:
        cand0 = (getattr(resp, "candidates", []) or [None])[0]
        fr = getattr(cand0, "finish_reason", None)
        d["finish_reason"] = getattr(fr, "name", None) or str(fr)
    except Exception:
        pass
    try:
        um = getattr(resp, "usage_metadata", None)
        if um:
            d["prompt_tokens"] = getattr(um, "prompt_token_count", None) or getattr(um, "input_token_count", None)
            d["out_tokens"] = getattr(um, "candidates_token_count", None) or getattr(um, "output_token_count", None)
            d["total_tokens"] = getattr(um, "total_token_count", None)
    except Exception:
        pass
    return d


def _full_log_path() -> str:
    """Resolve per-shard full log path.

    Priority:
      - runs/<RUN_ID>/full.shard-<SID>.log if CIRBENCH_RUN_ID is set
      - <LOG_DIR>/full.shard-<SID>.log if CIRBENCH_LOG_DIR is set
      - ./full.shard-<SID>.log otherwise
    """
    run_id = os.getenv("CIRBENCH_RUN_ID")
    log_dir = os.getenv("CIRBENCH_LOG_DIR")
    shard_id = os.getenv("CIRBENCH_SHARD_ID", "1")
    fname = f"full.shard-{shard_id}.log"
    if run_id:
        return os.path.join("runs", run_id, fname)
    if log_dir:
        return os.path.join(log_dir, fname)
    return fname


def _append_full(stage: str, payload: str, meta: Optional[dict] = None):
    """Append a stage-tagged payload to the per-shard full log.

    Disabled by default. Set CIRBENCH_ENABLE_FULL=1 to enable.
    """
    if os.getenv("CIRBENCH_ENABLE_FULL", "0") != "1":
        return
    path = _full_log_path()
    try:
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "a", encoding="utf-8") as f:
            f.write(f"[FULL] {stage}\n")
            if meta:
                try:
                    f.write(f"META: {json.dumps(meta, ensure_ascii=False)}\n")
                except Exception:
                    f.write(f"META: {str(meta)}\n")
            f.write(payload if isinstance(payload, str) else str(payload))
            f.write("\n")
    except Exception:
        pass


# --- Exception diagnostics helper -----------------------------------------

def _exc_diag(ex) -> dict:
    """Best-effort extraction of diagnostic info from exceptions raised by the SDK/HTTP layer."""
    d: Dict[str, Any] = {
        "type": ex.__class__.__name__,
        "msg": str(ex),
    }
    # Common attributes seen across httpx/requests/google errors
    for attr in ("status_code", "reason", "code", "errors"):
        try:
            v = getattr(ex, attr, None)
            if v is not None and v != "":
                d[attr] = int(v) if isinstance(v, int) else str(v)
        except Exception:
            pass
    # Try to pull HTTP response info if present
    resp = getattr(ex, "response", None)
    try:
        if resp is not None:
            try:
                d["response_status"] = getattr(resp, "status_code", None) or getattr(resp, "status", None)
            except Exception:
                pass
            try:
                body = getattr(resp, "text", None) or getattr(resp, "content", None)
                if body:
                    bs = body if isinstance(body, str) else body.decode("utf-8", "ignore")
                    d["response_body"] = bs[:2048]
            except Exception:
                pass
            try:
                hdrs = getattr(resp, "headers", None)
                if hdrs:
                    # Log a few safe headers
                    for hk in ("x-goog-quota-project", "x-quote-reason", "x-error-code"):
                        if hk in hdrs:
                            d.setdefault("response_headers", {})[hk] = str(hdrs.get(hk))
            except Exception:
                pass
    except Exception:
        pass
    # Minimal traceback (exception-only) for quick triage
    try:
        d["trace"] = "".join(traceback.format_exception_only(type(ex), ex)).strip()
    except Exception:
        pass
    return d


# --- Runner ---------------------------------------------------------------
class GeminiRunner:
    """Thin wrapper around the new `google-genai` Client API.

    The runner is configured through a ``params`` dict. Keys read by this
    class:

      * ``api_key`` (falls back to env ``GEMINI_API_KEY``)
      * ``model`` (default ``"gemini-2.5-flash"``)
      * ``temperature`` (default 0.1), ``top_p`` (default 0.9)
      * ``add_stop_ir_out`` (default True): add default stop sequences
      * ``thinking`` / ``include_thoughts``: see ``_build_thinking_config``
      * ``context_window``, ``window_safety``, ``max_output_tokens``:
        output-token budgeting, see ``_compute_max_out``
      * ``retry_on_maxtokens`` (default True), ``sleep_sec`` (default 0.0)
    """

    # Stop sequences applied when the caller supplies none and
    # ``add_stop_ir_out`` is enabled. Used for the first call AND the retry.
    _DEFAULT_STOP_SEQUENCES = ("</IR_OUT>", "</CIR_JSON>")

    def __init__(self, params: Dict[str, Any]):
        """Create the SDK client and resolve static generation settings.

        Args:
            params: Runner configuration (see class docstring); ``None`` or an
                empty dict is treated as ``{}``.

        Raises:
            RuntimeError: If no API key is available, or the ``genai`` module
                object is ``None``.
        """
        self.params = params or {}
        api_key = self.params.get("api_key") or os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise RuntimeError("GEMINI_API_KEY not set and no api_key in params.")
        if genai is None:
            raise RuntimeError(
                "google-genai not installed. Run: pip install -U google-genai"
            )

        # Developer API by default; users on Vertex AI can set envs or pass kwargs
        # (see migration guide for vertexai=True / project/location envs).
        self.client = genai.Client(api_key=api_key)

        # Model name/id (e.g., "gemini-2.5-flash" or "models/gemini-1.5-pro")
        self.model = self.params.get("model", "gemini-2.5-flash")

        # Base generation config; per-call we inject a dynamic max_output_tokens
        self.base_config: Dict[str, Any] = {
            "temperature": self.params.get("temperature", 0.1),
            "top_p": self.params.get("top_p", 0.9),
            # "max_output_tokens" is computed dynamically per prompt
        }
        self.add_stop_ir_out: bool = bool(self.params.get("add_stop_ir_out", True))

        # Thinking control: default OFF to avoid blowing token budgets.
        self.thinking_cfg = self._build_thinking_config()

    # --- configuration helpers --------------------------------------------
    def _build_thinking_config(self):
        """Translate ``params["thinking"]`` into a ``types.ThinkingConfig``.

        Accepted values:
          * ``None`` or ``False`` -> budget 0 (disabled), thoughts hidden
          * ``True``  -> budget -1 (dynamic, server picks)
          * ``int``   -> that budget
          * ``dict``  -> keys ``budget`` (or ``thinking_budget``) and
            ``include_thoughts``; a non-integer budget becomes 0

        ``params["include_thoughts"]`` supplies the default for exposing
        thoughts whenever thinking is enabled.

        Returns:
            A ``ThinkingConfig``, or ``None`` when the value has an
            unsupported type or the SDK type cannot be constructed at all
            (in which case no thinking config is sent with requests).
        """
        try:
            tcfg = self.params.get("thinking", None)
            include_default = bool(self.params.get("include_thoughts", False))
            if tcfg is None:
                # Explicitly disable thinking by default (budget=0)
                return types.ThinkingConfig(thinking_budget=0, include_thoughts=False)
            # NOTE: bool must be tested before int (bool is an int subclass).
            if isinstance(tcfg, bool):
                if tcfg:
                    return types.ThinkingConfig(thinking_budget=-1, include_thoughts=include_default)
                return types.ThinkingConfig(thinking_budget=0, include_thoughts=False)
            if isinstance(tcfg, int):
                return types.ThinkingConfig(thinking_budget=int(tcfg), include_thoughts=include_default)
            if isinstance(tcfg, dict):
                budget = tcfg.get("budget", tcfg.get("thinking_budget", 0))
                include = bool(tcfg.get("include_thoughts", include_default))
                try:
                    budget = int(budget)
                except Exception:
                    budget = 0
                return types.ThinkingConfig(thinking_budget=budget, include_thoughts=include)
            return None
        except Exception:
            # On any SDK/import issue, fall back to disabled thinking
            try:
                return types.ThinkingConfig(thinking_budget=0, include_thoughts=False)
            except Exception:
                return None

    def _build_request_config(self, decode: Dict[str, Any], max_out: int):
        """Assemble the generation config for one request.

        Merges the base config, per-call ``decode`` overrides and the output
        cap, then attaches the thinking config and default stop sequences.
        Both the first call and the retry go through here so they cannot
        drift apart.

        Returns:
            ``(cfg_dict, req_config)`` where ``cfg_dict`` is the plain dict
            and ``req_config`` is a ``types.GenerateContentConfig`` when it
            can be built, else the same plain dict.
        """
        cfg: Dict[str, Any] = {**self.base_config, **decode, "max_output_tokens": max_out}
        if self.thinking_cfg is not None:
            cfg["thinking_config"] = self.thinking_cfg
        # Helpful early stop for our IR tasks unless user overrides
        if "stop_sequences" not in cfg and self.add_stop_ir_out:
            cfg["stop_sequences"] = list(self._DEFAULT_STOP_SEQUENCES)
        # Build typed config when available; fall back to plain dict
        try:
            req_config = types.GenerateContentConfig(**cfg) if types else cfg
        except Exception:
            req_config = cfg
        return cfg, req_config

    @staticmethod
    def _to_json(obj: Any) -> str:
        """Serialize ``obj`` for logs without ever raising.

        Non-JSON values (e.g. SDK config objects) are rendered via ``str``;
        if serialization still fails, ``str(obj)`` is returned.
        """
        try:
            return json.dumps(obj, ensure_ascii=False, default=str)
        except Exception:
            return str(obj)

    def _int_param(self, key: str, default: int) -> int:
        """Read an integer param, treating a missing or ``None`` value as ``default``.

        Raises:
            ValueError/TypeError: If a non-None value cannot be converted.
        """
        value = self.params.get(key)
        return default if value is None else int(value)

    # --- token budgeting --------------------------------------------------
    def _get_context_window(self) -> int:
        """Best-effort context window size for budgeting output tokens.

        Priority: params.context_window -> env GEMINI_CONTEXT_WINDOW -> default 32768.
        Any conversion error yields 32768.
        """
        try:
            return int(self.params.get("context_window") or int(os.getenv("GEMINI_CONTEXT_WINDOW", "32768")))
        except Exception:
            return 32768

    def _count_prompt_tokens(self, text: str) -> int:
        """Ask the SDK to count tokens (client.models.count_tokens); fallback to char/4.

        Returns:
            The first integer found under ``total_tokens``, ``token_count`` or
            ``total_token_count`` of the SDK result; otherwise
            ``max(1, len(text) // 4)``.
        """
        try:
            # The unified SDK accepts raw strings or a list of contents
            ct = self.client.models.count_tokens(model=self.model, contents=[text])
            # Try common field names
            for k in ("total_tokens", "token_count", "total_token_count"):
                v = getattr(ct, k, None) if not isinstance(ct, dict) else ct.get(k)
                if isinstance(v, int):
                    return v
        except Exception:
            pass
        return max(1, len(text) // 4)

    def _compute_max_out(self, prompt: str, prompt_tokens: Optional[int] = None) -> int:
        """Compute a per-prompt max_output_tokens within the model window minus a safety margin.

        Args:
            prompt: Prompt text (token-counted unless ``prompt_tokens`` is given).
            prompt_tokens: Optional pre-computed token count, letting callers
                avoid a second ``count_tokens`` round-trip.

        Returns:
            ``min(params.max_output_tokens, window - prompt_tokens - safety)``
            with a floor of 128. ``max_output_tokens`` defaults to 16384 and
            ``window_safety`` to 512; explicit ``None`` values use the defaults.
        """
        window = self._get_context_window()
        safety = self._int_param("window_safety", 512)
        base_cap = self._int_param("max_output_tokens", 16384)
        ptok = self._count_prompt_tokens(prompt) if prompt_tokens is None else prompt_tokens
        budget = max(128, window - ptok - safety)
        return max(128, min(base_cap, budget))

    # --- request helpers --------------------------------------------------
    def _call_model(self, prompt: str, req_config: Any, stage: str, t_start: float):
        """Invoke ``generate_content`` once, logging and re-raising failures.

        Args:
            prompt: Prompt text sent as ``contents``.
            req_config: Typed or dict generation config.
            stage: Label written to error logs ("first_call" / "retry_call").
            t_start: ``time.time()`` value used to compute error latency.

        Raises:
            Exception: Whatever the SDK raised, after diagnostics are logged.
        """
        try:
            try:
                return self.client.models.generate_content(
                    model=self.model,
                    contents=prompt,
                    config=req_config,
                )
            except TypeError:
                # Older previews used `generation_config=` – try that for safety
                return self.client.models.generate_content(
                    model=self.model,
                    contents=prompt,
                    generation_config=req_config,  # type: ignore[arg-type]
                )
        except Exception as ex:
            elapsed_ms = int((time.time() - t_start) * 1000)
            diag = _exc_diag(ex)
            # _to_json cannot raise, so the original SDK error is never masked.
            _append_full("ERR", self._to_json(diag), {"latency_ms": elapsed_ms, "stage": stage})
            try:
                get_logger().error(f"Gemini generate_content error ({stage}): {diag}")
            except Exception:
                pass
            raise

    def _record_response(self, resp: Any, meta: dict):
        """Write RECV_RAW / RECV_TEXT log records and return ``(resp_dict, text)``."""
        resp_dict = _resp_to_dict(resp)
        _append_full("RECV_RAW", self._to_json(resp_dict), meta)
        text = _safe_text(resp)
        _append_full("RECV_TEXT", text, meta)
        return resp_dict, text

    # --- public API -------------------------------------------------------
    def generate(self, prompts: List[str], **decode) -> List[Completion]:
        """Generate one completion per prompt, sequentially.

        For each prompt the output-token cap is budgeted from the prompt's
        token count. If the response finishes with MAX_TOKENS/LENGTH, or is
        empty while a finish reason is reported, and
        ``params.retry_on_maxtokens`` is truthy (default) and the cap exceeds
        256, the request is retried once with the cap halved (floor 256) and
        otherwise the same config (thinking and stop sequences included). The
        retry's metadata replaces the first call's, with latencies summed.

        Args:
            prompts: Prompt strings.
            **decode: Per-call generation-config overrides merged over the
                base config (e.g. ``temperature``, ``stop_sequences``).

        Returns:
            A list of ``Completion`` in the same order as ``prompts``.

        Raises:
            Exception: Any error raised by the SDK call is logged and re-raised.
        """
        outs: List[Completion] = []
        logger = get_logger()
        for p in prompts:
            t0 = time.time()
            _append_full("ASK", p, None)

            # Count tokens once and reuse the value for budgeting and diagnostics.
            ptok = self._count_prompt_tokens(p)
            dyn_max = self._compute_max_out(p, prompt_tokens=ptok)
            diag_meta = {"model": self.model, "prompt_tokens_est": ptok, "dyn_max_out": dyn_max}
            _append_full(
                "SEND_CFG",
                self._to_json({"generation_config": {**self.base_config, **decode, "max_output_tokens": dyn_max}}),
                diag_meta,
            )

            cfg_dict, req_config = self._build_request_config(decode, dyn_max)
            logger.info(f"DEBUG: Sending generation_config: {cfg_dict}")
            logger.debug(f"DEBUG: thinking_budget={getattr(self.thinking_cfg, 'thinking_budget', None)} include_thoughts={getattr(self.thinking_cfg, 'include_thoughts', None)}")

            resp = self._call_model(p, req_config, "first_call", t0)
            t1 = time.time()

            meta = _usage(resp)
            meta["latency_ms"] = int((t1 - t0) * 1000)
            resp_dict, text = self._record_response(resp, meta)

            # If we hit MAX_TOKENS with empty/partial output, retry once with a smaller cap
            fr = (meta or {}).get("finish_reason")
            hit_cap = str(fr).upper() in {"MAX_TOKENS", "LENGTH"}
            empty_with_reason = (not text.strip()) and fr
            if (hit_cap or empty_with_reason) and self.params.get("retry_on_maxtokens", True) and dyn_max > 256:
                dyn_max = max(256, dyn_max // 2)
                cfg2, req2 = self._build_request_config(decode, dyn_max)
                _append_full("RETRY_CFG", self._to_json(cfg2), None)
                _append_full("RETRY_CFG_VERBOSE", self._to_json({"generation_config": cfg2}), {"dyn_max_out": dyn_max})
                t0b = time.time()
                resp = self._call_model(p, req2, "retry_call", t0b)
                t1b = time.time()
                m2 = _usage(resp)
                # Report total wall time spent on this prompt across both calls.
                m2["latency_ms"] = int((t1b - t0b) * 1000) + int(meta.get("latency_ms") or 0)
                meta = m2
                resp_dict, text = self._record_response(resp, meta)

            if debug_full_on():
                logger.info(f"RECV_RAW[{meta.get('finish_reason','?')} tokens_out={meta.get('out_tokens')}]\n{text}")
                if (not text.strip()) or not meta.get("out_tokens"):
                    try:
                        pretty = json.dumps(resp_dict, ensure_ascii=False, indent=2)
                    except Exception:
                        pretty = str(resp_dict)
                    logger.warning("Gemini returned no text. Diagnostics:\n" + pretty)

            if debug_on() and not debug_full_on():
                try:
                    fr = meta.get("finish_reason")
                    usage = f" fr={fr} prompt={meta.get('prompt_tokens','')} out={meta.get('out_tokens','')}"
                    logger.info(f"LLM[gemini] meta:{usage}")
                except Exception:
                    pass

            outs.append(Completion(text=text, meta=meta))

            # Optional pacing between prompts; None/0/negative means no pause.
            delay = float(self.params.get("sleep_sec") or 0.0)
            if delay > 0:
                time.sleep(delay)
        return outs