import base64
import io
import numpy as np
from PIL import Image
from openai import OpenAI
from constants import E2C


class QwenVLM:
    """Client for a Qwen2.5-VL model served behind an OpenAI-compatible (vLLM) endpoint.

    Images are sent inline as base64-encoded PNG data URLs together with a text
    prompt; the text content of the first returned choice is passed back as-is.
    """

    # Defaults live on the class so every instance has a value to fall back on;
    # the constructor can override them per instance.
    DEFAULT_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
    DEFAULT_SYSTEM_PROMPT = (
        "You are a helpful visual assistant that can understand images "
        "and answer questions about them accurately."
    )
    model = DEFAULT_MODEL
    max_tokens = 512

    def __init__(self, api_base="", api_key="", system_prompt=None,
                 model=DEFAULT_MODEL, max_tokens=512):
        """
        Initialize the QwenVL client

        Args:
            api_base (str): The base URL for the vLLM API
            api_key (str): The API key for authentication
            system_prompt (str, optional): System prompt to guide the model's
                behavior. Any falsy value (None, "") selects DEFAULT_SYSTEM_PROMPT.
            model (str, optional): Model name sent with every chat completion request.
            max_tokens (int, optional): Completion token limit per request.
        """
        self.client = OpenAI(base_url=api_base, api_key=api_key)
        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
        self.model = model
        self.max_tokens = max_tokens

    @staticmethod
    def _default_prompt(goal_name):
        """Return the built-in ObjectNav scoring prompt for the given goal name."""
        # The JSON example uses doubled braces ({{ }}) so the f-string emits them
        # literally instead of trying to evaluate `"score": ...` as a
        # replacement field (which raises ValueError at runtime).
        return f"""
            You are a senior indoor navigation expert currently performing an Object Goal Navigation task. Your starting position is randomly 
            initialized in an unfamiliar environment. The goal category **{goal_name}** has been specified. You are required to perceive the environment 
            using RGB visual inputs and autonomously navigate to any instance of the **{goal_name}** category. 
            Now RGB observations from a candidate region are provided. Based on the visual information, 
            evaluate the goal presence probability. You should consider but not be limited to:  
            1. **Don't just focus on whether the goal object appears in the picture, consider 
            co-occurrence probability between goal object and other objects**  
            2. **Deduce region type from visual features to evaluate goal presence likelihood** 
            3. **Assess exploration worthiness if the region is only partially visible**  
            4. **Other information and principles that you consider useful** 
            5. **Assign higher probability when goal object appears in view** 
            Finally, return a probability score within **[0,1]** (1 indicates highest likelihood) in strict **JSON** format: {{"score": your_float_score}}   
            """

    @classmethod
    def _resolve_prompt(cls, goal_category, question):
        """Return `question` if non-empty, otherwise the default prompt for `goal_category`."""
        if question:
            return question
        # E2C is looked up only when the default prompt is actually needed.
        return cls._default_prompt(E2C[goal_category])

    @staticmethod
    def _to_uint8(image_array):
        """Return `image_array` as a uint8 ndarray, clipping values to [0, 255].

        An array that is already uint8 is returned unchanged (same object).
        """
        image_array = np.asarray(image_array)
        if image_array.dtype == np.uint8:
            return image_array
        # Clip before casting: a bare astype(np.uint8) wraps out-of-range
        # integers modulo 256 and is platform-dependent for out-of-range floats.
        # NOTE(review): floats in [0, 1] are not rescaled and will render
        # nearly black — confirm callers always pass 0-255 data.
        return np.clip(image_array, 0, 255).astype(np.uint8)

    def _convert_numpy_to_base64(self, image_array):
        """Encode an image array as a base64 PNG string.

        Raises:
            ValueError: if `image_array` is None.
        """
        if image_array is None:
            raise ValueError("Image array cannot be None")

        image = Image.fromarray(self._to_uint8(image_array))
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def _image_part(self, image_array):
        """Build one `image_url` content part holding the image as a PNG data URL."""
        img_base64 = self._convert_numpy_to_base64(image_array)
        return {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{img_base64}"},
        }

    def _chat(self, user_content):
        """Send system prompt + user content parts; return the first choice's text."""
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_content},
        ]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=self.max_tokens,
        )
        return response.choices[0].message.content

    def ask_single_image(self, image_array, goal_category, question=None):
        """
        Ask a question about a single image

        Args:
            image_array (np.ndarray): RGB image as numpy array (H,W,3) with values 0-255
            goal_category: Key looked up in ``E2C`` to name the goal in the
                default prompt; only used when ``question`` is empty.
            question (str, optional): The question to ask about the image.
                When empty or None, the built-in ObjectNav scoring prompt is used.

        Returns:
            str: The model's response
        """
        prompt = self._resolve_prompt(goal_category, question)
        return self._chat([
            self._image_part(image_array),
            {"type": "text", "text": prompt},
        ])

    def ask_multiple_images(self, image_arrays, goal_category, question=None):
        """
        Ask a question about multiple images

        Args:
            image_arrays (list): List of RGB images as numpy arrays (H,W,3) with values 0-255
            goal_category: Key looked up in ``E2C`` to name the goal in the
                default prompt; only used when ``question`` is empty.
            question (str, optional): The question to ask about the images.
                When empty or None, the built-in ObjectNav scoring prompt is used.

        Returns:
            str: The model's response
        """
        prompt = self._resolve_prompt(goal_category, question)
        # All images first, then the text prompt, in a single user message.
        user_content = [self._image_part(image_array) for image_array in image_arrays]
        user_content.append({"type": "text", "text": prompt})
        return self._chat(user_content)


def image_to_numpy(image_path: str) -> np.ndarray:
    """Load an image file and return it as an RGB array of shape (H, W, 3).

    Any input mode (RGBA, grayscale, palette, ...) is converted to RGB first.

    Args:
        image_path: Path of the image file to read.

    Returns:
        np.ndarray: The RGB pixel data.

    Raises:
        FileNotFoundError: If `image_path` does not exist.
        RuntimeError: For any other failure while opening, decoding or
            converting the image (the original exception is chained as the cause).
    """
    try:
        # The context manager releases the file handle deterministically;
        # convert() runs inside it so the pixel data is loaded before close.
        with Image.open(image_path) as img:
            img_array = np.array(img.convert('RGB'))

        # Defensive check: convert('RGB') should always yield 3 channels.
        if img_array.ndim != 3 or img_array.shape[2] != 3:
            raise ValueError("Input image must be in color format (3 RGB channels)")

        return img_array

    except FileNotFoundError as e:
        raise FileNotFoundError(f"Image file not found: {image_path}") from e
    except Exception as e:
        raise RuntimeError(f"Image conversion failed: {e}") from e