from typing import List, Dict
from vllm import LLM, SamplingParams
from abc import ABC, abstractmethod


class InferenceEngine(ABC):
    """
    Common interface for text-generation backends.

    Concrete engines implement generate() to turn a batch of chat-formatted
    prompts into a list of response strings.
    """

    @abstractmethod
    def generate(self, prompts: List[List[Dict[str, str]]], **kwargs) -> List[str]:
        """
        Generate responses for a batch of chat-formatted prompts.

        Args:
            prompts: List of conversations, each a list of chat message dicts
            **kwargs: Backend-specific generation parameters

        Returns:
            List of generated responses
        """


class HuggingFaceEngine(InferenceEngine):
    """Inference engine that runs an already-loaded HuggingFace model and tokenizer."""

    def __init__(
        self,
        model,
        tokenizer,
        device,
        tokenizer_args=None,
        chat_template_args=None,
    ):
        """
        Initialize with HuggingFace model and tokenizer.

        Args:
            model: Loaded HuggingFace model
            tokenizer: Loaded HuggingFace tokenizer
            device: Device to run on ('cuda', 'cpu', etc.)
            tokenizer_args: Keyword arguments forwarded to the tokenizer call
                in generate(). When omitted, a fresh dict requesting "pt"
                tensors, padding, truncation, max_length 4096 and left
                padding is used.
            chat_template_args: Keyword arguments forwarded to
                tokenizer.apply_chat_template(). When omitted, a fresh dict
                with tokenize=False and add_generation_prompt=True is used.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Defaults are built per instance: a dict literal in the signature
        # would be one shared object, so editing one engine's settings would
        # silently change every other engine created with the defaults.
        if tokenizer_args is None:
            tokenizer_args = {
                'return_tensors': "pt",
                'padding': True,
                'truncation': True,
                'max_length': 4096,
                'padding_side': 'left',
            }
        if chat_template_args is None:
            # generate() tokenizes the templated text itself, so the template
            # step must yield strings rather than token ids.
            chat_template_args = {
                'tokenize': False,
                'add_generation_prompt': True,
            }
        self.tokenizer_args = tokenizer_args
        self.chat_template_args = chat_template_args

    def generate(self, prompts: List[List[Dict[str, str]]], **kwargs) -> List[str]:
        """
        Generate responses for multiple chat-formatted prompts.

        Args:
            prompts: List of chat messages (system/user/assistant dicts)
            **kwargs: Additional generation parameters passed to model.generate

        Returns:
            List of generated responses (only new tokens); empty list when
            no prompts are given
        """
        # Nothing to tokenize or generate for an empty batch.
        if not prompts:
            return []

        # The tokenizer call below expects text, so the chat template must
        # return strings regardless of the configured 'tokenize' value.
        template_args = {**self.chat_template_args, 'tokenize': False}
        formatted_prompts = [
            self.tokenizer.apply_chat_template(prompt, **template_args)
            for prompt in prompts
        ]

        # Tokenize with the configured padding/truncation and move to device.
        # NOTE(review): if the chat template already emits BOS/special tokens,
        # this call may add them a second time; consider passing
        # add_special_tokens=False in tokenizer_args — confirm for the model.
        inputs = self.tokenizer(
            formatted_prompts,
            **self.tokenizer_args
        ).to(self.device)

        # Generate responses
        outputs = self.model.generate(**inputs, **kwargs)

        # Decode only the generated portions.
        # Assumes each output row starts with the full padded prompt ids, so
        # slicing off len(input_ids) leaves the new tokens — TODO confirm for
        # encoder-decoder models.
        return [
            self.tokenizer.decode(
                gen_ids[len(input_ids):],
                skip_special_tokens=True
            ).strip()
            for gen_ids, input_ids in zip(outputs, inputs['input_ids'])
        ]


class VLLMEngine(InferenceEngine):
    """Inference engine that generates through a vLLM ``LLM`` instance."""

    def __init__(
        self,
        model_name_or_path: str,
        max_length: int = 4096,
        chat_template_args=None,
        **kwargs):
        """
        Initialize with vLLM backend.

        Args:
            model_name_or_path: Model name or path
            max_length: Maximum token length for inputs (stored as
                self.max_length; not read elsewhere in this class)
            chat_template_args: Keyword arguments forwarded to
                tokenizer.apply_chat_template(). When omitted, a fresh dict
                with tokenize=False and add_generation_prompt=True is used.
            **kwargs: Additional vLLM LLM initialization params
        """
        self.llm = LLM(model=model_name_or_path, **kwargs)
        self.max_length = max_length
        self.tokenizer = self.llm.get_tokenizer()
        # Build the default per instance: a dict literal in the signature
        # would be a single object shared by every engine using the default.
        if chat_template_args is None:
            chat_template_args = {
                'tokenize': False,
                'add_generation_prompt': True,
            }
        self.chat_template_args = chat_template_args

    def generate(self, prompts: List[List[Dict[str, str]]], **kwargs) -> List[str]:
        """
        Generate responses using vLLM's optimized backend.

        Args:
            prompts: List of chat messages (system/user/assistant dicts)
            **kwargs: Additional SamplingParams

        Returns:
            List of generated responses (only new tokens); empty list when
            no prompts are given
        """
        # Convert generation kwargs to vLLM's SamplingParams. Done first so
        # invalid parameters are rejected even for an empty batch.
        sampling_params = SamplingParams(**kwargs)

        # Nothing to submit to the engine for an empty batch.
        if not prompts:
            return []

        # Format prompts using the tokenizer's chat template
        formatted_prompts = [
            self.tokenizer.apply_chat_template(
                prompt,
                **self.chat_template_args
            ) for prompt in prompts
        ]

        # Generate with vLLM
        outputs = self.llm.generate(formatted_prompts, sampling_params, use_tqdm=False)

        # Keep the first completion's text for each request
        return [output.outputs[0].text.strip() for output in outputs]


class APIEngine(InferenceEngine):
    """Placeholder engine for a remote API; generation is not implemented yet."""

    def __init__(self, api_url: str, api_key: str = None):
        """
        Store the connection settings for the remote API.

        Args:
            api_url: Base URL for the API endpoint
            api_key: Optional authentication key
        """
        self.api_url, self.api_key = api_url, api_key

    def generate(self, prompts: List[List[Dict[str, str]]], **kwargs) -> List[str]:
        """
        Generate responses via API calls (not implemented).

        Args:
            prompts: List of chat messages
            **kwargs: Additional API parameters

        Returns:
            List of generated responses

        Raises:
            NotImplementedError: Always, until the request logic is written
        """
        # The request logic has not been written yet.
        message = "APIEngine requires implementation"
        raise NotImplementedError(message)