import argparse
import os
import sys
from pathlib import Path

from azure.identity import (
    AuthenticationRecord,
    DeviceCodeCredential,
    TokenCachePersistenceOptions,
    get_bearer_token_provider,
)
from openai import AzureOpenAI, OpenAI

# Model identifiers accepted by GPT.__init__; any other name is rejected with ValueError.
valid_models = [
    "gpt-4o",
    "ada-embeddings",
    "text-embedding-3-large",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "deepseek-ai/DeepSeek-V3",
    "deepseek-chat",
    "text-embedding-ada-002",
]


class GPT:
    def __init__(
        self,
        model_name: str,
        endpoint_url: str,
        endpoint_api_key: str,
        api_version: str = "2024-02-15-preview",
        system_msg: str = "You are an AI assistant.",
        max_retries: int = 12,
        temperature: int = 1.0,
        max_tokens: int = 4096,
        top_p: float = 0.95,
        frequency_penalty: int = 0,
        presence_penalty: int = 0,
        seed: int = None,
    ):
        if model_name not in valid_models:
            raise ValueError(
                f"Invalid model: {model_name}. Valid models are: {valid_models}"
            )

        token_provider = get_bearer_token_provider(
            self._get_credential(), "https://cognitiveservices.azure.com/.default"
        )

        # self.OA_client = AzureOpenAI(
        #     azure_endpoint=endpoint_url,
        #     api_version=api_version,
        #     azure_ad_token_provider=token_provider,
        # )

        self.openai_client = OpenAI(
            api_key=endpoint_api_key,
            base_url=endpoint_url,
        )

        self.max_retries = max_retries
        self.system_msg = system_msg
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.seed = seed

        # token usage logging
        self._token_logs: list[dict] = []
        self._token_totals: dict = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost_usd": 0.0,
        }
        self._price_per_1k: dict | None = None  # {model: {"input": x, "output": y}}
        self._token_log_path: str | None = None

    def set_token_logging(self, log_path: str | None = None, price_per_1k: dict | None = None) -> None:
        """Enable per-call token usage logging.

        Args:
            log_path: If provided, append JSON lines with usage to this file.
            price_per_1k: Optional pricing map {model: {"input": price, "output": price}}.
        """
        self._token_log_path = log_path
        self._price_per_1k = price_per_1k

    def _record_usage(self, kind: str, model: str, usage_obj) -> None:
        if usage_obj is None:
            return

        prompt_tokens = getattr(usage_obj, "prompt_tokens", None)
        completion_tokens = getattr(usage_obj, "completion_tokens", None)
        total_tokens = getattr(usage_obj, "total_tokens", None)

        cost_usd = None
        if self._price_per_1k and model in self._price_per_1k and prompt_tokens is not None and completion_tokens is not None:
            input_price = self._price_per_1k[model].get("input")
            output_price = self._price_per_1k[model].get("output")
            if input_price is not None and output_price is not None:
                cost_usd = (prompt_tokens / 1000.0) * input_price + (completion_tokens / 1000.0) * output_price

        log_entry = {
            "type": kind,
            "model": model,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "cost_usd": cost_usd,
        }

        self._token_logs.append(log_entry)

        if prompt_tokens is not None:
            self._token_totals["prompt_tokens"] += prompt_tokens
        if completion_tokens is not None:
            self._token_totals["completion_tokens"] += completion_tokens
        if total_tokens is not None:
            self._token_totals["total_tokens"] += total_tokens
        if cost_usd is not None:
            self._token_totals["cost_usd"] += cost_usd

        if self._token_log_path:
            try:
                import json, os
                os.makedirs(os.path.dirname(self._token_log_path), exist_ok=True)
                with open(self._token_log_path, "a", encoding="utf-8") as f:
                    f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
            except Exception:
                pass

    def get_token_usage_summary(self) -> dict:
        """Return cumulative token usage and estimated cost."""
        return dict(self._token_totals)

    def set_seed(self, seed: int):
        self.seed = seed

    def _get_credential(self, lib_name: str = "azure_openai") -> DeviceCodeCredential:
        """Retrieves a credential to be used for authentication in Azure"""
        if sys.platform.startswith("win"):
            auth_record_root_path = Path(os.environ["LOCALAPPDATA"])
        else:
            auth_record_root_path = Path.home()

        auth_record_path = auth_record_root_path / lib_name / "auth_record.json"
        cache_options = TokenCachePersistenceOptions(
            name=f"{lib_name}.cache", allow_unencrypted_storage=True
        )

        if auth_record_path.exists():
            with open(auth_record_path, "r") as f:
                record_json = f.read()
            deserialized_record = AuthenticationRecord.deserialize(record_json)
            credential = DeviceCodeCredential(
                authentication_record=deserialized_record,
                cache_persistence_options=cache_options,
            )
        else:
            auth_record_path.parent.mkdir(parents=True, exist_ok=True)
            credential = DeviceCodeCredential(cache_persistence_options=cache_options)
            record_json = credential.authenticate().serialize()
            with open(auth_record_path, "w") as f:
                f.write(record_json)

        return credential

    def api_call_chat(self, messages: list[dict]) -> str | None:
        for _ in range(self.max_retries):
            completion = self.openai_client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                top_p=self.top_p,
                frequency_penalty=self.frequency_penalty,
                presence_penalty=self.presence_penalty,
                seed=self.seed if self.seed else None,
            )
            if completion:
                return completion.choices[0].message.content
        return None

    def _api_call_embedding(self, text: str) -> list[float] | None:
        for _ in range(self.max_retries):
            embedding = self.openai_client.embeddings.create(
                input=text, model=self.model_name
            )
            if embedding:
                # Azure/OpenAI embeddings may not always include usage; record if present
                usage_obj = getattr(embedding, "usage", None)
                self._record_usage("embeddings", self.model_name, usage_obj)
                return embedding.data[0].embedding
        return None

    def generate_response(self, prompt: str) -> str | None:
        """
        Generate a response for the given prompt.
        This setup can be used for GPT4 models but not for embedding genneration.
        """
        messages = [
            {
                "role": "system",
                "content": self.system_msg,
            },
            {
                "role": "user",
                "content": prompt,
            },
        ]

        response = self.api_call_chat(messages)
        return response

    def generate_embedding(self, text: str) -> list[float] | None:
        """
        Generate an embedding for the given text.
        This setup can be used for Ada embeddings but not for text generation.
        """
        embedding = self._api_call_embedding(text)
        return embedding

    def _api_call_embedding_batch(self, texts: list[str]) -> list[list[float]] | None:
        """
        Generate embeddings for a batch of texts.
        """
        for _ in range(self.max_retries):
            embedding = self.openai_client.embeddings.create(
                input=texts, model=self.model_name
            )
            if embedding:
                return [data.embedding for data in embedding.data]
        return None

    def generate_embeddings_batch(self, texts: list[str]) -> list[list[float]] | None:
        """
        Generate embeddings for a batch of texts.
        This setup can be used for Ada embeddings but not for text generation.
        
        Args:
            texts: List of text strings to generate embeddings for
            
        Returns:
            List of embeddings, where each embedding is a list of floats
        """
        if not texts:
            return []
        
        embeddings = self._api_call_embedding_batch(texts)
        return embeddings


def parser_args():
    parser = argparse.ArgumentParser(description="GPT Session")
    parser.add_argument(
        "--model_name",
        type=str,
        default="ada-embeddings",
        help="Model name to use for embedding generation",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Embedding text",
        help="Prompt for text generation",
    )
    parser.add_argument(
        "--endpoint_url",
        type=str,
        help="Endpoint URL for the model",
    )
    parser.add_argument(
        "--endpoint_api_key",
        type=str,
        help="API key for the endpoint",
    )
    parser.add_argument(
        "--batch_mode",
        action="store_true",
        help="Enable batch mode for processing multiple texts",
    )

    return parser.parse_args()


if __name__ == "__main__":
    # Smoke test: request one embedding for the given prompt.
    args = parser_args()
    # GPT.__init__ requires the API key as its third argument; leaving it out
    # raises TypeError before any request is made.
    gpt = GPT(args.model_name, args.endpoint_url, args.endpoint_api_key)
    response = gpt.generate_embedding(args.prompt)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if response is None:
        raise RuntimeError("Embedding generation returned no result")
