import logging
import os
import threading
from functools import wraps
import time

from models.base import GenModelBase
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_fixed
from tqdm.contrib.concurrent import thread_map
from utils import configurable_retry

# Logger used by everything in this file; it is looked up under the fixed
# name "text2svg" rather than this module's own name.
logger = logging.getLogger("text2svg")


def on_retry_error(s):
    """Report that retrying has been abandoned and return an empty string.

    Passed as ``retry_error_callback`` to the tenacity ``retry`` decorator
    built in this file.

    Args:
        s: Retry state object; only ``s.outcome.exception()`` is read.

    Returns:
        str: Always ``""``.
    """
    logger.critical(f"give up retrying. error: {s.outcome.exception()}")
    return ""


def before_retry_sleep(s):
    """Warn about a failed attempt, but only once more than 10 have failed.

    Passed as ``before_sleep`` to the tenacity ``retry`` decorator built in
    this file. The message is always built (so the exception is always
    fetched), yet it is emitted only when ``s.attempt_number`` exceeds 10.

    Args:
        s: Retry state object; ``s.attempt_number`` and
            ``s.outcome.exception()`` are read.
    """
    attempt = s.attempt_number
    error = s.outcome.exception()
    msg = f"function call error for {attempt} time(s), will retry... error: {error}"
    if attempt <= 10:
        # Early attempts stay silent.
        return
    logger.warning(msg)


def configurable_retry(max_attempts):
    """Build a decorator that retries a synchronous callable with tenacity.

    The decorated callable is retried with a fixed 2 second wait between
    attempts and stops after ``max_attempts`` attempts. ``before_retry_sleep``
    runs before each wait, and ``on_retry_error`` supplies the result once
    the attempts are exhausted.

    NOTE(review): this definition rebinds the name ``configurable_retry``
    that the file also imports from ``utils``; code below this point uses
    this local version.

    Args:
        max_attempts: Value handed to tenacity's ``stop_after_attempt``.

    Returns:
        A decorator that wraps a function with the retry policy while
        keeping the function's metadata via ``functools.wraps``.
    """

    def decorator(func):
        def sync_wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        # Same composition as stacking @wraps(func) on top of @retry(...).
        retrying_wrapper = retry(
            wait=wait_fixed(2),
            stop=stop_after_attempt(max_attempts),
            before_sleep=before_retry_sleep,
            retry_error_callback=on_retry_error,
        )(sync_wrapper)
        return wraps(func)(retrying_wrapper)

    return decorator


class GenModelOpenAI(GenModelBase):
    """Text generation backend that calls an OpenAI-style HTTP API.

    The client endpoint and key come from the ``API_BASE`` and ``API_KEY``
    environment variables. ``model_path``, ``max_new_tokens``,
    ``temperature``, ``stop`` and ``is_chat_mode`` are read from ``self`` but
    not assigned in this class (presumably set by ``GenModelBase`` - confirm).
    """

    def __init__(self, **kwargs):
        """Create the API client and log the effective configuration.

        Args:
            **kwargs: Forwarded unchanged to ``GenModelBase.__init__``.
        """
        super().__init__(**kwargs)

        logger.info(f"Initializing: OpenAI Model {self.model_path}")

        # Unset environment variables yield None for either argument.
        self.client = OpenAI(
            base_url=os.getenv("API_BASE"),
            api_key=os.getenv("API_KEY"),
        )
        logger.info("OpenAI Config: ")
        logger.info(f"\t- Model: {self.model_path}")
        logger.info(f"\t- {self.max_new_tokens=}")
        logger.info(f"\t- {self.temperature=}")
        logger.info(f"\t- Stop: {self.stop}")

    def preprocess_prompt(self, prompt):
        """Return ``prompt`` unchanged; this backend needs no preprocessing."""
        return prompt

    def generate(self, prompts, enable_tqdm=False):
        """Generate one completion per prompt.

        In chat mode every prompt is sent as its own single-turn chat request
        from a pool of 16 threads; each request is wrapped with
        ``configurable_retry(5)``. Otherwise the whole batch is sent in one
        completions request.

        Args:
            prompts: Sequence of prompt strings.
            enable_tqdm: Show a progress bar in chat mode when true.

        Returns:
            list: Generated texts in the same order as ``prompts``.
        """
        if self.is_chat_mode:

            @configurable_retry(5)
            def request(text):
                completion = self.client.chat.completions.create(
                    model=self.model_path,
                    messages=[{"role": "user", "content": text}],
                    max_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                    stop=self.stop,
                )
                return completion.choices[0].message.content

            generations = thread_map(request, prompts, max_workers=16, chunksize=4, disable=not enable_tqdm)
        else:
            # Nothing to send: avoid a request that carries no prompts.
            if isinstance(prompts, (list, tuple)) and not prompts:
                return []

            logger.info(f"\tSending the whole batch. There's won't be a progress bar.")
            responses = self.client.completions.create(
                model=self.model_path,
                prompt=prompts,
                max_tokens=self.max_new_tokens,
                temperature=self.temperature,
                stop=self.stop,
            )

            def prompt_position(item):
                # Each choice reports the prompt it answers via ``index``;
                # fall back to arrival order if a server omits it.
                arrival, choice = item
                index = getattr(choice, "index", None)
                return arrival if index is None else index

            # Batched choices are not guaranteed to arrive in prompt order,
            # so realign them before extracting the text.
            ordered = sorted(enumerate(responses.choices), key=prompt_position)
            generations = [choice.text for _, choice in ordered]

        return generations

    def close(self):
        """Close the underlying API client."""
        self.client.close()
        logger.info(f"Cleanup for OpenAI [DONE]")


class GenModelVllmServedOpenAICompatible(GenModelBase):
    """Multimodal chat backend for an OpenAI-compatible endpoint.

    The client endpoint and key come from the ``API_BASE`` and ``API_KEY``
    environment variables. ``model_path``, ``max_new_tokens`` and ``stop`` are
    read from ``self`` but not assigned in this class (presumably set by
    ``GenModelBase`` - confirm).
    """

    # Pixel budget attached to every image part built by preprocess_prompt.
    # NOTE(review): presumably 10000 patches of 28x28 pixels - confirm.
    MAX_PIXELS = 10000 * 28 * 28  # 10k + 10k
    # Not referenced anywhere in this class.
    MAX_INPUTS = 4096

    def __init__(
        self,
        api_num_workers=128,
        **kwargs,
    ):
        """Create the API client and fix the sampling arguments.

        Args:
            api_num_workers: Number of worker threads used by ``generate``.
            **kwargs: Forwarded unchanged to ``GenModelBase.__init__``.
        """
        super().__init__(**kwargs)

        logger.info(f"Initializing: OpenAI MultiModal Model {self.model_path}")

        # Unset environment variables yield None for either argument.
        self.client = OpenAI(
            base_url=os.getenv("API_BASE"),
            api_key=os.getenv("API_KEY"),
        )

        # Standard sampling arguments, passed as keyword arguments.
        self.infer_args = dict(temperature=0.8, top_p=0.95)
        # Additional arguments, passed through the request's extra_body.
        self.infer_args_extra = dict(repetition_penalty=1.05, top_k=50)
        self.api_num_workers = api_num_workers

        logger.info("OpenAI MultiModal Config:")
        logger.info(f"\t- Model: {self.model_path}")
        logger.info(f"\t- {self.max_new_tokens=}")
        logger.info(f"\t- Stop: {self.stop}")

        logger.info(f"Infer Args: {self.infer_args}")
        logger.info(f"[IMPORTANT] {self.api_num_workers = }")

    def preprocess_prompt(self, prompt, list_of_b64_images=None):
        """Build the user message content for one request.

        Args:
            prompt: Text of the request; always appended as the last part.
            list_of_b64_images: Optional iterable of base64 strings. Each one
                becomes an ``image_url`` part using a ``data:image/jpeg`` URL.

        Returns:
            list[dict]: Image parts (if any) followed by a single text part.
        """
        user_content = [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                "max_pixels": GenModelVllmServedOpenAICompatible.MAX_PIXELS,
            }
            for base64_image in (list_of_b64_images or [])
        ]
        user_content.append({"type": "text", "text": prompt})
        return user_content

    def generate(self, prompts, enable_tqdm=True):
        """Send every prompt as its own chat request from a thread pool.

        Each request is tried up to 20 times with a 2 second pause between
        attempts. A failure whose message says the input is too long is not
        retried. Requests that fail for good produce ``""``.

        Args:
            prompts: Sequence of user message contents, one per request.
            enable_tqdm: Show a progress bar when true.

        Returns:
            list: One result per prompt, in the same order as ``prompts``.
        """
        logger.info(f"[NOTE] {self.api_num_workers = }")
        stats = {
            "send": 0,
            "failures/length": 0,
            "failures/others": 0,
            "giveup": 0,
            "success": 0,
        }
        stats_lock = threading.Lock()
        max_attempts = 20
        wait_seconds = 2

        def bump(counter):
            # Every counter is shared between worker threads, so all updates
            # go through the same lock.
            with stats_lock:
                stats[counter] += 1

        def request(user_message, idx):
            for attempt in range(1, max_attempts + 1):
                bump("send")
                try:
                    completion = self.client.chat.completions.create(
                        model=self.model_path,
                        messages=[{"role": "user", "content": user_message}],
                        max_tokens=self.max_new_tokens,
                        stop=self.stop,
                        **self.infer_args,
                        extra_body=self.infer_args_extra,
                        extra_headers={"X-DashScope-DataInspection": '{"input":"disable", "output":"disable"}'},
                    )
                    # Kept inside the try so a malformed response is retried
                    # like any other failure.
                    content = completion.choices[0].message.content
                except Exception as e:
                    if "is too long to fit into the model" in str(e):
                        # Retrying cannot shrink the input; give up at once.
                        logger.warning(f"[WARNING] Request {idx} skipped retry due to input length")
                        bump("failures/length")
                        return ""

                    logger.warning(f"[WARNING] Request {idx} skipped {str(e)} ({attempt}/{max_attempts})")
                    bump("failures/others")
                    # No point in waiting when there is no attempt left.
                    if attempt < max_attempts:
                        time.sleep(wait_seconds)
                else:
                    # Counted only once the content was actually obtained.
                    bump("success")
                    return content

            logger.critical(f"[CRITICAL] Request {idx} failed after {max_attempts} attempts. Giving up.")
            bump("giveup")
            return ""

        # The second iterable supplies each prompt's position for log messages.
        generations = thread_map(
            request,
            prompts,
            range(len(prompts)),
            max_workers=self.api_num_workers,
            chunksize=32,
            disable=not enable_tqdm,
        )
        logger.info(f"STAT = {stats}")
        return generations

    def close(self):
        """Close the underlying API client."""
        self.client.close()
        logger.info("Cleanup for OpenAI MultiModal [DONE]")
