import copy
import logging
from os import getenv
from time import sleep
from typing import Any

import backoff
import regex

from ._find import Find
from ._gen import Gen
from ._select import Select
from .backend import PathFinder
from .model import Model
from .trie import MarisaTrie, Trie

logger = logging.getLogger(__name__)

def can_be_int(s):
    try:
        int(s)  # Try converting `s` to int
        return True
    except ValueError:
        return False  # Return False if a ValueError is raised


class ModelAPI(PathFinder):
    token_in = 0
    token_out = 0

    def __init__(self, model_name: str, seed: int, api_assistant: bool = True) -> None:
        super().__init__(model_name)

        self.temperature: float = 0.0
        self.top_p: float = 1.0
        self.max_tokens: int = 1000
        self.seed: int = seed

        self.prefix_text: str = ""
        self.api_assistant: bool = api_assistant

    def _current_prompt(self):
        if isinstance(self.chat, list):
            prompt_render = str(self.chat)
        else:
            prompt_render = self.chat
        return prompt_render

    def _consume_assistant_text(self, value):
        self.prefix_text += value
        match = regex.match(
            r"(.*?)" + regex.escape(value) + r"(.*?)",
            self.text_to_consume,
            regex.DOTALL,
        )
        if match:
            self.text_to_consume = self.text_to_consume[len(match.group()) :]
            self.prefix_text = ""
        else:
            self.text_to_consume = ""

    def _get_gen(self, value: Gen):
        self.temperature = value.temperature
        self.top_p = value.top_p
        self.max_tokens = value.max_tokens
        if value.stop_regex is None:
            r = r"(.*?)"
        else:
            r = rf"(.*?)({value.stop_regex})"

        return self.run(self, r, value.name, True, value.save_stop_text)

    def _get_find(self, value: Find):
        self.temperature = value.temperature
        self.top_p = value.top_p
        return self.run_find(self, value.regex, value.name)

    def _get_select(self, value: Select):
        if all(can_be_int(x) for x in value.options):
            r = r"(\d+)"
        else:
            r = r"("
            r += r"|".join([regex.escape(o) for o in value.options])
            r += r")"
        return self.run(self, r, value.name, False, False)

    def request_api(self, chat, temperature, top_p, max_tokens):
        raise NotImplementedError

    def run_find(self, lm, r, name):
        if lm.text_to_consume == "":
            tmp_chat = (
                lm.chat[:-1]
                if lm.chat[-1]["role"] == "assistant" and lm.chat[-1]["content"] == ""
                else lm.chat
            )
            if self.api_assistant:
                lm.text_to_consume, self.reasoning = self.request_api(
                    tmp_chat, lm.temperature, lm.top_p, lm.max_tokens
                )
            else:
                tmp_chat = (
                    tmp_chat[:-1] if tmp_chat[-1]["role"] == "assistant" else tmp_chat
                )
                lm.text_to_consume, self.reasoning = self.request_api(
                    tmp_chat, lm.temperature, lm.top_p, lm.max_tokens
                )
                match = regex.match(
                    regex.escape(lm.prefix_text) + r"(.*?)",
                    lm.text_to_consume,
                    regex.DOTALL,
                )
                if match:
                    lm.text_to_consume = lm.text_to_consume[len(match.group()) :]
                    lm.prefix_text = ""

        original_res = lm.text_to_consume
        match = regex.search(r, lm.text_to_consume)
        if match:
            res = match.group(0)
            lm._variables[name] = res
            return res, original_res
        else:
            raise ValueError(f"Regex {r} not found in {lm.text_to_consume}")

    def run(self, lm, r, name, is_gen, save_stop_text):
        if lm.text_to_consume == "":
            tmp_chat = (
                lm.chat[:-1]
                if lm.chat[-1]["role"] == "assistant" and lm.chat[-1]["content"] == ""
                else lm.chat
            )
            if self.api_assistant:
                lm.text_to_consume, self.reasoning = self.request_api(
                    tmp_chat, lm.temperature, lm.top_p, lm.max_tokens
                )
            else:
                tmp_chat = (
                    tmp_chat[:-1] if tmp_chat[-1]["role"] == "assistant" else tmp_chat
                )
                lm.text_to_consume, self.reasoning = self.request_api(
                    tmp_chat, lm.temperature, lm.top_p, lm.max_tokens
                )
                match = regex.match(
                    regex.escape(lm.prefix_text) + r"(.*?)",
                    lm.text_to_consume,
                    regex.DOTALL,
                )
                if match:
                    lm.text_to_consume = lm.text_to_consume[len(match.group()) :]
                    lm.prefix_text = ""

            # remove any prefix, if any
            p = lm.chat[-1]["content"].strip()
            if lm.text_to_consume.startswith(p):
                lm.text_to_consume = lm.text_to_consume[len(p) :]

        if regex.search(r, lm.text_to_consume):
            match = regex.match(r + r"(.*?)", lm.text_to_consume, regex.DOTALL)
            if match:
                # complete match
                match_res = match.group()
                if save_stop_text:
                    res = match.group()
                    lm.text_to_consume = lm.text_to_consume[len(match_res) :]
                else:
                    res = match.group(1)
                    lm.text_to_consume = lm.text_to_consume[len(match.group(1)) :]
            else:
                match = regex.findall(r, lm.text_to_consume, regex.DOTALL)[0]
                lm.text_to_consume = ""  # reset since this was a search of the response
                res = match
        elif is_gen:
            # not stop token
            res = lm.text_to_consume
            lm.text_to_consume = ""
        else:
            raise ValueError(f"Cant find {r} in {lm.text_to_consume}")
        return res


class OpenAIAPI(ModelAPI):
    def __init__(self, model_name, seed):
        super().__init__(model_name, seed)
        from openai import OpenAI

        self.client = OpenAI()

    def request_api(self, chat, temperature, top_p, max_tokens):
        import openai

        @backoff.on_exception(backoff.expo, openai.RateLimitError)
        def completions_with_backoff(**kwargs):
            return self.client.chat.completions.create(**kwargs)

        out = completions_with_backoff(
            model=self.model_name,
            messages=chat,
            temperature=temperature,
            top_p=top_p,
            seed=self.seed,
            max_tokens=max_tokens,
        )
        logger.info(f"OpenAI system_fingerprint: {out.system_fingerprint}")
        return out.choices[0].message.content, ""

import uuid
API_KEYS = ["OPENROUTER_API_KEY"]#, "OPENROUTER_API_KEY_1", "OPENROUTER_API_KEY_2", "OPENROUTER_API_KEY_3"]

class OpenRouter(ModelAPI):
    def __init__(self, model_name, seed):
        super().__init__(model_name, seed)
        from openai import OpenAI

        self.client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=getenv(API_KEYS[0]),#.pop(0)),
        )
        self.random_name = str(uuid.uuid4())

    def request_api(self, chat, temperature, top_p, max_tokens):
        import openai
        @backoff.on_exception(backoff.expo, openai.RateLimitError)
        def completions_with_backoff(**kwargs):
            out = self.client.chat.completions.create(**kwargs)
            logger.debug(f"Out: {out}")
            while out.choices is None:
                msg = out.error["message"].lower()
                if "free-models-per-min" in msg or "provider returned error" in msg:
                    sleep(60)
                elif len(API_KEYS) > 0:
                    logger.debug("Switching openrouter API_KEY")
                    self.client.api_key = getenv(API_KEYS.pop(0))
                else:
                    raise ConnectionRefusedError(msg)
                out = self.client.chat.completions.create(**kwargs)
            return out

        out = completions_with_backoff(
            model=self.model_name,
            messages=chat,
            temperature=temperature,
            top_p=top_p,
            seed=self.seed,
            max_tokens=max_tokens,
        )
        logger.info(f"OpenAI system_fingerprint: {out.system_fingerprint}")
        if out.usage is not None:
            token_in = out.usage.prompt_tokens
            token_out = out.usage.completion_tokens
#            append_token_usage(token_in, token_out, self.model_name, self.random_name, openrouter=True)
        else:
            logger.warning("out.usage was None.")
        content = out.choices[0].message.content
        reasoning = out.choices[0].message.reasoning if hasattr(out.choices[0].message, "reasoning") else "" 
        return content, reasoning


import json
import os

class TokenCounter:
    total_in = 0
    total_out = 0
    total_cost = 0

    @classmethod
    def add_tokens(cls, token_in: int, token_out: int, cost: float) -> None:
        TokenCounter.total_in += token_in
        TokenCounter.total_out += token_out
        TokenCounter.total_cost += cost
        if TokenCounter.total_cost > 2:
            raise AssertionError("API cost exceeded limit.")

    @classmethod
    def log_total(cls) -> None:
        logger.info(f"Total token in: {TokenCounter.total_in}, total token out: {TokenCounter.total_out}, total cost: {TokenCounter.total_cost}")


def append_token_usage(token_in, token_out, model, file_name, openrouter=False):

    # Create a dictionary to append
    if openrouter:
        file_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "..", "..", "api_usage_openrouter"
        )
    else:
        file_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "..", "..", "api_usage"
        )
    os.makedirs(file_path, exist_ok=True)
    file_path = os.path.join(file_path, f"{file_name}.json")

    if "gpt-4-turbo-2024-04-09" in model:
        cost_in = token_in * 10 / 1e6
        cost_out = token_out * 30 / 1e6
    elif "gpt-4o-mini-2024-07-18" in model:
        cost_in = token_in * 0.15 / 1e6
        cost_out = token_out * 0.6 / 1e6
    elif "z-gpt-4o-2024-08-0" in model:
        cost_in = token_in * 2.5 / 1e6
        cost_out = token_out * 10 / 1e6
    elif "openai/gpt-4o-2024-08-06" in model:
        cost_in = token_in * 2.5 / 1e6
        cost_out = token_out * 10 / 1e6
    elif "gpt-o3-mini-2025-01-31" in model:
        cost_in = token_in * 1.1 / 1e6
        cost_out = token_out * 4.4 / 1e6
    elif "z-gpt-o4-mini-2025-04-16" in model:
        cost_in = token_in * 1.1 / 1e6
        cost_out = token_out * 4.4 / 1e6
    elif "deepseek/deepseek-chat-v3-0324" in model:
        cost_in = token_in * 0.27 / 1e6
        cost_out = token_out * 1.1 / 1e6
    elif "deepseek/deepseek-r1" in model:
        cost_in = token_in * 0.5 / 1e6
        cost_out = token_out * 2.18 / 1e6
    elif "google/gemini-2.5-flash-preview:thinking" in model:
        cost_in = token_in * 0.15 / 1e6
        cost_out = token_out * 3.5 / 1e6
    elif "google/gemini-2.5-flash-preview" in model:
        cost_in = token_in * 0.15 / 1e6
        cost_out = token_out * 0.6 / 1e6
    elif "qwen/qwen3-235b-a22b" in model:
        cost_in = token_in * 0.1 / 1e6
        cost_out = token_out * 0.1 / 1e6
    elif "anthropic/claude-3.7-sonnet" in model:
        cost_in = token_in * 3 / 1e6
        cost_out = token_out * 15 / 1e6
    elif "meta-llama/llama-3.3-70b-instruct" in model:
        cost_in = token_in * 0.1 / 1e6
        cost_out = token_out * 0.25 / 1e6
    else:
        raise ValueError(f"Model {model} not supported")

    new_entry = {
        "model": model,
        "token_in": token_in,
        "token_out": token_out,
        "cost_in": cost_in,
        "cost_out": cost_out,
        "total_cost": cost_in + cost_out,
    }
    TokenCounter.add_tokens(token_in, token_out, cost_in + cost_out)
    TokenCounter.log_total()

    # Check if file exists and is not empty
    if os.path.isfile(file_path) and os.path.getsize(file_path) > 0:
        # Open the file in read/write mode
        with open(file_path, "r+") as file:
            file.seek(0, os.SEEK_END)  # Move the cursor to the end of file
            file.seek(
                file.tell() - 1, os.SEEK_SET
            )  # Move the cursor back by one character from the end
            file.write(
                ",\n" + json.dumps(new_entry) + "\n]"
            )  # Convert dict to JSON string and append
    else:
        # If the file does not exist or is empty, write a new list with one element
        with open(file_path, "w") as file:
            json.dump([new_entry], file, indent=4)




class AzureOpenAIAPI(ModelAPI):
    def __init__(self, model_name, seed):
        super().__init__(model_name, seed)
        from openai import AzureOpenAI

        self.client = AzureOpenAI()
        self.random_name = str(uuid.uuid4())

    def request_api(self, chat, temperature, top_p, max_tokens):
        import openai

        @backoff.on_exception(backoff.expo, openai.RateLimitError)
        def completions_with_backoff(**kwargs):
            return self.client.chat.completions.create(**kwargs)
        if self.model_name == "z-gpt-o3-mini-2025-01-31" or self.model_name == "z-gpt-o4-mini-2025-04-16":
            generation_args = {"max_completion_tokens": max_tokens}
        else:
            generation_args = {"temperature": temperature, "max_tokens": max_tokens}
        max_retries = 2
        retries = 0
        while True:
            try:
                out = completions_with_backoff(
                    model=self.model_name,
                    messages=chat,
                    top_p=top_p,
                    seed=self.seed,
                    **generation_args,
                )
                logger.info(f"OpenAI system_fingerprint: {out.system_fingerprint}")

                self.token_in = out.usage.prompt_tokens
                self.token_out = out.usage.completion_tokens

#                append_token_usage(
#                    self.token_in, self.token_out, self.model_name, self.random_name
#                )

                res = out.choices[0].message.content
                if retries == max_retries or res is not None:
                    break
            except openai.BadRequestError as e:
                logger.warning(e)
            retries += 1


        return res, ""


import os


class MistralAPI(ModelAPI):
    def __init__(self, model_name, seed):
        super().__init__(model_name, seed, api_assistant=False)
        from httpx import Client as HTTPClient
        from httpx import HTTPTransport
        from mistralai.client import MistralClient

        api_key = os.environ["MISTRAL_API_KEY"]
        self.client = MistralClient(api_key=api_key)

        http_proxies = [
            proxy
            for varname, proxy in os.environ.items()
            if varname.lower() == "http_proxy"
        ]
        https_proxies = [
            proxy
            for varname, proxy in os.environ.items()
            if varname.lower() == "https_proxy"
        ]
        all_proxies = [
            proxy
            for varname, proxy in os.environ.items()
            if varname.lower() == "all_proxy"
        ]
        proxies = {
            "http://": http_proxies[0] if len(http_proxies) > 0 else None,
            "https://": https_proxies[0] if len(https_proxies) > 0 else None,
            "all://": all_proxies[0] if len(all_proxies) > 0 else None,
        }

        self.client._client = HTTPClient(
            proxies=proxies,
            follow_redirects=True,
            timeout=self.client._timeout,
            transport=HTTPTransport(retries=self.client._max_retries),
        )

    def request_api(self, chat, temperature, top_p, max_tokens):
        from mistralai.exceptions import MistralException

        @backoff.on_exception(backoff.expo, MistralException)
        def completions_with_backoff(**kwargs):
            return self.client.chat(**kwargs)

        from mistralai.models.chat_completion import ChatMessage

        if chat[-1]["role"] == "assistant":
            raise Exception(
                "Assistant should not be the last role in the chat for Mistral."
            )

        chat_mistral = [
            ChatMessage(role=entry["role"], content=entry["content"]) for entry in chat
        ]
        out = completions_with_backoff(
            model=self.model_name,
            messages=chat_mistral,
            temperature=temperature,
            top_p=top_p,
            random_seed=self.seed,
            max_tokens=max_tokens,
        )
        return out.choices[0].message.content, ""


class AnthropicAPI(ModelAPI):
    def __init__(self, model_name, seed):
        super().__init__(model_name, seed, api_assistant=False)
        from anthropic import Anthropic

        self.client = Anthropic()

    def request_api(self, chat, temperature, top_p, max_tokens):
        from anthropic._exceptions import APIStatusError

        @backoff.on_exception(backoff.expo, APIStatusError)
        def completions_with_backoff(**kwargs):
            return self.client.messages.create(**kwargs)

        if chat[-1]["role"] == "assistant":
            raise Exception(
                "Assistant should not be the last role in the chat for Anthropic."
            )

        out = completions_with_backoff(
            model=self.model_name,
            messages=chat,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        return out.content[0].text, ""
