import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import fire

def generate_outputs(
    model_name: str,
    dataset_name: str = "wikitext",
    dataset_config: str = "wikitext-2-raw-v1",
    dataset_split: str = "train",
    num_samples: int = 100,
    output_dir: str = "data/outputs",
    trust_remote_code: bool = False,
    max_length: int = 512,
    return_logits: bool = False,
):
    """Generate final-position log-probabilities from a causal LM and save them.

    Streams ``dataset_split`` of ``dataset_name``/``dataset_config`` and, for
    each of the first ``num_samples`` examples whose ``"text"`` field is
    non-empty, runs the model on the (truncated) text and records the output
    distribution at the last token position. The rows are stacked with
    ``np.array`` and saved to ``<output_dir>/<model_name with "/" -> "_">.npy``.

    Args:
        model_name: Model id or path passed to ``from_pretrained`` for both the
            model and the tokenizer.
        dataset_name: Dataset name passed to ``datasets.load_dataset``.
        dataset_config: Dataset configuration name.
        dataset_split: Split to stream.
        num_samples: Maximum number of rows to collect; fewer are saved if the
            stream runs out first.
        output_dir: Directory for the ``.npy`` file (created if missing).
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.
        max_length: Token truncation length for each text (was hard-coded 512).
        return_logits: If True, save the raw final-position logits instead of
            log-probabilities (the output format of earlier versions).
    """
    print(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    # Explicit inference mode: dropout in train mode would make outputs nondeterministic.
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)

    print(f"Loading dataset: {dataset_name}")
    dataset = load_dataset(dataset_name, dataset_config, split=dataset_split, streaming=True)

    rows = []

    print(f"Generating {num_samples} samples...")
    for example in dataset:
        if len(rows) >= num_samples:
            break

        text = example["text"]
        # Same skip rule as before, so the selected prompts do not change.
        if not text:
            continue

        inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
        # Some tokenizers emit no tokens for certain inputs; position -1 of an
        # empty sequence cannot be indexed, so skip such examples.
        if inputs["input_ids"].shape[-1] == 0:
            continue
        inputs = inputs.to(model.device)

        with torch.no_grad():
            outputs = model(**inputs)
            # Upcast first: .numpy() rejects low-precision dtypes such as
            # bfloat16, and log_softmax is more accurate in float32.
            last = outputs.logits[0, -1, :].float()
            if not return_logits:
                last = torch.log_softmax(last, dim=-1)
            rows.append(last.cpu().numpy())

    logprobs = np.array(rows)

    model_basename = model_name.replace("/", "_")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{model_basename}.npy")

    print(f"Saving outputs to {output_path}")
    np.save(output_path, logprobs)

if __name__ == "__main__":
    # Hand generate_outputs to python-fire so its parameters can be supplied
    # from the command line.
    fire.Fire(component=generate_outputs)