import time
import random
import os
import logging

import litellm


class LLMInvocation:
    """Send a system/user prompt to an LLM, with retries and exponential backoff.

    Two backends are supported, selected by ``TE_LLM_BACKEND``:

    * ``litellm`` (default): streams through ``litellm.completion`` and
      rebuilds the full response with ``litellm.stream_chunk_builder``.
    * ``openai``: uses the OpenAI Python SDK directly -- the Responses API for
      ``gpt-5*`` models, Chat Completions otherwise, and the legacy
      ``openai.ChatCompletion`` API when the ``OpenAI`` client class cannot be
      imported.

    Environment variables read in ``__init__``:
        TE_QUIET                -- truthy flag; suppresses console output and
                                   sets the LiteLLM loggers to ERROR.
        TE_LLM_BACKEND          -- ``litellm`` or ``openai``.
        TE_OPENAI_BASE_URL      -- optional base URL for the OpenAI client.
        TE_OPENAI_STREAM        -- truthy flag; stream Chat Completions output.
        TE_LLM_MAX_RETRIES      -- total attempts per call (default 3, min 1).
        TE_LLM_BACKOFF_BASE     -- base seconds for exponential backoff
                                   (default 2, min 0).
        TE_LLM_REQUEST_TIMEOUT  -- per-request timeout in seconds; 0 = none.

    Environment variables read on every ``call_model`` call:
        TE_DISABLE_TEMPERATURE  -- truthy flag; omit ``temperature`` entirely.
        TE_TEMPERATURE          -- overrides the ``temperature`` argument.

    Truthy flag values are ``1``, ``true``, ``yes``, ``on`` (case-insensitive,
    surrounding whitespace ignored).
    """

    # Tokens accepted as "true" for boolean environment flags.
    _TRUTHY = ("1", "true", "yes", "on")

    def __init__(self, model: str):
        self.model = model
        self.quiet = self._env_flag("TE_QUIET")
        # Backend selection: 'litellm' (default) or 'openai'
        self.backend = os.environ.get("TE_LLM_BACKEND", "litellm").strip().lower()
        self.openai_base_url = os.environ.get("TE_OPENAI_BASE_URL")
        self.openai_stream = self._env_flag("TE_OPENAI_STREAM")
        # Retry / timeout controls. Despite its name, max_retries is used as the
        # total number of attempts, so clamp it to make at least one request.
        self.max_retries = max(1, self._env_number("TE_LLM_MAX_RETRIES", 3, int))
        # A negative base would make time.sleep() raise ValueError mid-retry.
        self.base_delay = max(0.0, self._env_number("TE_LLM_BACKOFF_BASE", 2.0, float))
        # 0 (the default) means "do not override the library's timeout".
        self.request_timeout = self._env_number("TE_LLM_REQUEST_TIMEOUT", 0.0, float)
        if self.quiet:
            # Reduce LiteLLM logger noise
            for name in ("LiteLLM", "litellm"):
                logging.getLogger(name).setLevel(logging.ERROR)

    def call_model(self, prompt: dict, max_tokens=4096, temperature=0.2):
        """Send ``prompt`` to the configured model and return its reply.

        Args:
            prompt: dict with ``"system"`` and ``"user"`` entries. A falsy
                system prompt (e.g. ``""``) is omitted from the request.
            max_tokens: completion token limit for Chat Completions / LiteLLM.
                It is not forwarded on the Responses API (``gpt-5*``) path.
            temperature: sampling temperature, unless overridden by
                ``TE_TEMPERATURE`` or suppressed by ``TE_DISABLE_TEMPERATURE``.

        Returns:
            tuple: On success, ``(content, prompt_tokens, completion_tokens)``.
            A token count is -1 when the provider did not report it (always
            the case for streamed OpenAI responses). If every attempt fails,
            the 2-tuple ``(False, error_message)`` is returned instead of
            raising.

        Raises:
            KeyError: if ``prompt`` lacks the ``"system"`` or ``"user"`` key.
        """
        if "system" not in prompt or "user" not in prompt:
            raise KeyError("The prompt dictionary must contain 'system' and 'user' keys.")

        messages = [{"role": "user", "content": prompt["user"]}]
        if prompt["system"]:
            messages.insert(0, {"role": "system", "content": prompt["system"]})

        sampling = self._sampling_kwargs(temperature)

        if self.backend == "openai":
            return self._with_retries(
                lambda: self._openai_once(prompt, messages, max_tokens, sampling),
                "OpenAI",
                "OpenAI backend error",
            )

        completion_params = self._litellm_params(messages, max_tokens, sampling)
        return self._with_retries(
            lambda: self._litellm_once(completion_params, messages),
            "LLM",
            "LLM invocation failed",
        )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _env_flag(name: str) -> bool:
        """Return True when environment variable ``name`` holds a truthy token."""
        return os.environ.get(name, "").strip().lower() in LLMInvocation._TRUTHY

    @staticmethod
    def _env_number(name: str, default, cast):
        """Parse env var ``name`` with ``cast``; return ``default`` if unset or malformed."""
        raw = os.environ.get(name)
        if raw is None:
            return default
        try:
            return cast(raw)
        except ValueError:
            return default

    def _sampling_kwargs(self, temperature) -> dict:
        """Return ``{"temperature": t}`` after env overrides, or ``{}`` when disabled."""
        if self._env_flag("TE_DISABLE_TEMPERATURE"):
            return {}
        try:
            return {"temperature": float(os.environ.get("TE_TEMPERATURE", str(temperature)))}
        except ValueError:
            return {"temperature": temperature}

    def _with_retries(self, attempt_fn, label: str, failure_prefix: str):
        """Call ``attempt_fn`` until it succeeds or the attempt budget runs out.

        Returns ``attempt_fn()``'s result, or ``(False, f"{failure_prefix}: {err}")``
        carrying the last exception once every attempt has failed.
        """
        attempts = max(1, int(self.max_retries))
        last_error = None
        for attempt in range(attempts):
            try:
                return attempt_fn()
            except Exception as e:  # provider SDKs raise many unrelated types
                last_error = e
                if attempt + 1 >= attempts:
                    break  # no point backing off after the final attempt
                delay = max(0.0, self.base_delay) * (2 ** attempt) + random.uniform(0, 1)
                if not self.quiet:
                    print(
                        f"{label} error; retrying in {delay:.2f}s "
                        f"(Attempt {attempt + 1}/{attempts}): {e}"
                    )
                time.sleep(delay)
        message = f"{failure_prefix}: {last_error}"
        if not self.quiet:
            print(message)
        return False, message

    @staticmethod
    def _usage_tokens(usage, *names: str) -> int:
        """Return the first non-None usage field among ``names`` as int, else -1."""
        if usage is None:
            return -1
        for name in names:
            value = getattr(usage, name, None)
            if value is not None:
                try:
                    return int(value)
                except (TypeError, ValueError):
                    return -1
        return -1

    @staticmethod
    def _collect_output_text(response) -> str:
        """Concatenate ``output_text`` parts of message items in a Responses result."""
        pieces = []
        for item in getattr(response, "output", None) or []:
            if getattr(item, "type", None) != "message":
                continue
            for part in getattr(item, "content", None) or []:
                if getattr(part, "type", None) == "output_text":
                    pieces.append(getattr(part, "text", "") or "")
        return "".join(pieces)

    def _openai_once(self, prompt, messages, max_tokens, sampling):
        """Make one OpenAI SDK request; return ``(content, prompt_tokens, completion_tokens)``."""
        try:
            from openai import OpenAI  # openai>=1.0 client class
        except ImportError:
            # openai<1.0 exposes no OpenAI class; use the module-level API. If
            # openai is not installed at all, that import raises and the retry
            # loop reports it.
            return self._openai_legacy_once(messages, max_tokens, sampling)

        client_kwargs = {}
        if self.openai_base_url:
            client_kwargs["base_url"] = self.openai_base_url
        if self.request_timeout > 0:
            client_kwargs["timeout"] = self.request_timeout
        client = OpenAI(**client_kwargs)

        # Use Responses API for GPT-5 family, else Chat Completions
        if str(self.model).lower().startswith("gpt-5"):
            return self._openai_responses(client, prompt, sampling)
        return self._openai_chat(client, messages, max_tokens, sampling)

    def _openai_responses(self, client, prompt, sampling):
        """Call the Responses API with system and user text folded into one input string."""
        input_text = (
            f"{prompt['system']}\n\n{prompt['user']}" if prompt.get("system") else prompt["user"]
        )
        # NOTE(review): max_tokens is not forwarded here. On reasoning models
        # max_output_tokens also bounds reasoning tokens -- confirm before adding.
        params = {"model": self.model, "input": input_text, **sampling}
        # Non-streaming so output_text and usage arrive in one response.
        response = client.responses.create(**params)
        content = getattr(response, "output_text", None)
        if content is None:
            content = self._collect_output_text(response)
        usage = getattr(response, "usage", None)
        return (
            content,
            self._usage_tokens(usage, "prompt_tokens", "input_tokens"),
            self._usage_tokens(usage, "completion_tokens", "output_tokens"),
        )

    def _openai_chat(self, client, messages, max_tokens, sampling):
        """Call Chat Completions, echoing streamed deltas when TE_OPENAI_STREAM is set."""
        params = {
            "model": self.model,
            "messages": messages,
            # Streaming is off by default so usage is reported reliably.
            "stream": bool(self.openai_stream),
            "max_tokens": max_tokens,
            **sampling,
        }
        response = client.chat.completions.create(**params)
        if not params["stream"]:
            content = response.choices[0].message.content or ""
            usage = getattr(response, "usage", None)
            return (
                content,
                self._usage_tokens(usage, "prompt_tokens"),
                self._usage_tokens(usage, "completion_tokens"),
            )

        pieces = []
        try:
            for chunk in response:
                # Some OpenAI-compatible servers emit chunks with an empty
                # `choices` list; skip them rather than raising IndexError.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content or ""
                if not self.quiet:
                    print(delta, end="", flush=True)
                pieces.append(delta)
        finally:
            if not self.quiet:
                print("\n")
        # Usage is not requested for streamed responses, so counts are unknown.
        return ("".join(pieces), -1, -1)

    def _openai_legacy_once(self, messages, max_tokens, sampling):
        """Make one request through the pre-1.0 ``openai.ChatCompletion`` API."""
        import openai as openai_legacy  # type: ignore
        if self.openai_base_url:
            openai_legacy.api_base = self.openai_base_url  # type: ignore[attr-defined]
        params = {"model": self.model, "messages": messages, "max_tokens": max_tokens, **sampling}
        if self.request_timeout > 0:
            params["request_timeout"] = self.request_timeout
        resp = openai_legacy.ChatCompletion.create(**params)  # type: ignore[attr-defined]
        content = resp["choices"][0]["message"]["content"]
        usage = resp.get("usage") or {}
        return (
            content,
            int(usage.get("prompt_tokens", -1)),
            int(usage.get("completion_tokens", -1)),
        )

    def _litellm_params(self, messages, max_tokens, sampling) -> dict:
        """Build the keyword arguments passed to ``litellm.completion``."""
        if self.model == "deepseek-r1":
            # The "deepseek-r1" alias is routed to a hard-coded SageMaker
            # endpoint in us-east-2.
            params = {
                "model": "sagemaker/endpoint-deepseek-r1-nashid",
                "messages": messages,
                "max_tokens": max_tokens,
                "stream": True,
                "aws_region_name": "us-east-2",
            }
        else:
            params = {
                "model": self.model,
                "messages": messages,
                "max_tokens": max_tokens,
                "stream": True,
            }
        params.update(sampling)
        if self.request_timeout > 0:
            params["request_timeout"] = self.request_timeout
        return params

    def _litellm_once(self, completion_params, messages):
        """Stream one LiteLLM completion; return ``(content, prompt_tokens, completion_tokens)``."""
        response = litellm.completion(**completion_params)
        chunks = []
        try:
            for chunk in response:
                if not self.quiet and chunk.choices:
                    print(chunk.choices[0].delta.content or "", end="", flush=True)
                chunks.append(chunk)
                # NOTE(review): fixed 10 ms pause per chunk, applied even when
                # quiet; its purpose is not evident here -- confirm before removing.
                time.sleep(0.01)
        except Exception as e:
            # Best effort: keep whatever arrived before the stream broke.
            if not self.quiet:
                print(f"Error during streaming: {e}")
        if not self.quiet:
            print("\n")
        if not chunks:
            raise RuntimeError("LiteLLM stream returned no chunks")
        model_response = litellm.stream_chunk_builder(chunks, messages=messages)
        return (
            model_response["choices"][0]["message"]["content"],
            int(model_response["usage"]["prompt_tokens"]),
            int(model_response["usage"]["completion_tokens"]),
        )
