import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_model_outputs(model, tokenizer, inputs, out_len=100):
    """
    Generate one sampled continuation from a model for each prompt in ``inputs``.

    Prompts are processed one at a time. Progress (the prompt, the generated
    text and the generated token count) is printed to stdout as it goes.

    Generation samples with ``temperature=0.7`` and ``top_p=0.9``, so the
    output is not expected to be reproducible unless the caller seeds the
    random number generator beforehand.

    Args:
        model: The language model to use for generation. It must expose
            ``device`` and ``generate``. The returned sequence is assumed to
            start with the prompt tokens (decoder-only behaviour) -- the
            prompt-length prefix is sliced off to obtain the continuation.
        tokenizer: The tokenizer for the model. It does not need a pad token:
            padding is only requested when one is defined, and
            ``eos_token_id`` is used as the generation pad id otherwise.
        inputs: List of input strings/prompts
        out_len: Maximum number of tokens to generate beyond the prompt

    Returns:
        List of dictionaries, one per prompt and in the same order, with keys:
            "input": the original prompt text
            "output": the decoded continuation (special tokens skipped)
            "input_tokens": number of prompt tokens fed to the model
            "output_tokens": number of generated token ids (counted before
                decoding, so it may include special tokens that were
                dropped from "output")
    """
    results = []

    print(f"\n{'='*60}")
    print(f"Generating {len(inputs)} outputs with max length {out_len}")
    print(f"{'='*60}")

    # Resolve pad handling once; it does not depend on the individual prompt.
    # Many causal-LM tokenizers ship without a pad token. Asking such a
    # tokenizer to pad raises an error, and handing ``None`` to generate()
    # leaves it without a pad id, so fall back to the EOS id in that case.
    pad_token_id = tokenizer.pad_token_id
    can_pad = pad_token_id is not None
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, "eos_token_id", None)

    for i, input_text in enumerate(inputs):
        print(f"\nInput {i+1}: {input_text}")

        # Tokenize the input. Padding is only enabled when the tokenizer can
        # actually pad, which keeps the previous behaviour for tokenizers
        # that define a pad token.
        encoded = tokenizer(
            input_text,
            return_tensors="pt",
            padding=can_pad,
            truncation=True,
            return_attention_mask=True,
        )
        input_ids = encoded.input_ids.to(model.device)
        attention_mask = encoded.attention_mask.to(model.device)
        prompt_len = input_ids.shape[1]

        # Generate output (no gradients are needed for inference)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                # max_length counts the prompt too, so add its length.
                max_length=prompt_len + out_len,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=pad_token_id,
            )

        # Get only the generated part (excluding the input)
        generated_ids = output_ids[0, prompt_len:]

        # Decode the generated text
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Store result
        result = {
            "input": input_text,
            "output": generated_text,
            "input_tokens": prompt_len,
            "output_tokens": len(generated_ids),
        }

        # Print the generated output
        print(f"Output: {generated_text}...")
        print(f"Output tokens: {result['output_tokens']}")

        results.append(result)

    return results