"""
Model Generation Utilities

This module provides efficient batched generation utilities for language models.
Supports both full models and PEFT adapter models.

Includes disk-based caching of generated responses and of chat-template-formatted
prompts to avoid redundant model and tokenizer calls.
"""

import torch
import os
import json
import pickle
import hashlib
import atexit
from typing import List, Union, Dict, Tuple, Optional, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel


class ResponseCache:
    """
    Disk-based cache for model responses.

    Caches responses per model checkpoint to avoid redundant generation calls.
    Cache files are pickles stored next to the checkpoint directory (see
    ``_get_cache_path``).

    Keys are normalized prompts only (see ``normalize_prompt``); generation
    settings such as temperature or max_new_tokens are not part of the key.

    Security note: the cache file is read with ``pickle.load``, which can
    execute arbitrary code when given a crafted file. Only use cache files
    from locations you trust.
    """

    def __init__(self, model_path: str):
        """
        Initialize cache for a specific model and load any existing cache file.

        Args:
            model_path: Path to model checkpoint (e.g., '/path/to/models/PPO-QA3A/checkpoint-290')
        """
        self.model_path = model_path
        self.cache_path = self._get_cache_path(model_path)
        self.cache: Dict[str, str] = {}
        self._tokenizer = None  # Lazy-loaded tokenizer for prompt normalization
        self._load_cache()

    @property
    def tokenizer(self):
        """Tokenizer used to normalize chat-format prompts, loaded on first access."""
        if self._tokenizer is None:
            self._tokenizer = self._load_tokenizer()
        return self._tokenizer

    def _load_tokenizer(self):
        """
        Load tokenizer from model path for prompt normalization.

        Handles both adapter models and full models, similar to
        apply_chat_template_to_prompt().
        """
        adapter_config_path = os.path.join(self.model_path, 'adapter_config.json')
        is_adapter = os.path.exists(adapter_config_path)

        if is_adapter:
            # Load adapter config to get base model
            with open(adapter_config_path, 'r') as f:
                adapter_config = json.load(f)
            base_model_name = adapter_config.get('base_model_name_or_path', 'meta-llama/Llama-3.1-8B')

            # Check if tokenizer files exist in the adapter directory
            tokenizer_config_path = os.path.join(self.model_path, 'tokenizer_config.json')
            if os.path.exists(tokenizer_config_path):
                tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            else:
                tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
        else:
            # Load tokenizer from full model path
            tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)

        # Set pad token if needed
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return tokenizer

    @staticmethod
    def _get_cache_path(model_path: str) -> str:
        """
        Get cache file path from model path.

        For model_path='/path/to/models/PPO-QA3A/checkpoint-290':
        Returns: '/path/to/models/PPO-QA3A/checkpoint-290-cache.pkl'

        For a bare relative name such as 'checkpoint-290' the result is the
        relative path 'checkpoint-290-cache.pkl' (current working directory).
        """
        # Strip trailing separators so basename() yields the checkpoint name.
        model_path = model_path.rstrip(os.sep + (os.altsep or ''))
        # Create cache filename based on checkpoint name
        cache_filename = os.path.basename(model_path) + '-cache.pkl'
        # Save in parent directory of checkpoint
        parent_dir = os.path.dirname(model_path)
        return os.path.join(parent_dir, cache_filename)

    def normalize_prompt(self, prompt: Union[str, List[Dict[str, str]]]) -> str:
        """
        Normalize prompt to a consistent string format for cache key.

        If prompt is already a string, use it directly (assumed to be pre-formatted).
        If prompt is a dialogue list (e.g., [{"role": "user", "content": "..."}]),
        apply the chat template to convert it to a formatted string.

        This ensures that the same conversation has the same cache key regardless
        of whether it was passed as a pre-formatted string or as a dialogue list.

        Args:
            prompt: Either a string or list of message dicts

        Returns:
            Normalized string representation of the prompt
        """
        if isinstance(prompt, str):
            # String prompts are assumed to be already formatted
            return prompt
        elif isinstance(prompt, list):
            # Apply chat template to convert dialogue format to string
            # This matches the behavior in tokenize_batch() for list prompts
            formatted = self.tokenizer.apply_chat_template(
                prompt,
                tokenize=False,
                add_generation_prompt=True
            )
            return formatted
        else:
            return str(prompt)

    def _load_cache(self) -> None:
        """Load cache from disk if it exists; fall back to an empty cache on any failure."""
        if os.path.exists(self.cache_path):
            try:
                with open(self.cache_path, 'rb') as f:
                    loaded = pickle.load(f)
                # Guard against a file that unpickles to something other than a dict.
                if not isinstance(loaded, dict):
                    raise TypeError(f"expected a dict, got {type(loaded).__name__}")
                self.cache = loaded
                print(f"Loaded {len(self.cache)} cached responses from {self.cache_path}")
            except Exception as e:
                print(f"Warning: Failed to load cache from {self.cache_path}: {e}")
                self.cache = {}
        else:
            self.cache = {}

    def _save_cache(self) -> None:
        """
        Save cache to disk atomically (best effort).

        The pickle is written to a temporary sibling file and then moved over
        the real cache file with ``os.replace``, so an interrupted write cannot
        leave a truncated cache file behind. Failures are printed as a warning
        rather than raised.
        """
        parent_dir = os.path.dirname(self.cache_path)
        tmp_path = f"{self.cache_path}.{os.getpid()}.tmp"
        try:
            # dirname is '' for a bare relative checkpoint name; os.makedirs('')
            # raises, so only create a directory when there is one.
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            with open(tmp_path, 'wb') as f:
                pickle.dump(self.cache, f)
            os.replace(tmp_path, self.cache_path)
        except Exception as e:
            print(f"Warning: Failed to save cache to {self.cache_path}: {e}")
            # Best-effort removal of a partially written temp file.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def get(self, prompt: Union[str, List[Dict[str, str]]]) -> Optional[str]:
        """
        Get cached response for a prompt.

        Args:
            prompt: The input prompt

        Returns:
            Cached response or None if not found
        """
        key = self.normalize_prompt(prompt)
        return self.cache.get(key)

    def set(self, prompt: Union[str, List[Dict[str, str]]], response: str) -> None:
        """
        Cache a response for a prompt (in memory; call ``save`` to persist).

        Args:
            prompt: The input prompt
            response: The generated response
        """
        key = self.normalize_prompt(prompt)
        self.cache[key] = response

    def has(self, prompt: Union[str, List[Dict[str, str]]]) -> bool:
        """Check if prompt is in cache."""
        key = self.normalize_prompt(prompt)
        return key in self.cache

    def save(self) -> None:
        """Save cache to disk."""
        self._save_cache()

    def __len__(self) -> int:
        return len(self.cache)

# Process-wide registry keyed by normalized checkpoint path, so each response
# cache file is read from disk at most once per process.
_cache_registry: Dict[str, ResponseCache] = {}

def get_response_cache(model_path: str) -> ResponseCache:
    """
    Return the shared ResponseCache for ``model_path``, creating it on first use.

    The path is normalized with ``os.path.normpath`` before lookup, so spellings
    such as 'a/b/' and 'a/b' map to the same cache object.

    Args:
        model_path: Path to model checkpoint

    Returns:
        ResponseCache instance for the model
    """
    key = os.path.normpath(model_path)

    cache = _cache_registry.get(key)
    if cache is None:
        cache = _cache_registry[key] = ResponseCache(key)
    return cache

def clear_cache_registry() -> None:
    """
    Clear the global ResponseCache registry (useful for testing or memory management).

    The registry dict is emptied in place rather than rebound, so every holder
    of a reference to it sees the same, now-empty registry. The dropped caches
    are not saved first; call ``save()`` on them beforehand if needed.
    """
    _cache_registry.clear()


class ChatTemplateCache:
    """
    Disk-based cache for chat template formatted prompts.

    Caches formatted prompts per model to avoid redundant tokenizer loading
    and chat template application. Reuses tokenizer loading pattern from ResponseCache.
    Entries are stored untruncated; ``max_length`` truncation in
    ``get_formatted`` is applied on every read.

    Security note: the cache file is read with ``pickle.load``; only use cache
    files from locations you trust.
    """

    def __init__(self, model_path: str):
        """
        Initialize cache for a specific model and load any existing cache file.

        Args:
            model_path: Path to model checkpoint (adapter or full model directory)
        """
        self.model_path = model_path
        self.cache_path = self._get_cache_path(model_path)
        self.cache: Dict[str, str] = {}
        self._tokenizer = None  # Lazy-loaded on the first cache miss
        self._load_cache()

    @property
    def tokenizer(self):
        """Lazy-load tokenizer."""
        if self._tokenizer is None:
            self._tokenizer = self._load_tokenizer()
        return self._tokenizer

    def _load_tokenizer(self):
        """Load tokenizer from model path (same pattern as ResponseCache)."""
        adapter_config_path = os.path.join(self.model_path, 'adapter_config.json')
        is_adapter = os.path.exists(adapter_config_path)

        if is_adapter:
            with open(adapter_config_path, 'r') as f:
                adapter_config = json.load(f)
            base_model_name = adapter_config.get('base_model_name_or_path', 'meta-llama/Llama-3.1-8B')
            tokenizer_config_path = os.path.join(self.model_path, 'tokenizer_config.json')
            if os.path.exists(tokenizer_config_path):
                tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            else:
                tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
        else:
            tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    @staticmethod
    def _get_cache_path(model_path: str) -> str:
        """
        Get cache file path (uses different suffix than ResponseCache).

        '/path/to/run/checkpoint-290' -> '/path/to/run/checkpoint-290-chat-template-cache.pkl'
        """
        # Strip trailing separators so basename() yields the checkpoint name.
        model_path = model_path.rstrip(os.sep + (os.altsep or ''))
        cache_filename = os.path.basename(model_path) + '-chat-template-cache.pkl'
        parent_dir = os.path.dirname(model_path)
        return os.path.join(parent_dir, cache_filename)

    def _load_cache(self) -> None:
        """Load cache from disk if it exists; keep an empty cache on any failure."""
        if os.path.exists(self.cache_path):
            try:
                with open(self.cache_path, 'rb') as f:
                    loaded = pickle.load(f)
                # Guard against a file that unpickles to something other than a dict.
                if not isinstance(loaded, dict):
                    raise TypeError(f"expected a dict, got {type(loaded).__name__}")
                self.cache = loaded
                print(f"Loaded {len(self.cache)} cached chat templates from {self.cache_path}")
            except Exception as e:
                print(f"Warning: Failed to load chat template cache: {e}")
                self.cache = {}

    def save(self) -> None:
        """
        Save cache to disk atomically (best effort).

        Writes to a temporary sibling file and moves it into place with
        ``os.replace``; failures are printed as a warning rather than raised.
        """
        parent_dir = os.path.dirname(self.cache_path)
        tmp_path = f"{self.cache_path}.{os.getpid()}.tmp"
        try:
            # dirname is '' for a bare relative checkpoint name; os.makedirs('')
            # raises, so only create a directory when there is one.
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            with open(tmp_path, 'wb') as f:
                pickle.dump(self.cache, f)
            os.replace(tmp_path, self.cache_path)
        except Exception as e:
            print(f"Warning: Failed to save chat template cache: {e}")
            # Best-effort removal of a partially written temp file.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def get_formatted(self, prompt: Union[str, List[Dict[str, str]]], max_length: Optional[int] = 2000) -> str:
        """
        Get formatted prompt, using cache if available.

        A list prompt is rendered with the tokenizer's chat template as given; a
        string prompt is first wrapped as a single user message. Both use
        ``add_generation_prompt=False``. The untruncated result is cached.

        Args:
            prompt: Either a string or list of message dicts
            max_length: Maximum character length. Longer results keep only the
                last ``max_length`` characters (truncated from the left).
                ``None`` disables truncation; 0 or a negative value yields ''.

        Returns:
            Formatted prompt string
        """
        # NOTE(review): str(prompt) cannot distinguish a string prompt from a
        # message list whose repr equals that string. Changing the key format
        # would invalidate existing cache files, so it is left as is.
        cache_key = str(prompt)

        if cache_key in self.cache:
            formatted = self.cache[cache_key]
        else:
            # Apply chat template
            if isinstance(prompt, list):
                messages = prompt
            else:
                messages = [{"role": "user", "content": prompt}]
            formatted = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            self.cache[cache_key] = formatted

        # Apply left truncation. Slice by start index: formatted[-0:] would
        # return the whole string instead of an empty one.
        if max_length is not None and len(formatted) > max_length:
            keep = max(max_length, 0)
            formatted = formatted[len(formatted) - keep:]

        return formatted


# Process-wide registry: one ChatTemplateCache per normalized model path.
_chat_template_cache_registry: Dict[str, ChatTemplateCache] = {}


def get_chat_template_cache(model_path: str) -> ChatTemplateCache:
    """
    Return the shared ChatTemplateCache for ``model_path``, creating it on first use.

    The path is normalized with ``os.path.normpath`` so equivalent spellings
    share a single cache object.
    """
    normalized = os.path.normpath(model_path)
    if normalized in _chat_template_cache_registry:
        return _chat_template_cache_registry[normalized]

    fresh = ChatTemplateCache(normalized)
    _chat_template_cache_registry[normalized] = fresh
    return fresh


def save_all_chat_template_caches() -> None:
    """
    Persist every registered ChatTemplateCache to disk.

    Delegates to each cache's ``save()``, which reports failures as a printed
    warning instead of raising.
    """
    registered = _chat_template_cache_registry.values()
    for template_cache in registered:
        template_cache.save()


# Flush all chat-template caches automatically when the interpreter exits.
atexit.register(save_all_chat_template_caches)


def generate_responses_batched(
    model_path: Optional[str] = None,
    prompts: List[Union[str, List[Dict[str, str]]]] = None,
    model: Optional[object] = None,
    tokenizer: Optional[object] = None,
    max_new_tokens: int = 512,
    batch_size: int = 8,
    temperature: float = 0.7,
    top_p: float = 0.9,
    device: str = "auto",
    return_model_and_tokenizer: bool = False,
    use_cache: bool = True
) -> Union[List[str], Tuple[List[str], object, object]]:
    """
    Generate responses for multiple prompts using batched inference.

    This function efficiently processes multiple prompts in batches to improve
    throughput compared to generating responses one by one.

    Supports disk-based caching to avoid regenerating responses for the same prompts.
    Cache is stored per model checkpoint.

    Args:
        model_path: Path to the model checkpoint (can be adapter or full model).
                   Either model_path OR (model, tokenizer) must be provided.
        prompts: List of prompts - either strings or chat message format
        model: Pre-loaded model object (alternative to model_path)
        tokenizer: Pre-loaded tokenizer object (must be provided with model)
        max_new_tokens: Maximum number of new tokens to generate per response
        batch_size: Number of prompts to process in each batch
        temperature: Sampling temperature for generation
        top_p: Top-p (nucleus) sampling parameter
        device: Device to use for model ("auto", "cuda", "cpu")
        return_model_and_tokenizer: If True, return loaded model and tokenizer
                                   along with responses (useful for caching)
        use_cache: If True, use disk-based caching to avoid regenerating responses
                  for the same prompts. Cache is only used when model_path is provided.
                  Default: True.

    Returns:
        List of generated responses, or tuple of (responses, model, tokenizer)
        if return_model_and_tokenizer is True

    Example:
        >>> # Using model path with caching (default)
        >>> responses = generate_responses_batched(
        ...     model_path="/path/to/model",
        ...     prompts=["What is AI?", "Explain ML"],
        ...     batch_size=2
        ... )

        >>> # Disable caching explicitly
        >>> responses = generate_responses_batched(
        ...     model_path="/path/to/model",
        ...     prompts=["What is AI?"],
        ...     use_cache=False
        ... )
    """
    if not prompts:
        if return_model_and_tokenizer:
            return [], model, tokenizer
        return []

    # Initialize cache if caching is enabled and model_path is provided
    cache = None
    if use_cache and model_path is not None:
        cache = get_response_cache(model_path)
        print("Loaded response cache.")

    # Check cache for existing responses
    cached_responses = {}  # Maps original index to cached response
    prompts_to_generate = []  # Prompts that need generation
    prompts_to_generate_indices = []  # Original indices of prompts to generate

    if cache is not None:
        for idx, prompt in enumerate(prompts):
            cached_response = cache.get(prompt)
            if cached_response is not None:
                cached_responses[idx] = cached_response
            else:
                prompts_to_generate.append(prompt)
                prompts_to_generate_indices.append(idx)

        if cached_responses:
            print(f"Found {len(cached_responses)} cached responses, need to generate {len(prompts_to_generate)}")
    else:
        prompts_to_generate = prompts
        prompts_to_generate_indices = list(range(len(prompts)))

    # If all responses are cached, return them directly
    
    if not prompts_to_generate:
        responses = [cached_responses[i] for i in range(len(prompts))]
        if return_model_and_tokenizer:
            return responses, model, tokenizer
        return responses

    # Either load model/tokenizer from path OR use provided ones
    if model is None or tokenizer is None:
        if model_path is None:
            raise ValueError("Either model_path OR (model, tokenizer) must be provided")
        # Load model and tokenizer from path
        model, tokenizer = load_model_and_tokenizer(model_path, device)
        cleanup_after = not return_model_and_tokenizer
    else:
        # Use provided model and tokenizer
        if model_path is not None and not use_cache:
            print("Warning: Both model_path and model/tokenizer provided. Using provided model/tokenizer.")
        cleanup_after = False

    # Ensure padding token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    generated_responses = []
    total_batches = (len(prompts_to_generate) + batch_size - 1) // batch_size

    print(f"Generating {len(prompts_to_generate)} responses in {total_batches} batches...")

    # Process prompts in batches
    for batch_idx in range(0, len(prompts_to_generate), batch_size):
        batch_prompts = prompts_to_generate[batch_idx:batch_idx + batch_size]
        batch_size_actual = len(batch_prompts)

        # Tokenize batch
        input_ids, attention_masks = tokenize_batch(
            batch_prompts,
            tokenizer,
            max_length=1024
        )

        # Move to device
        input_ids = input_ids.to(model.device)
        attention_masks = attention_masks.to(model.device)

        # Generate responses
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_masks,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id
            )

        # Decode responses (only the generated part)
        for i in range(batch_size_actual):
            # For left-padded inputs, find where the actual input starts
            # Count the number of padding tokens at the beginning
            padding_length = (attention_masks[i] == 0).sum().item()

            # The actual input length (excluding padding)
            actual_input_length = (attention_masks[i] == 1).sum().item()

            # The position where generation starts is padding_length + actual_input_length
            generation_start = padding_length + actual_input_length

            # Decode only the generated tokens (everything after the original input)
            response = tokenizer.decode(
                outputs[i][generation_start:],
                skip_special_tokens=True
            )
            generated_responses.append(response)

            # Cache the new response
            if cache is not None:
                prompt = batch_prompts[i]
                cache.set(prompt, response)

        # Progress update
        if (batch_idx // batch_size + 1) % 5 == 0 or batch_idx + batch_size >= len(prompts_to_generate):
            print(f"  Processed {min(batch_idx + batch_size, len(prompts_to_generate))}/{len(prompts_to_generate)} prompts")

    # Save cache to disk after generating all new responses
    if cache is not None and generated_responses:
        cache.save()
        print(f"Saved {len(cache)} total responses to cache")

    # Combine cached and generated responses in original order
    responses = [None] * len(prompts)
    for idx, response in cached_responses.items():
        responses[idx] = response
    for i, idx in enumerate(prompts_to_generate_indices):
        responses[idx] = generated_responses[i]
    
    if return_model_and_tokenizer:
        return responses, model, tokenizer
    elif cleanup_after:
        # Clean up memory only if we loaded the model ourselves
        del model
        del tokenizer
        torch.cuda.empty_cache()
        return responses
    else:
        # Don't clean up if using externally provided model
        return responses


# def generate_responses_batched(
#     model_path: Optional[str] = None,
#     prompts: List[Union[str, List[Dict[str, str]]]] = None,
#     model: Optional[object] = None,
#     tokenizer: Optional[object] = None,
#     max_new_tokens: int = 512,
#     batch_size: int = 8,
#     temperature: float = 0.7,
#     top_p: float = 0.9,
#     device: str = "auto",
#     return_model_and_tokenizer: bool = False
# ) -> Union[List[str], Tuple[List[str], object, object]]:
#     """
#     Generate responses for multiple prompts using batched inference.

#     This function efficiently processes multiple prompts in batches to improve
#     throughput compared to generating responses one by one.

#     Args:
#         model_path: Path to the model checkpoint (can be adapter or full model).
#                    Either model_path OR (model, tokenizer) must be provided.
#         prompts: List of prompts - either strings or chat message format
#         model: Pre-loaded model object (alternative to model_path)
#         tokenizer: Pre-loaded tokenizer object (must be provided with model)
#         max_new_tokens: Maximum number of new tokens to generate per response
#         batch_size: Number of prompts to process in each batch
#         temperature: Sampling temperature for generation
#         top_p: Top-p (nucleus) sampling parameter
#         device: Device to use for model ("auto", "cuda", "cpu")
#         return_model_and_tokenizer: If True, return loaded model and tokenizer
#                                    along with responses (useful for caching)

#     Returns:
#         List of generated responses, or tuple of (responses, model, tokenizer)
#         if return_model_and_tokenizer is True

#     Example:
#         >>> # Using model path
#         >>> responses = generate_responses_batched(
#         ...     model_path="/path/to/model",
#         ...     prompts=["What is AI?", "Explain ML"],
#         ...     batch_size=2
#         ... )

#         >>> # Using pre-loaded model and tokenizer
#         >>> responses = generate_responses_batched(
#         ...     model=my_model,
#         ...     tokenizer=my_tokenizer,
#         ...     prompts=["What is AI?", "Explain ML"],
#         ...     batch_size=2
#         ... )
#     """
#     if not prompts:
#         if return_model_and_tokenizer:
#             return [], model, tokenizer
#         return []

#     # Either load model/tokenizer from path OR use provided ones
#     if model is None or tokenizer is None:
#         if model_path is None:
#             raise ValueError("Either model_path OR (model, tokenizer) must be provided")
#         # Load model and tokenizer from path
#         model, tokenizer = load_model_and_tokenizer(model_path, device)
#         cleanup_after = not return_model_and_tokenizer
#     else:
#         # Use provided model and tokenizer
#         if model_path is not None:
#             print("Warning: Both model_path and model/tokenizer provided. Using provided model/tokenizer.")
#         cleanup_after = False

#     # Ensure padding token is set
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token

#     responses = []
#     total_batches = (len(prompts) + batch_size - 1) // batch_size

#     print(f"Generating {len(prompts)} responses in {total_batches} batches...")

#     # Process prompts in batches
#     for batch_idx in range(0, len(prompts), batch_size):
#         batch_prompts = prompts[batch_idx:batch_idx + batch_size]
#         batch_size_actual = len(batch_prompts)

#         # Tokenize batch
#         input_ids, attention_masks = tokenize_batch(
#             batch_prompts,
#             tokenizer,
#             max_length=1024
#         )

#         # Move to device
#         input_ids = input_ids.to(model.device)
#         attention_masks = attention_masks.to(model.device)

#         # Generate responses
#         with torch.no_grad():
#             outputs = model.generate(
#                 input_ids=input_ids,
#                 attention_mask=attention_masks,
#                 max_new_tokens=max_new_tokens,
#                 do_sample=True,
#                 temperature=temperature,
#                 top_p=top_p,
#                 pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id
#             )

#         # Decode responses (only the generated part)
#         for i in range(batch_size_actual):
#             # For left-padded inputs, find where the actual input starts
#             # Count the number of padding tokens at the beginning
#             padding_length = (attention_masks[i] == 0).sum().item()

#             # The actual input length (excluding padding)
#             actual_input_length = (attention_masks[i] == 1).sum().item()

#             # The position where generation starts is padding_length + actual_input_length
#             generation_start = padding_length + actual_input_length

#             # Decode only the generated tokens (everything after the original input)
#             response = tokenizer.decode(
#                 outputs[i][generation_start:],
#                 skip_special_tokens=True
#             )
#             responses.append(response)

#         # Progress update
#         if (batch_idx // batch_size + 1) % 5 == 0 or batch_idx + batch_size >= len(prompts):
#             print(f"  Processed {min(batch_idx + batch_size, len(prompts))}/{len(prompts)} prompts")

#     if return_model_and_tokenizer:
#         return responses, model, tokenizer
#     elif cleanup_after:
#         # Clean up memory only if we loaded the model ourselves
#         del model
#         del tokenizer
#         torch.cuda.empty_cache()
#         return responses
#     else:
#         # Don't clean up if using externally provided model
#         return responses


def _default_bnb_config() -> "BitsAndBytesConfig":
    """Build the 4-bit NF4 quantization config (double quant, bf16 compute) shared by all model loads."""
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4'
    )


def load_model_and_tokenizer(
    model_path: str,
    device: str = "auto"
) -> Tuple[object, object]:
    """
    Load model and tokenizer, handling both adapter and full models.

    A directory containing ``adapter_config.json`` is treated as a PEFT adapter:
    the base model named in that config (default 'meta-llama/Llama-3.1-8B') is
    loaded in 4-bit and the adapter is applied on top. Any other path is loaded
    directly as a causal-LM checkpoint, also in 4-bit.

    Args:
        model_path: Path to model checkpoint
        device: Passed to ``from_pretrained`` as ``device_map``

    Returns:
        Tuple of (model, tokenizer). For adapters, a missing pad token is added
        as a new '[PAD]' special token; for full models the tokenizer's pad
        token is left unchanged.
    """
    # Check if this is an adapter model
    adapter_config_path = os.path.join(model_path, 'adapter_config.json')
    is_adapter = os.path.exists(adapter_config_path)

    if is_adapter:
        print(f"Loading adapter model from: {model_path}")

        # Load adapter config to get base model
        with open(adapter_config_path, 'r') as f:
            adapter_config = json.load(f)
        base_model_name = adapter_config.get('base_model_name_or_path', 'meta-llama/Llama-3.1-8B')

        # Load tokenizer (prefer from adapter directory if available)
        tokenizer_config_path = os.path.join(model_path, 'tokenizer_config.json')
        tokenizer_source = model_path if os.path.exists(tokenizer_config_path) else base_model_name
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)

        # Add pad token if needed (this grows the vocabulary by one token)
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            quantization_config=_default_bnb_config(),
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True
        )

        # Resize embeddings if needed.
        # NOTE(review): `!=` also SHRINKS the embeddings when config.vocab_size is
        # padded above len(tokenizer). Confirm this matches how adapters were
        # trained before narrowing the condition to `>`.
        if len(tokenizer) != base_model.config.vocab_size:
            print(f"Resizing model embeddings from {base_model.config.vocab_size} to {len(tokenizer)}")
            base_model.resize_token_embeddings(len(tokenizer))

        # Apply adapter
        model = PeftModel.from_pretrained(base_model, model_path)

    else:
        print(f"Loading full model from: {model_path}")

        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=_default_bnb_config(),
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True
        )

    return model, tokenizer


def tokenize_batch(
    prompts: List[Union[str, List[Dict[str, str]]]],
    tokenizer: object,
    max_length: int = 1024
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Tokenize a batch of prompts with left-side padding for generation.

    For generation tasks, padding should be on the left side so that the
    actual prompt content appears at the end of the sequence. This ensures
    proper model behavior during autoregressive generation.

    Args:
        prompts: List of prompts. Strings are tokenized directly; lists of
            message dicts are rendered with the tokenizer's chat template with
            the generation prompt appended.
        tokenizer: Tokenizer instance
        max_length: Maximum sequence length per prompt for truncation

    Returns:
        Tuple of (input_ids, attention_mask), both int64 tensors of shape
        (len(prompts), longest_prompt_length). Padding positions hold the pad
        token id and mask 0. An empty ``prompts`` list yields two (0, 0) tensors.

    NOTE(review): truncation follows the tokenizer's ``truncation_side``, which
    in HF defaults to right-side truncation. For over-long chat prompts that
    drops the trailing generation prompt; confirm whether left truncation is wanted.
    """
    # Resolve the pad id once. Compare with None explicitly: an id of 0 is
    # valid and must not fall through to eos as it would with `or`.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # First pass: tokenize all prompts into plain lists of token ids
    encoded: List[List[int]] = []
    for prompt in prompts:
        if isinstance(prompt, list):
            ids = tokenizer.apply_chat_template(
                prompt,
                padding=False,
                add_generation_prompt=True,
                truncation=True,
                max_length=max_length
            )
            # Defensive: if a dict-like encoding is returned instead of a list
            # of ids, take its input_ids (presumably flat for one conversation).
            if hasattr(ids, "keys") and "input_ids" in ids:
                ids = ids["input_ids"]
        else:
            ids = tokenizer(
                prompt,
                truncation=True,
                max_length=max_length
            )["input_ids"]
        encoded.append(list(ids))

    if not encoded:
        empty = torch.empty((0, 0), dtype=torch.long)
        return empty, empty.clone()

    # Second pass: apply LEFT-side padding to the longest sequence's width
    width = max(len(ids) for ids in encoded)
    padded_input_ids = []
    attention_masks = []
    for ids in encoded:
        padding_length = width - len(ids)
        padded_input_ids.append([pad_token_id] * padding_length + ids)
        attention_masks.append([0] * padding_length + [1] * len(ids))

    return (
        torch.tensor(padded_input_ids, dtype=torch.long),
        torch.tensor(attention_masks, dtype=torch.long),
    )


def generate_huggingface_response(
    model: Any,
    tokenizer: Any,
    prompt: str,
    system_prompt: Optional[str] = None,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    do_sample: bool = True
) -> str:
    """
    Generate a response using a HuggingFace model with consistent formatting.

    Builds a chat message list (optional system turn + user turn), renders it
    with the tokenizer's chat template (with the assistant generation prompt
    appended), runs ``model.generate`` and decodes only the newly generated
    tokens.

    Args:
        model: The loaded HuggingFace model (must expose ``.device`` and
            ``.generate``).
        tokenizer: The model's tokenizer.
        prompt: The user prompt.
        system_prompt: Optional system prompt; omitted when None or empty.
        max_new_tokens: Maximum tokens to generate (default 512).
        temperature: Sampling temperature (default 0.7). Only forwarded to
            ``generate`` when ``do_sample`` is True.
        top_p: Top-p sampling parameter (default 0.9). Only forwarded to
            ``generate`` when ``do_sample`` is True.
        do_sample: Whether to use sampling (default True). When False,
            greedy decoding is used.

    Returns:
        str: The generated text response (special tokens stripped).

    Examples:
        >>> # For scoring tasks (greedy decoding)
        >>> score = generate_huggingface_response(
        ...     model, tokenizer, prompt,
        ...     system_prompt=SCORING_SYSTEM_PROMPT,
        ...     max_new_tokens=10,
        ...     do_sample=False,
        ... )

        >>> # For general generation
        >>> response = generate_huggingface_response(
        ...     model, tokenizer, prompt,
        ...     max_new_tokens=1024
        ... )
    """
    # Build messages list; the system turn is only included when non-empty.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    # Render the chat template to text, ending with the assistant header.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered template already contains whatever special tokens (e.g. BOS)
    # the model expects; letting the tokenizer add them again would duplicate them.
    model_inputs = tokenizer(
        [text], return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    prompt_length = model_inputs.input_ids.shape[-1]

    generation_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
    }
    # Sampling knobs are meaningless for greedy decoding and only produce
    # generation-config warnings there, so forward them only when sampling.
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    # Fall back to EOS as pad token; compare against None explicitly so a
    # legitimate pad id of 0 is not discarded.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    if pad_token_id is not None:
        generation_kwargs["pad_token_id"] = pad_token_id

    with torch.no_grad():
        generated_ids = model.generate(**model_inputs, **generation_kwargs)

    # Decode only the tokens produced after the prompt.
    output_ids = generated_ids[0][prompt_length:].tolist()
    return tokenizer.decode(output_ids, skip_special_tokens=True)


# Backwards compatibility alias for scoring
def generate_huggingface_score(
    model: Any,
    tokenizer: Any,
    prompt: str,
    system_prompt: str = None,
    max_new_tokens: int = 10
) -> str:
    """
    Produce a short completion for scoring prompts.

    Delegates to generate_huggingface_response with fixed decoding settings:
    do_sample=False and temperature=0.1, plus a small default token budget.

    Args:
        model: The loaded HuggingFace model.
        tokenizer: The model's tokenizer.
        prompt: The user prompt to score.
        system_prompt: Optional system prompt passed through unchanged.
        max_new_tokens: Maximum tokens to generate (default 10).

    Returns:
        str: The decoded completion text.
    """
    scoring_overrides = {"temperature": 0.1, "do_sample": False}
    return generate_huggingface_response(
        model,
        tokenizer,
        prompt,
        system_prompt=system_prompt,
        max_new_tokens=max_new_tokens,
        **scoring_overrides,
    )


def apply_chat_template_to_prompt(
    model_path: str,
    prompt: Union[str, List[Dict[str, str]]],
    max_length: Optional[int] = 2000,
    use_cache: bool = True
) -> str:
    """
    Load tokenizer and apply chat template to a prompt without tokenizing.

    This method properly handles multi-turn conversations by applying the model's
    chat template formatting, which is crucial for consistent prompt handling.

    Uses disk-based caching to avoid redundant tokenizer loading and template application.

    Args:
        model_path: Path to model checkpoint (used to load appropriate tokenizer).
            For adapter checkpoints (containing ``adapter_config.json``) without
            their own ``tokenizer_config.json``, the tokenizer is loaded from the
            adapter's ``base_model_name_or_path`` (default 'meta-llama/Llama-3.1-8B').
        prompt: Either a string (wrapped as a single user message) or a list of
            message dicts (e.g., [{"role": "user", "content": "..."}]).
        max_length: Maximum character length. If exceeded, keeps only the last
            ``max_length`` characters (truncates from the left); a value <= 0
            yields an empty string. None disables truncation. Default: 2000.
        use_cache: If True, delegate to the disk-based chat template cache. Default: True.

    Returns:
        String with chat template applied (and optionally truncated). No
        generation prompt is appended.
    """
    if use_cache:
        cache = get_chat_template_cache(model_path)
        return cache.get_formatted(prompt, max_length)

    # Non-cached path (kept for backwards compatibility).
    tokenizer_source = model_path
    adapter_config_path = os.path.join(model_path, 'adapter_config.json')
    if os.path.exists(adapter_config_path):
        # Adapter checkpoints may not ship a tokenizer; only then do we need the
        # adapter config to locate the base model's tokenizer.
        tokenizer_config_path = os.path.join(model_path, 'tokenizer_config.json')
        if not os.path.exists(tokenizer_config_path):
            with open(adapter_config_path, 'r') as f:
                adapter_config = json.load(f)
            tokenizer_source = adapter_config.get(
                'base_model_name_or_path', 'meta-llama/Llama-3.1-8B'
            )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)

    # Kept for parity with the tokenizer setup used elsewhere in this module.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    messages = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
    formatted_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )

    if max_length is not None and len(formatted_text) > max_length:
        # Slice by an explicit start index: formatted_text[-0:] would return the
        # whole string, so max_length == 0 (or negative) must map to "".
        keep = max(max_length, 0)
        formatted_text = formatted_text[len(formatted_text) - keep:]

    return formatted_text