import gc
import logging

import torch
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
from models.base import GenModelBase

try:
    from qwen_vl_utils import process_vision_info
except ImportError:
    print(f"[WARNING] package `qwen_vl_utils` not found.")


# Module-wide logger, registered under the fixed name "text2svg" (not __name__).
logger = logging.getLogger(name="text2svg")


class GenModelVllm(GenModelBase):
    """Text-only generation backend running a model through vLLM's offline ``LLM`` engine.

    Configuration (``model_path``, ``temperature``, ``max_new_tokens``, ``stop``,
    ``inference_args``, ``is_chat_mode``) is read from attributes that are
    expected to be set by ``GenModelBase.__init__``.
    """

    # Maximum number of prompts submitted to vLLM in a single ``generate`` call.
    MAX_INPUTS = 4096

    def __init__(self, tp, **kwargs):
        """Load the vLLM engine and, in chat mode, the matching tokenizer.

        Args:
            tp: tensor-parallel size passed to ``vllm.LLM``.
            **kwargs: forwarded unchanged to ``GenModelBase.__init__``.
        """
        super().__init__(**kwargs)
        logger.info(f"Initializing: vLLM Model {self.model_path}")

        self.model = LLM(
            model=self.model_path,
            tensor_parallel_size=tp,
            gpu_memory_utilization=0.95,
            # Disables CUDA graph capture (eager execution), per vLLM's `LLM` options.
            enforce_eager=True,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(
            temperature=self.temperature,
            max_tokens=self.max_new_tokens,
            stop=self.stop,
            # Keep the matched stop string at the end of the returned text.
            include_stop_str_in_output=True,
            **self.inference_args,
        )
        logger.info(f"vLLM model init: {self.model_path}")
        logger.info(f"{self.sampling_params = }")

        # The tokenizer is only needed to render chat templates in preprocess_prompt.
        if self.is_chat_mode:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            logger.info("Tokenizer loaded.")
        else:
            logger.info("No Tokenizer Loaded")

    def preprocess_prompt(self, prompt, force=None):
        """Turn a raw user prompt into the text string fed to vLLM.

        In chat mode the prompt is wrapped as a single user turn with the
        tokenizer's chat template (generation prompt appended). ``force`` then
        optionally adjusts the end of the rendered prompt:

        * ``None`` / ``"normal"``: leave the templated prompt untouched.
        * ``"thinking"``: append an opening ``<think>`` tag unless the prompt
          already ends with one.
        * ``"no_thinking"``: append an empty ``<think>``/``</think>`` pair unless
          the prompt already ends with ``</think>``.
        * anything else: logged as a warning and ignored.

        In base (non-chat) mode the prompt is returned unchanged and ``force``
        is ignored.
        """
        if not self.is_chat_mode:
            return prompt

        prompt = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
        if force is None or force == "normal":
            return prompt

        # The modes are mutually exclusive; some chat templates already emit the tag,
        # hence the endswith checks to avoid duplicating it.
        if force == "thinking":
            if not prompt.rstrip().endswith("<think>"):
                prompt = f"{prompt.strip()}\n<think>\n"
        elif force == "no_thinking":
            if not prompt.rstrip().endswith("</think>"):
                prompt = f"{prompt.strip()}\n<think>\n\n</think>\n\n"
        else:
            logger.warning(f"Unknown `force` value {force!r}; prompt left unchanged.")
        return prompt

    def generate(self, prompts, enable_tqdm=False):
        """Generate completions for ``prompts`` and return their texts.

        Prompts are submitted to vLLM in chunks of at most ``MAX_INPUTS``; for each
        request only the first candidate (``outputs[0].text``) is kept. Results are
        concatenated chunk by chunk in the order vLLM returns them.
        """
        max_input_size = self.MAX_INPUTS

        all_generations = []
        for batch_idx, start in enumerate(range(0, len(prompts), max_input_size)):
            batch = prompts[start : start + max_input_size]
            # Exclusive end index of this chunk; the last chunk may be shorter.
            upper = start + len(batch)
            logger.info(f"Infer batch[{batch_idx}] {start} -> {upper}")
            vllm_outputs = self.model.generate(batch, sampling_params=self.sampling_params, use_tqdm=enable_tqdm)
            all_generations.extend(x.outputs[0].text for x in vllm_outputs)

        return all_generations

    def close(self):
        """Best-effort teardown of the vLLM engine and its distributed state.

        Each step runs independently, so a failure in one (logged as a warning)
        no longer skips the release of the model reference and the CUDA cache.
        Calling it again after a successful close is harmless.
        """
        logger.info("Try cleanup...")
        try:
            destroy_model_parallel()
            destroy_distributed_environment()
        except Exception as e:
            logger.warning(f"Error when cleanup: {str(e)}")

        if hasattr(self, "model"):
            try:
                del self.model.llm_engine.model_executor
            except AttributeError as e:
                # Engine layout differs (or executor already gone); still drop the model below.
                logger.warning(f"Error when cleanup: {str(e)}")
            del self.model

        try:
            gc.collect()
            torch.cuda.empty_cache()
        except Exception as e:
            logger.warning(f"Error when cleanup: {str(e)}")

        logger.info("Cleanup for vllm [DONE]")


class GenModelVllmMultiModal(GenModelBase):
    """Vision-language generation backend on vLLM's offline ``LLM`` engine.

    Requests are chat histories (system turn + user turn with optional base64
    images) rendered with an ``AutoProcessor`` chat template; image payloads are
    extracted with ``qwen_vl_utils.process_vision_info``. Only chat mode is
    supported.
    """

    # Pixel budget per image given to the processor: 10000 tiles of 28x28 pixels.
    MAX_PIXELS = 10000 * 28 * 28
    # Maximum number of requests submitted to vLLM in a single ``generate`` call.
    MAX_INPUTS = 4096

    def __init__(self, tp, **kwargs):
        """Load the vLLM engine and the multi-modal processor.

        Args:
            tp: tensor-parallel size passed to ``vllm.LLM``.
            **kwargs: forwarded unchanged to ``GenModelBase.__init__``.

        Raises:
            ValueError: if ``is_chat_mode`` is false. This is checked before the
                engine is loaded.
        """
        super().__init__(**kwargs)
        logger.info(f"Initializing: vLLM Model (Multi-Modal) {self.model_path}")

        # Validate before the expensive engine load so a misconfiguration does not
        # first load (and then strand) the whole model on the GPUs.
        if not self.is_chat_mode:
            raise ValueError("VL model doesn't support base mode.")

        self.model = LLM(
            model=self.model_path,
            tensor_parallel_size=tp,
            enforce_eager=True,
            trust_remote_code=True,
            # At most two images per request; preprocess_prompt does not enforce this itself.
            limit_mm_per_prompt={"image": 2},
        )
        # NOTE(review): unlike GenModelVllm, these settings are hard-coded: they ignore
        # self.temperature / self.inference_args and do not keep the stop string in
        # the output. Kept as-is to preserve current outputs; confirm this is intended.
        self.sampling_params = SamplingParams(
            max_tokens=self.max_new_tokens,
            stop=self.stop,
            temperature=0.8,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.05,
        )
        logger.info(f"vLLM model init: {self.model_path}")
        logger.info(f"{self.sampling_params = }")

        self.tokenizer = AutoProcessor.from_pretrained(self.model_path, use_fast=True, max_pixels=self.MAX_PIXELS)
        logger.info("Tokenizer loaded for multi-modal usage.")

    def _prepare_llm_input(self, history):
        """Build a vLLM request dict from a chat ``history``.

        Returns:
            ``{"prompt": <templated text>, "multi_modal_data": {...}}``, where
            ``multi_modal_data`` contains ``"image"`` only when images were found.

        Raises:
            ImportError: if the optional ``qwen_vl_utils`` import failed at module load.
        """
        # The module-level import of process_vision_info is optional; fail with a clear
        # message instead of an opaque NameError when it is missing.
        vision_fn = globals().get("process_vision_info")
        if vision_fn is None:
            raise ImportError("package `qwen_vl_utils` is required for multi-modal prompts.")

        final_prompt = self.tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = vision_fn(history)
        mm_data = {"image": image_inputs} if image_inputs else {}
        return {"prompt": final_prompt, "multi_modal_data": mm_data}

    def preprocess_prompt(self, prompt, list_of_b64_images=None):
        """Wrap a text prompt and optional images into a vLLM multi-modal request.

        Args:
            prompt: user text, placed after any images in the user turn.
            list_of_b64_images: optional base64 payloads without a data-URL prefix.
                Every image is labelled ``image/jpeg`` regardless of its real format.

        Returns:
            The request dict produced by ``_prepare_llm_input``.
        """
        user_content = [
            {"type": "image", "image": f"data:image/jpeg;base64,{b64_image}"}
            for b64_image in (list_of_b64_images or [])
        ]
        user_content.append({"type": "text", "text": prompt})

        history = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]
        return self._prepare_llm_input(history)

    def generate(self, prompts, enable_tqdm=False):
        """Generate completions for request dicts and return their texts.

        Requests are submitted in chunks of at most this class's ``MAX_INPUTS``; for
        each request only the first candidate (``outputs[0].text``) is kept.
        Results are concatenated chunk by chunk in the order vLLM returns them.
        """
        max_input_size = self.MAX_INPUTS

        all_generations = []
        for batch_idx, start in enumerate(range(0, len(prompts), max_input_size)):
            batch = prompts[start : start + max_input_size]
            # Exclusive end index of this chunk; the last chunk may be shorter.
            upper = start + len(batch)
            logger.info(f"Infer batch[{batch_idx}] {start} -> {upper}")
            vllm_outputs = self.model.generate(batch, sampling_params=self.sampling_params, use_tqdm=enable_tqdm)
            all_generations.extend(x.outputs[0].text for x in vllm_outputs)

        return all_generations

    def close(self):
        """Best-effort teardown of the vLLM engine and its distributed state.

        Each step runs independently, so a failure in one (logged as a warning)
        no longer skips the release of the model reference and the CUDA cache.
        Calling it again after a successful close is harmless.
        """
        logger.info("Try cleanup (multi-modal vLLM)...")
        try:
            destroy_model_parallel()
            destroy_distributed_environment()
        except Exception as e:
            logger.warning(f"Error when cleanup: {str(e)}")

        if hasattr(self, "model"):
            try:
                del self.model.llm_engine.model_executor
            except AttributeError as e:
                # Engine layout differs (or executor already gone); still drop the model below.
                logger.warning(f"Error when cleanup: {str(e)}")
            del self.model

        try:
            gc.collect()
            torch.cuda.empty_cache()
        except Exception as e:
            logger.warning(f"Error when cleanup: {str(e)}")

        logger.info("Cleanup for vLLM (Multi-Modal) [DONE]")
