from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import re
import torch
import json
from tqdm import tqdm

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

# Use the MPS backend when PyTorch reports it as available; otherwise the CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Half precision is only used on the accelerator. On the CPU fallback,
# float16 kernels may be missing or very slow depending on the PyTorch build,
# so full precision is used there instead.
model_dtype = torch.float16 if device.type == "mps" else torch.float32

# 1. Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_fast=True,
    # NOTE(review): truncation is normally a call-time tokenizer argument;
    # confirm that passing it at load time has any effect.
    truncation=True,
)
# make sure pad_token_id is defined (fall back to the EOS token when the
# tokenizer has no padding token of its own)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=model_dtype,
    device_map=device,
)


# 2. Create a text-generation pipeline
# Sampling settings given here are stored on the pipeline and apply to every
# call of `generator` unless overridden at call time.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    truncation=True,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=True,
    temperature=0.3,
    top_p=0.9,
    repetition_penalty=1.1,
)

def generate_synthetic_questions(
    examples: list[str],
    max_new_tokens: int = 200
) -> list[str]:
    """Generate new questions in the style of ``examples``.

    The examples are placed in a few-shot prompt that asks the model for
    five new questions, one per line. The model's continuation is split on
    newlines and returned as a list of stripped, non-empty lines.

    The prompt *requests* five questions, but the number of lines actually
    returned depends on what the model emits within ``max_new_tokens``.

    Args:
        examples: Example questions to imitate. Surrounding whitespace on
            each example is stripped before it is added to the prompt.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The non-empty lines of the generated text, or an empty list when
        ``examples`` is empty (the model is not called in that case).
    """
    # Without examples there is no style to imitate; skip the model call.
    if not examples:
        return []

    # 3. Build the prompt
    instructions = (
        "Generate five new questions that follow the same style as the examples below. "
        "Each question should be separated by a newline.\n\n"
    )
    example_block = "".join(f"{ex.strip()}\n" for ex in examples)
    prompt = f"{instructions}{example_block}\nNew questions:\n"

    # 4. Run the model
    output = generator(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]

    # 5. Strip off the prefix (the prompt itself); this relies on the
    # generated text beginning with the prompt.
    generated_part = output[len(prompt):]

    # Keep only lines that still contain text after trimming whitespace.
    stripped_lines = (line.strip() for line in generated_part.split("\n"))
    return [line for line in stripped_lines if line]


if __name__ == "__main__":
    # Script entry point: read the source questions, generate synthetic
    # questions for each consecutive group of `num_examples`, and write the
    # accumulated results to data/synth_data.json.

    num_examples = 5
    synth_questions = []

    def save_results() -> None:
        """Write everything generated so far to the output file."""
        with open('data/synth_data.json', 'w', encoding='utf-8') as out_file:
            json.dump(synth_questions, out_file, indent=2)

    with open('data/aug_data.json', 'r', encoding='utf-8') as f:
        questions = json.load(f)

    for i in tqdm(range(0, len(questions), num_examples)):

        examples = questions[i:i+num_examples]
        new_questions = generate_synthetic_questions(examples)
        synth_questions.append({'questions': examples, 'synth_questions':  new_questions})

        # Periodic checkpoint so a crash mid-run does not lose all progress.
        if i % 10 == 0:
            save_results()

    # Final save: the periodic checkpoint skips batches whose start index is
    # not a multiple of 10, so the last batches would otherwise never be
    # written to disk.
    save_results()