"""
Non-role-play setting generation: 
already have all the baseline
"""
import torch
import os
import json
import random
from transformers import pipeline
from tqdm import tqdm
from pprint import pprint

from pathlib import Path


# Pin both random number generators (Python's and PyTorch's) to the same
# fixed seed at import time.
random.seed(42)
torch.manual_seed(42)

# Pipelines already built in this process, keyed by model id. Building a
# pipeline loads the model, so it must not be repeated for every question.
_PIPELINE_CACHE = {}

# System prompts for the supported prompting modes.
_SYSTEM_PROMPTS = {
    "die": "You must strictly adhere to the role assigned to you and respond as if you are that character or person. You should only possess knowledge that your role would have during their lifetime or within their story. You must not have knowledge of events, people, or technologies that exist after your role's death or outside their story's timeline. Please only output the answer to the questions",
    "normal": "You should play the role given to you. Please only output the answer to the questions.",
    "ai": "Please think step by step and carefully consider the context and only output the answer to the questions.",
}


def instruct_model(model_id, prompt, prompt_type="ai"):
    """Ask a chat model a single question and return its reply text.

    Args:
        model_id: Model identifier handed to ``transformers.pipeline``.
        prompt: The user message to send.
        prompt_type: Which system prompt to use: ``"ai"`` (default, the
            previously hard-coded choice), ``"normal"`` or ``"die"``.

    Returns:
        The ``content`` of the last message in the generated conversation.

    Raises:
        ValueError: If ``prompt_type`` is not a known system prompt.
    """
    # Validate first so a typo fails before any model loading starts.
    try:
        system_prompt = _SYSTEM_PROMPTS[prompt_type]
    except KeyError:
        raise ValueError(
            f"Unknown prompt_type {prompt_type!r}; "
            f"expected one of {sorted(_SYSTEM_PROMPTS)}"
        ) from None

    # Reuse the pipeline across calls instead of rebuilding it each time.
    pipe = _PIPELINE_CACHE.get(model_id)
    if pipe is None:
        pipe = pipeline(
            "text-generation",
            model=model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            temperature=0,
        )
        _PIPELINE_CACHE[model_id] = pipe

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = pipe(
        messages,
        max_new_tokens=100,
        do_sample=False,
    )
    # The generated conversation ends with the model's reply.
    return outputs[0]["generated_text"][-1]["content"]

# Run configuration:
#   type          -- used in the name of the output file.
#   question_type -- selects which set of questions the script asks
#                    ("whendie" or "four_president").
cfg = dict(
    type="ai",
    question_type="whendie",
)

def build_death_question(char, source):
    """Return the death-year question for ``char``.

    Characters whose ``source`` is not ``"real"`` are asked about within
    their story; real people are asked about directly.
    """
    if source != "real":
        return f"Which year in the story did {char} in {source} die?"
    return f"Which year did {char} die?"


def collect_death_entries(data, ask, skipped_types=("description_human",)):
    """Ask the death-year question for every character in ``data``.

    Args:
        data: List of categories, each a dict with a ``"type"`` and a
            ``"characters"`` list whose items carry ``"role"``, ``"year"``
            and ``"source"``.
        ask: Callable taking the question string and returning the answer.
        skipped_types: Category types that are not queried.

    Returns:
        One entry per character with the keys ``character``, ``type``,
        ``death_year`` and ``questions``.
    """
    entries = []
    for category in data:
        category_type = category["type"]
        if category_type in skipped_types:
            continue
        # Every character is queried, not just the last one of the category.
        for char_data in category["characters"]:
            char = char_data["role"]
            question = build_death_question(char, char_data["source"])
            output = ask(question)
            print(output)
            print("\n")
            entries.append(
                {
                    "character": char,
                    "type": category_type,
                    "death_year": char_data["year"],
                    "questions": [f"Question: {question} Answer: {output}"],
                }
            )
    return entries


def collect_president_entries(ask, newest=46, count=27):
    """Ask who each of the ``count`` presidents from ``newest`` downwards is.

    Returns a single-entry list using the same keys as the death-year
    entries; ``character`` and ``death_year`` do not apply and are ``None``.
    """
    answers = []
    for ordinal in range(newest, newest - count, -1):
        question = f"Who is the {ordinal}th US president?"
        output = ask(question)
        print(output)
        print("\n")
        answers.append(f"Question: {question} Answer: {output}")
    return [
        {
            "character": None,
            "type": "four_president",
            "death_year": None,
            "questions": answers,
        }
    ]


def main():
    """Run the configured question set and write the answers as JSON."""
    model_id = "meta-llama/Llama-3.1-8B-Instruct"

    def ask(question):
        return instruct_model(model_id, question)

    question_type = cfg["question_type"]
    if question_type == "whendie":
        with open("answers/question.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        entries = collect_death_entries(data, ask)
    elif question_type == "four_president":
        entries = collect_president_entries(ask)
    else:
        raise ValueError(f"Unsupported question_type: {question_type!r}")

    out_path = Path(
        f"answers/ai/formatted_output_{cfg['type']}_{cfg['question_type']}.json"
    )
    # Make sure the output directory exists before writing.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(entries, f, indent=4)


if __name__ == "__main__":
    main()
