"""
Role-play with alive characters
"""

import torch
import os
import json
import random
from transformers import pipeline
from tqdm import tqdm
from pprint import pprint
import argparse
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from src.constant import TYPE, QUESTION_TYPE
from pathlib import Path

# Command-line options; both values are validated against src.constant below.
parser = argparse.ArgumentParser()
parser.add_argument("--question_type", default="four_president", type=str)
parser.add_argument("--type", default="yearly", type=str)
args = parser.parse_args()

# Fail fast, checking --type first, on any value not listed in src.constant.
for _value, _allowed, _message in (
    (args.type, TYPE, "Invalid type"),
    (args.question_type, QUESTION_TYPE, "Invalid question type"),
):
    if _value not in _allowed:
        raise ValueError(_message)

# Seed Python's and torch's RNGs with a fixed value.
random.seed(42)
torch.manual_seed(42)

# Personas used for role-play; every entry is a living person ("alive" type).
live_character = [
    "Taylor Swift",
    "Justin Bieber",
    "Elon Musk",
    "Emma Stone",
    "Tom Cruise",
    "Beyonce",
]
# System prompts for instruct_model. DIE_SYSTEM_PROMPT is the default (and the
# only one the script sends); the other two are alternative instructions.
DIE_SYSTEM_PROMPT = (
    "You must strictly adhere to the role assigned to you and respond as if you "
    "are that character or person. You should only possess knowledge that your "
    "role would have during their lifetime or within their story. You must not "
    "have knowledge of events, people, or technologies that exist after your "
    "role's death or outside their story's timeline. Please only output the "
    "answer to the questions"
)
NORMAL_SYSTEM_PROMPT = "You should play the role given to you. Please only output the answer to the questions."
COT_SYSTEM_PROMPT = "Please think step by step and output the answer to the questions."

# Text-generation pipelines keyed by model id. Constructing a pipeline loads the
# full model weights, so it should happen once per model, not once per prompt.
_PIPELINE_CACHE = {}


def _get_pipeline(model_id):
    """Return the bf16 text-generation pipeline for *model_id*, building it on first use."""
    pipe = _PIPELINE_CACHE.get(model_id)
    if pipe is None:
        pipe = pipeline(
            "text-generation",
            model=model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        _PIPELINE_CACHE[model_id] = pipe
    return pipe


def instruct_model(model_id, prompt, system_prompt=DIE_SYSTEM_PROMPT, max_new_tokens=100):
    """Ask *model_id* a single chat question and return the model's reply text.

    Args:
        model_id: Hugging Face model id passed to ``transformers.pipeline``.
        prompt: Content of the user message.
        system_prompt: Content of the system message (defaults to the strict
            role-play prompt).
        max_new_tokens: Generation length limit.

    Returns:
        The ``content`` of the last message in the returned conversation, i.e.
        the reply appended by the model.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = _get_pipeline(model_id)(
        messages,
        max_new_tokens=max_new_tokens,
        # Greedy decoding; a temperature setting would be ignored (and warned about).
        do_sample=False,
    )
    return outputs[0]["generated_text"][-1]["content"]

# Run configuration read by the __main__ block, taken from the validated CLI
# arguments. (The model id and output path are set in the __main__ block.)
cfg = {
    "type": args.type,
    "question_type": args.question_type,
}

def ordinal(n):
    """Return *n* followed by its English ordinal suffix.

    >>> [ordinal(k) for k in (1, 2, 3, 4, 11, 12, 13, 21, 22, 23, 46)]
    ['1st', '2nd', '3rd', '4th', '11th', '12th', '13th', '21st', '22nd', '23rd', '46th']
    """
    # 11th, 12th, 13th (and 111th, 112th, ...) are irregular and always take "th".
    if 10 <= n % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


if __name__ == "__main__":
    # `period_type` rather than `type`, to avoid shadowing the builtin.
    period_type = cfg["type"]
    question_type = cfg["question_type"]

    model_id = "meta-llama/Llama-3.1-8B-Instruct"

    # Ensure the output directory exists before the slow generation loop;
    # otherwise open() would raise only after every answer had been generated.
    output_dir = Path("answers")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"formatted_output_{period_type}_{question_type}.json"

    all_characters = []
    for char in live_character:
        all_answers = []  # Reset answers list for each character
        # Ask about the 46th down to the 20th president (27 questions per character).
        for number in range(46, 19, -1):
            prompt = f"You are {char}. Who is the {ordinal(number)} US president?"
            answer = instruct_model(model_id, prompt)
            all_answers.append(f"Question: {prompt} Answer: {answer}")

        all_characters.append(
            {
                "character": char,
                "type": "alive",
                "questions": all_answers,
            }
        )

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_characters, f, indent=4)

    print(output_path)