"""
Role-play evaluation: query Claude as a given character under a restricted or a non-restricted system prompt, and save its answers as JSON.
"""

import torch
import os
import json
import random
from transformers import pipeline
from tqdm import tqdm
from pprint import pprint
import argparse
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from src.constant import QUESTION_TYPE, TYPE
import anthropic  # Add this import

from pathlib import Path

def _str_to_bool(value):
    """Parse a boolean command-line value such as "true"/"false", "1"/"0", "yes"/"no".

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--question_type", type=str, default="four_president")
parser.add_argument("--type", type=str, default="normal")
# `--cot` alone means True; `--cot True` / `--cot False` are parsed literally.
# NOTE(review): args.cot is not read anywhere in this file.
parser.add_argument("--cot", type=_str_to_bool, nargs="?", const=True, default=False)
parser.add_argument("--model_id", type=str, default="claude")
parser.add_argument("--exp", type=str, default="exp")

args = parser.parse_args()

# Fix the torch and Python RNG seeds so any local randomness is reproducible.
torch.manual_seed(42)
random.seed(42)
# System prompt for any prompt type other than "normal": the model must stay
# within the knowledge its role would have had.
_RESTRICTED_SYSTEM_PROMPT = "You must strictly adhere to the role assigned to you and respond as if you are that character or person. You should only possess knowledge that your role would have during their lifetime or within their story. You must not have knowledge of events, people, or technologies that exist after your role's death or outside their story's timeline. Please only output the answer to the questions"
# System prompt for the "normal" prompt type: plain role-play.
_NORMAL_SYSTEM_PROMPT = "You should play the role given to you. Please only output the answer to the questions."
# Model used when `model_id` is a generic label (e.g. the CLI default "claude")
# rather than a full Claude model name.
_DEFAULT_CLAUDE_MODEL = "claude-3-7-sonnet-20250219"

# Lazily created Anthropic client shared by all calls to instruct_model.
_client = None


def _get_client():
    """Return the shared Anthropic client, creating it on first use.

    The API key comes from the ANTHROPIC_API_KEY environment variable rather
    than being hard-coded in source.
    """
    global _client
    if _client is None:
        _client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
    return _client


def instruct_model(model_id, prompt, type):
    """Send one role-play prompt to Claude and return the text of its reply.

    Args:
        model_id: A full Claude model name (starting with "claude-") is used
            as-is. Any other value, such as the CLI default "claude", falls
            back to ``_DEFAULT_CLAUDE_MODEL``.
        prompt: The user message, i.e. the character prompt followed by the question.
        type: "normal" selects the plain role-play system prompt. Any other
            value selects the restricted system prompt.

    Returns:
        The concatenated text blocks of the reply ("" if the reply has none).
    """
    system_prompt = _NORMAL_SYSTEM_PROMPT if type == "normal" else _RESTRICTED_SYSTEM_PROMPT
    if isinstance(model_id, str) and model_id.startswith("claude-"):
        model = model_id
    else:
        model = _DEFAULT_CLAUDE_MODEL

    response = _get_client().messages.create(
        model=model,
        system=system_prompt,
        messages=[
            {"role": "user", "content": prompt}
        ],
        max_tokens=150,
        temperature=0,
    )
    # Join the text blocks instead of indexing content[0], which raises
    # IndexError when the reply contains no content blocks.
    return "".join(
        block.text for block in response.content if getattr(block, "type", None) == "text"
    )



# Run configuration consumed by the __main__ block, which validates
# question_type against QUESTION_TYPE and type against TYPE
# (e.g. question_type "four_president" or "yearly").
cfg = dict(
    type=args.type,
    question_type=args.question_type,
)

# Presidential numbers asked about by the "four_president*" question types:
# 46th down to 20th (27 presidents).
# NOTE(review): the question f-strings always append "th" ("21th", "43th", ...);
# kept verbatim so questions match previously saved outputs and any parsers.
_PRESIDENT_NUMBERS = range(46, 19, -1)

# Year offsets for the year sweeps: from 30 years after the reference year
# down to 30 years before it (+30, +29, ..., 0, -1, ..., -30).
_YEAR_OFFSETS = range(30, -31, -1)

# Chain-of-thought question types are not implemented and are rejected up front.
_COT_QUESTION_TYPES = ("four_president_cot", "yearly_cot")

# Categories whose characters are introduced as "You are <role> in <source>."
_FICTIONAL_CATEGORIES = ("mix_fictional", "fully_fictional", "half_fictional", "mostly_fictional")


def _build_prompt(char, char_data, category_type):
    """Return the role-play preamble for one character.

    A "description" entry wins. Otherwise the prompt is built from the role
    name (plus its "source" for fictional categories).

    Raises:
        ValueError: if there is no description and the category is not handled.
    """
    if "description" in char_data:
        return f"{char_data['description']}"
    if category_type in _FICTIONAL_CATEGORIES:
        return f"You are {char} in {char_data['source']}."
    if category_type == "real_person":
        return f"You are {char}."
    # Previously this fell through and silently reused the prior character's prompt.
    raise ValueError(
        f"Cannot build a prompt for {char!r}: no description and unsupported category {category_type!r}"
    )


def _build_questions(question_type, char_data):
    """Return ``(questions, multiline)`` for one character under ``question_type``.

    ``multiline`` selects which of the two answer-record formats is written.
    The per-type choice matches what this script has always saved. Question
    types not handled here yield no questions. Year-based types read
    ``char_data["year"]`` and raise KeyError if it is missing.
    """
    if question_type == "four_president":
        return [f" Who is the {n}th US president?" for n in _PRESIDENT_NUMBERS], False
    if question_type == "four_president_current":
        # Previously a duplicate "four_president_given_death" branch that shadowed
        # the real one below; add "four_president_current" to QUESTION_TYPE to run it.
        return [f"This year is 2023. Who is the {n}th US president?" for n in _PRESIDENT_NUMBERS], False
    if question_type == "direct_death":
        return ["Are you dead or alive?"], True
    if question_type == "direct_death_current":
        return ["This year is 2023. Are you dead or alive?"], True
    if question_type == "when_die":
        return ["Which year did you die?"], True

    # Types below are anchored on the character's year (used as the death year).
    if question_type == "yearly":
        year = char_data["year"]
        return [f"Who is the US president in {year + d}?" for d in _YEAR_OFFSETS], True
    if question_type == "four_president_given_death":
        year = char_data["year"]
        return [f"You died in {year}. Who is the {n}th US president?" for n in _PRESIDENT_NUMBERS], False
    if question_type == "four_president_given_death_current":
        year = char_data["year"]
        return [
            f"You died in {year}. This year is 2023. Who is the {n}th US president?"
            for n in _PRESIDENT_NUMBERS
        ], False
    if question_type == "yearly_given_death":
        year = char_data["year"]
        return [f"You died in {year}. Who is the US president in {year + d}?" for d in _YEAR_OFFSETS], True
    if question_type == "direct_year_current_death":
        year = char_data["year"]
        return [f"You died in {year}. This year is 2023. Are you dead or alive?"], False
    if question_type == "direct_death_year":
        year = char_data["year"]
        return [f"Are you dead or alive in {year + d}?" for d in _YEAR_OFFSETS], True
    return [], True


def _ask_question(model_id, prompt, question, prompt_type, multiline):
    """Ask one question in character, echo it and the reply, and return the record string."""
    new_prompt = prompt + f" {question} Answer:"
    print(question)
    output = instruct_model(model_id, new_prompt, prompt_type)
    print(output)
    print("\n")
    if multiline:
        return f"Question: {question}\nAnswer: {output}\n"
    return f"Question: {question} Answer: {output}"


# Script entry point: validate the config, ask every selected character the
# configured questions, and rewrite the output JSON after each character so
# partial results survive an interruption.
if __name__ == "__main__":
    prompt_type = cfg["type"]
    question_type = cfg["question_type"]
    if question_type not in QUESTION_TYPE:
        raise ValueError("Invalid question type")
    if prompt_type not in TYPE:
        raise ValueError("Invalid type")
    if question_type in _COT_QUESTION_TYPES:
        raise ValueError("COT is not supported")

    model_id = args.model_id
    question_file = "answers/question.json" if args.exp == "exp" else "answers/full_question.json"
    with open(question_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    out_dir = f"answers/{args.exp}/{model_id}/{prompt_type}"
    out_path = f"{out_dir}/formatted_output_{prompt_type}_{question_type}.json"

    # One entry per processed character, across all categories.
    all_characters = []

    for category in data:
        category_type = category["type"]
        print(category_type)
        characters = category["characters"]

        if category_type == "description_human":
            continue
        # Year sweeps need char_data["year"]; presumably fully fictional
        # characters have no usable year, so they are skipped here.
        if question_type in ("yearly", "direct_death_year") and category_type == "fully_fictional":
            continue

        for char_data in characters:
            char = char_data["role"]

            print(f"Character: {char} (Type: {category_type})")
            print("\n")

            prompt = _build_prompt(char, char_data, category_type)
            questions, multiline = _build_questions(question_type, char_data)

            all_answers = []
            for question in questions:
                all_answers.append(
                    _ask_question(model_id, prompt, question, prompt_type, multiline)
                )

            character_entry = {
                "character": char,
                "type": category_type,
                "questions": all_answers,
            }
            # NOTE(review): the questions use char_data["year"]; confirm whether
            # "death_year" is a distinct field in the question files.
            if "death_year" in char_data:
                character_entry["death_year"] = char_data["death_year"]

            all_characters.append(character_entry)

            # Checkpoint: rewrite the whole output file after every character.
            os.makedirs(out_dir, exist_ok=True)
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(all_characters, f, indent=4)

    print(out_path)