from __future__ import annotations
from pathlib import Path

import boto3
import json
import os
from dataclasses import dataclass, field
from typing import Any

from botocore.config import Config
from openai import OpenAI
import yaml
import copy
from google import genai  # type: ignore


# Usage counters tracked for every request, in the order they appear in snapshots.
_USAGE_FIELDS = ("prompt_tokens", "completion_tokens", "total_tokens", "cached_tokens")


def _zero_usage() -> dict[str, int]:
    """Return a fresh counter dict with every tracked usage field set to 0."""
    return dict.fromkeys(_USAGE_FIELDS, 0)


@dataclass
class UsageTracker:
    """
    Running token-usage totals gathered from LLM responses.

    Counts are kept overall, per provider, and per model. Everything is plain
    dicts/ints so `snapshot()` can be written straight to JSON (e.g. into
    outputs/behavior_*.json) for cost estimation.
    """

    # Overall sums. "cached_tokens" holds prompt tokens the provider reported as
    # served from its prompt cache (cached input is typically billed at a discount).
    total: dict[str, int] = field(default_factory=_zero_usage)
    # provider name -> counters with the same keys as `total`
    by_provider: dict[str, dict[str, int]] = field(default_factory=dict)
    # model id -> counters with the same keys as `total`
    by_model: dict[str, dict[str, int]] = field(default_factory=dict)
    # Number of add() calls that carried non-empty usage.
    requests: int = 0

    def _add_bucket(self, bucket: dict[str, dict[str, int]], key: str, usage: dict[str, int]) -> None:
        """Add `usage` into `bucket[key]`, creating a zeroed entry on first use."""
        counts = bucket.setdefault(key, _zero_usage())
        for name in _USAGE_FIELDS:
            # Missing keys and None/0 values both count as zero.
            counts[name] += int(usage.get(name, 0) or 0)

    def add(self, provider: str, model: str, usage: dict[str, int] | None) -> None:
        """
        Record one response's usage under `provider` and `model`.

        Empty or None usage is ignored entirely (it does not bump `requests`).
        """
        if not usage:
            return
        self.requests += 1
        for name in _USAGE_FIELDS:
            self.total[name] += int(usage.get(name, 0) or 0)
        self._add_bucket(self.by_provider, provider, usage)
        self._add_bucket(self.by_model, model, usage)

    def snapshot(self) -> dict[str, Any]:
        """Return a copy of all counters that later add() calls will not modify."""
        provider_copy = {name: dict(counts) for name, counts in self.by_provider.items()}
        model_copy = {name: dict(counts) for name, counts in self.by_model.items()}
        return {
            "requests": self.requests,
            "total": dict(self.total),
            "by_provider": provider_copy,
            "by_model": model_copy,
        }

class Conversation:
    """
    Tracks the conversation between user and assistant. System prompt (if any)
    is stored separately from the conversation history. Conversation can be
    serialized in below format:
    ```
    [
        {"role": "system", "content": "..."},
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},
        ...
    ]
    ```
    The system entry is only emitted when `system_prompt` is a non-empty string.
    """

    def __init__(self, system_prompt: str | None = None) -> None:
        # Ordered user/assistant turns; the system prompt is never stored here.
        self.history: list[dict] = []
        self.system_prompt = system_prompt

    def add_user_message(self, message: str) -> None:
        """Append a user turn to the history."""
        self.history.append({"role": "user", "content": message})

    def add_assistant_message(self, message: str) -> None:
        """Append an assistant turn to the history."""
        self.history.append({"role": "assistant", "content": message})

    def serialize(self, include_system: bool = True) -> list[dict]:
        """
        Return the conversation as a list of role/content dicts.

        The returned list is a new (shallow) copy, so appending to it does not
        affect `self.history`; the message dicts themselves are shared.

        Args:
         - include_system: prepend the system prompt entry when one is set
           (non-empty).
        """
        # Named `messages` (not `copy`) so the module-level `copy` import is not shadowed.
        messages = list(self.history)
        if include_system and self.system_prompt:
            messages.insert(0, {"role": "system", "content": self.system_prompt})
        return messages

    def __str__(self, include_system: bool = True) -> str:
        """Return `serialize(include_system)` as pretty-printed JSON (indent=2)."""
        return json.dumps(self.serialize(include_system), indent=2)
    
class UnifiedLLMClient:
    """
    A unified client for invoking LLM from different providers.

    This client:
     - only supports text input for now (no multimodal support)
     - only supports models listed in `supported_models.yml`

    Refer to pricing before choosing model:
     - OpenAI: https://platform.openai.com/docs/pricing
     - Azure: https://azure.microsoft.com/en-us/pricing/details/ai-foundry-models/model-router/
     - AWS: https://aws.amazon.com/bedrock/pricing/
     - Gemini: https://ai.google.dev/gemini-api/docs/pricing

    You need to configure environment variables for different providers:
     - OpenAI: `OPENAI_API_KEY`
     - AWS: `AWS_DEFAULT_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`
     - Gemini: `GEMINI_API_KEY` (or `GOOGLE_API_KEY`)
     - SGLang / local: none, but `base_url` must be passed to the constructor

    Token usage reported by each provider is accumulated per client instance;
    read it with `get_usage_snapshot()`.

    Sample usage:
     ```python
     # for online provider
     client = UnifiedLLMClient("gpt-5-mini")
     # for local deployment
     client = UnifiedLLMClient("local", base_url="http://localhost:30000/v1/")
     # single-turn (don't need conversation tracking)
     output, _ = client.generate("Tell me a joke")
     # multi-turn (with conversation tracking)
     output_1, conv = client.generate("Tell me a joke")
     output_2, conv = client.generate("Now explain why it's funny", conversation=conv)
     ```
    """

    # Expose Conversation as UnifiedLLMClient.Conversation for callers and type hints.
    Conversation = Conversation

    # Load supported models from supported_models.yml (next to this file) into
    # {model name: {provider: provider-specific model id}}. Path.read_text opens
    # and closes the file, so no handle is leaked at import time; an empty YAML
    # file (safe_load -> None) yields no supported models instead of crashing.
    SUPPORTED_MODELS = {
        m["name"]: m["providers"]
        for m in (
            yaml.safe_load(
                (Path(__file__).parent / "supported_models.yml").read_text(
                    encoding="utf-8"
                )
            )
            or {}
        ).get("supported_models", [])
    }

    def __init__(
        self, model: str, provider: str | None = None, base_url: str | None = None
    ) -> None:
        """
        Initialize an LLM client. Can optionally specify which provider to use.
        Supported (model, provider) combinations are listed in supported_models.yml

        Args:
         - model (required): The model to use
         - provider (optional): A specific provider to use. Omit to use the first
           provider listed for `model` in supported_models.yml
         - base_url (optional): server URL; required for the "sglang"/"local" providers

        Raises:
         - ValueError: unsupported model/provider, or missing base_url for sglang/local
         - EnvironmentError: required credentials are missing from the environment
         - ImportError: Gemini requested but `google-genai` is not available
        """
        # validate model
        if model not in UnifiedLLMClient.SUPPORTED_MODELS:
            raise ValueError(f"UnifiedLLMClient doesn't support {model} yet")
        providers = UnifiedLLMClient.SUPPORTED_MODELS[model]

        # validate / choose provider (default: first provider listed for the model)
        if provider is None:
            provider = next(iter(providers))
        elif provider not in providers:
            raise ValueError(
                f"UnifiedLLMClient doesn't support {model} from {provider} yet"
            )
        self.provider = provider

        # get provider-specific model id
        self.model_id = providers[self.provider]
        self.model = model  # original model name, used for GPT-5 family checks
        self.usage_tracker = UsageTracker()

        # initialize client based on provider choice
        if self.provider == "openai":
            if "OPENAI_API_KEY" not in os.environ:
                raise EnvironmentError("Please configure OPENAI_API_KEY as an environment variable")
            self.client = OpenAI()

        elif self.provider == "gemini":
            # Defensive check in case the google-genai import is made optional.
            if genai is None:
                raise ImportError(
                    "Gemini provider requested, but dependency `google-genai` is not available. "
                    "Install it (e.g., `pip install google-genai`) or use provider='openai'/'aws'/'sglang'."
                )
            api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
            if not api_key:
                raise EnvironmentError("Please configure GEMINI_API_KEY or GOOGLE_API_KEY as an environment variable")
            self.client = genai.Client(api_key=api_key)

        elif self.provider in ("sglang", "local"):
            if base_url is None:
                raise ValueError(
                    "Please specify base_url (e.g. http://localhost:30000/v1) to use SGLang"
                )
            # OpenAI-compatible server; an empty api_key is passed, presumably
            # because these servers do not check auth (confirm per deployment).
            self.client = OpenAI(base_url=base_url, api_key="")

        elif self.provider == "aws":
            if "AWS_ACCESS_KEY_ID" not in os.environ:
                raise EnvironmentError("Please configure AWS_ACCESS_KEY_ID as an environment variable")
            if "AWS_SECRET_ACCESS_KEY" not in os.environ:
                raise EnvironmentError("Please configure AWS_SECRET_ACCESS_KEY as an environment variable")
            if "AWS_DEFAULT_REGION" not in os.environ:
                raise EnvironmentError("Please configure AWS_DEFAULT_REGION as an environment variable")
            # disable boto3's automatic retries (a single attempt per call)
            config = Config(retries={"max_attempts": 1, "mode": "standard"})
            self.client = boto3.client("bedrock-runtime", config=config)

        else:
            raise ValueError(f"Invalid provider {self.provider}")

    @staticmethod
    def _format_messages_as_text(messages: list[dict], system_prompt: str | None) -> str:
        """
        Best-effort flattening of chat messages into a single text prompt for providers
        that don't take OpenAI-style role messages (e.g., Gemini via google-genai).

        Each message becomes a "Role: content" line; the system prompt (if any)
        comes first as "System: ...". Missing/blank roles default to "user".
        """
        parts: list[str] = []
        if system_prompt:
            parts.append(f"System: {system_prompt}")
        for m in messages:
            role = str(m.get("role", "")).strip() or "user"
            content = m.get("content", "")
            parts.append(f"{role.capitalize()}: {content}")
        return "\n".join(parts).strip()

    @staticmethod
    def _extract_openai_usage(response: Any) -> dict[str, int] | None:
        """
        Best-effort extraction of token usage from OpenAI python client responses.

        Expected shapes:
        - response.usage.prompt_tokens / completion_tokens / total_tokens
        - response["usage"]["prompt_tokens"] ... (dict-like)

        Returns a dict with whichever of prompt/completion/total/cached tokens
        were found, or None if none were.
        """
        usage = None
        try:
            usage = getattr(response, "usage", None)
        except Exception:
            usage = None

        if usage is None:
            # dict-like fallback
            try:
                usage = response.get("usage")  # type: ignore[attr-defined]
            except Exception:
                usage = None

        if usage is None:
            return None

        def _get(obj: Any, key: str) -> Any:
            try:
                if isinstance(obj, dict):
                    return obj.get(key)
                return getattr(obj, key, None)
            except Exception:
                return None

        prompt_tokens = _get(usage, "prompt_tokens")
        completion_tokens = _get(usage, "completion_tokens")
        total_tokens = _get(usage, "total_tokens")
        # Prompt caching:
        # - Some API shapes return usage.prompt_tokens_details.cached_tokens
        # - Some may return usage.cached_tokens directly
        cached_tokens = None
        try:
            prompt_details = _get(usage, "prompt_tokens_details")
            if prompt_details is not None:
                if isinstance(prompt_details, dict):
                    cached_tokens = prompt_details.get("cached_tokens")
                else:
                    cached_tokens = getattr(prompt_details, "cached_tokens", None)
        except Exception:
            cached_tokens = None
        if cached_tokens is None:
            cached_tokens = _get(usage, "cached_tokens")

        # Some providers may only provide total_tokens.
        out: dict[str, int] = {}
        if prompt_tokens is not None:
            out["prompt_tokens"] = int(prompt_tokens)
        if completion_tokens is not None:
            out["completion_tokens"] = int(completion_tokens)
        if total_tokens is not None:
            out["total_tokens"] = int(total_tokens)
        if cached_tokens is not None:
            out["cached_tokens"] = int(cached_tokens)
        if not out:
            return None
        # normalize: if total missing but prompt+completion present
        if "total_tokens" not in out and "prompt_tokens" in out and "completion_tokens" in out:
            out["total_tokens"] = out["prompt_tokens"] + out["completion_tokens"]
        return out

    @staticmethod
    def _extract_bedrock_usage(response: Any) -> dict[str, int] | None:
        """
        Best-effort mapping of a Bedrock Converse `usage` block onto the
        OpenAI-style keys used by UsageTracker. Returns None if nothing is found.

        NOTE(review): Bedrock may count cache reads separately from inputTokens,
        unlike OpenAI's prompt_tokens (which includes cached tokens) — verify
        before combining cached counts across providers.
        """
        try:
            raw = response.get("usage")
        except AttributeError:
            return None
        if not isinstance(raw, dict):
            return None
        key_map = {
            "inputTokens": "prompt_tokens",
            "outputTokens": "completion_tokens",
            "totalTokens": "total_tokens",
            "cacheReadInputTokens": "cached_tokens",
        }
        out = {
            dst: int(raw[src])
            for src, dst in key_map.items()
            if raw.get(src) is not None
        }
        return out or None

    def get_usage_snapshot(self) -> dict[str, Any]:
        """Return a JSON-serializable copy of the token usage recorded so far."""
        return self.usage_tracker.snapshot()

    def generate(
        self,
        user_input,
        conversation: UnifiedLLMClient.Conversation | list[dict] | None = None,
        system_prompt: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        max_completion_tokens: int | None = None,
        reasoning_effort: str | None = None,
        response_format: dict | None = None,
        top_p: float | None = None,
        n: int | None = None,
        seed: int | None = None,
        extra_params: dict | None = None,
    ):
        """
        Generate a response based on the user input and optional parameters.

        Args:
         - user_input: a str, or any object with a `.value` attribute (e.g. a
           textgrad Variable) whose value is used as the user message. Any other
           object (including a list of messages) is converted with `str()`.
         - conversation: None (start fresh), a Conversation, or a legacy list of
           {"role", "content"} dicts. It is copied, never mutated.
         - system_prompt: fills the conversation's system prompt when it has none;
           otherwise it must match the conversation's system prompt exactly.
         - max_completion_tokens is preferred over max_tokens when both are given.
         - the remaining arguments are forwarded to the provider; see the
           per-provider helpers for which ones each provider honors.

        Returns:
         (output_text, conversation) where conversation is a new Conversation
         containing this turn's user message and the assistant reply.

        Raises:
         - ValueError: invalid `conversation` type or mismatched system prompt
        """
        # Normalize user_input to string (support textgrad Variable with .value).
        # Raw message lists are not special-cased; they are stringified like any object.
        if hasattr(user_input, "value"):
            user_text = user_input.value
        else:
            user_text = str(user_input)

        if isinstance(conversation, list):
            # legacy support for list of messages as conversation; other roles are dropped
            new_conversation = UnifiedLLMClient.Conversation()
            for msg in conversation:
                if msg["role"] == "system":
                    new_conversation.system_prompt = msg["content"]
                elif msg["role"] == "user":
                    new_conversation.add_user_message(msg["content"])
                elif msg["role"] == "assistant":
                    new_conversation.add_assistant_message(msg["content"])
        elif isinstance(conversation, UnifiedLLMClient.Conversation):
            # deep copy so the caller's conversation object is left untouched
            new_conversation = copy.deepcopy(conversation)
        elif conversation is None:
            new_conversation = UnifiedLLMClient.Conversation(system_prompt=system_prompt)
        else:
            raise ValueError("conversation must be None, list of messages, or Conversation object")
        new_conversation.add_user_message(user_text)

        # Match upstream behavior: if conversation has no system prompt (None),
        # fill it with the provided system_prompt; otherwise require exact match.
        if system_prompt is not None and new_conversation.system_prompt is None:
            new_conversation.system_prompt = system_prompt
        # check if the provided system prompt matches with conversation system prompt
        if system_prompt is not None and new_conversation.system_prompt != system_prompt:
            raise ValueError(
                f"System prompt does not match the one in the conversation.\n{system_prompt}\n{new_conversation.system_prompt}"
            )

        # use system prompt from conversation (even if not provided explicitly)
        system_prompt = new_conversation.system_prompt

        # GPT-5 family: gpt-5, gpt-5-mini and dated variants (e.g. gpt-5-2025-08-07,
        # gpt-5-mini-2025-08-07), but not point releases like gpt-5.1 / gpt-5.2.
        # Every name starting with "gpt-5-" is in the family, and no such name can
        # also start with "gpt-5.", so this check alone covers all of the above.
        model_lower = self.model.lower()
        is_gpt5_series = model_lower == "gpt-5" or model_lower.startswith("gpt-5-")

        # For GPT-5 or gpt-5-mini: default reasoning_effort to "minimal" and skip temperature
        if is_gpt5_series:
            if reasoning_effort is None:
                reasoning_effort = "minimal"
            # Skip temperature for GPT-5 series (they don't support custom temperature)
            temperature = None

        # Set deterministic defaults if not provided
        # For determinism: top_p=1 (consider all tokens), n=1 (single response), seed=123 (default seed)
        # Setting ENABLE_DETERMINISM to anything other than "true" disables the default seed.
        enable_determinism = os.getenv("ENABLE_DETERMINISM", "true").lower() == "true"

        if top_p is None:
            top_p = 1
        if n is None:
            n = 1
        if seed is None and self.provider in ("openai", "gemini") and enable_determinism:
            # Use default seed for OpenAI/Gemini models to ensure determinism (only if enabled)
            seed = 123

        if self.provider == "openai":
            output_text = self._generate_openai_responses(
                input=new_conversation.history,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                max_completion_tokens=max_completion_tokens,
                reasoning_effort=reasoning_effort,
                response_format=response_format,
                top_p=top_p,
                n=n,
                seed=seed,
                extra_params=extra_params,
            )
        elif self.provider == "gemini":
            output_text = self._generate_gemini(
                input=new_conversation.history,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                max_completion_tokens=max_completion_tokens,
                top_p=top_p,
                seed=seed,
            )
        elif self.provider in ("sglang", "local"):
            output_text = self._generate_openai_chat_completion(
                input=new_conversation.history,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                max_completion_tokens=max_completion_tokens,
                reasoning_effort=reasoning_effort,
                response_format=response_format,
                top_p=top_p,
                n=n,
                seed=seed,
                extra_params=extra_params,
            )
        elif self.provider == "aws":
            # Bedrock has a single output cap; honor max_completion_tokens first,
            # consistent with the other providers.
            output_text = self._generate_aws_converse(
                input=new_conversation.history,
                system_prompt=system_prompt,
                temperature=temperature,
                max_tokens=(
                    max_completion_tokens if max_completion_tokens is not None else max_tokens
                ),
                reasoning_effort=reasoning_effort,
            )
        else:
            raise RuntimeError("Internal error: Provider not supported.")

        new_conversation.add_assistant_message(output_text)
        return output_text, new_conversation

    def _generate_gemini(
        self,
        input: list[dict],
        system_prompt: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        seed: int | None = None,
    ) -> str:
        """
        Generate a response using Google Gemini via `google-genai`.

        Note: We flatten role messages into a single text prompt. This keeps behavior
        simple and supports multi-turn history, but may not perfectly match native
        Gemini chat semantics. Temperature/top_p default to 0.0/1.0 when None,
        function calling is disabled and the thinking budget is set to 0.
        """
        if genai is None:
            raise RuntimeError("Internal error: genai is not available but provider is gemini")

        prompt_text = UnifiedLLMClient._format_messages_as_text(input, system_prompt=system_prompt)

        cfg: dict[str, Any] = {
            "temperature": 0.0 if temperature is None else float(temperature),
            "top_p": 1.0 if top_p is None else float(top_p),
            "automatic_function_calling": {"disable": True},
            "thinking_config": {"thinking_budget": 0},
        }
        # Gemini uses max_output_tokens.
        if max_completion_tokens is not None:
            cfg["max_output_tokens"] = int(max_completion_tokens)
        elif max_tokens is not None:
            cfg["max_output_tokens"] = int(max_tokens)
        if seed is not None:
            cfg["seed"] = int(seed)

        response = self.client.models.generate_content(  # type: ignore[union-attr]
            model=str(self.model_id),
            contents=prompt_text,
            config=cfg,
        )

        # Best-effort usage extraction (never raises)
        usage = None
        try:
            usage_meta = getattr(response, "usage_metadata", None)
            if usage_meta is not None:
                field_map = {
                    "prompt_token_count": "prompt_tokens",
                    "candidates_token_count": "completion_tokens",
                    "total_token_count": "total_tokens",
                    "cached_content_token_count": "cached_tokens",
                }
                usage = {}
                for src, dst in field_map.items():
                    value = getattr(usage_meta, src, None)
                    if value is not None:
                        usage[dst] = int(value)
                if usage == {}:
                    usage = None
        except Exception:
            usage = None

        self.usage_tracker.add(
            provider=self.provider,
            model=str(self.model_id),
            usage=usage,
        )

        # Extract text (match unified_judge_framework.py logic)
        if hasattr(response, "text") and response.text:
            return response.text
        if hasattr(response, "candidates") and response.candidates:
            candidate = response.candidates[0]
            if hasattr(candidate, "content"):
                if hasattr(candidate.content, "parts") and candidate.content.parts:
                    return candidate.content.parts[0].text
                if hasattr(candidate.content, "text"):
                    return candidate.content.text
        return str(response) if response else ""

    def _build_chat_request(
        self,
        input: list[dict],
        system_prompt: str | None,
        temperature: float | None,
        max_tokens: int | None,
        max_completion_tokens: int | None,
        reasoning_effort: str | None,
        response_format: dict | None,
        top_p: float | None,
        n: int | None,
        seed: int | None,
        extra_params: dict | None,
    ) -> dict[str, Any]:
        """
        Build keyword arguments for `client.chat.completions.create`.

        The system prompt (if not None) is prepended as a "system" message.
        Optional parameters are included only when not None; `extra_params`
        is applied last and can override any key.
        """
        messages = []
        if system_prompt is not None:
            messages.append({"role": "system", "content": system_prompt})
        messages += input
        request_params: dict[str, Any] = {
            "model": self.model_id,
            "messages": messages,
        }
        if temperature is not None:
            request_params["temperature"] = temperature
        # Prefer max_completion_tokens over max_tokens (newer API); either is sent
        # to the server as max_completion_tokens.
        if max_completion_tokens is not None:
            request_params["max_completion_tokens"] = max_completion_tokens
        elif max_tokens is not None:
            request_params["max_completion_tokens"] = max_tokens
        optional = {
            "reasoning_effort": reasoning_effort,
            "response_format": response_format,
            "top_p": top_p,
            "n": n,
            "seed": seed,
        }
        request_params.update({k: v for k, v in optional.items() if v is not None})
        if extra_params:
            request_params.update(extra_params)
        return request_params

    def _generate_openai_responses(
        self,
        input: list[dict],
        system_prompt: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        max_completion_tokens: int | None = None,
        reasoning_effort: str | None = None,
        response_format: dict | None = None,
        top_p: float | None = None,
        n: int | None = None,
        seed: int | None = None,
        extra_params: dict | None = None
    ) -> str:
        """
        Generate a response using OpenAI's Chat Completions API (supports response_format).

        Despite the method name, this calls `chat.completions.create`, not the
        Responses API. Returns the content of the first choice.
        """
        # GPT-5 models don't support temperature parameter (only default value 1 is supported)
        # Skip temperature for GPT-5 series models to avoid errors
        if temperature is not None and (
            self.model.startswith("gpt-5") or self.model_id.startswith("gpt-5")
        ):
            temperature = None
        request_params = self._build_chat_request(
            input=input,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            max_completion_tokens=max_completion_tokens,
            reasoning_effort=reasoning_effort,
            response_format=response_format,
            top_p=top_p,
            n=n,
            seed=seed,
            extra_params=extra_params,
        )
        response = self.client.chat.completions.create(**request_params)
        self.usage_tracker.add(
            provider=self.provider,
            model=str(self.model_id),
            usage=self._extract_openai_usage(response),
        )
        return response.choices[0].message.content

    def _generate_openai_chat_completion(
        self,
        input: list[dict],
        system_prompt: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        max_completion_tokens: int | None = None,
        reasoning_effort: str | None = None,
        response_format: dict | None = None,
        top_p: float | None = None,
        n: int | None = None,
        seed: int | None = None,
        extra_params: dict | None = None
    ) -> str:
        """
        Generate a response using OpenAI's legacy Chat Completion API:
        https://platform.openai.com/docs/api-reference/chat/create

        Sglang doesn't support the new Responses API yet. Returns the content
        of the first choice.
        """
        request_params = self._build_chat_request(
            input=input,
            system_prompt=system_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            max_completion_tokens=max_completion_tokens,
            reasoning_effort=reasoning_effort,
            response_format=response_format,
            top_p=top_p,
            n=n,
            seed=seed,
            extra_params=extra_params,
        )
        response = self.client.chat.completions.create(**request_params)
        # SGLang is OpenAI-compatible; usage may or may not be present.
        self.usage_tracker.add(
            provider=self.provider,
            model=str(self.model_id),
            usage=self._extract_openai_usage(response),
        )
        return response.choices[0].message.content

    def _generate_aws_converse(
        self,
        input: list[dict],
        system_prompt: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        reasoning_effort: str | None = None,
    ) -> str:
        """
        Generate a response using AWS's new Converse API (released May 2024):
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html

        `reasoning_effort` is not supported and only triggers a warning. Returns
        the text of the first text content block in the reply.
        """
        # By default, input has following structure:
        #   [ {"role": "user", "content": "Hi"} ]
        # Bedrock requires the following structure:
        #   [ {"role": "user", "content": [{"text": "Hi"}]} ]
        messages = [
            {"role": entry.get("role"), "content": [{"text": entry.get("content")}]}
            for entry in input
        ]
        request_params = {
            "modelId": self.model_id,
            "messages": messages,
        }
        if system_prompt is not None:
            request_params["system"] = [{"text": system_prompt}]
        inference_config = {}
        if temperature is not None:
            inference_config["temperature"] = temperature
        if max_tokens is not None:
            inference_config["maxTokens"] = max_tokens
        if inference_config:
            request_params["inferenceConfig"] = inference_config
        if reasoning_effort is not None:
            print("[WARNING] reasoning_effort ignored for aws")
        response = self.client.converse(**request_params)
        # Record usage like the other providers do.
        self.usage_tracker.add(
            provider=self.provider,
            model=str(self.model_id),
            usage=self._extract_bedrock_usage(response),
        )
        content_blocks = response["output"]["message"]["content"]
        # Some models emit non-text blocks (e.g. reasoning) before the text block.
        for block in content_blocks:
            if "text" in block:
                return block["text"]
        # No text block at all: fail the same way as direct indexing would.
        return content_blocks[0]["text"]


if __name__ == "__main__":
    # Demo: a two-turn chat. The system prompt is supplied on the first turn only;
    # the Conversation returned by generate() carries it into the next turn.
    demo_client = UnifiedLLMClient("gpt-5-mini")
    demo_system = "Use all uppercase letters"

    print("----------------1st turn----------------")
    reply, history = demo_client.generate("Tell me a joke", system_prompt=demo_system)
    print(reply)

    print("----------------2nd turn----------------")
    follow_up = "In a sentence, why is that joke funny?"
    reply, history = demo_client.generate(follow_up, conversation=history)
    print(reply)

    print("----------conversation history----------")
    print(history)
