# concept_alignment/models/utils.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

def load_pretrained_llm(model_name="gpt2-medium", tokenizer_name=None, device=None):
    """
    Load a pre-trained causal language model and its tokenizer.

    The model config is overridden with ``output_hidden_states=True`` and
    ``return_dict=True`` so that forward passes return an output object
    that also carries the per-layer hidden states.

    Args:
        model_name: Name or path of the pre-trained model to load.
        tokenizer_name: Name or path of the tokenizer to load. If None
            (the default), the tokenizer is loaded from ``model_name`` so
            the tokenizer always matches the model. Pass it explicitly only
            when the tokenizer lives under a different name/path.
        device: Device to load the model on ("cuda", "cpu", or None for
            auto-detection: "cuda" when available, otherwise "cpu").

    Returns:
        model: The loaded language model, moved to ``device``.
        tokenizer: The corresponding tokenizer.
    """
    # Previously this defaulted to "gpt2-medium" unconditionally, which
    # silently paired non-GPT-2 models with a mismatched vocabulary.
    if tokenizer_name is None:
        tokenizer_name = model_name

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info("Loading %s on %s...", model_name, device)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # These kwargs override the corresponding fields of the model config.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        output_hidden_states=True,
        return_dict=True,
    )

    model = model.to(device)

    logger.info("Successfully loaded %s on %s", model_name, device)
    return model, tokenizer