import os
import tempfile
import shutil
import torch
import logging
from typing import List, Dict, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams

# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)

class ModelEvaluator:
    """Batch text generation through vLLM, with optional LoRA merging.

    On construction the tokenizer is loaded with ``transformers`` and the
    model is loaded with vLLM. When ``lora_path`` is given, the adapter is
    first merged into the base model on CPU and the merged checkpoint is
    written to a temporary directory, which vLLM then loads and which is
    removed again when the evaluator is finalized.

    Two behaviour flags are derived from ``model_path`` by substring match:
    ``qwen3_model`` (path contains ``"Qwen3"``) and ``base_model`` (path
    contains ``"base"``, case-insensitive).
    """

    def __init__(self,
                 model_path: str,
                 lora_path: Optional[str] = None,
                 torch_dtype: str = "auto",
                 tensor_parallel_size: int = 1,
                 gpu_memory_utilization: float = 0.8):
        """Store the configuration and load tokenizer and model immediately.

        Args:
            model_path: Path or identifier of the base model; also used to
                load the tokenizer.
            lora_path: Optional LoRA adapter to merge into the base model.
            torch_dtype: Dtype string. Passed as-is to ``from_pretrained``
                when merging; for vLLM, ``"auto"`` is replaced by ``"half"``.
            tensor_parallel_size: Forwarded to ``vllm.LLM``.
            gpu_memory_utilization: Forwarded to ``vllm.LLM``.
        """
        # Assigned first so that __del__ always finds it, even if a later
        # statement in this constructor raises.
        self.temp_model_dir = None

        self.model_path = model_path
        self.qwen3_model = "Qwen3" in model_path
        self.base_model = "base" in model_path.lower()
        print(
            f"\033[91mBased on model path: {self.model_path}, detected qwen3_model={self.qwen3_model} and base_model={self.base_model}\033[0m")

        self.lora_path = lora_path
        self.torch_dtype = torch_dtype
        self.tensor_parallel_size = tensor_parallel_size
        self.gpu_memory_utilization = gpu_memory_utilization

        self.tokenizer = None
        self.llm = None

        self._prepare_model()

    def _prepare_model(self):
        """Load the tokenizer and the vLLM engine.

        If a LoRA path was configured, the adapter is merged first and vLLM
        loads the merged checkpoint from the temporary directory; otherwise
        vLLM loads ``model_path`` directly. The tokenizer is always loaded
        from ``model_path``.
        """
        logger.info("Preparing model from: %s", self.model_path)

        if self.lora_path:
            logger.info("Merging LoRA weights from: %s", self.lora_path)
            self._merge_lora_and_save()
            model_path_for_vllm = self.temp_model_dir
        else:
            model_path_for_vllm = self.model_path

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            padding_side="left"
        )
        # Fall back to the EOS token when the tokenizer defines no pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # NOTE: "auto" is mapped to "half" here instead of being forwarded to
        # vLLM; kept unchanged so existing results are not affected.
        vllm_dtype = self.torch_dtype if self.torch_dtype != "auto" else "half"
        self.llm = LLM(
            model=model_path_for_vllm,
            trust_remote_code=True,
            tensor_parallel_size=self.tensor_parallel_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            dtype=vllm_dtype,
        )

        logger.info("Model loaded successfully with vLLM")

    def _merge_lora_and_save(self):
        """Merge the LoRA adapter into the base model and save the result.

        Creates a temporary directory (stored in ``self.temp_model_dir``),
        loads the base model on CPU, applies and merges the adapter from
        ``self.lora_path``, and writes the merged model and the tokenizer
        into that directory.

        Raises:
            Exception: Any error from loading, merging or saving is
                re-raised after the temporary directory has been removed.
        """
        logger.info("Merging LoRA weights...")

        self.temp_model_dir = tempfile.mkdtemp(prefix="vllm_model_")
        logger.info("Created temporary model directory: %s", self.temp_model_dir)

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )

            model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                torch_dtype=self.torch_dtype,
                device_map="cpu"
            )

            # Imported lazily so peft is only required when a LoRA is used.
            from peft import PeftModel
            model = PeftModel.from_pretrained(model, self.lora_path)
            model = model.merge_and_unload()

            logger.info("Saving merged model to: %s", self.temp_model_dir)
            model.save_pretrained(self.temp_model_dir)
            tokenizer.save_pretrained(self.temp_model_dir)
        except Exception:
            # Do not leave a partially written checkpoint on disk until the
            # object happens to be finalized.
            self._cleanup_temp_dir()
            raise

        # Drop the CPU copies before vLLM allocates its own memory.
        del model
        del tokenizer
        torch.cuda.empty_cache()

        logger.info("LoRA weights merged and saved successfully")

    def _cleanup_temp_dir(self):
        """Remove the temporary merged-model directory, if one exists.

        Safe to call repeatedly and on partially constructed instances.
        """
        temp_dir = getattr(self, "temp_model_dir", None)
        if not temp_dir or not os.path.exists(temp_dir):
            return
        logger.info("Cleaning up temporary model directory: %s", temp_dir)
        # Best effort: a failed delete must not raise out of a destructor.
        shutil.rmtree(temp_dir, ignore_errors=True)
        self.temp_model_dir = None

    def __del__(self):
        """Clean up the temporary directory if it exists."""
        self._cleanup_temp_dir()

    @staticmethod
    def _stop_token_ids(eos_token_id: Optional[int]) -> Optional[List[int]]:
        """Return ``[eos_token_id]``, or ``None`` when no EOS id is defined.

        Compares against ``None`` explicitly so that a valid id of ``0`` is
        kept rather than being treated as missing.
        """
        return None if eos_token_id is None else [eos_token_id]

    def generate_response(self,
                          conversations: List[List[Dict[str, str]]],
                          max_new_tokens: int = 512,
                          temperature: float = 0.3,
                          top_p: float = 0.8,
                          top_k: int = 20) -> List[str]:
        """Generate one response per conversation.

        Args:
            conversations: Batch of conversations, each a list of message
                dicts that is handed to the tokenizer's chat template.
            max_new_tokens: Passed to vLLM as ``max_tokens``.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling cutoff.

        Returns:
            The first completion of each request, whitespace-stripped, in
            the order vLLM returns them. An empty batch yields ``[]``.

        Raises:
            NotImplementedError: If the model was detected as a base model;
                prompt formatting for base models is not implemented.
        """
        if self.base_model:
            # TODO: add plain-text prompt formatting for base models.
            raise NotImplementedError("Base model formatting not implemented yet.")

        if not conversations:
            return []

        template_kwargs = {"tokenize": False, "add_generation_prompt": True}
        if self.qwen3_model:
            template_kwargs["enable_thinking"] = True
        formatted_inputs = self.tokenizer.apply_chat_template(
            conversations,
            **template_kwargs
        )

        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stop_token_ids=self._stop_token_ids(self.tokenizer.eos_token_id)
        )

        logger.info("Generating responses for %d inputs...", len(formatted_inputs))
        outputs = self.llm.generate(formatted_inputs, sampling_params)

        return [output.outputs[0].text.strip() for output in outputs]