from openai import AsyncOpenAI
import os
import json
import httpx
from functools import partial
import logging
import copy
from rich import print
from .base_client import ModelResponseBase

class SGLangBase(ModelResponseBase):
    """Chat client for a model served through an SGLang OpenAI-compatible endpoint.

    The server address comes from the required ``host`` and ``port`` keyword
    arguments; requests are sent to ``http://{host}:{port}/v1``. The sampling
    settings (``temperature``, ``top_p``, ``presence_penalty``) are fixed at
    construction time and reused for every request.
    """

    def __init__(self, name, in_token_costs=0, out_token_costs=0, api_config=None, **kwargs):
        """Create the client.

        Args:
            name: Model identifier sent as ``model`` in every request.
            in_token_costs: Forwarded to ``ModelResponseBase``.
            out_token_costs: Forwarded to ``ModelResponseBase``.
            api_config: Forwarded to ``ModelResponseBase``; the attribute is
                reset to ``None`` afterwards.
            **kwargs: Must contain ``host`` and ``port``. May contain
                ``temperature`` (default 0.7), ``top_p`` (default 0.95) and
                ``presence_penalty`` (default 0.6).

        Raises:
            KeyError: If ``host`` or ``port`` is missing from ``kwargs``.
        """
        super().__init__(
            name=name,
            in_token_costs=in_token_costs,
            out_token_costs=out_token_costs,
            api_config=api_config
        )
        # Whatever the base class stored is discarded on purpose here: this
        # client is configured through host/port, not through an API config.
        self.api_config = None
        host = kwargs['host']
        port = kwargs['port']
        self._client = AsyncOpenAI(
            # The literal string "None" is sent as a placeholder key.
            api_key="None",
            base_url=f"http://{host}:{port}/v1",
        )
        self.temperature = kwargs.get("temperature", 0.7)
        self.top_p = kwargs.get("top_p", 0.95)
        self.presence_penalty = kwargs.get("presence_penalty", 0.6)
        self.extra_body = {}

    @property
    def client(self):
        """The ``AsyncOpenAI`` client used for requests."""
        return self._client

    @client.setter
    def client(self, value):
        self._client = value

    @staticmethod
    def _extract_content(response):
        """Return ``{"text": ..., "reasoning": ...}`` for one response item.

        Attribute access (``response.message.content``) is tried first. On
        ``AttributeError`` the message is read with ``.get('content', '')``
        instead and the reasoning is reported as an empty string.
        """
        try:
            return {"text": response.message.content, "reasoning": getattr(response, 'reasoning', '')}
        except AttributeError:
            return {"text": getattr(response, 'message', {}).get('content', ''), "reasoning": ""}

    async def query_response(self, messages: list, **kwargs) -> str:
        """Normalise the token-limit argument and delegate to the base class.

        The limit is taken from the first of ``max_new_tokens``, ``max_tokens``
        or ``max_completion_tokens`` found in ``kwargs``. When none is given a
        warning is logged and 128 is used. ``add_think`` is accepted and
        dropped. The instance's sampling settings and ``extra_body`` are always
        passed on.
        """
        # Same alias order as OpenRouterBase.query_response, so that
        # ``max_completion_tokens`` is honoured instead of falling through to
        # the 128-token default.
        for alias in ("max_new_tokens", "max_tokens", "max_completion_tokens"):
            if alias in kwargs:
                max_tokens = kwargs.pop(alias)
                break
        else:
            logging.warning("max_new_tokens not specified, using default 128")
            max_tokens = 128

        kwargs.pop("add_think", None)

        return await super().query_response(
            messages,
            max_tokens=max_tokens,
            presence_penalty=self.presence_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            extra_body=self.extra_body,
            **kwargs
        )

    async def create_response(self, messages, **kwargs):
        """Send one chat-completion request and return the raw response.

        Args:
            messages: List of chat message dicts.
            **kwargs: ``max_tokens`` / ``max_completion_tokens`` /
                ``max_new_tokens`` (the largest value wins) and
                ``continue_final_message`` (default ``False``). Every other
                key is ignored; sampling settings come from the instance.

        Raises:
            ValueError: If no positive token limit was supplied.
        """
        # ``kwargs`` is already a dict private to this call, so popping from it
        # cannot affect the caller and no defensive copy is required.
        continue_final_message = kwargs.pop("continue_final_message", False)
        # ``or 0`` treats an explicit ``None`` like a missing limit.
        max_tokens = max(
            kwargs.pop(alias, 0) or 0
            for alias in ("max_tokens", "max_completion_tokens", "max_new_tokens")
        )
        if max_tokens == 0:
            raise ValueError("max_tokens cannot be zero.")
        resp = await self.client.chat.completions.create(
            model=self.name,
            messages=messages,  # messages should be a list of dicts
            temperature=self.temperature,
            top_p=self.top_p,
            presence_penalty=self.presence_penalty,
            max_completion_tokens=max_tokens,
            # Prefill logic, see
            # https://github.com/sgl-project/sglang/tree/main/examples/runtime
            extra_body={"continue_final_message": continue_final_message},
        )
        return resp


def get_default_config(config_path, type="openrouter"):
    """Load one provider section from a JSON API configuration file.

    Args:
        config_path: Path to a JSON config file. It is used when it is a
            string or path-like object naming an existing file. For any other
            value the ``api_config.json`` located next to this module is read
            instead, which was previously the only behavior.
        type: Top-level key of the section to return.

    Returns:
        The value stored under ``type`` in the parsed JSON document.

    Raises:
        FileNotFoundError: If the selected file does not exist.
        KeyError: If the document has no ``type`` section.
    """
    bundled_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "api_config.json"
    )
    # Callers pass assorted values here (paths, section names, already-loaded
    # dicts), so only an actual existing file overrides the bundled default.
    path_is_usable = isinstance(config_path, (str, os.PathLike)) and os.path.isfile(config_path)
    selected_path = config_path if path_is_usable else bundled_path
    with open(selected_path, "r", encoding="utf-8") as fp:
        return json.load(fp)[type]


class OpenRouterBase(ModelResponseBase):
    """Base chat client for models reached through an OpenAI-compatible API.

    Connection details (``api_key``, ``base_url``) come from the configuration
    returned by ``get_default_config``. Sampling settings and the request
    ``extra_body`` are fixed at construction time.
    """

    def __init__(self, name, in_token_costs, out_token_costs, api_config, **kwargs):
        """Create the client.

        Args:
            name: Model identifier.
            in_token_costs: Forwarded to ``ModelResponseBase``.
            out_token_costs: Forwarded to ``ModelResponseBase``.
            api_config: Passed to ``get_default_config``; the returned mapping
                must provide ``api_key`` and ``base_url``.
            **kwargs: Optional ``temperature`` (0.7), ``top_p`` (0.95),
                ``presence_penalty`` (0.6), ``extra_body`` (replaces the whole
                default body), ``reasoning`` (its ``max_tokens`` entry feeds
                the default body) and ``verify_ssl`` (``False``).
        """
        self.api_config = get_default_config(api_config)
        super().__init__(
            name=name,
            in_token_costs=in_token_costs,
            out_token_costs=out_token_costs,
            api_config=self.api_config,
        )
        # NOTE(security): certificate verification is OFF by default to keep
        # the historical behavior. The API key is then sent over a connection
        # whose peer is not authenticated; pass ``verify_ssl=True`` wherever
        # the environment allows it.
        verify_ssl = kwargs.pop("verify_ssl", False)
        self._client = AsyncOpenAI(
            api_key=self.api_config["api_key"],
            base_url=self.api_config["base_url"],
            http_client=httpx.AsyncClient(verify=verify_ssl),
        )
        self.temperature = kwargs.get("temperature", 0.7)
        self.top_p = kwargs.get("top_p", 0.95)
        self.presence_penalty = kwargs.get("presence_penalty", 0.6)
        if "extra_body" in kwargs:
            # An explicit value (even ``None``) replaces the defaults entirely.
            self.extra_body = kwargs["extra_body"]
        else:
            # ``or {}`` tolerates an explicit ``reasoning=None``.
            reasoning_options = kwargs.get("reasoning") or {}
            self.extra_body = {
                "top_k": 20,
                "min_p": 0,
                "provider": {"sort": "throughput"},
                "reasoning": {
                    "max_tokens": reasoning_options.get("max_tokens", 10),
                    "exclude": False,
                },
            }

    @property
    def client(self):
        """The ``AsyncOpenAI`` client used for requests."""
        return self._client

    @client.setter
    def client(self, value):
        self._client = value

    @staticmethod
    def _extract_content(response):
        """Return ``{"text": ..., "reasoning": ...}`` for one response item.

        Attribute access (``response.message.content``) is tried first. On
        ``AttributeError`` the message is read with ``.get('content', '')``
        instead and the reasoning is reported as an empty string.
        """
        try:
            return {"text": response.message.content, "reasoning": getattr(response, 'reasoning', '')}
        except AttributeError:
            return {"text": getattr(response, 'message', {}).get('content', ''), "reasoning": ""}

    async def query_response(self, messages: list, **kwargs) -> str:
        """Normalise the token-limit argument and delegate to the base class.

        The limit is taken from the first of ``max_new_tokens``, ``max_tokens``
        or ``max_completion_tokens`` found in ``kwargs``. ``add_think`` is
        accepted and dropped. The instance's sampling settings and
        ``extra_body`` are always passed on.

        Raises:
            RuntimeError: If none of the token-limit arguments is supplied.
        """
        for alias in ("max_new_tokens", "max_tokens", "max_completion_tokens"):
            if alias in kwargs:
                max_tokens = kwargs.pop(alias)
                break
        else:
            raise RuntimeError("max_completion_tokens not specified")

        kwargs.pop("add_think", None)

        return await super().query_response(
            messages,
            max_tokens=max_tokens,
            presence_penalty=self.presence_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            extra_body=self.extra_body,
            **kwargs
        )


class OpenRouterFreeModel(OpenRouterBase):
    """OpenRouter client for an arbitrary model with both token costs set to zero."""

    def __init__(self, name, api_config: str, **kwargs):
        # Arguments follow OpenRouterBase's positional order:
        # name, in_token_costs, out_token_costs, api_config.
        super().__init__(name, 0.0, 0.0, api_config, **kwargs)


# Maps a short registry key to the full model name for models that are served
# through OpenRouterFreeModel (zero token costs). Each entry becomes a
# ``partial(OpenRouterFreeModel, name)`` factory in ``openrouter_model_map``.
_free_model_name_map = {
    "glm-z1-9b:free": "thudm/glm-z1-9b:free",
}


class GLM_Z1_32b(OpenRouterBase):
    """OpenRouter client preconfigured for ``thudm/glm-z1-32b``."""

    def __init__(self, api_config: str, **kwargs):
        # Fixed model name and token-cost figures; everything in ``kwargs`` is
        # handed to OpenRouterBase unchanged.
        model_spec = dict(
            name="thudm/glm-z1-32b",
            in_token_costs=0.24,
            out_token_costs=0.24,
        )
        super().__init__(api_config=api_config, **model_spec, **kwargs)


class Qwen3_8b(OpenRouterBase):
    """OpenRouter client preconfigured for ``qwen/qwen3-8b``."""

    def __init__(self, api_config: str, **kwargs):
        # Fixed model name and token-cost figures; everything in ``kwargs`` is
        # handed to OpenRouterBase unchanged.
        model_spec = dict(
            name="qwen/qwen3-8b",
            in_token_costs=0.035,
            out_token_costs=0.138,
        )
        super().__init__(api_config=api_config, **model_spec, **kwargs)


class Qwen3_32b(OpenRouterBase):
    """OpenRouter client preconfigured for ``qwen/qwen3-32b``."""

    def __init__(self, api_config: str, **kwargs):
        # Fixed model name and token-cost figures; everything in ``kwargs`` is
        # handed to OpenRouterBase unchanged.
        model_spec = dict(
            name="qwen/qwen3-32b",
            in_token_costs=0.1,
            out_token_costs=0.4,
        )
        super().__init__(api_config=api_config, **model_spec, **kwargs)


class Qwen3_235b(OpenRouterBase):
    """OpenRouter client preconfigured for ``qwen/qwen3-235b-a22b``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="qwen/qwen3-235b-a22b",
            in_token_costs=0.14,
            out_token_costs=0.85,
            api_config=api_config,
            **kwargs
        )


class Qwen3_30b_2507(OpenRouterBase):
    """OpenRouter client for ``qwen/qwen3-30b-a3b-instruct-2507``.

    Overrides ``create_response`` to build its own request body instead of
    using the instance's default ``extra_body``.
    """

    def __init__(self, api_config: str, **kwargs):
        super().__init__(
            name="qwen/qwen3-30b-a3b-instruct-2507",
            in_token_costs=0.2,
            out_token_costs=0.8,
            api_config=api_config,
            **kwargs
        )

    async def create_response(self, messages, **kwargs):
        """Send one chat-completion request and return the raw response.

        Args:
            messages: List of chat message dicts.
            **kwargs: ``temperature`` (0.7), ``top_p`` (0.95),
                ``max_tokens`` / ``max_completion_tokens`` / ``max_new_tokens``
                (the largest value wins), ``continue_final_message`` (``False``)
                and ``provider`` (omitted from the request when ``None``).
                Every other key is ignored.

        Raises:
            ValueError: If no positive token limit was supplied.
        """
        # ``kwargs`` is already a dict private to this call, so popping from it
        # cannot affect the caller and no defensive deep copy is required.
        temperature = kwargs.pop("temperature", 0.7)
        top_p = kwargs.pop("top_p", 0.95)
        # ``or 0`` treats an explicit ``None`` like a missing limit.
        max_tokens = max(
            kwargs.pop(alias, 0) or 0
            for alias in ("max_tokens", "max_completion_tokens", "max_new_tokens")
        )
        if max_tokens == 0:
            raise ValueError("max_tokens cannot be zero.")
        # Prefill logic, see
        # https://github.com/sgl-project/sglang/tree/main/examples/runtime;
        # not supported in openrouter.
        extra_body = {"continue_final_message": kwargs.pop("continue_final_message", False)}
        provider = kwargs.pop("provider", None)
        if provider is not None:
            extra_body['provider'] = provider
        resp = await self.client.chat.completions.create(
            model=self.name,
            messages=messages,  # messages should be a list of dicts
            # presumably provided by ModelResponseBase -- not defined in this file
            extra_headers=self.extra_headers,
            temperature=temperature,
            top_p=top_p,
            max_completion_tokens=max_tokens,
            extra_body=extra_body
        )
        return resp

class Qwen3_235b_2507(OpenRouterBase):
    """OpenRouter client for ``qwen/qwen3-235b-a22b-2507``.

    Overrides ``create_response`` to build its own request body instead of
    using the instance's default ``extra_body``.
    """

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase``. They used to be accepted
                but silently discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="qwen/qwen3-235b-a22b-2507",
            in_token_costs=0.078,
            out_token_costs=0.312,
            api_config=api_config,
            **kwargs
        )

    async def create_response(self, messages, **kwargs):
        """Send one chat-completion request and return the raw response.

        Args:
            messages: List of chat message dicts.
            **kwargs: ``temperature`` (0.7), ``top_p`` (0.95),
                ``max_tokens`` / ``max_completion_tokens`` / ``max_new_tokens``
                (the largest value wins), ``continue_final_message`` (``False``)
                and ``provider`` (omitted from the request when ``None``).
                Every other key is ignored.

        Raises:
            ValueError: If no positive token limit was supplied.
        """
        # ``kwargs`` is already a dict private to this call, so popping from it
        # cannot affect the caller and no defensive deep copy is required.
        temperature = kwargs.pop("temperature", 0.7)
        top_p = kwargs.pop("top_p", 0.95)
        # ``or 0`` treats an explicit ``None`` like a missing limit.
        max_tokens = max(
            kwargs.pop(alias, 0) or 0
            for alias in ("max_tokens", "max_completion_tokens", "max_new_tokens")
        )
        if max_tokens == 0:
            raise ValueError("max_tokens cannot be zero.")
        # Prefill logic, see
        # https://github.com/sgl-project/sglang/tree/main/examples/runtime;
        # not supported in openrouter.
        extra_body = {"continue_final_message": kwargs.pop("continue_final_message", False)}
        provider = kwargs.pop("provider", None)
        if provider is not None:
            extra_body['provider'] = provider
        resp = await self.client.chat.completions.create(
            model=self.name,
            messages=messages,  # messages should be a list of dicts
            # presumably provided by ModelResponseBase -- not defined in this file
            extra_headers=self.extra_headers,
            temperature=temperature,
            top_p=top_p,
            max_completion_tokens=max_tokens,
            extra_body=extra_body
        )
        return resp



class Qwen3_30b_a3b(OpenRouterBase):
    """OpenRouter client for ``qwen/qwen3-30b-a3b``.

    Overrides ``create_response`` to request that reasoning be excluded from
    the response and to skip the SiliconFlow provider by default.
    """

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase``. They used to be accepted
                but silently discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="qwen/qwen3-30b-a3b",
            in_token_costs=0.02,
            out_token_costs=0.08,
            api_config=api_config,
            **kwargs
        )

    async def create_response(self, messages, **kwargs):
        """Send one chat-completion request and return the raw response.

        Args:
            messages: List of chat message dicts.
            **kwargs: ``temperature`` (0.7), ``top_p`` (0.95),
                ``max_tokens`` / ``max_completion_tokens`` / ``max_new_tokens``
                (the largest value wins), ``continue_final_message`` (``False``)
                and ``provider`` (defaults to ignoring SiliconFlow; omitted
                from the request when explicitly ``None``). Every other key is
                ignored.

        Raises:
            ValueError: If no positive token limit was supplied.
        """
        # ``kwargs`` is already a dict private to this call, so popping from it
        # cannot affect the caller and no defensive deep copy is required.
        temperature = kwargs.pop("temperature", 0.7)
        top_p = kwargs.pop("top_p", 0.95)
        # ``or 0`` treats an explicit ``None`` like a missing limit.
        max_tokens = max(
            kwargs.pop(alias, 0) or 0
            for alias in ("max_tokens", "max_completion_tokens", "max_new_tokens")
        )
        if max_tokens == 0:
            raise ValueError("max_tokens cannot be zero.")
        # Prefill logic, see
        # https://github.com/sgl-project/sglang/tree/main/examples/runtime;
        # not supported in openrouter.
        extra_body = {
            "continue_final_message": kwargs.pop("continue_final_message", False),
            "reasoning": {"exclude": True},
        }
        # SiliconFlow doesn't respect the max_tokens
        provider = kwargs.pop("provider", {"ignore": ["SiliconFlow"]})
        if provider is not None:
            extra_body['provider'] = provider
        resp = await self.client.chat.completions.create(
            model=self.name,
            messages=messages,  # messages should be a list of dicts
            # presumably provided by ModelResponseBase -- not defined in this file
            extra_headers=self.extra_headers,
            temperature=temperature,
            top_p=top_p,
            max_completion_tokens=max_tokens,
            extra_body=extra_body
        )
        return resp


class Llama4Scout(OpenRouterBase):
    """OpenRouter client for ``meta-llama/llama-4-scout``.

    Overrides ``create_response`` to build its own request body instead of
    using the instance's default ``extra_body``.
    """

    def __init__(self, api_config: str, **kwargs):
        super().__init__(
            name="meta-llama/llama-4-scout",
            in_token_costs=0.08,
            out_token_costs=0.3,
            api_config=api_config,
            **kwargs
        )

    async def create_response(self, messages, **kwargs):
        """Send one chat-completion request and return the raw response.

        Args:
            messages: List of chat message dicts.
            **kwargs: ``temperature`` (0.7), ``top_p`` (0.95),
                ``max_tokens`` / ``max_completion_tokens`` / ``max_new_tokens``
                (the largest value wins), ``continue_final_message`` (``False``)
                and ``provider`` (omitted from the request when ``None``).
                Every other key is ignored.

        Raises:
            ValueError: If no positive token limit was supplied.
        """
        # ``kwargs`` is already a dict private to this call, so popping from it
        # cannot affect the caller and no defensive deep copy is required.
        temperature = kwargs.pop("temperature", 0.7)
        top_p = kwargs.pop("top_p", 0.95)
        # ``or 0`` treats an explicit ``None`` like a missing limit.
        max_tokens = max(
            kwargs.pop(alias, 0) or 0
            for alias in ("max_tokens", "max_completion_tokens", "max_new_tokens")
        )
        if max_tokens == 0:
            raise ValueError("max_tokens cannot be zero.")
        # Prefill logic, see
        # https://github.com/sgl-project/sglang/tree/main/examples/runtime;
        # not supported in openrouter.
        extra_body = {"continue_final_message": kwargs.pop("continue_final_message", False)}
        provider = kwargs.pop("provider", None)
        if provider is not None:
            extra_body['provider'] = provider
        resp = await self.client.chat.completions.create(
            model=self.name,
            messages=messages,  # messages should be a list of dicts
            # presumably provided by ModelResponseBase -- not defined in this file
            extra_headers=self.extra_headers,
            temperature=temperature,
            top_p=top_p,
            max_completion_tokens=max_tokens,
            extra_body=extra_body
        )
        return resp


class Claude_3_7_Sonnet(OpenRouterBase):
    """OpenRouter client preconfigured for ``anthropic/claude-3.7-sonnet``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="anthropic/claude-3.7-sonnet",
            in_token_costs=3.0,
            out_token_costs=15.0,
            api_config=api_config,
            **kwargs
        )


class Gemini_2_5_Pro(OpenRouterBase):
    """OpenRouter client preconfigured for ``google/gemini-2.5-pro-preview``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="google/gemini-2.5-pro-preview",
            in_token_costs=1.25,
            out_token_costs=10.0,
            api_config=api_config,
            **kwargs
        )


class GPT_4_1_Nano(OpenRouterBase):
    """OpenRouter client preconfigured for ``openai/gpt-4.1-nano``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="openai/gpt-4.1-nano",
            in_token_costs=0.1,
            out_token_costs=0.4,
            api_config=api_config,
            **kwargs
        )


class GPT_4_1_Mini(OpenRouterBase):
    """OpenRouter client preconfigured for ``openai/gpt-4.1-mini``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="openai/gpt-4.1-mini",
            in_token_costs=0.4,
            out_token_costs=1.6,
            api_config=api_config,
            **kwargs
        )


class GPT_4o_Mini(OpenRouterBase):
    """OpenRouter client preconfigured for ``openai/gpt-4o-mini``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="openai/gpt-4o-mini",
            in_token_costs=0.15,
            out_token_costs=0.6,
            api_config=api_config,
            **kwargs
        )

class O4_Mini(OpenRouterBase):
    """OpenRouter client preconfigured for ``openai/o4-mini``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="openai/o4-mini",
            in_token_costs=1.1,
            out_token_costs=4.4,
            api_config=api_config,
            **kwargs
        )
class GPT_OSS_20B(OpenRouterBase):
    """OpenRouter client preconfigured for ``openai/gpt-oss-20b``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="openai/gpt-oss-20b",
            in_token_costs=0.05,
            out_token_costs=0.2,
            api_config=api_config,
            **kwargs
        )

class GPT_OSS_120B(OpenRouterBase):
    """OpenRouter client preconfigured for ``openai/gpt-oss-120b``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="openai/gpt-oss-120b",
            in_token_costs=0.15,
            out_token_costs=0.6,
            api_config=api_config,
            **kwargs
        )

class GPT_5_nano(OpenRouterBase):
    """OpenRouter client preconfigured for ``openai/gpt-5-nano``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="openai/gpt-5-nano",
            in_token_costs=0.05,
            out_token_costs=0.4,
            api_config=api_config,
            **kwargs
        )


class GPT_5_mini(OpenRouterBase):
    """OpenRouter client preconfigured for ``openai/gpt-5-mini``."""

    def __init__(self, api_config: str, **kwargs):
        """Set the model name and token-cost figures.

        Args:
            api_config: Passed through to ``OpenRouterBase``.
            **kwargs: Forwarded to ``OpenRouterBase`` (for example
                ``temperature``, ``top_p``, ``presence_penalty``,
                ``extra_body``). They used to be accepted but silently
                discarded, unlike in the sibling model classes.
        """
        super().__init__(
            name="openai/gpt-5-mini",
            in_token_costs=0.25,
            out_token_costs=2,
            api_config=api_config,
            **kwargs
        )

# Registry from a short model key to a factory for the matching client.
# Free-tier entries come first and are ``OpenRouterFreeModel`` factories with
# the model name pre-bound; the remaining entries are the classes themselves.
# Note that ``SGLangBase`` has a different constructor signature from the
# OpenRouter classes (it takes ``name`` first and needs ``host``/``port``).
openrouter_model_map = {
    **{
        alias: partial(OpenRouterFreeModel, full_name)
        for alias, full_name in _free_model_name_map.items()
    },
    "OR-claude-3.7-sonnet": Claude_3_7_Sonnet,
    "OR-gemini-2.5-pro": Gemini_2_5_Pro,
    "OR-qwen3-235b": Qwen3_235b,
    "OR-qwen3-32b": Qwen3_32b,
    "OR-qwen3-8b": Qwen3_8b,
    "OR-qwen3-30b-a3b": Qwen3_30b_a3b,
    "OR-qwen3-30b-a3b-2507-instruct": Qwen3_30b_2507,
    "OR-qwen3-235b-a22b-2507-instruct": Qwen3_235b_2507,
    "OR-glm-z1-32b": GLM_Z1_32b,
    "OR-o4-mini": O4_Mini,
    "OR-gpt-4.1-nano": GPT_4_1_Nano,
    "OR-gpt-4.1-mini": GPT_4_1_Mini,
    "OR-gpt-4o-mini": GPT_4o_Mini,
    "OR-gpt-oss-20b": GPT_OSS_20B,
    "OR-gpt-oss-120b": GPT_OSS_120B,
    "OR-gpt-5-nano": GPT_5_nano,
    "OR-gpt-5-mini": GPT_5_mini,
    "OR-llama-4-scout": Llama4Scout,
    "SGLang": SGLangBase,
}

if __name__ == "__main__":
    import asyncio
    # NOTE(review): this rebinds the module-level ``get_default_config``, so
    # OpenRouterBase.__init__ also calls the ``.utils`` version below --
    # confirm that is intended.
    from .utils import get_default_config

    async def unittest():
        """Manual smoke test: send one prompt to two OpenRouter models and print the replies."""
        prompt = [{"role": "user", "content": "What model are you? What is 2+5?"}]
        # OpenRouterBase.query_response raises RuntimeError when no token
        # limit is supplied, so one must be passed explicitly.
        token_limit = 1024

        config = get_default_config("openrouter")
        model = O4_Mini(config)
        out = await model.query_response(
            prompt, n=1, return_extras=True, max_completion_tokens=token_limit
        )
        print(f"o4 mini: {out}")

        model = GPT_OSS_20B(config)
        out = await model.query_response(
            prompt, n=1, return_extras=True, max_completion_tokens=token_limit
        )
        print(f"gpt oss 20B: {out}")

        ### Test all models
        # for k, v in openrouter_model_map.items():
        #     print("--------", k, "----------")
        #     model = v(config)
        #     out = await model.query_response(prompt, n=2, max_completion_tokens=token_limit)
        #     print(out)

    asyncio.run(unittest())
