from typing import Any, List
from langchain_community.llms import LlamaCpp
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms import Ollama
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import ollama
from langchain_groq import ChatGroq
from langchain_community.chat_models import ChatOllama
from langchain_together import ChatTogether
from langchain_mistralai import ChatMistralAI
from langchain_google_genai import ChatGoogleGenerativeAI



class Singleton(type):
    """Metaclass that gives each class using it exactly one instance.

    The first call ``Cls(*args, **kwargs)`` constructs the instance and caches
    it; every later call returns that cached instance and its arguments are
    ignored (``__init__`` is not run again). The cache is stored on the
    metaclass and keyed by class, so each class using this metaclass gets its
    own single instance.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class Logger(metaclass=Singleton):
    """Empty class declared as a singleton; no members are defined here.

    Uses the Python 3 ``metaclass=`` keyword. The former Python 2 spelling,
    ``__metaclass__ = Singleton`` as a class attribute, is ignored by
    Python 3, so the class was not actually a singleton.
    """


# NOTE(review): this re-declares the identical Singleton/Logger pair defined
# earlier in this module. Being later, this is the binding that customLLM and
# OllamaLLM below use; one of the two copies can be deleted.
class Singleton(type):
    """Metaclass that gives each class using it exactly one instance.

    The first call ``Cls(*args, **kwargs)`` constructs the instance and caches
    it; every later call returns that cached instance and its arguments are
    ignored (``__init__`` is not run again). The cache is stored on the
    metaclass and keyed by class, so each class using this metaclass gets its
    own single instance.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class Logger(metaclass=Singleton):
    """Empty class declared as a singleton; no members are defined here.

    Uses the Python 3 ``metaclass=`` keyword. The former Python 2 spelling,
    ``__metaclass__ = Singleton`` as a class attribute, is ignored by
    Python 3, so the class was not actually a singleton.
    """

class customLLM(metaclass=Singleton):
    """Local chat model that renders messages as ChatML and runs them on ``LlamaCpp``.

    Calling an instance with a list of System/Human/AI messages renders them
    into a ChatML (``<|im_start|>role ... <|im_end|>``) prompt that ends with
    an open assistant turn. It then runs the llama.cpp model on that prompt
    and wraps the completion in an ``AIMessage``.

    Because of the ``Singleton`` metaclass, only the arguments of the first
    ``customLLM(...)`` call take effect; later calls return the same instance.
    """

    # Message class -> ChatML / chat-template role, matched with isinstance in order.
    _ROLES = (
        (SystemMessage, "system"),
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
    )
    # Opening of the assistant turn the model is asked to complete.
    _ASSISTANT_CUE = "<|im_start|>assistant\n"

    def __init__(
        self,
        model_path: str = "ADD_PATH_HERE",
        n_gpu_layers: int = 1,
        n_batch: int = 512,
        n_ctx: int = 8000,
        max_tokens: int = 1000,
    ):
        """Load the llama.cpp model.

        Args:
            model_path: Path to the model file. The default is a placeholder
                and must be replaced for the model to load.
            n_gpu_layers: Number of layers offloaded to the GPU (1 is enough on Metal).
            n_batch: Prompt batch size; should be between 1 and ``n_ctx``
                (consider the available RAM).
            n_ctx: Context window passed to ``LlamaCpp``.
            max_tokens: Maximum number of tokens generated per call.
        """
        self.type = "orca"
        self.llm = LlamaCpp(
            model_path=model_path,
            n_gpu_layers=n_gpu_layers,
            n_batch=n_batch,
            f16_kv=True,  # MUST be True, otherwise you run into problems after a couple of calls
            echo=False,
            verbose=False,
            n_ctx=n_ctx,
            streaming=False,
            max_tokens=max_tokens,
        )

    @staticmethod
    def _chatml_turn(role: str, content) -> str:
        """Render one ChatML turn for *role*, closed by ``<|im_end|>`` and a newline."""
        return f"<|im_start|>{role}\n{content}<|im_end|>\n"

    def _role_of(self, message) -> str:
        """Return the role name for *message*.

        Raises:
            TypeError: if *message* is not a System/Human/AI message.
        """
        for message_type, role in self._ROLES:
            if isinstance(message, message_type):
                return role
        raise TypeError("message type not supported")

    def orca_prompt(self, system_message: SystemMessage, human_message: HumanMessage) -> str:
        """Render a system + user exchange as a ChatML prompt ending in an assistant cue.

        The previous triple-quoted f-string embedded the method's source
        indentation and a trailing blank line into the prompt, so every line
        after the first started with spaces and the ChatML markers were
        malformed.
        """
        return (
            self._chatml_turn("system", system_message.content)
            + self._chatml_turn("user", human_message.content)
            + self._ASSISTANT_CUE
        )

    def hf_prompt(self, chat) -> str:
        """Apply a Hugging Face tokenizer's chat template to *chat*.

        ``__init__`` does not create a tokenizer, so this only works if
        ``self.tokenizer`` has been assigned externally.

        Raises:
            AttributeError: if no tokenizer is configured.
        """
        tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is None:
            raise AttributeError("customLLM has no tokenizer configured; hf_prompt is unavailable")
        return tokenizer.apply_chat_template(chat)

    def convert_to_str_template(self, messages: List[SystemMessage | HumanMessage | AIMessage]) -> str:
        """Render *messages* as one ChatML string followed by an open assistant turn.

        Raises:
            TypeError: on a message that is not a System/Human/AI message.
        """
        turns = [self._chatml_turn(self._role_of(m), m.content) for m in messages]
        turns.append(self._ASSISTANT_CUE)
        return "".join(turns)

    def convert_to_chat_template(self, messages: List[SystemMessage | HumanMessage | AIMessage]):
        """Convert *messages* to a list of ``{"role": ..., "content": ...}`` dicts.

        Raises:
            TypeError: on a message that is not a System/Human/AI message.
        """
        return [{"role": self._role_of(m), "content": m.content} for m in messages]

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Run the model on ``args[0]`` (a list of messages) and return an ``AIMessage``.

        Extra positional and keyword arguments are accepted but ignored.
        """
        prompt = self.convert_to_str_template(args[0])
        # invoke() is the supported entry point; BaseLLM.__call__ is deprecated.
        result = self.llm.invoke(prompt)
        return AIMessage(content=result)



# # Create a session
# session = requests.Session()

# # Define a retry policy with 0 total retries
# retry = Retry(total=0, connect=0, read=0, redirect=0, status=0)

# # Mount the session with the retry policy
# adapter = HTTPAdapter(max_retries=retry)
# session.mount('http://', adapter)
# session.mount('https://', adapter)

class OllamaLLM(metaclass=Singleton):
    """Chat model that forwards LangChain messages to ``ollama.chat``.

    NOTE(review): because of the ``Singleton`` metaclass, the ``model`` given
    to the first ``OllamaLLM(...)`` call is used for the rest of the process.
    Later calls with a different ``model`` return the same instance, and
    their argument is silently ignored.
    """

    def __init__(self, model) -> None:
        # Model name handed to ``ollama.chat`` on every request.
        self.model = model

    def call_ollama(self, messages: List[SystemMessage | HumanMessage | AIMessage]):
        """Send *messages* to Ollama and return the reply wrapped in an ``AIMessage``.

        HumanMessage is sent with role "user" and AIMessage with role
        "assistant". Anything else (SystemMessage or any other type) is sent
        with role "system".
        """
        def role_for(msg):
            if isinstance(msg, HumanMessage):
                return "user"
            if isinstance(msg, AIMessage):
                return "assistant"
            return "system"

        payload = [{"role": role_for(msg), "content": msg.content} for msg in messages]
        reply = ollama.chat(model=self.model, messages=payload)
        return AIMessage(content=reply['message']['content'])

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Forward ``args[0]`` (the message list) to ``call_ollama``; other arguments are ignored."""
        return self.call_ollama(args[0])

# class GroqLLM(metaclass=Singleton):
#         def __init__(self) -> None:
#             self.llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768")

#         def __call__(self, *args: Any, **kwds: Any) -> Any:
#             return self.llm.invoke(args[0])

#         def bind(self, *args: Any, **kwds: Any) -> Any:
#             return self.llm.bind(args[0])

def get_llm(model_name, temperature: float = 0, max_tokens=None):
    """Build a LangChain chat model for *model_name*.

    Routing, checked in order:

    * ``"lambda"`` anywhere in *model_name*: Lambda Labs' OpenAI-compatible
      endpoint through ``ChatOpenAI``. The second ``/``-separated segment is
      the model id (e.g. ``"lambda/<model>"``). The API key comes from
      ``LAMBDA_API_KEY`` when set, otherwise from the placeholder below.
    * Otherwise ``model`` is the last two ``/``-separated segments of
      *model_name*, and:
      - if ``model`` starts with ``"o"``, ``ChatOpenAI`` with ``"gpt-4"`` is used;
      - if it starts with ``"g"``, ``ChatGoogleGenerativeAI`` with ``model`` is used;
      - anything else goes to ``ChatTogether`` with ``model``.

    Args:
        model_name: Provider-qualified model name.
        temperature: Sampling temperature passed to the chat model.
        max_tokens: Maximum tokens to generate; ``None`` leaves the provider default.

    Returns:
        A configured LangChain chat model instance.

    Raises:
        ValueError: if *model_name* is empty.
    """
    import os

    if "lambda" in model_name:
        return ChatOpenAI(
            model=model_name.split("/")[1],
            temperature=temperature,
            max_tokens=max_tokens,
            base_url="https://api.lambdalabs.com/v1",
            # Prefer the environment over a key hard-coded in source; the
            # placeholder keeps the previous behaviour when the variable is unset.
            api_key=os.environ.get("LAMBDA_API_KEY", "ADD_KEY_HERE"),
        )

    model = "/".join(str(model_name).split("/")[-2:])
    if not model:
        raise ValueError("model_name must be a non-empty string")

    if model.startswith("o"):
        # NOTE(review): every name whose last-two-segment form starts with "o"
        # (e.g. "openai/gpt-4o", but also "ollama/..." or "open-orca/...") is
        # served by gpt-4 regardless of the requested model — confirm intended.
        return ChatOpenAI(
            model="gpt-4",
            temperature=temperature,
            max_tokens=max_tokens,
        )
    if model.startswith("g"):
        # `model` is ChatGoogleGenerativeAI's documented field name.
        # NOTE(review): a "google/..." prefix is passed through unchanged — confirm
        # the Gemini API accepts the names callers use here.
        return ChatGoogleGenerativeAI(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )
    return ChatTogether(
        model_name=model,
        temperature=temperature,
        max_tokens=max_tokens,
    )
