import json
import os
import re
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from utils.utils import *
from llm.avior_api import *
from llm.gpt import *
from prompt.ConflictQA_prompt import *
import numpy as np
from tqdm import tqdm


# Maps the ``question_type`` argument of ``generate_prompt`` to the instruction
# template it fills in. The template constants are not defined in this file;
# presumably they come from the wildcard import of prompt.ConflictQA_prompt.
instruction_template_dict = {
    "Direct": DIRECT_PROMPT,  # generate_prompt only substitutes "[Question]" for this one
    "NP": NP_PROMPT,  # "[Taxon]" and "[Question]" are substituted
    "TOKP": TRUST_OWN_KNOWLEDGE_PROMPT,  # "[Taxon]" and "[Question]" are substituted
}

def _to_prompt_text(value):
    """Render *value* for insertion into a template: indented JSON for dict/list, ``str()`` otherwise."""
    if isinstance(value, (dict, list)):
        return json.dumps(value, indent=2)
    return str(value)


def generate_prompt(taxon_info, question, question_type, know_triple_list=None):
    """
    Build a prompt by filling the placeholders of an instruction template.

    Args:
    taxon_info (str | dict | list): Context substituted for "[Taxon]". Ignored
        when question_type is "Direct". dict/list values are JSON-encoded.
    question (str): Text substituted for "[Question]" (non-strings go through ``str()``).
    question_type (str): Key of ``instruction_template_dict`` selecting the template.
    know_triple_list (str | dict | list, optional): Content substituted for
        "[KNOWLEDGE]" when that placeholder is present in the prompt.

    Returns:
    str: The formatted prompt.

    Raises:
    KeyError: If question_type is not a key of ``instruction_template_dict``.
    ValueError: If the prompt contains "[KNOWLEDGE]" but know_triple_list is None.
    """
    instruction_template = instruction_template_dict[question_type]

    question_str = question if isinstance(question, str) else str(question)

    if question_type == "Direct":
        prompt = instruction_template.replace("[Question]", question_str)
    else:
        # "[Taxon]" is filled before "[Question]", same order as always.
        prompt = instruction_template.replace("[Taxon]", _to_prompt_text(taxon_info)).replace("[Question]", question_str)

    if "[KNOWLEDGE]" in prompt:
        # Previously the replacement text was only bound for a truthy
        # know_triple_list, so a missing or empty value crashed here with
        # UnboundLocalError. Empty containers are now rendered ("[]"/"{}") and
        # a missing value fails with an explicit message.
        if know_triple_list is None:
            raise ValueError(
                f"The '{question_type}' prompt contains a [KNOWLEDGE] placeholder "
                "but know_triple_list was not provided"
            )
        prompt = prompt.replace("[KNOWLEDGE]", _to_prompt_text(know_triple_list))

    return prompt



# The three prompt settings evaluated for every item:
# (key of instruction_template_dict, suffix of the item fields, label in progress output).
_WHOLE_SETTINGS = (
    ("Direct", "direct", "direct"),
    ("NP", "np", "neutral prompt"),
    ("TOKP", "tokp", "trust own prompt"),
)

# Option indices as laid out in the multiple-choice question built by
# whole_evaluation: "0" is the item's memory answer, "3" is "Unknown".
_MEMORY_CHOICE = "0"
_UNKNOWN_CHOICE = "3"

# Characters models commonly wrap around the chosen option ("[0]", "**0**", "(0).").
_CHOICE_WRAPPERS = " \t\r\n.[]()*\"'`"
_CHOICE_PATTERN = re.compile(r"(?:option\s*)?(\d+)")


def _query_model(prompt, model):
    """Send *prompt* to the backend selected by the model name; return ``(text, price)``."""
    if 'o1' in model:
        text, price = generate_o1_response(prompt, model)
    elif "gpt" in model:
        text, price = generate_chatgpt_response(prompt, model)
    else:
        # This backend reports no price in this file, so it contributes 0.
        text, price = chat_completion(prompt, model), 0
    return text, price


def _parse_choice(response):
    """
    Extract the chosen option index from a model response.

    Only the text after the last "Final Choice:" marker is inspected (the whole
    response if the marker is absent). Returns the leading number as a string,
    "3" when the answer starts with the word "unknown", or "" when no choice
    can be recognised.
    """
    # The previous implementation used str.strip("unknown"), which removes a
    # *set of characters* rather than the word: a bare "Unknown" answer became
    # "" and was never counted as the Unknown option.
    answer = response.split("Final Choice:")[-1].lower().strip(_CHOICE_WRAPPERS)
    if answer.startswith("unknown"):
        return _UNKNOWN_CHOICE
    found = _CHOICE_PATTERN.match(answer)
    return found.group(1) if found else ""


def _new_tallies():
    """Return two dicts (trust-pk flags, unknown flags), each mapping a setting suffix to a list."""
    trust_pk = {suffix: [] for _question_type, suffix, _label in _WHOLE_SETTINGS}
    unknown = {suffix: [] for _question_type, suffix, _label in _WHOLE_SETTINGS}
    return trust_pk, unknown


def _print_running_metrics(trust_pk, unknown):
    """Print the mean trust-pk and unknown rate of every prompt setting."""
    for _question_type, suffix, label in _WHOLE_SETTINGS:
        print(f"current {label} trust pk acc: ", np.mean(trust_pk[suffix]))
        print(f"current {label} unknown: ", np.mean(unknown[suffix]))


def whole_evaluation(input_file, output_file, model):
    """
    Query the model with the Direct, NP and TOKP prompts for every item that
    has a "property" field and record whether it picked the memory answer
    (option 0) or "Unknown" (option 3).

    Each processed item gets 'reasoning_question' plus, per setting,
    '<suffix>_result', 'trust_pk_<suffix>' and 'unkown_<suffix>'. Processed
    items are checkpointed to output_file whenever the item index is a multiple
    of 20 and written once more at the end.

    Args:
    input_file (str): The path to the input JSON file with the questions.
    output_file (str): The path of the JSON file the processed items are written to.
    model (str): Model name; names containing 'o1' or "gpt" select the priced backends.
    """
    data = read_json_file(input_file)
    option_dict = build_relation_answer_dict_pool(data)

    trust_pk, unknown = _new_tallies()
    total_price = 0
    new_data = []

    for index, item in enumerate(tqdm(data)):
        if "property" not in item:
            continue
        question = item.get('new_question')
        memory_answer_short = item.get('memory_answer_short')
        counter_answer_short = item.get('counter_answer_short')
        ground_truth = item.get('ground_truth')
        new_old_relation = item.get('new_old_relation')

        # NOTE(review): the pool also holds this item's own memory answer, so a
        # sampled distractor may duplicate option 0 — confirm how
        # get_random_items samples before relying on the trust-pk rate.
        random_option = get_random_items(option_dict[item.get('property')], 2)

        new_question = question + "\n0. " + memory_answer_short + "\n\n1. " + random_option[0] + "\n\n2. " + random_option[1] + "\n\n3. Unknown"
        item['reasoning_question'] = new_question

        for question_type, suffix, _label in _WHOLE_SETTINGS:
            # The Direct prompt gets no context; NP and TOKP get new_old_relation.
            context = "" if question_type == "Direct" else new_old_relation
            prompt = generate_prompt(context, new_question, question_type)
            result, price = _query_model(prompt, model)

            print(prompt)
            print(result)
            total_price += price
            print("current total price: ", total_price)
            print(str(ground_truth), memory_answer_short, counter_answer_short)

            item[f'{suffix}_result'] = result
            choice = _parse_choice(result)
            picked_memory = choice == _MEMORY_CHOICE
            picked_unknown = choice == _UNKNOWN_CHOICE

            trust_pk[suffix].append(picked_memory)
            unknown[suffix].append(picked_unknown)
            item[f'trust_pk_{suffix}'] = picked_memory
            # The 'unkown' spelling is part of the stored schema; keep it.
            item[f'unkown_{suffix}'] = picked_unknown

        _print_running_metrics(trust_pk, unknown)

        new_data.append(item)
        if index % 20 == 0:
            write_json_file(new_data, output_file)

    write_json_file(new_data, output_file)


def whole_metrics(input_file, output_file, model):
    """
    Recompute the summary metrics from a results file produced by
    ``whole_evaluation`` and print them once. No model is queried and nothing
    is written.

    Args:
    input_file (str): Unused; kept for signature compatibility.
    output_file (str): The path of the results JSON file to read.
    model (str): Unused; kept for signature compatibility.

    Returns:
    dict: 'trust_pk_<suffix>' and 'unkown_<suffix>' means for each setting plus
        'num_items'; an empty dict when the file holds no scorable item.

    Raises:
    KeyError: If an item with a "property" field lacks one of the stored
        '<suffix>_result' responses.
    """
    data = read_json_file(output_file)
    trust_pk, unknown = _new_tallies()

    for item in data:
        if "property" not in item:
            continue
        for _question_type, suffix, _label in _WHOLE_SETTINGS:
            choice = _parse_choice(item[f'{suffix}_result'])
            trust_pk[suffix].append(choice == _MEMORY_CHOICE)
            unknown[suffix].append(choice == _UNKNOWN_CHOICE)

    num_items = len(unknown["tokp"])
    if num_items == 0:
        # Avoid np.mean([]) (nan plus a RuntimeWarning) on an empty results file.
        print("length of data: ", num_items)
        return {}

    _print_running_metrics(trust_pk, unknown)
    print("length of data: ", num_items)

    metrics = {"num_items": num_items}
    for _question_type, suffix, _label in _WHOLE_SETTINGS:
        metrics[f'trust_pk_{suffix}'] = float(np.mean(trust_pk[suffix]))
        metrics[f'unkown_{suffix}'] = float(np.mean(unknown[suffix]))
    return metrics


def build_relation_answer_dict_pool(data):
    """
    Group the short memory answers of all items by their relation.

    Items without a 'property' field are ignored, as are items whose
    'memory_answer_short' is missing (a None entry would later break the
    string concatenation that builds the answer options). The known relations
    are always present in the result, even when empty; relations outside that
    list are added on first sight instead of raising KeyError.

    Args:
    data (list[dict]): Items read from the input JSON file.

    Returns:
    dict[str, list]: Relation name -> answers, in input order (duplicates kept).
    """
    known_relations = (
        "occupation",
        "place of birth",
        "genre",
        "father",
        "country",
        "producer",
        "director",
        "capital of",
        "screenwriter",
        "composer",
        "color",
        "religion",
        "sport",
        "author",
        "mother",
        "capital",
    )
    relation_answer_dict = {relation: [] for relation in known_relations}

    for item in data:
        if 'property' not in item:
            continue
        answer = item.get('memory_answer_short')
        if answer is None:
            continue
        relation_answer_dict.setdefault(item['property'], []).append(answer)

    return relation_answer_dict


def golden_evaluation(input_file, output_file, model):
    """
    Query the model with the NP prompt whose context is the item's parametric
    memory followed by its 'new_old_relation' text, and record whether it picked
    the memory answer (option 0) or "Unknown" (option 3).

    Items that already carry 'unknown_golden' are skipped, so a previous output
    can be resumed. Items without a 'property' field are skipped as well, since
    no distractor options can be drawn for them. Each processed item gets
    'golden_result', 'trust_pk_golden' and 'unknown_golden'. The full item list
    is checkpointed to output_file whenever the item index is a multiple of 20
    and written once more at the end.

    Args:
    input_file (str): The path to the input JSON file with the questions.
    output_file (str): The path of the JSON file the annotated items are written to.
    model (str): Model name; names containing 'o1' or "gpt" select the priced backends.
    """
    data = read_json_file(input_file)
    option_dict = build_relation_answer_dict_pool(data)

    trust_pk_golden_list = []
    unknown_golden_list = []
    total_price = 0
    price = 0  # stays 0 for the backend that reports no price

    for index, item in enumerate(tqdm(data)):
        if "unknown_golden" in item:
            continue
        # Without this guard the lookup below raised KeyError for items that
        # have no 'property'; whole_evaluation skips them the same way.
        if "property" not in item:
            continue

        question = item.get('new_question')
        memory_answer_short = item.get('memory_answer_short')
        counter_answer_short = item.get('counter_answer_short')
        new_old_relation = item.get('new_old_relation')
        parametric_memory = item.get('parametric_memory')
        ground_truth = item.get('ground_truth')
        random_option = get_random_items(option_dict[item.get('property')], 2)

        new_question = question + "\n0. " + memory_answer_short + "\n\n1. "+ random_option[0] + "\n\n2. " + random_option[1] + "\n\n3. Unknown"
        # golden paragraph
        prompt = generate_prompt(parametric_memory+"\n###\n"+new_old_relation, new_question, "NP")
        if 'o1' in model:
            result, price = generate_o1_response(prompt, model)
        elif "gpt" in model:
            result, price = generate_chatgpt_response(prompt, model)
        else:
            result = chat_completion(prompt, model)

        print(prompt)
        print(result)
        total_price += price
        print("current total price: ", total_price)

        print(str(ground_truth), memory_answer_short, counter_answer_short)
        item['golden_result'] = result

        # Parse the option after the last "Final Choice:" marker. The earlier
        # str.strip("unknown") call stripped a character set, not the word, so
        # a bare "Unknown" answer was reduced to "" and never counted.
        answer = result.split("Final Choice:")[-1].lower().strip(" \t\r\n.[]()*\"'`")
        if answer.startswith("unknown"):
            choice = "3"
        else:
            number = re.match(r"(?:option\s*)?(\d+)", answer)
            choice = number.group(1) if number else ""

        trust_pk_golden = choice == "0"
        unknown_golden = choice == "3"

        trust_pk_golden_list.append(trust_pk_golden)
        unknown_golden_list.append(unknown_golden)
        # item is the very dict stored in data, so these writes update data in place.
        item['trust_pk_golden'] = trust_pk_golden
        item['unknown_golden'] = unknown_golden

        print("current golden trust pk acc: ", np.mean(trust_pk_golden_list))
        print("current golden unknown: ", np.mean(unknown_golden_list))

        if index % 20 == 0:
            write_json_file(data, output_file)

    write_json_file(data, output_file)


def evaluation(input_file, output_file, question_type, model, max_items=3200):
    """
    Ask the model each item's original question with the "Direct" prompt and
    record whether its answer appears in the item's memory answer.

    Each processed item gets 'result', 'is_in_parametric' and, when the answer
    is found in the memory answer, 'is_correct' (1 if the answer is also in
    'ground_truth', else 0). The item list is checkpointed to output_file
    whenever the item index is a multiple of 300 and written once more at the end.

    Args:
    input_file (str): The path to the input JSON file with the questions.
    output_file (str): The path of the JSON file the annotated items are written to.
    question_type (str): Unused; the "Direct" template is always applied.
    model (str): Model name; names containing 'o1' or "gpt" select the priced backends.
    max_items (int | None): Only the first max_items items are processed
        (None processes all). Defaults to the previously hard-coded 3200.

    Returns:
    list[str]: The prompts that were sent, in order.
    """
    data = read_json_file(input_file)
    is_in_parametric_list = []
    prompts = []

    for index, item in enumerate(tqdm(data[:max_items])):
        origin_question = item.get('question')
        ground_truth = item.get('ground_truth')
        memory_answer = item.get('memory_answer')

        prompt = generate_prompt("", origin_question, "Direct")
        prompts.append(prompt)
        # The priced backends return (text, price), as the other evaluation
        # functions in this file show; the tuple used to be kept whole and then
        # crashed on .split(). Model names with 'o1' are routed like elsewhere.
        if 'o1' in model:
            result, _price = generate_o1_response(prompt, model)
        elif "gpt" in model:
            result, _price = generate_chatgpt_response(prompt, model)
        else:
            result = chat_completion(prompt, model)

        result = result.split("Answer:")[-1].strip(".").strip().lower()
        print(result)
        print(memory_answer)

        # NOTE(review): result is lower-cased but memory_answer is compared
        # as stored — confirm the stored answers are lower-case too.
        try:
            is_in_parametric = (result in memory_answer)
        except TypeError as e:
            # memory_answer is missing or not a container.
            print(e)
            is_in_parametric = 0

        # Was `is_in_parametric.append(...)`, i.e. .append on the flag itself.
        is_in_parametric_list.append(is_in_parametric)
        # item is the very dict stored in data, so these writes update data in place.
        item['result'] = result
        item['is_in_parametric'] = is_in_parametric

        if is_in_parametric: # in parameter
            item['is_correct'] = 1 if result in ground_truth else 0

        # Average over all items so far (was the mean of the single flag).
        print("current acc: ", np.mean(is_in_parametric_list))

        if index % 300 == 0:
            write_json_file(data, output_file)

    write_json_file(data, output_file)
    return prompts

# Example usage
# NOTE(review): these statements sit at module top level without an
# `if __name__ == "__main__":` guard, so merely importing this file starts both
# evaluation runs below.
# question_type is only used to build the output file names; "KA" is not a key
# of instruction_template_dict and is not passed to either evaluation call.
question_type = "KA"
model = "o1-preview" # "gpt-4o-mini" or "Qwen_2_7B"

# Paths are relative to the current working directory.
input_file = f'./data/ConflictQA/conflictQA-popQA-{model}_new_entity_separate_relation.json'
output_file = f'./output/ConflictQA/{question_type}-conflictQA-popQA-{model}-whole.json'
output_file_golden = f'./output/ConflictQA/{question_type}-conflictQA-popQA-{model}-golden.json'

# Both runs read the same input file and write to separate output files.
whole_evaluation(input_file, output_file, model)
golden_evaluation(input_file, output_file_golden, model)