import re
import boto3
import json
from botocore import UNSIGNED
from botocore.config import Config
from typing import List
from openai import AsyncOpenAI, AsyncAzureOpenAI
import os
import httpx
import logging
import copy
import time
import asyncio
from rich import print

from .base_client import ModelResponseBase

# Maps the short Claude model keys used in this module to the identifier
# strings that the Claude client classes pass as ``modelId`` to
# ``bedrock_runtime.invoke_model``.
MODEL_NAMES = {
    "claude-3-5-sonnet-v1": "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
    "claude-3-5-sonnet-v2": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
    "claude-3-7-sonnet-v1": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    "claude-4-sonnet-v1": "us.anthropic.claude-sonnet-4-20250514-v1:0",
    "claude-4-opus-v1": "us.anthropic.claude-opus-4-20250514-v1:0",
}



def get_default_config(config_path, type="luxo"):
    """Load one named profile from the ``api_config.json`` next to this module.

    Args:
        config_path: Accepted for interface compatibility; this function does
            not read it and always opens ``api_config.json`` in the directory
            of this source file.
        type: Top-level key of the JSON document to return.

    Returns:
        The value stored under ``type`` in the parsed JSON document.

    Raises:
        FileNotFoundError: If ``api_config.json`` does not exist.
        KeyError: If ``type`` is not a key of the JSON document.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_file = os.path.join(module_dir, "api_config.json")
    with open(config_file, "r") as handle:
        profiles = json.load(handle)
    return profiles[type]


class DavinciBase(ModelResponseBase):
    """Shared base for the OpenAI-style chat clients defined in this module.

    Builds an ``AsyncAzureOpenAI`` client from ``api_config`` (subclasses may
    replace it through the ``client`` property) and stores the sampling
    defaults that ``query_response`` forwards to the parent implementation.
    """

    def __init__(self, name, in_token_costs, out_token_costs, api_config, **kwargs):
        """Create the client and record sampling defaults.

        Args:
            name: Model/deployment name handed to the parent class.
            in_token_costs: Input-token cost handed to the parent class.
            out_token_costs: Output-token cost handed to the parent class.
            api_config: Mapping that must provide ``api_key``, ``api_version``
                and ``base_url``.
            **kwargs: Optional ``temperature`` (default 0.7), ``top_p``
                (default 0.95), ``extra_body`` and ``reasoning``
                (``{"max_tokens": ...}``, default 2000).
        """
        super().__init__(
            name=name,
            in_token_costs=in_token_costs,
            out_token_costs=out_token_costs,
            api_config=api_config,
        )
        self.api_config = api_config
        # NOTE: TLS certificate verification is disabled for the HTTP client.
        self._client = AsyncAzureOpenAI(
            api_key=api_config["api_key"],
            api_version=api_config["api_version"],
            azure_endpoint=api_config["base_url"],
            http_client=httpx.AsyncClient(verify=False),
        )
        self.temperature = kwargs.pop("temperature", 0.7)
        self.top_p = kwargs.get("top_p", 0.95)

        # Stored for subclasses/parent use; nothing in this class reads it.
        self.extra_body = kwargs.get(
            "extra_body",
            {
                "reasoning": {
                    "max_tokens": kwargs.get("reasoning", {}).get("max_tokens", 2000),
                }
            },
        )

    @property
    def client(self):
        """The underlying API client object."""
        return self._client

    @client.setter
    def client(self, value):
        self._client = value

    @staticmethod
    def _extract_content(response):
        """Return ``{"text", "reasoning"}`` extracted from ``response``.

        Prefers ``response.message.content`` (chat format) and falls back to
        ``response.text``; ``reasoning`` is ``response.reasoning`` when present,
        otherwise the empty string.
        """
        # Use chat completion format if available, else fallback to legacy text
        if hasattr(response, 'message') and hasattr(response.message, 'content'):
            return {"text": response.message.content, "reasoning": getattr(response, 'reasoning', '')}
        reasoning = getattr(response, 'reasoning', '')
        return {"text": getattr(response, 'text', ''), "reasoning": reasoning}

    async def query_response(self, messages: list, **kwargs) -> str:
        """Normalise token-budget/sampling kwargs and delegate to the parent.

        The output-token budget is taken from the first of ``max_new_tokens``,
        ``max_tokens`` or ``max_completion_tokens`` that is present and is
        forwarded as ``max_tokens``. ``temperature`` and ``top_p`` may be given
        per call; otherwise the instance defaults are used. ``add_think`` is
        accepted and discarded.

        Raises:
            RuntimeError: If none of the token-budget keywords is supplied.
        """
        for alias in ("max_new_tokens", "max_tokens", "max_completion_tokens"):
            if alias in kwargs:
                max_tokens = kwargs.pop(alias)
                break
        else:
            raise RuntimeError("max_completion_tokens not specified")

        # If ``max_new_tokens`` won, a caller-supplied ``max_tokens`` is still
        # in kwargs and would collide with the explicit ``max_tokens=`` below.
        kwargs.pop("max_tokens", None)

        _ = kwargs.pop("add_think", False)

        # Per-call overrides; previously passing either keyword raised a
        # duplicate-keyword TypeError in the call below.
        temperature = kwargs.pop("temperature", self.temperature)
        top_p = kwargs.pop("top_p", self.top_p)

        return await super().query_response(
            messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            **kwargs,
        )


class GPT4o(DavinciBase):
    """Chat client for the ``aide-gpt-4o`` deployment.

    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        super().__init__(
            name="aide-gpt-4o",
            in_token_costs=2.5,
            out_token_costs=10.0,
            api_config=resolved_config,
        )

    @staticmethod
    def _extract_content(response):
        """Return the message text; ``reasoning`` is always the empty string."""
        content = response.message.content
        return {"text": content, "reasoning": ""}

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request.

        Only ``max_tokens`` (default 1024) is read from ``kwargs``; it is sent
        as ``max_completion_tokens``. All other keyword arguments are ignored.
        """
        token_budget = kwargs.get("max_tokens", 1024)
        request = dict(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=token_budget,
        )
        return await self.client.chat.completions.create(**request)


class GPT4o_mini(DavinciBase):
    """Chat client for the ``aide-gpt-4o-mini`` deployment.

    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        super().__init__(
            name="aide-gpt-4o-mini",
            in_token_costs=0.15,
            out_token_costs=0.6,
            api_config=resolved_config,
        )

    @staticmethod
    def _extract_content(response):
        """Return the message text; ``reasoning`` is always the empty string."""
        content = response.message.content
        return {"text": content, "reasoning": ""}

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request.

        Only ``max_tokens`` (default 1024) is read from ``kwargs``; it is sent
        as ``max_completion_tokens``. All other keyword arguments are ignored.
        """
        token_budget = kwargs.get("max_tokens", 1024)
        request = dict(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=token_budget,
        )
        return await self.client.chat.completions.create(**request)


class GPTo3_mini(DavinciBase):
    """Chat client for the ``aide-o3-mini`` deployment.

    Overrides the configured ``api_version`` with ``2025-04-01-preview``.
    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        resolved_config["api_version"] = "2025-04-01-preview"
        super().__init__(
            name="aide-o3-mini",
            in_token_costs=1.1,
            out_token_costs=4.4,
            api_config=resolved_config,
        )

    @staticmethod
    def _extract_content(response):
        """Return the message text; ``reasoning`` is always the empty string."""
        content = response.message.content
        return {"text": content, "reasoning": ""}

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request.

        Only ``max_tokens`` (default 1024) is read from ``kwargs``; it is sent
        as ``max_completion_tokens``. All other keyword arguments are ignored.
        """
        token_budget = kwargs.get("max_tokens", 1024)
        request = dict(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=token_budget,
        )
        return await self.client.chat.completions.create(**request)


class GPTo3(DavinciBase):
    """Chat client for the ``aide-o3`` deployment.

    Overrides the configured ``api_version`` with ``2025-04-01-preview``.
    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        resolved_config["api_version"] = "2025-04-01-preview"
        super().__init__(
            name="aide-o3",
            in_token_costs=10.0,
            out_token_costs=40.0,
            api_config=resolved_config,
        )

    @staticmethod
    def _extract_content(response):
        """Return the message text; ``reasoning`` is always the empty string."""
        content = response.message.content
        return {"text": content, "reasoning": ""}

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request.

        Only ``max_tokens`` (default 1024) is read from ``kwargs``; it is sent
        as ``max_completion_tokens``. All other keyword arguments are ignored.
        """
        token_budget = kwargs.get("max_tokens", 1024)
        request = dict(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=token_budget,
        )
        return await self.client.chat.completions.create(**request)


class GPTo4_mini(DavinciBase):
    """Chat client for the ``aide-o4-mini`` deployment.

    Overrides the configured ``api_version`` with ``2025-04-01-preview``.
    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        resolved_config["api_version"] = "2025-04-01-preview"
        super().__init__(
            name="aide-o4-mini",
            in_token_costs=1.1,
            out_token_costs=4.4,
            api_config=resolved_config,
        )

    @staticmethod
    def _extract_content(response):
        """Return the message text; ``reasoning`` is always the empty string."""
        content = response.message.content
        return {"text": content, "reasoning": ""}

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request.

        Only ``max_tokens`` (default 1024) is read from ``kwargs``; it is sent
        as ``max_completion_tokens``. All other keyword arguments are ignored.
        """
        token_budget = kwargs.get("max_tokens", 1024)
        request = dict(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=token_budget,
        )
        return await self.client.chat.completions.create(**request)


class GPT4_1(DavinciBase):
    """Chat client for the ``aide-gpt-4.1`` deployment.

    Overrides the configured ``api_version`` with ``2024-12-01-preview``.
    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        resolved_config["api_version"] = "2024-12-01-preview"
        super().__init__(
            name="aide-gpt-4.1",
            in_token_costs=2.0,
            out_token_costs=8.0,
            api_config=resolved_config,
        )

    @staticmethod
    def _extract_content(response):
        """Return the message text; ``reasoning`` is always the empty string."""
        content = response.message.content
        return {"text": content, "reasoning": ""}

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request.

        Only ``max_tokens`` (default 1024) is read from ``kwargs``; it is sent
        as ``max_completion_tokens``. All other keyword arguments are ignored.
        """
        token_budget = kwargs.get("max_tokens", 1024)
        request = dict(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=token_budget,
        )
        return await self.client.chat.completions.create(**request)


class GPT4_1_mini(DavinciBase):
    """Chat client for the ``aide-gpt-4.1-mini`` deployment.

    Overrides the configured ``api_version`` with ``2024-12-01-preview``.
    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        api_config = get_default_config(api_config)
        api_config["api_version"] = "2024-12-01-preview"
        super().__init__(
            name="aide-gpt-4.1-mini",
            in_token_costs=0.4,
            out_token_costs=1.6,
            api_config=api_config,
        )

    @staticmethod
    def _extract_content(response):
        """Return the message text; ``reasoning`` is always the empty string."""
        return {"text": response.message.content, "reasoning": ""}

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request, forwarding remaining kwargs.

        ``role`` is stripped if present, ``max_tokens`` (default 1024) is sent
        as ``max_completion_tokens``, and every other keyword argument is
        passed through to the API call unchanged.
        """
        tmp_kwargs = copy.deepcopy(kwargs)
        # ``role`` is optional: a bare pop() raised KeyError whenever the
        # caller did not supply it.
        tmp_kwargs.pop("role", None)
        max_output_tokens = tmp_kwargs.pop("max_tokens", 1024)
        return await self.client.chat.completions.create(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=max_output_tokens,
            **tmp_kwargs
        )

class GPT5_mini(DavinciBase):
    """Chat client for the ``aide-gpt-5-mini`` deployment.

    Overrides the configured ``api_version`` with ``2025-04-01-preview``.
    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        api_config = get_default_config(api_config)
        api_config["api_version"] = "2025-04-01-preview"
        super().__init__(
            name="aide-gpt-5-mini",
            in_token_costs=0.4,
            out_token_costs=1.6,
            api_config=api_config,
        )

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request with a reasoning-effort setting.

        Reads ``max_tokens`` (default 1024, sent as ``max_completion_tokens``)
        and ``reasoning_effort`` (default ``"minimal"``) from ``kwargs``; any
        other keyword argument is ignored.
        """
        tmp_kwargs = copy.deepcopy(kwargs)
        # Diagnostics go through logging instead of unconditional stdout prints.
        logging.debug("GPT5_mini request kwargs: %s", tmp_kwargs)
        max_output_tokens = tmp_kwargs.pop("max_tokens", 1024)
        reasoning_effort = tmp_kwargs.pop("reasoning_effort", "minimal")
        resp = await self.client.chat.completions.create(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=max_output_tokens,
            reasoning_effort=reasoning_effort,
        )
        logging.debug("GPT5_mini response: %s", resp)
        return resp


class GPT5_nano(DavinciBase):
    """Chat client for the ``aide-gpt-5-nano`` deployment.

    Overrides the configured ``api_version`` with ``2025-04-01-preview``.
    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        api_config = get_default_config(api_config)
        api_config["api_version"] = "2025-04-01-preview"
        super().__init__(
            name="aide-gpt-5-nano",
            in_token_costs=0.1,
            out_token_costs=0.4,
            api_config=api_config,
        )

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request with a reasoning-effort setting.

        Reads ``max_tokens`` (default 1024, sent as ``max_completion_tokens``)
        and ``reasoning_effort`` (default ``"minimal"``) from ``kwargs``; any
        other keyword argument is ignored.
        """
        tmp_kwargs = copy.deepcopy(kwargs)
        # Diagnostics go through logging instead of unconditional stdout prints.
        logging.debug("GPT5_nano request kwargs: %s", tmp_kwargs)
        max_output_tokens = tmp_kwargs.pop("max_tokens", 1024)
        reasoning_effort = tmp_kwargs.pop("reasoning_effort", "minimal")
        resp = await self.client.chat.completions.create(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=max_output_tokens,
            reasoning_effort=reasoning_effort,
        )
        logging.debug("GPT5_nano response: %s", resp)
        return resp


class GPT4_1_nano(DavinciBase):
    """Chat client for the ``aide-gpt-4.1-nano`` deployment.

    Overrides the configured ``api_version`` with ``2024-12-01-preview``.
    Constructor keyword arguments beyond ``api_config`` are accepted but not
    forwarded to the base class.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        resolved_config["api_version"] = "2024-12-01-preview"
        super().__init__(
            name="aide-gpt-4.1-nano",
            in_token_costs=0.1,
            out_token_costs=0.4,
            api_config=resolved_config,
        )

    @staticmethod
    def _extract_content(response):
        """Return the message text; ``reasoning`` is always the empty string."""
        content = response.message.content
        return {"text": content, "reasoning": ""}

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request.

        Only ``max_tokens`` (default 1024) is read from ``kwargs``; it is sent
        as ``max_completion_tokens``. All other keyword arguments are ignored.
        """
        token_budget = kwargs.get("max_tokens", 1024)
        request = dict(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=token_budget,
        )
        return await self.client.chat.completions.create(**request)


class DeepseekCoderv2(DavinciBase):
    """Client for ``deepseek-coder-v2-lite-instruct`` on the models gateway.

    After the base constructor runs, the client is replaced with an
    ``AsyncOpenAI`` client pointed at the gateway URL.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        super().__init__(
            name="deepseek-coder-v2-lite-instruct",
            in_token_costs=0.8,
            out_token_costs=0.8,
            api_config=resolved_config,
        )
        # NOTE: TLS certificate verification is disabled for this client.
        gateway_http = httpx.AsyncClient(verify=False)
        self.client = AsyncOpenAI(
            api_key=resolved_config["api_key"],
            base_url="https://mlop-azure-gateway.mediatek.inc/llm/v3/models",
            http_client=gateway_http,
        )

    async def query_response(self, messages: list, **kwargs) -> str:
        """Delegate to the base class with ``max_new_tokens`` renamed.

        ``max_new_tokens`` (default 32768) is passed on as
        ``max_completion_tokens``; the remaining kwargs are forwarded as-is.
        """
        token_budget = kwargs.pop("max_new_tokens", 32768)
        return await super().query_response(
            messages, max_completion_tokens=token_budget, **kwargs
        )


class DeepseekR1(DavinciBase):
    """Client for ``deepseek-r1`` on the models gateway.

    After the base constructor runs, the client is replaced with an
    ``AsyncOpenAI`` client pointed at the gateway URL.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        super().__init__(
            name="deepseek-r1",
            in_token_costs=0.8,
            out_token_costs=0.8,
            api_config=resolved_config,
        )
        # NOTE: TLS certificate verification is disabled for this client.
        gateway_http = httpx.AsyncClient(verify=False)
        self.client = AsyncOpenAI(
            api_key=resolved_config["api_key"],
            base_url="https://mlop-azure-gateway.mediatek.inc/llm/v3/models",
            http_client=gateway_http,
        )

    @staticmethod
    def _extract_content(response):
        """Split ``response.text`` into the answer and the thinking text.

        Every span that ends with ``</think>`` is treated as reasoning: it is
        removed from the answer and, with the ``<think>``/``</think>`` tags
        stripped, joined by newlines into ``reasoning``. The remaining text is
        whitespace-stripped and returned as ``text``.
        """
        think_span = re.compile(r".*?</think>", flags=re.DOTALL)
        raw_text = response.text

        # Reasoning: each matched span with its think tags removed.
        reasoning_parts = [
            re.sub(r"</?think>", "", span) for span in think_span.findall(raw_text)
        ]
        # Answer: whatever is left once the matched spans are dropped.
        answer = think_span.sub("", raw_text).strip()

        return {
            "text": answer,
            "reasoning": "\n".join(reasoning_parts),
        }

    async def query_response(
        self, prompt: str | List[str], add_think: bool = False, **kwargs
    ) -> str:
        """Optionally append a ``<think>`` opener to the prompt(s), then delegate.

        When ``add_think`` is true, ``"\\n<think>\\n"`` is appended to the prompt
        string, or to each element when ``prompt`` is not a string.
        """
        if add_think:
            think_suffix = "\n<think>\n"
            if isinstance(prompt, str):
                prompt = prompt + think_suffix
            else:
                prompt = [item + think_suffix for item in prompt]
        return await super().query_response(prompt, **kwargs)


class Qwen3_30b(DavinciBase):
    """Client for ``qwen3-30b-a3b`` on the models gateway.

    After the base constructor runs, the client is replaced with an
    ``AsyncOpenAI`` client pointed at the gateway URL.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        super().__init__(
            name="qwen3-30b-a3b",
            in_token_costs=0.0,
            out_token_costs=0.0,
            api_config=resolved_config,
        )
        # NOTE: TLS certificate verification is disabled for this client.
        gateway_http = httpx.AsyncClient(verify=False)
        self.client = AsyncOpenAI(
            api_key=resolved_config["api_key"],
            base_url="https://mlop-azure-gateway.mediatek.inc/llm/v3/models",
            http_client=gateway_http,
        )

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request with thinking/prefill switches.

        Reads ``temperature`` (0.7), ``top_p`` (0.95), ``enable_thinking``
        (False) and ``continue_final_message`` (False) from ``kwargs``. The
        token budget is the largest of ``max_tokens``,
        ``max_completion_tokens`` and ``max_new_tokens``.

        Raises:
            ValueError: If no non-zero token budget was supplied.
        """
        options = copy.deepcopy(kwargs)
        temperature = options.pop("temperature", 0.7)
        top_p = options.pop("top_p", 0.95)

        # enable_thinking follows CSES usage:
        # https://wiki.mediatek.inc/display/CSES/Reasoning+Usage
        enable_thinking = options.pop("enable_thinking", False)
        # continue_final_message uses prefill logic
        # https://github.com/sgl-project/sglang/tree/main/examples/runtime;
        # not supported in openrouter
        continue_final_message = options.pop("continue_final_message", False)

        budget_keys = ("max_tokens", "max_completion_tokens", "max_new_tokens")
        max_tokens = max([options.pop(key, 0) for key in budget_keys])
        if max_tokens == 0:
            raise ValueError("max_tokens cannot be zero.")

        extra_body = {
            "chat_template_kwargs": {"enable_thinking": enable_thinking},
            "continue_final_message": continue_final_message,
        }
        return await self.client.chat.completions.create(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            max_completion_tokens=max_tokens,
            extra_body=extra_body,
            temperature=temperature,
            top_p=top_p,
        )

# Add support for deepseek-r1-distill-qwen-32b
class DeepseekR1DistillQwen32b(DavinciBase):
    """Client for ``deepseek-r1-distill-qwen-32b`` on the models gateway.

    After the base constructor runs, the client is replaced with an
    ``AsyncOpenAI`` client pointed at the gateway URL. Request handling is
    inherited unchanged.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        super().__init__(
            name="deepseek-r1-distill-qwen-32b",
            in_token_costs=0.0,
            out_token_costs=0.0,
            api_config=resolved_config,
        )
        # NOTE: TLS certificate verification is disabled for this client.
        gateway_http = httpx.AsyncClient(verify=False)
        self.client = AsyncOpenAI(
            api_key=resolved_config["api_key"],
            base_url="https://mlop-azure-gateway.mediatek.inc/llm/v3/models",
            http_client=gateway_http,
        )

class Llama4Scout17B(DavinciBase):
    """Client for ``llama4-scout-17b-16e-instruct`` on the models gateway.

    After the base constructor runs, the client is replaced with an
    ``AsyncOpenAI`` client pointed at the gateway URL.
    """

    def __init__(self, api_config: dict, **kwargs):
        resolved_config = get_default_config(api_config)
        super().__init__(
            name="llama4-scout-17b-16e-instruct",
            in_token_costs=0.0,
            out_token_costs=0.0,
            api_config=resolved_config,
        )
        # NOTE: TLS certificate verification is disabled for this client.
        gateway_http = httpx.AsyncClient(verify=False)
        self.client = AsyncOpenAI(
            api_key=resolved_config["api_key"],
            base_url="https://mlop-azure-gateway.mediatek.inc/llm/v3/models",
            http_client=gateway_http,
        )

    async def create_response(self, messages, **kwargs):
        """Issue a chat-completions request with an optional prefill switch.

        Reads ``temperature`` (0.7), ``top_p`` (0.95) and
        ``continue_final_message`` (False) from ``kwargs``. The token budget is
        the largest of ``max_tokens``, ``max_completion_tokens`` and
        ``max_new_tokens``. ``messages`` should be a list of dicts.

        Raises:
            ValueError: If no non-zero token budget was supplied.
        """
        options = copy.deepcopy(kwargs)
        temperature = options.pop("temperature", 0.7)
        top_p = options.pop("top_p", 0.95)
        continue_final_message = options.pop("continue_final_message", False)

        budget_keys = ("max_tokens", "max_completion_tokens", "max_new_tokens")
        max_tokens = max([options.pop(key, 0) for key in budget_keys])
        if max_tokens == 0:
            raise ValueError("max_tokens cannot be zero.")

        # continue_final_message uses prefill logic
        # https://github.com/sgl-project/sglang/tree/main/examples/runtime;
        # not supported in openrouter
        prefill_body = {"continue_final_message": continue_final_message}
        return await self.client.chat.completions.create(
            model=self.name,
            messages=messages,
            extra_headers=self.extra_headers,
            temperature=temperature,
            top_p=top_p,
            max_completion_tokens=max_tokens,
            extra_body=prefill_body,
        )

class ClaudeBase(ModelResponseBase):
    """Shared base for Claude clients that call ``invoke_model`` via boto3.

    Builds an unsigned ``bedrock-runtime`` client against the configured
    ``base_url`` and injects the user-id and API-key headers on every
    ``InvokeModel`` call.
    """

    def __init__(self, name, in_token_costs, out_token_costs, api_config, model_id, **kwargs):
        """Create the boto3 client and register the header hook.

        Args:
            name: Model name handed to the parent class.
            in_token_costs: Input-token cost handed to the parent class.
            out_token_costs: Output-token cost handed to the parent class.
            api_config: Passed to ``get_default_config``; the resolved config
                must provide ``base_url``, ``api_key`` and
                ``extra_headers["X-User-Id"]``.
            model_id: Identifier sent as ``modelId`` to ``invoke_model``.
            **kwargs: Optional ``temperature`` (default 0.7) and ``top_p``
                (default 0.95). ``top_p`` is stored but not sent in requests.
        """
        api_config = get_default_config(api_config)
        super().__init__(
            name=name,
            in_token_costs=in_token_costs,
            out_token_costs=out_token_costs,
            api_config=api_config,
        )
        self.api_config = api_config
        self.model_id = model_id
        # NOTE: requests are unsigned and TLS verification is disabled;
        # authentication is carried by the custom headers added below.
        self.bedrock_runtime = boto3.client(
            service_name="bedrock-runtime",
            region_name="us-east-1",
            endpoint_url=api_config["base_url"],
            verify=False,
            config=Config(signature_version=UNSIGNED),
        )

        # Add custom headers
        event_system = self.bedrock_runtime.meta.events
        event_system.register(
            "before-call.bedrock-runtime.InvokeModel",
            self._add_custom_header_before_call
        )
        self.temperature = kwargs.pop("temperature", 0.7)
        self.top_p = kwargs.get("top_p", 0.95)

    @property
    def client(self):
        """The underlying boto3 ``bedrock-runtime`` client."""
        return self.bedrock_runtime

    @client.setter
    def client(self, value):
        self.bedrock_runtime = value

    def _add_custom_header_before_call(self, model, params, request_signer, **kwargs):
        """boto3 ``before-call`` hook: add the user-id and API-key headers."""
        params["headers"]["X-User-Id"] = self.api_config['extra_headers']["X-User-Id"]
        params["headers"]["api-key"] = self.api_config["api_key"]

    @staticmethod
    def _extract_content(response):
        """Read the response body stream and return the first content text.

        Consumes ``response["body"]``; ``reasoning`` is always the empty string.
        """
        response_body = json.loads(response["body"].read())
        return {
            "text": response_body["content"][0]["text"],
            "reasoning": "",
        }

    async def create_response(self, messages, **kwargs):
        """Build the request body and call ``invoke_model``.

        The output budget comes from ``max_new_tokens`` or ``max_tokens``
        (default 256, with a warning). A leading ``system`` message is lifted
        out of ``messages`` into the request's ``system`` field. ``add_think``
        is accepted and discarded.

        Returns:
            The raw ``invoke_model`` response dict.
        """
        max_output_tokens = 256
        if "max_new_tokens" in kwargs:
            max_output_tokens = kwargs.pop("max_new_tokens")
        elif "max_tokens" in kwargs:
            max_output_tokens = kwargs.pop("max_tokens")
        else:
            logging.warning("max_new_tokens not specified, using default 256")

        _ = kwargs.pop("add_think", False)

        new_messages = copy.deepcopy(messages)
        system_context = ""
        # Guard against an empty message list before peeking at element 0.
        if new_messages and new_messages[0]['role'] == 'system':
            system_context = new_messages[0]['content']
            new_messages = new_messages[1:]
        request_body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": max_output_tokens,
            "temperature": self.temperature,
            'system': system_context,
            "messages": new_messages,
        })

        # invoke_model is a blocking call; run it in a worker thread so the
        # event loop stays free for other coroutines.
        return await asyncio.to_thread(
            self.bedrock_runtime.invoke_model,
            modelId=self.model_id,
            body=request_body,
        )

    async def _query_single_response(self, messages, **kwargs):
        """Call ``create_response`` up to ``self.max_retries`` times.

        Returns:
            ``(response, (num_in_tokens, num_out_tokens, cost))``. Token counts
            are not read from the response and are reported as zero, so the
            cost is zero as well.

        Raises:
            TimeoutError: If no attempt produced a response containing
                ``"body"``; chained to the last exception seen, if any.
        """
        response = None
        last_error = None
        for attempt in range(self.max_retries):
            if attempt > 0:
                print(f"Retrying... {attempt}")

            try:
                response = await self.create_response(messages, **kwargs)
            except Exception as e:
                # Remember the failure so the final error can report its cause.
                last_error = e
                print("API response error: ", e)
                print("Check if you are using the correct api profile for this model")
                continue

            if "body" not in response:
                print("Error in response\n", response)
            else:
                break

        if response is None or "body" not in response:
            raise TimeoutError("No valid response", response) from last_error

        # num_in_tokens = response.usage.prompt_tokens
        # num_out_tokens = response.usage.completion_tokens
        num_in_tokens = 0
        num_out_tokens = 0
        cost = (
            self.in_token_costs * num_in_tokens + self.out_token_costs * num_out_tokens
        ) / 1e6

        return response, (num_in_tokens, num_out_tokens, cost)


    async def query_response(
        self, messages: list, return_extras: bool = False, **kwargs
    ) -> tuple:
        """Query the model once and return the parsed content with statistics.

        Args:
            messages: Chat messages as a list of role/content dicts.
            return_extras: Accepted for interface compatibility; not used.
            **kwargs: Forwarded to ``create_response``.

        Returns:
            ``(content, (in_tokens, out_tokens, cost, time_taken))`` where
            ``content`` is the dict produced by ``_extract_content`` and
            ``time_taken`` is wall-clock seconds.
        """
        start_time = time.time()
        response, (in_tokens, out_tokens, cost) = await self._query_single_response(messages, **kwargs)
        time_taken = time.time() - start_time
        return self._extract_content(response), (in_tokens, out_tokens, cost, time_taken)


class Claude35SonnetV1(ClaudeBase):
    """Claude client registered under the key ``claude-3-5-sonnet-v1``."""

    def __init__(self, api_config: dict = None, **kwargs):
        model_key = "claude-3-5-sonnet-v1"
        resolved_config = get_default_config(api_config)
        super().__init__(
            name=model_key,
            in_token_costs=2.0,
            out_token_costs=6.0,
            api_config=resolved_config,
            model_id=MODEL_NAMES[model_key],
            **kwargs,
        )


class Claude35SonnetV2(ClaudeBase):
    """Claude client registered under the key ``claude-3-5-sonnet-v2``."""

    def __init__(self, api_config: dict = None, **kwargs):
        model_key = "claude-3-5-sonnet-v2"
        resolved_config = get_default_config(api_config)
        super().__init__(
            name=model_key,
            in_token_costs=2.5,
            out_token_costs=7.5,
            api_config=resolved_config,
            model_id=MODEL_NAMES[model_key],
            **kwargs,
        )


class Claude37SonnetV1(ClaudeBase):
    """Claude client registered under the key ``claude-3-7-sonnet-v1``."""

    def __init__(self, api_config: dict = None, **kwargs):
        model_key = "claude-3-7-sonnet-v1"
        resolved_config = get_default_config(api_config)
        super().__init__(
            name=model_key,
            in_token_costs=3.0,
            out_token_costs=9.0,
            api_config=resolved_config,
            model_id=MODEL_NAMES[model_key],
            **kwargs,
        )


class Claude4SonnetV1(ClaudeBase):
    """Claude client registered under the key ``claude-4-sonnet-v1``."""

    def __init__(self, api_config: dict = None, **kwargs):
        model_key = "claude-4-sonnet-v1"
        resolved_config = get_default_config(api_config)
        super().__init__(
            name=model_key,
            in_token_costs=4.0,
            out_token_costs=12.0,
            api_config=resolved_config,
            model_id=MODEL_NAMES[model_key],
            **kwargs,
        )


class Claude4OpusV1(ClaudeBase):
    """Claude client registered under the key ``claude-4-opus-v1``."""

    def __init__(self, api_config: dict = None, **kwargs):
        model_key = "claude-4-opus-v1"
        resolved_config = get_default_config(api_config)
        super().__init__(
            name=model_key,
            in_token_costs=5.0,
            out_token_costs=15.0,
            api_config=resolved_config,
            model_id=MODEL_NAMES[model_key],
            **kwargs,
        )


# Registry of the Claude client classes, keyed by short model name.
claude_model_map = {
    "claude-3-5-sonnet-v1": Claude35SonnetV1,
    "claude-3-5-sonnet-v2": Claude35SonnetV2,
    "claude-3-7-sonnet-v1": Claude37SonnetV1,
    "claude-4-sonnet-v1": Claude4SonnetV1,
    "claude-4-opus-v1": Claude4OpusV1,
}

# Combined model map: every client class in this module keyed by short model
# name, including the Claude entries above.
davinci_model_map = {
    "gpt-4.1": GPT4_1,
    "gpt-4.1-nano": GPT4_1_nano,
    "gpt-4.1-mini": GPT4_1_mini,
    "gpt-5-mini": GPT5_mini,
    # GPT5_nano was defined but never registered, so it could not be selected
    # by name through this map.
    "gpt-5-nano": GPT5_nano,
    "gpt-4o": GPT4o,
    "gpt-4o-mini": GPT4o_mini,
    "gpt-o3": GPTo3,
    "gpt-o3-mini": GPTo3_mini,
    "gpt-o4-mini": GPTo4_mini,
    "deepseek-coder-v2": DeepseekCoderv2,
    "deepseek-R1": DeepseekR1,
    "qwen3-30b": Qwen3_30b,
    "deepseek-r1-distill-qwen-32b": DeepseekR1DistillQwen32b,
    "llama4-scout-17b": Llama4Scout17B,
    **claude_model_map,  # Add Claude models
}

if __name__ == '__main__':
    async def _demo():
        """Smoke test: send one prompt to Qwen3-30B and print the parsed reply."""
        question = "100 + 100 ="
        messages = [dict(role="user", content=question)]

        # Claude alternative (returns the raw invoke_model response dict):
        # model = Claude37SonnetV1(None)
        # resp = await model.create_response(messages, max_tokens=128)
        # print(model._extract_content(resp))

        model = Qwen3_30b(None)
        # create_response is a coroutine, so it must be awaited, and
        # Qwen3_30b raises ValueError when no token budget is given.
        resp = await model.create_response(messages, max_tokens=128)
        # _extract_content reads ``.message.content``, which lives on a single
        # choice of the chat completion rather than on the completion itself.
        print(model._extract_content(resp.choices[0]))

    asyncio.run(_demo())

