# output_generator.py
import json
from datasets import Dataset
from tqdm import tqdm
import torch

def generate_model_outputs(
    model,
    tokenizer,
    dataset_file,
    output_file,
    *,
    max_new_tokens=64,
    device=None,
):
    """Greedily generate a completion for every instruction in a JSON dataset.

    Reads ``dataset_file`` (a JSON list of records, each carrying a
    ``"generated_instruction"`` field) and generates a completion for each
    instruction with ``model``. Each completion is trimmed at the earliest
    stop phrase. The result is written to ``output_file`` as a JSON list of
    ``{"instruction", "input", "output"}`` entries, where ``"input"`` is
    always ``""``.

    Args:
        model: Model exposing ``generate(**inputs, ...)`` (e.g. a Hugging Face
            ``PreTrainedModel``).
        tokenizer: Tokenizer paired with ``model``. It must be callable with
            ``return_tensors="pt"`` and provide ``decode`` and ``eos_token``.
        dataset_file: Path to the input JSON file.
        output_file: Path the results are written to (overwritten).
        max_new_tokens: Maximum number of tokens generated per instruction.
        device: Device the tokenized inputs are moved to. Defaults to
            ``model.device`` when the model exposes one, otherwise CUDA if
            available, else CPU.
    """
    # Resolve the input device instead of assuming CUDA, so CPU-only hosts and
    # models placed elsewhere do not fail with a device mismatch.
    if device is None:
        device = getattr(model, "device", None)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    with open(dataset_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    dataset = Dataset.from_list(data)

    # Encoder-decoder models return only decoder tokens from generate(), so the
    # prompt length must not be sliced off their output.
    config = getattr(model, "config", None)
    is_encoder_decoder = bool(getattr(config, "is_encoder_decoder", False))

    # Built once, outside the loop. A missing (None) or empty eos token is
    # dropped: `None in str` raises TypeError and str.split("") raises ValueError.
    stop_phrases = [
        phrase
        for phrase in (
            "Q:",
            "[END]",
            "Question:",
            "Answer:",
            "Problem:",
            "Explanation:",
            "####",
            tokenizer.eos_token,
        )
        if phrase
    ]

    results = []
    model_name = model.__class__.__name__
    print(f"Getting {model_name}'s outputs...")
    for record in tqdm(dataset, desc=f"Generating with {model_name}"):
        input_text = record["generated_instruction"]
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
        with torch.no_grad():
            # Greedy decoding. Temperature is unused when do_sample=False, so it
            # is not passed (transformers warns when it is).
            outputs = model.generate(
                **inputs, max_new_tokens=max_new_tokens, do_sample=False
            )
        sequence = outputs[0]
        if not is_encoder_decoder:
            # Decoder-only models echo the prompt; keep only the new tokens.
            sequence = sequence[inputs["input_ids"].shape[-1]:]
        output = tokenizer.decode(sequence, skip_special_tokens=True).strip()

        # Cutting at each phrase in turn leaves the text truncated at the
        # earliest occurrence of any stop phrase.
        for phrase in stop_phrases:
            if phrase in output:
                output = output.split(phrase, 1)[0].strip()

        results.append({"instruction": input_text, "input": "", "output": output})

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"✓ Saved outputs to {output_file}")
