import json
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from utils.utils import *
from llm.avior_api import *
from llm.gpt import *
from prompt.OpenbookQA_prompt import *
# Imported after the star imports so that np / tqdm cannot be shadowed by them.
import numpy as np
from tqdm import tqdm


# Prompting-strategy code -> instruction template. The template constants come
# from the star import of prompt.OpenbookQA_prompt.
#   Direct - only "[Question]" is filled in by generate_prompt
#   NP     - neutral prompt, conflicting knowledge in context
#   TOKP   - asks the model to trust its own knowledge
#   TOHP   - "trust own half" variant (see CON_REASONING_TRUST_OWN_HALF_PROMPT)
#   OKF    - asks the model to output its own knowledge first
instruction_template_dict = dict(
    Direct=DIRECT_REASONING_PROMPT,
    NP=REASONING_NEUTRAL_PROMPT,
    TOKP=CON_REASONING_TRUST_OWN_KNOWLEDGE_PROMPT,
    TOHP=CON_REASONING_TRUST_OWN_HALF_PROMPT,
    OKF=CON_OUTPUT_OWN_FIRST_PROMPT,
)

def generate_prompt(taxon_info, question, question_type, know_triple_list=None, template_dict=None):
    """
    Build the prompt for one question by filling an instruction template.

    The template is looked up by ``question_type``. ``[Question]`` is always
    filled in, ``[Taxon]`` for every type except "Direct", and ``[KNOWLEDGE]``
    whenever the resulting prompt contains that placeholder.

    Args:
    taxon_info: Context inserted at ``[Taxon]``. Dicts and lists are rendered
        as indented JSON; anything else via ``str()``.
    question: The question inserted at ``[Question]`` (converted with ``str()``).
    question_type (str): Key into the template mapping
        (e.g. "Direct", "NP", "TOKP", "TOHP", "OKF").
    know_triple_list: Knowledge inserted at ``[KNOWLEDGE]``, rendered like
        ``taxon_info``. Required when the template contains ``[KNOWLEDGE]``.
    template_dict (dict, optional): Mapping from question type to template
        string. Defaults to the module-level ``instruction_template_dict``.

    Returns:
    str: The formatted prompt.

    Raises:
    KeyError: If ``question_type`` is not a key of the template mapping.
    ValueError: If the template needs ``[KNOWLEDGE]`` but ``know_triple_list``
        is None.
    """
    if template_dict is None:
        template_dict = instruction_template_dict
    instruction_template = template_dict[question_type]

    def _render(value):
        # Structured values are pretty-printed as JSON; everything else via str().
        if isinstance(value, (dict, list)):
            return json.dumps(value, indent=2)
        return str(value)

    question_str = str(question)

    if question_type == "Direct":
        # Direct prompts carry no in-context taxon information.
        prompt = instruction_template.replace("[Question]", question_str)
    else:
        prompt = (instruction_template
                  .replace("[Taxon]", _render(taxon_info))
                  .replace("[Question]", question_str))

    if "[KNOWLEDGE]" in prompt:
        # Previously this path hit an unbound local when no knowledge was given.
        if know_triple_list is None:
            raise ValueError(
                f"template for question_type={question_type!r} contains "
                "[KNOWLEDGE] but know_triple_list was not provided"
            )
        prompt = prompt.replace("[KNOWLEDGE]", _render(know_triple_list))

    return prompt

def tokp_evaluation(input_file, output_file, model):
    """
    Query the model with the TOKP ("trust own knowledge") prompt for every item.

    Each item's ``conflict_knowledge`` is placed in context and an extra
    "C. unknown" option is appended to its ``new_question``. The raw answer is
    stored under ``tokp_result`` and the parsed choice under the boolean flags
    ``trust_pk_tokp`` (chose A), ``trust_ck_tokp`` (chose B) and
    ``unkown_tokp`` (chose C). Running rates and the accumulated price are
    printed after every item.

    Args:
    input_file (str): Path to the input JSON file (a list of item dicts).
    output_file (str): Path the annotated items are written to; checkpointed
        every 300 items and once more at the end.
    model (str): Model name. Names containing "o1" or "gpt" go through the
        helpers that also return a price; anything else goes through
        ``chat_completion``.
    """
    data = read_json_file(input_file)

    trust_pk_tokp_list = []
    trust_ck_tokp_list = []
    unknown_tokp_list = []

    total_price = 0
    for index, item in enumerate(tqdm(data)):
        counter_memory = item.get('conflict_knowledge')
        new_question = item.get('new_question') + "\nC. unknown"

        # trust own prompt, ck in context, ask to use own knowledge
        prompt = generate_prompt(counter_memory, new_question, "TOKP")
        if 'o1' in model:
            result, price = generate_o1_response(prompt, model)
        elif "gpt" in model:
            result, price = generate_chatgpt_response(prompt, model)
        else:
            # chat_completion returns only the text, so it adds no cost.
            result, price = chat_completion(prompt, model), 0

        print(prompt)
        print(result)
        total_price += price
        print("current total price: ", total_price)

        item['tokp_result'] = result

        # Text after the last "Final Choice:" marker, lower-cased. Note that
        # .strip("unknown") strips the *characters* u/n/k/o/w from both ends
        # (which removes a trailing " unknown" word), not the substring.
        choice = (result.split("Final Choice:")[-1].lower().strip()
                  .strip("unknown").strip().strip(".").strip("]").strip("[")
                  .strip(".").strip())
        # Keep only the leading option letter, matching whole_metrics and
        # okf_metrics, so that e.g. "b. paris" still counts as choice B.
        choice = choice[:1]

        trust_pk_tokp = choice == "a"
        trust_ck_tokp = choice == "b"
        unknown_tokp = choice == "c"

        trust_pk_tokp_list.append(trust_pk_tokp)
        trust_ck_tokp_list.append(trust_ck_tokp)
        unknown_tokp_list.append(unknown_tokp)

        item['trust_pk_tokp'] = trust_pk_tokp
        item['trust_ck_tokp'] = trust_ck_tokp
        # Key spelling kept as-is: okf_evaluation reads 'unkown_tokp'.
        item['unkown_tokp'] = unknown_tokp

        print("current trust own prompt trust pk acc: ", np.mean(trust_pk_tokp_list))
        print("current trust own prompt trust ck acc: ", np.mean(trust_ck_tokp_list))
        print("current trust own prompt unknown: ", np.mean(unknown_tokp_list))

        # item is the same dict object stored in data, so no write-back is needed.
        if index % 300 == 0:
            write_json_file(data, output_file)

    write_json_file(data, output_file)


def whole_evaluation(input_file, output_file, model):
    """
    Query the model with the NP, TOKP and OKF prompts for every item.

    For each item, ``conflict_knowledge`` is placed in context and an extra
    "C. unknown" option is appended to ``new_question``. For every strategy,
    the raw answer and the parsed choice flags (A -> ``trust_pk_*``,
    B -> ``trust_ck_*``, C -> unknown) are stored on the item. Running rates,
    the memorization ratio and the accumulated price are printed as the loop
    progresses.

    Args:
    input_file (str): Path to the input JSON file (a list of item dicts).
    output_file (str): Path the processed items are written to; checkpointed
        every 300 items and once more at the end.
    model (str): Model name. Names containing "o1" or "gpt" go through the
        helpers that also return a price; anything else goes through
        ``chat_completion``.
    """
    data = read_json_file(input_file)

    # (prompt type, raw-answer key, flag suffix, "unknown" flag key, log label),
    # queried in this order for every item. The irregular keys 'tohp_result'
    # (holds the OKF answer) and 'unkown_*' are kept unchanged because
    # whole_metrics, okf_evaluation and okf_metrics read them back.
    strategies = (
        ("NP", "np_result", "np", "unkown_np", "neutral prompt"),
        ("TOKP", "tokp_result", "tokp", "unkown_tokp", "trust own prompt"),
        ("OKF", "tohp_result", "okf", "unknown_okf", "okf"),
    )
    # Running (trust_pk, trust_ck, unknown) flag lists per strategy suffix.
    flags = {suffix: ([], [], []) for _, _, suffix, _, _ in strategies}

    new_data = []
    total_price = 0
    for index, item in enumerate(tqdm(data)):
        counter_memory = item.get('conflict_knowledge')
        new_question = item.get('new_question') + "\nC. unknown"

        for prompt_type, result_key, suffix, unknown_key, label in strategies:
            prompt = generate_prompt(counter_memory, new_question, prompt_type)
            if 'o1' in model:
                result, price = generate_o1_response(prompt, model)
            elif "gpt" in model:
                result, price = generate_chatgpt_response(prompt, model)
            else:
                # chat_completion returns only the text, so it adds no cost.
                result, price = chat_completion(prompt, model), 0

            print(prompt)
            print(result)
            total_price += price
            print("current total price: ", total_price)

            item[result_key] = result

            # Text after the last "Final Choice:" marker. .strip("unknown")
            # strips the characters u/n/k/o/w (removing a trailing " unknown"),
            # not the substring. Only the leading option letter is kept, so a
            # verbose answer such as "b. ... a ..." cannot match several options.
            choice = (result.split("Final Choice:")[-1].lower().strip()
                      .strip("unknown").strip().strip(".").strip("]")
                      .strip("[").strip(".").strip())[:1]

            pk_flags, ck_flags, unknown_flags = flags[suffix]
            pk_flags.append(choice == "a")
            ck_flags.append(choice == "b")
            unknown_flags.append(choice == "c")

            item['trust_pk_' + suffix] = pk_flags[-1]
            item['trust_ck_' + suffix] = ck_flags[-1]
            item[unknown_key] = unknown_flags[-1]

            pk_rate = np.mean(pk_flags)
            ck_rate = np.mean(ck_flags)
            decided = pk_rate + ck_rate
            print(f"current {label} trust pk acc: ", pk_rate)
            print(f"current {label} trust ck acc: ", ck_rate)
            print(f"current {label} unknown: ", np.mean(unknown_flags))
            # Share of A/B answers that followed the model's own knowledge;
            # nan (without numpy's 0/0 warning) until one has been seen.
            print(f"memorization ratio {suffix}:",
                  pk_rate / decided if decided else float("nan"))

        new_data.append(item)

        if index % 300 == 0:
            write_json_file(new_data, output_file)

    write_json_file(new_data, output_file)


def whole_metrics(input_file, output_file, model):
    """
    Recompute NP, TOKP and OKF metrics from a finished whole_evaluation run.

    The raw answers stored in ``output_file`` are re-parsed: ``np_result``,
    ``tokp_result`` and ``tohp_result``, which is the key whole_evaluation
    uses for the OKF answer. Running pk / ck / unknown rates plus the
    memorization ratio are printed for each strategy. Nothing is written back.

    Args:
    input_file (str): Unused; kept so all entry points share one signature.
    output_file (str): Path to the JSON file produced by whole_evaluation.
    model (str): Unused; kept so all entry points share one signature.
    """
    data = read_json_file(output_file)

    # (raw-answer key, log label, ratio suffix, echo the raw answer?)
    strategies = (
        ("np_result", "neutral prompt", "np", False),
        ("tokp_result", "trust own prompt", "tokp", True),
        ("tohp_result", "okf", "okf", True),
    )
    # Running (trust_pk, trust_ck, unknown) flag lists per strategy suffix.
    flags = {suffix: ([], [], []) for _, _, suffix, _ in strategies}

    for item in tqdm(data):
        for result_key, label, suffix, echo in strategies:
            result = item[result_key]
            if echo:
                print(result)

            # Text after the last "Final Choice:" marker, reduced to its
            # leading option letter. .strip("unknown") strips the characters
            # u/n/k/o/w, not the substring.
            choice = (result.split("Final Choice:")[-1].lower().strip()
                      .strip("unknown").strip().strip(".").strip("]")
                      .strip("[").strip(".").strip())[:1]

            pk_flags, ck_flags, unknown_flags = flags[suffix]
            pk_flags.append(choice == "a")
            ck_flags.append(choice == "b")
            unknown_flags.append(choice == "c")

            pk_rate = np.mean(pk_flags)
            ck_rate = np.mean(ck_flags)
            decided = pk_rate + ck_rate
            print(f"current {label} trust pk acc: ", pk_rate)
            print(f"current {label} trust ck acc: ", ck_rate)
            print(f"current {label} unknown: ", np.mean(unknown_flags))
            # nan (without numpy's 0/0 warning) until an A/B answer is seen.
            print(f"memorization ratio {suffix}:",
                  pk_rate / decided if decided else float("nan"))

        print("length of data: ", len(flags["okf"][0]))


def okf_evaluation(input_file, output_file, model):
    """
    Re-run the OKF ("own knowledge first") prompt on previously evaluated items.

    The input items must already carry the NP and TOKP flags written by
    whole_evaluation: ``trust_pk_np``, ``trust_ck_np``, ``unkown_np``,
    ``trust_pk_tokp``, ``trust_ck_tokp`` and ``unkown_tokp``. Those flags are
    only re-reported. The OKF prompt is then queried fresh: the raw answer is
    stored under ``okf_result`` and the parsed choice under ``trust_pk_okf``,
    ``trust_ck_okf`` and ``unkown_okf``.

    Args:
    input_file (str): Path to the JSON file with the previously evaluated items.
    output_file (str): Path the processed items are written to; checkpointed
        every 300 items and once more at the end.
    model (str): Model name. Names containing "o1" or "gpt" go through the
        helpers that also return a price; anything else goes through
        ``chat_completion``.
    """
    data = read_json_file(input_file)

    trust_pk_np_list = []
    trust_ck_np_list = []
    unknown_np_list = []

    trust_pk_tokp_list = []
    trust_ck_tokp_list = []
    unknown_tokp_list = []

    trust_pk_okf_list = []
    trust_ck_okf_list = []
    unknown_okf_list = []

    new_data = []
    total_price = 0

    def memorization_ratio(pk_flags, ck_flags):
        # Share of A/B answers that followed the model's own knowledge;
        # nan (without numpy's 0/0 warning) until one has been seen.
        pk_rate = np.mean(pk_flags)
        ck_rate = np.mean(ck_flags)
        decided = pk_rate + ck_rate
        return pk_rate / decided if decided else float("nan")

    for index, item in enumerate(tqdm(data)):
        counter_memory = item.get('conflict_knowledge')
        new_question = item.get('new_question') + "\nC. unknown"

        # neutral prompt, ck in context (flags from the earlier run)
        trust_pk_np_list.append(item['trust_pk_np'])
        trust_ck_np_list.append(item['trust_ck_np'])
        unknown_np_list.append(item['unkown_np'])

        print("current neutral prompt trust pk acc: ", np.mean(trust_pk_np_list))
        print("current neutral prompt trust ck acc: ", np.mean(trust_ck_np_list))
        print("current neutral prompt unknown: ", np.mean(unknown_np_list))
        print("memorization ratio np:", memorization_ratio(trust_pk_np_list, trust_ck_np_list))

        # trust own prompt, ck in context (flags from the earlier run)
        trust_pk_tokp_list.append(item['trust_pk_tokp'])
        trust_ck_tokp_list.append(item['trust_ck_tokp'])
        unknown_tokp_list.append(item['unkown_tokp'])

        print("current trust own prompt trust pk acc: ", np.mean(trust_pk_tokp_list))
        print("current trust own prompt trust ck acc: ", np.mean(trust_ck_tokp_list))
        print("current trust own prompt unknown: ", np.mean(unknown_tokp_list))
        print("memorization ratio tokp:", memorization_ratio(trust_pk_tokp_list, trust_ck_tokp_list))

        # own knowledge first prompt, ck in context, queried fresh
        prompt = generate_prompt(counter_memory, new_question, "OKF")
        if 'o1' in model:
            result, price = generate_o1_response(prompt, model)
        elif "gpt" in model:
            result, price = generate_chatgpt_response(prompt, model)
        else:
            # chat_completion returns only the text, so it adds no cost.
            result, price = chat_completion(prompt, model), 0

        print(prompt)
        print(result)
        total_price += price
        print("current total price: ", total_price)

        item['okf_result'] = result

        # Text after the last "Final Choice:" marker, reduced to its leading
        # option letter (matching okf_metrics). .strip("unknown") strips the
        # characters u/n/k/o/w, not the substring.
        choice = (result.split("Final Choice:")[-1].lower().strip()
                  .strip("unknown").strip().strip(".").strip("]")
                  .strip("[").strip(".").strip())[:1]

        trust_pk_okf = choice == "a"
        trust_ck_okf = choice == "b"
        unknown_okf = choice == "c"

        trust_pk_okf_list.append(trust_pk_okf)
        trust_ck_okf_list.append(trust_ck_okf)
        unknown_okf_list.append(unknown_okf)

        item['trust_pk_okf'] = trust_pk_okf
        item['trust_ck_okf'] = trust_ck_okf
        item['unkown_okf'] = unknown_okf

        print("current own knowledge first pk acc: ", np.mean(trust_pk_okf_list))
        print("current own knowledge first ck acc: ", np.mean(trust_ck_okf_list))
        print("current own knowledge first unknown: ", np.mean(unknown_okf_list))
        print("memorization ratio okf:", memorization_ratio(trust_pk_okf_list, trust_ck_okf_list))

        new_data.append(item)

        if index % 300 == 0:
            write_json_file(new_data, output_file)

    write_json_file(new_data, output_file)


def okf_metrics(input_file, output_file, model):
    """
    Recompute OKF ("own knowledge first") metrics from stored raw answers.

    Each item's OKF answer comes from ``okf_result`` (written by
    okf_evaluation) when present; otherwise ``tohp_result`` is used, which is
    where whole_evaluation stores its OKF answer. Running pk / ck / unknown
    rates, the memorization ratio and the item count are printed. Nothing is
    written back.

    Args:
    input_file (str): Unused; kept so all entry points share one signature.
    output_file (str): Path to the JSON file holding the evaluated items.
    model (str): Unused; kept so all entry points share one signature.

    Raises:
    KeyError: If an item has neither ``okf_result`` nor ``tohp_result``.
    """
    data = read_json_file(output_file)

    trust_pk_okf_list = []
    trust_ck_okf_list = []
    unknown_okf_list = []

    for item in tqdm(data):
        # When okf_evaluation ran on whole_evaluation output, items carry both
        # the older 'tohp_result' and the fresh 'okf_result'; prefer the fresh one.
        raw = item.get('okf_result')
        if raw is None:
            raw = item['tohp_result']

        # Text after the last "Final Choice:" marker, reduced to its leading
        # option letter. .strip("unknown") strips the characters u/n/k/o/w,
        # not the substring.
        choice = (raw.split("Final Choice:")[-1].lower().strip()
                  .strip("unknown").strip().strip(".").strip("]")
                  .strip("[").strip(".").strip())[:1]

        trust_pk_okf_list.append(choice == "a")
        trust_ck_okf_list.append(choice == "b")
        unknown_okf_list.append(choice == "c")

        pk_rate = np.mean(trust_pk_okf_list)
        ck_rate = np.mean(trust_ck_okf_list)
        decided = pk_rate + ck_rate
        print("current own knowledge first pk acc: ", pk_rate)
        print("current own knowledge first ck acc: ", ck_rate)
        print("current own knowledge first unknown: ", np.mean(unknown_okf_list))
        # nan (without numpy's 0/0 warning) until an A/B answer is seen.
        print("memorization ratio okf:", pk_rate / decided if decided else float("nan"))
        print("length of data: ", len(trust_pk_okf_list))



# Example usage
model = "o1-preview" # "gpt-4o-mini" or "LLAMA_3_70B"
# model = "Vicuna_13b" # "gpt-4o-mini" or "LLAMA_3_70B"

input_file = f'./data/openbookQA/test_complete_harder_{model}_conflict.json'
output_file = f'./output/openbookQA/test_complete_reasoning_{model}_conflict.json'
output_file_tokp = f'./output/openbookQA/test_complete_conflict_{model}_tokp.json'
output_file_okf = f'./output/openbookQA/test_complete_conflict_{model}_okf.json'

if __name__ == "__main__":
    # Launch the (priced) evaluation only when run as a script, never on import.
    whole_evaluation(input_file, output_file, model)
