"""
Simple client for interacting with a VLLM server using the OpenAI API.
"""

from openai import OpenAI
from typing import List, Dict, Any, Optional, Union
import base64
import os
import httpx

class _Output:
    """One generated completion; the text is exposed as ``.text``."""

    def __init__(self, text):
        self.text = text

    def __repr__(self):
        return f"{type(self).__name__}(text={self.text!r})"


class _Response:
    """Result of one request; completions are exposed as ``.outputs``."""

    def __init__(self, outputs):
        self.outputs = outputs

    def __repr__(self):
        return f"{type(self).__name__}(outputs={self.outputs!r})"


class VLLMServerWrapper:
    """
    A wrapper class for a vLLM server using the OpenAI API.

    The wrapper sends chat-completion requests through an ``OpenAI`` client
    and repackages the reply so it can be read as
    ``result[0].outputs[0].text``.
    """

    # Total number of times a chat-completion request is attempted.
    _MAX_ATTEMPTS = 3

    def __init__(
        self,
        server_url: str = "http://localhost:8000",
        api_key: str = "token-abc123",
        model: str = "Qwen/Qwen2.5-VL-32B-Instruct",
        timeout: float = 600,
    ):
        """
        Initialize the VLLM server wrapper.

        Args:
            server_url: URL of the server; ``/v1`` is appended to it. A
                trailing slash is tolerated.
            api_key: API key for the server
            model: Model name to use for every request
            timeout: Timeout handed to ``httpx.Client``. It is only applied
                when ``server_url`` contains ``localhost``; otherwise the
                OpenAI client's own default HTTP client is used.
        """
        self.model = model  # Default model
        # A URL containing "localhost" is treated as a local vLLM server,
        # which enables the vLLM-only sampling parameters in generate().
        self.use_vllm = 'localhost' in server_url

        client_kwargs = {
            # Strip a trailing slash so we never produce ".../​/v1".
            "base_url": f"{server_url.rstrip('/')}/v1",
            "api_key": api_key,
        }
        if self.use_vllm:
            # Only build the long-timeout HTTP client when it is really used,
            # so no idle connection pool is left behind for remote servers.
            client_kwargs["http_client"] = httpx.Client(timeout=timeout)
        self.client = OpenAI(**client_kwargs)

    @staticmethod
    def _collect_sampling_params(sampling_params: Optional[Any]) -> Dict[str, Any]:
        """Merge user sampling parameters over the built-in defaults."""
        params = {
            "temperature": 0.7,
            "top_p": 0.95,
            "max_tokens": None,
        }
        if not sampling_params:
            return params

        # Check for dicts first: dict subclasses also carry a __dict__ and
        # would otherwise be scanned for attributes, losing their items.
        if isinstance(sampling_params, dict):
            params.update(sampling_params)
        elif hasattr(sampling_params, '__dict__'):
            # Copy every public, non-callable attribute of the object.
            for attr in dir(sampling_params):
                if attr.startswith('_'):
                    continue
                value = getattr(sampling_params, attr)
                if not callable(value):
                    params[attr] = value
        else:
            raise TypeError(
                "sampling_params must be a dict or an object with attributes, "
                f"got {type(sampling_params).__name__}"
            )
        return params

    def generate(
        self,
        prompt_or_messages: Union[str, List[Dict[str, Any]]],
        sampling_params: Optional[Any] = None,
    ) -> List[_Response]:
        """
        Generate text from a prompt or messages.

        Args:
            prompt_or_messages: Text prompt (sent as a single user message)
                or a messages array that is passed through unchanged
            sampling_params: Dictionary of sampling parameters, or an object
                whose public attributes are the parameters (optional).
                Recognized keys are ``temperature``, ``top_p``,
                ``max_tokens``, ``seed`` and, for a local vLLM server,
                ``top_k`` and ``repetition_penalty``.

        Returns:
            A one-element list whose item has ``outputs[0].text`` set to the
            content of the first returned choice.

        Raises:
            TypeError: If ``sampling_params`` is neither a dict nor an object
                with attributes.
            Exception: Whatever the client raised on the final attempt, if
                every attempt failed.
        """
        params = self._collect_sampling_params(sampling_params)

        # Convert prompt to messages format if it's a string
        if isinstance(prompt_or_messages, str):
            messages = [{"role": "user", "content": prompt_or_messages}]
        else:
            messages = prompt_or_messages

        # Prepare API call parameters
        api_params = {
            "model": self.model,
            "messages": messages,
            "temperature": params.get("temperature", 0.7),
            "top_p": params.get("top_p", 0.95),
            "max_tokens": params.get("max_tokens", None),
        }
        if "seed" in params:
            api_params["seed"] = params["seed"]
        # vLLM-specific parameters go in extra_body (not available in the OpenAI API)
        if self.use_vllm:
            api_params["extra_body"] = {
                "top_k": params.get("top_k", -1),
                "repetition_penalty": params.get("repetition_penalty", 1.0),
            }

        # Call the API, retrying on failure. Only Exception is caught so that
        # KeyboardInterrupt / SystemExit still abort immediately.
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                completion = self.client.chat.completions.create(**api_params)
                break
            except Exception:
                if attempt == self._MAX_ATTEMPTS:
                    raise
                print(f"Try-{attempt + 1}")

        # Extract the generated text
        generated_text = completion.choices[0].message.content

        # Wrap the response to match VLLM's format
        return [_Response([_Output(generated_text)])]
