import os

import openai
import torch
import transformers

# Module-level cache for the local vLLM backend: _call_local_vllm fills these
# the first time it needs them and reuses them on every later call.
tokenizer = None
llm = None

def LLM_summaraize(data, **kwargs):
    """Rewrite a vague query plus its clarification dialogue into one fine-grained query.

    The first message in ``data["messages"]`` is treated as the original query.
    Every following message is rendered as a ``"<role>: <content>"`` line and
    inserted into the prompt as the clarification history. The assembled prompt
    is sent as a single user turn to the local vLLM backend.

    Args:
        data: Mapping with a ``"messages"`` list of dicts, each having
            ``"role"`` and ``"content"`` keys. Must contain at least one message.
        **kwargs: Forwarded unchanged to ``_call_local_vllm`` (e.g.
            ``model_path``, ``temperature``, ``max_tokens``).

    Returns:
        Whatever ``_call_local_vllm`` returns: the generated text, or an
        ``"Error..."`` string on failure.

    Raises:
        IndexError: If ``data["messages"]`` is empty.
        KeyError: If ``"messages"`` or a message's ``"role"``/``"content"`` is missing.
    """
    prompt = """You are an expert Query Optimizer. Your task is to transform a simple, vague original query into a comprehensive, high-quality finegrained query based on your history clarification dialogue.

# Input Information
1. Original Query: {original_query}
2. History Clarification: 
{history_clarification}

# Guidelines
1. Carefully review your history clarification. Identify every constraint, preference, and detail in the conversation.
2. Merge the core topic of the original query with the specific details extracted from the dialogue.
3. Rewrite the prompt into a single, cohesive, and professional query.

# Information Notes
- Perfectly styled as a direct, imperative user command (e.g., starts with "Analyze...", "Create...", "Act as..."). It is written from the user's perspective, is fully actionable, and contains no conversational AI-like phrasing.
- The language should be precise and adapted to the target audience mentioned in the dialogue (if any).
- Ensure there are no logical conflicts.
- Your output language must match the language of the Original Query. That is, if the Original Query is in English, your output `finegrained_query` must also be in English. Conversely, if the Original Query is in Chinese, your output `finegrained_query` must also be in Chinese.

## Output Format
Your response must be a **single paragraph** containing **ONLY** the text of the `finegrained_query`. Do not include explanations."""
    dialogue = data["messages"]
    original_query = dialogue[0]["content"]
    # One "role: content" line per clarification turn (each newline-terminated);
    # str.join avoids repeated string concatenation in a loop.
    history_clarification = "".join(
        f"{message['role']}: {message['content']}\n" for message in dialogue[1:]
    )

    messages = [
        {
            "role": "user",
            "content": prompt.format(
                original_query=original_query,
                history_clarification=history_clarification,
            ),
        },
    ]
    return _call_local_vllm(messages, **kwargs)



def _call_online_api(messages, **kwargs):
    """Send ``messages`` to an OpenAI-compatible chat-completions endpoint.

    Connection settings are taken from ``kwargs`` (``api_key``, ``api_base``)
    when present, otherwise from the ``OPENAI_API_KEY`` / ``OPENAI_BASE_URL``
    environment variables. Generation settings default to model
    ``gpt-4.1-2025-04-14``, temperature 0.7 and max_tokens 1024.

    A fresh client is constructed on every call. Errors raised by the client
    are not caught here; they propagate to the caller.

    Returns:
        The message content of the first returned choice.
    """
    env = os.environ
    client = openai.OpenAI(
        api_key=kwargs.get("api_key", env.get("OPENAI_API_KEY")),
        base_url=kwargs.get("api_base", env.get("OPENAI_BASE_URL")),
    )
    completion = client.chat.completions.create(
        model=kwargs.get("model", "gpt-4.1-2025-04-14"),
        messages=messages,
        temperature=kwargs.get("temperature", 0.7),
        max_tokens=kwargs.get("max_tokens", 1024),
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


def _visible_gpu_count():
    """Return how many GPUs to shard the local model across (never less than 1).

    When ``CUDA_VISIBLE_DEVICES`` is set and non-empty, its non-blank
    comma-separated entries are counted (so ``"0,1,"`` yields 2). Otherwise
    ``torch.cuda.device_count()`` is used. The result is clamped to 1 because a
    tensor-parallel size of 0 cannot work.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if visible_devices:
        count = len([dev for dev in visible_devices.split(",") if dev.strip()])
    else:
        count = torch.cuda.device_count()
    return max(1, count)


def _call_local_vllm(messages, **kwargs):
    """Run a chat completion on a locally loaded vLLM engine.

    The engine (``llm``) and ``tokenizer`` are module-level globals. They are
    created on first use and reused by every later call.

    NOTE: because of this cache, ``model_path`` and the engine options
    (``tensor_parallel_size``, ``gpu_memory_utilization``, ``enforce_eager``,
    ``dtype``, ``max_model_len``) only take effect on the call that first
    creates the engine. Later calls reuse whatever was loaded first.

    Args:
        messages: Chat messages (``{"role": ..., "content": ...}`` dicts),
            rendered into one prompt with the tokenizer's chat template.
        **kwargs:
            ``model_path``: required.
            Sampling options: ``temperature`` (default 0.6),
            ``max_tokens`` (512), ``top_p`` (0.9), ``repetition_penalty`` (1.1).
            Engine options: ``tensor_parallel_size`` (default: visible GPU
            count), ``gpu_memory_utilization`` (0.9), ``enforce_eager`` (True),
            ``dtype`` ("auto"), ``max_model_len`` (4096).

    Returns:
        The text of the first generated completion. Failures are reported as
        returned strings beginning with ``"Error"`` rather than raised.
    """
    global llm, tokenizer
    import traceback

    try:
        from vllm import LLM, SamplingParams
    except ImportError:
        return "Error: vLLM library not installed. Please install it with 'pip install vllm'"

    try:
        model_path = kwargs.get("model_path")
        if not model_path:
            return "Error: model_path is required for local vLLM inference"

        just_initialized = False
        if llm is None:
            # GPU discovery is only needed when the engine is first built.
            if "tensor_parallel_size" in kwargs:
                tensor_parallel_size = kwargs["tensor_parallel_size"]
            else:
                tensor_parallel_size = _visible_gpu_count()
            try:
                print(f"Initializing vLLM model with tensor_parallel_size={tensor_parallel_size}...")
                llm = LLM(
                    model=model_path,
                    tensor_parallel_size=tensor_parallel_size,
                    gpu_memory_utilization=kwargs.get("gpu_memory_utilization", 0.9),
                    # Use eager mode to avoid compilation issues
                    enforce_eager=kwargs.get("enforce_eager", True),
                    dtype=kwargs.get("dtype", "auto"),
                    max_model_len=kwargs.get("max_model_len", 4096),
                )
                print("vLLM model initialized successfully!")
            except Exception as init_error:
                error_msg = f"Error initializing vLLM model: {str(init_error)}\n{traceback.format_exc()}"
                print(error_msg)
                return error_msg
            just_initialized = True

        # The tokenizer supplies the chat template that turns messages into one prompt.
        if tokenizer is None:
            try:
                tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
            except Exception as tokenizer_error:
                return f"Error loading tokenizer: {str(tokenizer_error)}"

        if just_initialized:
            # Best-effort warm-up with a 1-token dummy request so all workers
            # are ready. Any failure here is logged and otherwise ignored.
            try:
                dummy_prompt = tokenizer.apply_chat_template(
                    [{"role": "user", "content": "test"}],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                llm.generate([dummy_prompt], SamplingParams(temperature=0.7, max_tokens=1))
                print("Model warm-up completed!")
            except Exception as warmup_error:
                print(f"Warning: Warm-up failed but continuing: {warmup_error}")

        try:
            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception as template_error:
            return f"Error applying chat template: {str(template_error)}"

        sampling_params = SamplingParams(
            temperature=kwargs.get("temperature", 0.6),
            top_p=kwargs.get("top_p", 0.9),
            max_tokens=kwargs.get("max_tokens", 512),
            repetition_penalty=kwargs.get("repetition_penalty", 1.1),
        )

        try:
            outputs = llm.generate([prompt], sampling_params)
            if not outputs or not outputs[0].outputs:
                return "Error: Empty output from vLLM"
            return outputs[0].outputs[0].text
        except Exception as gen_error:
            return f"Error during vLLM generation: {str(gen_error)}\n{traceback.format_exc()}"

    except Exception as e:
        return f"Error in local vLLM inference: {str(e)}\n{traceback.format_exc()}"


def parse_llm_output(output_str):
    """
    Convert the LLM info extraction output string to a list of strings.

    The text is parsed with ``ast.literal_eval``, which accepts only Python
    literal syntax and never executes arbitrary code. Leading and trailing
    whitespace is ignored, since model output often carries stray spaces or
    newlines.

    Args:
        output_str (str): String in format "['information: the time of the survey ranges from 2024-2025', 'information: focus on Chinese internet companies']"

    Returns:
        list: The parsed list if successful. Otherwise an error message string
        starting with "Error:", either because the text is not a valid Python
        literal or because it is a literal that is not a list. Element types
        are not checked.
    """
    import ast

    text = output_str.strip() if isinstance(output_str, str) else output_str
    try:
        result = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as parse_error:
        # Honor the documented contract: report failures as a string, don't raise.
        return f"Error: Could not parse LLM output as a Python literal: {parse_error}"

    if not isinstance(result, list):
        return f"Error: Expected a list, got {type(result)}"

    return result
