import os
import time
from openai import OpenAI
# from vllm import LLM, SamplingParams
# from transformers import AutoModelForCausalLM, AutoTokenizer
from concurrent.futures import ThreadPoolExecutor, as_completed
from utils.args_utils import proprietary_models

# Define a function to create the appropriate LLM model based on the input args
def _create_llm(args):
    """
    Build the generation backend selected by ``args.model``.

    Local Qwen checkpoints are served through ``LMChat`` (vLLM) with a fixed
    sampling configuration; names listed in ``proprietary_models`` are served
    through ``OpenAIClient``. The Qwen list is checked first.

    Arguments:
        args: Object exposing ``model``, plus ``device`` and ``model_path``
            (the latter two are only read for Qwen models).
    Returns:
        LMChat | OpenAIClient: The constructed backend.
    Raises:
        ValueError: If ``args.model`` matches neither family.
    """
    qwen_variants = {
        "qwen2.5_7B",
        "qwen2.5_14B",
        "qwen2.5_32B",
        "qwen3_4B",
        "qwen3_8B",
        "qwen3_32B",
    }
    requested = args.model

    if requested in qwen_variants:
        gpu_ids = str(args.device)
        checkpoint = args.model_path
        return LMChat(
            device=gpu_ids,
            model_name=checkpoint,
            max_new_tokens=2048,
            temperature=0.7,
            top_k=20,
            top_p=0.98,
        )

    if requested in proprietary_models:
        return OpenAIClient()

    raise ValueError(f"Unknown model: {args.model}")


# LMChat: wrapper class for generating text with a locally loaded vLLM model
class LMChat:
    def __init__(
        self,
        device: str,
        model_name: str,
        max_new_tokens: int = 512,
        temperature: float = 1.0,
        top_k: int = None,
        top_p: float = None,
    ):
        """
        Initializes the LMChat class for interacting with VLLM models.
        Arguments:
            device (str): CUDA device identifier (e.g., '0', '1').
            model_name (str): VLLM model path or identifier.
            max_new_tokens (int): Maximum number of new tokens to generate.
            temperature (float): Controls randomness in sampling.
            top_k (int): Top-k sampling parameter.
            top_p (float): Top-p sampling parameter.
        """
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{device}"
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p

        # Set SamplingParams
        params = {
            "max_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "repetition_penalty": 1.05,
            "min_p": 0,
        }
        if self.top_k is not None:
            params["top_k"] = self.top_k
        if self.top_p is not None:
            params["top_p"] = self.top_p
        self.sampling_params = SamplingParams(**params)

        # Load VLLM model
        self.llm = LLM(
            model=self.model_name,
            dtype="float16",
            device="auto",
            gpu_memory_utilization=0.8,
            enable_chunked_prefill=False,
            tensor_parallel_size=len(device.split(","))
        )

    def generate_response(self, prompt, **kwargs):
        """
        Generate a response based on the input prompt(s).
        Supports single string or list of strings as input.
        Arguments:
            prompt (str or list): User input prompt(s).
            **kwargs: Optional generation parameters to override initial settings.
        Returns:
            str or list: Generated response(s).
        """
        start_time = time.time()

        # Handle input format (string or list)
        if isinstance(prompt, str):
            prompts = [prompt]
            single_input = True
        elif isinstance(prompt, list):
            prompts = prompt
            single_input = False
        else:
            raise ValueError("prompt must be a string or a list of strings")

        # Call VLLM model to generate output
        outputs = self.llm.generate(prompts, self.sampling_params)

        # Extract responses
        responses = [out.outputs[0].text for out in outputs]

        print("Total Time: {:.4f} seconds".format(time.time() - start_time))

        return responses[0] if single_input else responses


# OpenAIClient Class to interact with OpenAI models
class OpenAIClient:
    def __init__(
        self,
        use_concurrent: bool = True,
        max_workers: int | None = None
    ):
        """
        OpenAI client to interact with the OpenAI API.
        Arguments:
            use_concurrent (bool): Whether to use concurrent requests for a list of prompts.
            max_workers (int): Maximum number of concurrent threads. Default is len(self.keys).
        """
        # API keys for OpenAI (replace with actual keys)
        self.keys = [os.getenv('OPENAI_API_KEY')]
        self.index = 0
        self.client = OpenAI(api_key=self.keys[self.index])

        # Concurrent settings
        self.use_concurrent = use_concurrent
        self.max_workers = max_workers

    def switch_key(self):
        """Rotate to the next API key."""
        self.index = (self.index + 1) % len(self.keys)
        self.client = OpenAI(api_key=self.keys[self.index])

    def _generate_single_response(self, prompt: str, model: str, retry: int) -> str:
        """
        Handle the logic for generating a single response from OpenAI API.
        Arguments:
            prompt (str): User input prompt.
            model (str): Model to use.
            retry (int): Number of retry attempts in case of failure.
        Returns:
            str: Generated response.
        """
        messages = [{"role": "user", "content": prompt}]
        for attempt in range(1, retry + 1):
            try:
                response = self.client.chat.completions.create(
                    model=model, messages=messages, temperature=0.8
                )
                return response.choices[0].message.content
            except Exception as e:
                if attempt == retry:
                    raise
                # Switch key and retry
                self.switch_key()
                time.sleep(1)

    def generate_response(
        self,
        prompt: str | list[str],
        model: str = "gpt-4.1-nano",
        retry: int = 3
    ) -> str | list[str]:
        """
        Generate a response from OpenAI API, supporting both single and multiple prompts.
        If the prompt is a list, either concurrent or sequential calls are made.
        Arguments:
            prompt (str | list[str]): User input prompt(s).
            model (str): Model to use for response generation.
            retry (int): Number of retry attempts.
        Returns:
            str | list[str]: Generated response(s).
        """
        if isinstance(prompt, str):
            return self._generate_single_response(prompt, model, retry)

        # For a list of prompts
        prompts = prompt
        workers = self.max_workers or len(self.keys) or len(prompts)

        if self.use_concurrent:
            # Concurrent requests using thread pool
            results: list[str] = [None] * len(prompts)
            with ThreadPoolExecutor(max_workers=workers) as executor:
                future_to_idx = {
                    executor.submit(self._generate_single_response, p, model, retry): i
                    for i, p in enumerate(prompts)
                }
                for fut in as_completed(future_to_idx):
                    idx = future_to_idx[fut]
                    results[idx] = fut.result()
            return results

        else:
            # Sequential calls
            return [
                self._generate_single_response(p, model, retry)
                for p in prompts
            ]