from typing import Optional, Dict
from src.utils.api_lib.huggingface import generation_pipeline_hf
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    from vllm import LLM 
    from vllm.sampling_params import SamplingParams  
    from vllm.lora.request import LoRARequest 
except ImportError:
    LLM = None  
    SamplingParams = None  
    LoRARequest = None 

from src.utils.api import get_llm_outputs


def is_api_model_auto(model_name):
    """Guess from the model name whether it should be served via an API.

    The check is a case-sensitive substring match: the name is treated as
    an API model when it contains ``'gpt'`` or ``'claude'``.

    Args:
        model_name: Name of the model to classify.

    Returns:
        bool: True if either marker occurs in ``model_name``.
    """
    api_markers = ('gpt', 'claude')
    return any(marker in model_name for marker in api_markers)


def generate_response(
    chat,
    local_model: Optional[AutoModelForCausalLM] = None,
    local_tokenizer: Optional[AutoTokenizer] = None,
    is_api_model: Optional[str] = 'auto',
    vllm_model: Optional[LLM] = None,
    peft_dir: Optional[str] = None,
    **generation_kwargs,
):
    """Generate a response from the assistant model.

    The request is routed to one of three backends, in this order:
    an API model (``get_llm_outputs``), a vLLM engine (``vllm_model.chat``),
    or a local HuggingFace model (``generation_pipeline_hf``).

    Args:
        chat: A single prompt string, a list of prompt strings (batch), or
            an already-structured message list. A string becomes one
            user message; a list of strings becomes one single-message
            conversation per string. Anything else (including an empty
            list) is passed through to the backend unchanged.
        local_model: Optional HuggingFace model, used on the local path.
        local_tokenizer: Optional HuggingFace tokenizer. Required on the
            local HuggingFace path; it is modified in place there
            (left padding, pad token set to the EOS token).
        is_api_model: ``'auto'`` to decide from the ``model`` keyword via
            ``is_api_model_auto``, otherwise a truthy/falsy flag that
            selects the API backend directly.
        vllm_model: Optional vLLM model; when given (and the API backend
            is not selected) generation goes through vLLM.
        peft_dir: Optional path to a LoRA adapter, attached to the vLLM
            request when provided. Unused by the other backends.
        **generation_kwargs: Additional generation parameters. ``model``
            is required when ``is_api_model`` is ``'auto'``. ``max_tokens``
            is renamed to ``max_new_tokens`` for the API and HuggingFace
            backends; if neither key is given, 1024 is used.

    Returns:
        API backend: whatever ``get_llm_outputs`` returns.
        vLLM backend: a list of strings, one per request output (empty
        string when a request produced no outputs).
        HuggingFace backend: a list (a non-list result is wrapped).

    Raises:
        KeyError: If ``is_api_model`` is ``'auto'`` and no ``model``
            keyword was supplied.
        ImportError: If ``vllm_model`` is given but vllm is not installed.
        ValueError: If the HuggingFace backend is selected without a
            ``local_tokenizer``.
    """
    # The model name is only needed for auto-detection, so do not demand it
    # when the caller has already said which kind of backend to use.
    if is_api_model == 'auto':
        is_api_model = is_api_model_auto(generation_kwargs['model'])

    if isinstance(chat, str):
        chat = [{"role": "user", "content": chat}]
    elif chat and isinstance(chat[0], str):
        # `chat and ...` guards against indexing into an empty batch.
        chat = [[{"role": "user", "content": message}] for message in chat]

    if is_api_model:
        _apply_max_new_tokens(generation_kwargs)
        # These sampling options are stripped before calling the API helper.
        for unsupported in ("top_p", "top_k", "repetition_penalty"):
            generation_kwargs.pop(unsupported, None)
        return get_llm_outputs(chat, **generation_kwargs)

    # Local backends receive the model object directly, not its name.
    generation_kwargs.pop("model", None)

    if vllm_model is not None:
        if SamplingParams is None:
            raise ImportError("vllm is required for vLLM inference but is not installed.")
        sampling_params = convert_to_sampling_params(generation_kwargs)
        lora_request = LoRARequest("interactive_adapter", 1, peft_dir) if (LoRARequest and peft_dir) else None
        responses = vllm_model.chat(
            messages=chat,
            sampling_params=sampling_params,
            lora_request=lora_request
        )
        # Keep only the first completion of each request.
        return [
            response_set.outputs[0].text if response_set.outputs else ""
            for response_set in responses
        ]

    if local_tokenizer is None:
        raise ValueError(
            "local_tokenizer is required for local HuggingFace generation "
            "when neither an API model nor a vLLM model is used."
        )

    # NOTE: mutates the caller's tokenizer.
    local_tokenizer.padding_side = "left"
    local_tokenizer.pad_token = local_tokenizer.eos_token

    _apply_max_new_tokens(generation_kwargs)
    responses = generation_pipeline_hf(
        chat,
        local_model,
        local_tokenizer,
        stop_sequence=[],
        **generation_kwargs,
    )
    if not isinstance(responses, list):
        responses = [responses]
    return responses


def _apply_max_new_tokens(generation_kwargs: dict, default: int = 1024) -> None:
    """Rename ``max_tokens`` to ``max_new_tokens`` in place.

    An explicit ``max_tokens`` takes precedence. Otherwise a caller-supplied
    ``max_new_tokens`` is kept, and ``default`` is used only when neither
    key is present.
    """
    if "max_tokens" in generation_kwargs:
        generation_kwargs["max_new_tokens"] = generation_kwargs.pop("max_tokens")
    else:
        generation_kwargs.setdefault("max_new_tokens", default)


def convert_to_sampling_params(generation_kwargs: dict):
    """Build a vLLM ``SamplingParams`` from loosely-typed generation kwargs.

    Keys on the allow-list below are forwarded; every other key is dropped
    after a warning is printed for it.

    Args:
        generation_kwargs: Mapping of generation option names to values.

    Returns:
        The object produced by ``SamplingParams.from_optional`` for the
        accepted keys.

    Raises:
        ImportError: If vllm could not be imported (``SamplingParams`` is
            ``None``).
    """
    if SamplingParams is None:
        raise ImportError("vllm is required for vLLM SamplingParams but is not installed.")

    # Option names this helper forwards to SamplingParams.
    supported = frozenset((
        "n",
        "best_of",
        "presence_penalty",
        "frequency_penalty",
        "repetition_penalty",
        "temperature",
        "top_p",
        "top_k",
        "min_p",
        "seed",
        "stop",
        "stop_token_ids",
        "bad_words",
        "ignore_eos",
        "max_tokens",
        "min_tokens",
        "logprobs",
        "prompt_logprobs",
        "detokenize",
        "skip_special_tokens",
        "spaces_between_special_tokens",
        "truncate_prompt_tokens",
    ))

    accepted = {}
    for name, setting in generation_kwargs.items():
        if name not in supported:
            # Unknown options are reported and skipped rather than raising.
            print(f"Warning: Parameter '{name}' not found in VLLM-supported sampling parameters")
            continue
        accepted[name] = setting

    return SamplingParams.from_optional(**accepted)

