import logging
import os

logging.basicConfig(
    format="[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger_initialized = {}
import copy as cp
import hashlib
import json
import mimetypes
import os.path as osp
import random as rd
import time
from abc import abstractmethod

import numpy as np
import requests
import validators


# from api import LMUDataRoot, download_file
def LMUDataRoot():
    if "LMUData" in os.environ and osp.exists(os.environ["LMUData"]):
        return os.environ["LMUData"]
    home = osp.expanduser("~")
    root = osp.join(home, "LMUData")
    os.makedirs(root, exist_ok=True)
    return root


def download_file(url, filename=None):
    import urllib.request

    from tqdm import tqdm

    class DownloadProgressBar(tqdm):
        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            self.update(b * bsize - self.n)

    if filename is None:
        filename = url.split("/")[-1]

    try:
        with DownloadProgressBar(
            unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
        ) as t:
            urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
    except Exception as e:
        import logging

        logging.warning(f"{type(e)}: {e}")
        # Handle Failed Downloads from huggingface.co
        if "huggingface.co" in url:
            url_new = url.replace("huggingface.co", "hf-mirror.com")
            try:
                download_file(url_new, filename)
                return filename
            except Exception as e:
                logging.warning(f"{type(e)}: {e}")
                raise Exception(f"Failed to download {url}")
        else:
            raise Exception(f"Failed to download {url}")

    return filename


def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="w"):
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    try:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
    except ImportError:
        rank = 0

    if rank == 0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        "[%(asctime)s] %(levelname)s - %(name)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s"
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True
    return logger


def md5(s):
    hash = hashlib.new("md5")
    if osp.exists(s):
        with open(s, "rb") as f:
            for chunk in iter(lambda: f.read(2**20), b""):
                hash.update(chunk)
    else:
        hash.update(s.encode("utf-8"))
    return str(hash.hexdigest())


import base64
import io

from PIL import Image


def decode_base64_to_image(base64_string, target_size=-1):
    image_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(image_data))
    if image.mode in ("RGBA", "P"):
        image = image.convert("RGB")
    if target_size > 0:
        image.thumbnail((target_size, target_size))
    return image


def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
    image = decode_base64_to_image(base64_string, target_size=target_size)
    image.save(image_path)


def parse_file(s):
    if osp.exists(s) and s != ".":
        assert osp.isfile(s)
        suffix = osp.splitext(s)[1].lower()
        mime = mimetypes.types_map.get(suffix, "unknown")
        return (mime, s)
    elif s.startswith("data:image/"):
        # To be compatible with OPENAI base64 format
        content = s[11:]
        mime = content.split(";")[0]
        content = ";".join(content.split(";")[1:])
        dname = osp.join(LMUDataRoot(), "files")
        assert content.startswith("base64,")
        b64 = content[7:]
        os.makedirs(dname, exist_ok=True)
        tgt = osp.join(dname, md5(b64) + ".png")
        decode_base64_to_image_file(b64, tgt)
        return parse_file(tgt)
    elif validators.url(s):
        suffix = osp.splitext(s)[1].lower()
        if suffix in mimetypes.types_map:
            mime = mimetypes.types_map[suffix]
            dname = osp.join(LMUDataRoot(), "files")
            os.makedirs(dname, exist_ok=True)
            tgt = osp.join(dname, md5(s) + suffix)
            download_file(s, tgt)
            return (mime, tgt)
        else:
            return ("url", s)
    else:
        return (None, s)


class BaseAPI:
    allowed_types = ["text", "image"]
    INTERLEAVE = True
    INSTALL_REQ = False

    def __init__(
        self,
        retry=10,
        wait=3,
        system_prompt=None,
        verbose=True,
        fail_msg="Failed to obtain answer via API.",
        **kwargs,
    ):
        """Base Class for all APIs.

        Args:
            retry (int, optional): The retry times for `generate_inner`. Defaults to 10.
            wait (int, optional): The wait time after each failed retry of `generate_inner`. Defaults to 3.
            system_prompt (str, optional): Defaults to None.
            verbose (bool, optional): Defaults to True.
            fail_msg (str, optional): The message to return when failed to obtain answer.
                Defaults to 'Failed to obtain answer via API.'.
            **kwargs: Other kwargs for `generate_inner`.
        """

        self.wait = wait
        self.retry = retry
        self.system_prompt = system_prompt
        self.verbose = verbose
        self.fail_msg = fail_msg
        self.logger = get_logger("ChatAPI")

        if len(kwargs):
            self.logger.info(f"BaseAPI received the following kwargs: {kwargs}")
            self.logger.info("Will try to use them as kwargs for `generate`. ")
        self.default_kwargs = kwargs

    @abstractmethod
    def generate_inner(self, inputs, **kwargs):
        """The inner function to generate the answer.

        Returns:
            tuple(int, str, str): ret_code, response, log
        """
        self.logger.warning("For APIBase, generate_inner is an abstract method. ")
        assert 0, "generate_inner not defined"
        ret_code, answer, log = None, None, None
        # if ret_code is 0, means succeed
        return ret_code, answer, log

    def working(self):
        """If the API model is working, return True, else return False.

        Returns:
            bool: If the API model is working, return True, else return False.
        """
        self.old_timeout = None
        if hasattr(self, "timeout"):
            self.old_timeout = self.timeout
            self.timeout = 120

        retry = 5
        while retry > 0:
            ret = self.generate("hello")
            if ret is not None and ret != "" and self.fail_msg not in ret:
                if self.old_timeout is not None:
                    self.timeout = self.old_timeout
                return True
            retry -= 1

        if self.old_timeout is not None:
            self.timeout = self.old_timeout
        return False

    def check_content(self, msgs):
        """Check the content type of the input. Four types are allowed: str, dict, liststr, listdict.

        Args:
            msgs: Raw input messages.

        Returns:
            str: The message type.
        """
        if isinstance(msgs, str):
            return "str"
        if isinstance(msgs, dict):
            return "dict"
        if isinstance(msgs, list):
            types = [self.check_content(m) for m in msgs]
            if all(t == "str" for t in types):
                return "liststr"
            if all(t == "dict" for t in types):
                return "listdict"
        return "unknown"

    def preproc_content(self, inputs):
        """Convert the raw input messages to a list of dicts.

        Args:
            inputs: raw input messages.

        Returns:
            list(dict): The preprocessed input messages. Will return None if failed to preprocess the input.
        """
        if self.check_content(inputs) == "str":
            return [dict(type="text", value=inputs)]
        elif self.check_content(inputs) == "dict":
            assert "type" in inputs and "value" in inputs
            return [inputs]
        elif self.check_content(inputs) == "liststr":
            res = []
            for s in inputs:
                mime, pth = parse_file(s)
                if mime is None or mime == "unknown":
                    res.append(dict(type="text", value=s))
                else:
                    res.append(dict(type=mime.split("/")[0], value=pth))
            return res
        elif self.check_content(inputs) == "listdict":
            for item in inputs:
                assert "type" in item and "value" in item
                mime, s = parse_file(item["value"])
                if mime is None:
                    assert item["type"] == "text", item["value"]
                else:
                    assert mime.split("/")[0] == item["type"]
                    item["value"] = s
            return inputs
        else:
            return None

    # May exceed the context windows size, so try with different turn numbers.
    def chat_inner(self, inputs, **kwargs):
        _ = kwargs.pop("dataset", None)
        while len(inputs):
            try:
                return self.generate_inner(inputs, **kwargs)
            except Exception as e:
                if self.verbose:
                    self.logger.info(f"{type(e)}: {e}")
                inputs = inputs[1:]
                while len(inputs) and inputs[0]["role"] != "user":
                    inputs = inputs[1:]
                continue
        return (
            -1,
            self.fail_msg + ": " + "Failed with all possible conversation turns.",
            None,
        )

    def chat(self, messages, **kwargs1):
        """The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages."""
        assert hasattr(self, "chat_inner"), (
            "The API model should has the `chat_inner` method. "
        )
        for msg in messages:
            assert isinstance(msg, dict) and "role" in msg and "content" in msg, msg
            assert self.check_content(msg["content"]) in [
                "str",
                "dict",
                "liststr",
                "listdict",
            ], msg
            msg["content"] = self.preproc_content(msg["content"])
        # merge kwargs
        kwargs = cp.deepcopy(self.default_kwargs)
        kwargs.update(kwargs1)

        answer = None
        # a very small random delay [0s - 0.5s]
        T = rd.random() * 0.5
        time.sleep(T)

        assert messages[-1]["role"] == "user"

        for i in range(self.retry):
            try:
                ret_code, answer, log = self.chat_inner(messages, **kwargs)
                if ret_code == 0 and self.fail_msg not in answer and answer != "":
                    if self.verbose:
                        print(answer)
                    return answer
                elif self.verbose:
                    if not isinstance(log, str):
                        try:
                            log = log.text
                        except Exception as e:
                            self.logger.warning(
                                f"Failed to parse {log} as an http response: {str(e)}. "
                            )
                    self.logger.info(
                        f"RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}"
                    )
            except Exception as err:
                if self.verbose:
                    self.logger.error(f"An error occured during try {i}: ")
                    self.logger.error(f"{type(err)}: {err}")
            # delay before each retry
            T = rd.random() * self.wait * 2
            time.sleep(T)

        return self.fail_msg if answer in ["", None] else answer

    def preprocess_message_with_role(self, message):
        system_prompt = ""
        new_message = []

        for data in message:
            assert isinstance(data, dict)
            role = data.pop("role", "user")
            if role == "system":
                system_prompt += data["value"] + "\n"
            else:
                new_message.append(data)

        if system_prompt != "":
            if self.system_prompt is None:
                self.system_prompt = system_prompt
            else:
                self.system_prompt += "\n" + system_prompt
        return new_message

    def generate(self, message, **kwargs1):
        """The main function to generate the answer. Will call `generate_inner` with the preprocessed input messages.

        Args:
            message: raw input messages.

        Returns:
            str: The generated answer of the Failed Message if failed to obtain answer.
        """
        if self.check_content(message) == "listdict":
            message = self.preprocess_message_with_role(message)

        assert self.check_content(message) in ["str", "dict", "liststr", "listdict"], (
            f"Invalid input type: {message}"
        )
        message = self.preproc_content(message)
        assert message is not None and self.check_content(message) == "listdict"
        for item in message:
            assert item["type"] in self.allowed_types, (
                f"Invalid input type: {item['type']}"
            )

        # merge kwargs
        kwargs = cp.deepcopy(self.default_kwargs)
        kwargs.update(kwargs1)

        answer = None
        # a very small random delay [0s - 0.5s]
        T = rd.random() * 0.5
        time.sleep(T)

        for i in range(self.retry):
            try:
                ret_code, answer, log = self.generate_inner(message, **kwargs)
                if ret_code == 0 and self.fail_msg not in answer and answer != "":
                    if self.verbose:
                        print(answer)
                    return answer
                elif self.verbose:
                    if not isinstance(log, str):
                        try:
                            log = log.text
                        except Exception as e:
                            self.logger.warning(
                                f"Failed to parse {log} as an http response: {str(e)}. "
                            )
                    self.logger.info(
                        f"RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}"
                    )
            except Exception as err:
                if self.verbose:
                    self.logger.error(f"An error occured during try {i}: ")
                    self.logger.error(f"{type(err)}: {err}")
            # delay before each retry
            T = rd.random() * self.wait * 2
            time.sleep(T)

        return self.fail_msg if answer in ["", None] else answer

    def message_to_promptimg(self, message, dataset=None):
        assert not self.INTERLEAVE
        model_name = self.__class__.__name__
        import warnings

        warnings.warn(
            f"Model {model_name} does not support interleaved input. "
            "Will use the first image and aggregated texts as prompt. "
        )
        num_images = len([x for x in message if x["type"] == "image"])
        if num_images == 0:
            prompt = "\n".join([x["value"] for x in message if x["type"] == "text"])
            image = None
        elif num_images == 1:
            prompt = "\n".join([x["value"] for x in message if x["type"] == "text"])
            image = [x["value"] for x in message if x["type"] == "image"][0]
        else:
            prompt = "\n".join(
                [x["value"] if x["type"] == "text" else "<image>" for x in message]
            )

            image = [x["value"] for x in message if x["type"] == "image"][0]
        return prompt, image


APIBASES = {
    # 'OFFICIAL': 'https://api.openai.com/v1/chat/completions',
    "OFFICIAL": "https://api.bianxie.ai/v1/chat/completions"
    # 'OFFICIAL': 'https://api.chatanywhere.tech'
}


def encode_image_to_base64(img, target_size=-1, fmt="JPEG"):
    # if target_size == -1, will not do resizing
    # else, will set the max_size ot (target_size, target_size)
    if img.mode in ("RGBA", "P"):
        img = img.convert("RGB")
    if target_size > 0:
        img.thumbnail((target_size, target_size))
    img_buffer = io.BytesIO()
    img.save(img_buffer, format=fmt)
    image_data = img_buffer.getvalue()
    ret = base64.b64encode(image_data).decode("utf-8")
    return ret


class OpenAIWrapper(BaseAPI):
    is_api: bool = True

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-0613",
        retry: int = 5,
        wait: int = 5,
        key: str = None,
        verbose: bool = False,
        system_prompt: str = None,
        temperature: float = 0,
        timeout: int = 60,
        api_base: str = None,
        max_tokens: int = 2048,
        img_size: int = 512,
        img_detail: str = "low",
        use_azure: bool = False,
        **kwargs,
    ):
        self.model = model
        self.cur_idx = 0
        self.fail_msg = "Failed to obtain answer via API. "
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.use_azure = use_azure

        if "step" in model:
            env_key = os.environ.get("STEPAI_API_KEY", "")
            if key is None:
                key = env_key
        elif "yi-vision" in model:
            env_key = os.environ.get("YI_API_KEY", "")
            if key is None:
                key = env_key
        elif "internvl2-pro" in model:
            env_key = os.environ.get("InternVL2_PRO_KEY", "")
            if key is None:
                key = env_key
        elif "abab" in model:
            env_key = os.environ.get("MiniMax_API_KEY", "")
            if key is None:
                key = env_key
        elif "moonshot" in model:
            env_key = os.environ.get("MOONSHOT_API_KEY", "")
            if key is None:
                key = env_key
        elif "grok" in model:
            env_key = os.environ.get("XAI_API_KEY", "")
            if key is None:
                key = env_key
        else:
            if use_azure:
                env_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
                assert env_key is not None, (
                    "Please set the environment variable AZURE_OPENAI_API_KEY. "
                )

                if key is None:
                    key = env_key
                assert isinstance(key, str), (
                    "Please set the environment variable AZURE_OPENAI_API_KEY to your openai key. "
                )
            else:
                env_key = os.environ.get("OPENAI_API_KEY", "")
                if key is None:
                    key = env_key
                assert isinstance(key, str) and key.startswith("sk-"), (
                    f"Illegal openai_key {key}. "
                    "Please set the environment variable OPENAI_API_KEY to your openai key. "
                )

        self.key = key
        assert img_size > 0 or img_size == -1
        self.img_size = img_size
        assert img_detail in ["high", "low"]
        self.img_detail = img_detail
        self.timeout = timeout

        super().__init__(
            wait=wait,
            retry=retry,
            system_prompt=system_prompt,
            verbose=verbose,
            **kwargs,
        )

        if use_azure:
            api_base_template = "{endpoint}openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"
            endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", None)
            assert endpoint is not None, (
                "Please set the environment variable AZURE_OPENAI_ENDPOINT. "
            )
            deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", None)
            assert deployment_name is not None, (
                "Please set the environment variable AZURE_OPENAI_DEPLOYMENT_NAME. "
            )
            api_version = os.getenv("OPENAI_API_VERSION", None)
            assert api_version is not None, (
                "Please set the environment variable OPENAI_API_VERSION. "
            )

            self.api_base = api_base_template.format(
                endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
                deployment_name=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
                api_version=os.getenv("OPENAI_API_VERSION"),
            )
        else:
            if api_base is None:
                if (
                    "OPENAI_API_BASE" in os.environ
                    and os.environ["OPENAI_API_BASE"] != ""
                ):
                    self.logger.info(
                        "Environment variable OPENAI_API_BASE is set. Will use it as api_base. "
                    )
                    api_base = os.environ["OPENAI_API_BASE"]
                else:
                    api_base = "OFFICIAL"

            assert api_base is not None

            if api_base in APIBASES:
                self.api_base = APIBASES[api_base]
            elif api_base.startswith("http"):
                self.api_base = api_base
            else:
                self.logger.error("Unknown API Base. ")
                raise NotImplementedError

        self.logger.info(f"Using API Base: {self.api_base}; API Key: {self.key}")

    # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
    # content can be a string or a list of image & text
    def prepare_itlist(self, inputs):
        assert np.all([isinstance(x, dict) for x in inputs])
        has_images = np.sum([x["type"] == "image" for x in inputs])
        if has_images:
            content_list = []
            for msg in inputs:
                if msg["type"] == "text":
                    content_list.append(dict(type="text", text=msg["value"]))
                elif msg["type"] == "image":
                    from PIL import Image

                    img = Image.open(msg["value"])
                    b64 = encode_image_to_base64(img, target_size=self.img_size)
                    img_struct = dict(
                        url=f"data:image/jpeg;base64,{b64}", detail=self.img_detail
                    )
                    content_list.append(dict(type="image_url", image_url=img_struct))
        else:
            assert all([x["type"] == "text" for x in inputs])
            text = "\n".join([x["value"] for x in inputs])
            content_list = [dict(type="text", text=text)]
        return content_list

    def prepare_inputs(self, inputs):
        input_msgs = []
        if self.system_prompt is not None:
            input_msgs.append(dict(role="system", content=self.system_prompt))
        assert isinstance(inputs, list) and isinstance(inputs[0], dict)
        assert np.all(["type" in x for x in inputs]) or np.all(
            ["role" in x for x in inputs]
        ), inputs
        if "role" in inputs[0]:
            assert inputs[-1]["role"] == "user", inputs[-1]
            for item in inputs:
                input_msgs.append(
                    dict(
                        role=item["role"], content=self.prepare_itlist(item["content"])
                    )
                )
        else:
            input_msgs.append(dict(role="user", content=self.prepare_itlist(inputs)))
        return input_msgs

    def generate_inner(self, inputs, **kwargs) -> str:
        input_msgs = self.prepare_inputs(inputs)
        temperature = kwargs.pop("temperature", self.temperature)
        max_tokens = kwargs.pop("max_tokens", self.max_tokens)

        # context_window = GPT_context_window(self.model)
        # new_max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
        # if 0 < new_max_tokens <= 100 and new_max_tokens < max_tokens:
        #     self.logger.warning(
        #         'Less than 100 tokens left, '
        #         'may exceed the context window with some additional meta symbols. '
        #     )
        # if new_max_tokens <= 0:
        #     return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
        # max_tokens = new_max_tokens

        # Will send request if use Azure, dk how to use openai client for it
        if self.use_azure:
            headers = {"Content-Type": "application/json", "api-key": self.key}
        elif "internvl2-pro" in self.model:
            headers = {"Content-Type": "application/json", "Authorization": self.key}
        else:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.key}",
            }
        payload = dict(
            model=self.model,
            messages=input_msgs,
            max_tokens=max_tokens,
            n=1,
            temperature=temperature,
            **kwargs,
        )
        response = requests.post(
            self.api_base,
            headers=headers,
            data=json.dumps(payload),
            timeout=self.timeout * 1.1,
        )
        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
        answer = self.fail_msg
        try:
            resp_struct = json.loads(response.text)
            answer = resp_struct["choices"][0]["message"]["content"].strip()
        except Exception as err:
            if self.verbose:
                self.logger.error(f"{type(err)}: {err}")
                self.logger.error(
                    response.text if hasattr(response, "text") else response
                )

        return ret_code, answer, response

    def get_image_token_len(self, img_path, detail="low"):
        import math

        if detail == "low":
            return 85

        im = Image.open(img_path)
        height, width = im.size
        if width > 1024 or height > 1024:
            if width > height:
                height = int(height * 1024 / width)
                width = 1024
            else:
                width = int(width * 1024 / height)
                height = 1024

        h = math.ceil(height / 512)
        w = math.ceil(width / 512)
        total = 85 + 170 * h * w
        return total

    def get_token_len(self, inputs) -> int:
        import tiktoken

        try:
            enc = tiktoken.encoding_for_model(self.model)
        except Exception as err:
            if "gpt" in self.model.lower():
                if self.verbose:
                    self.logger.warning(f"{type(err)}: {err}")
                enc = tiktoken.encoding_for_model("gpt-4")
            else:
                return 0
        assert isinstance(inputs, list)
        tot = 0
        for item in inputs:
            if "role" in item:
                tot += self.get_token_len(item["content"])
            elif item["type"] == "text":
                tot += len(enc.encode(item["value"]))
            elif item["type"] == "image":
                tot += self.get_image_token_len(item["value"], detail=self.img_detail)
        return tot
