import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from utils.utils import *
from llm.avior_api import *
from llm.gpt import *
from prompt.ALCUNA_prompt import *
import numpy as np
from tqdm import tqdm


# Lookup table from question-type key to instruction template. generate_prompt
# selects a template by key and fills in its "[Taxon]", "[Question]" and
# "[KNOWLEDGE]" placeholders.
# NOTE(review): the *_PROMPT names are not defined in this file; presumably they
# come from the star import of prompt.ALCUNA_prompt — confirm.
instruction_template_dict = {
    "KA": KA_PROMPT,
    "KA_TRUST_OWN": KA_TRUST_OWN_PROMPT,
    "KA_GOLDEN_INT": KA_GOLDEN_INT_PROMPT,
    "KU": KU_PROMPT,
    "KD": KD_PROMPT,
    "INT": INT_PROMPT,
    "INT_CONTEXT": INT_CONTEXT_PROMPT,
    "CONFLICT": CONFLICT_PROMPT,
    "CK_PK": PKCK_KA_PROMPT,
}


def _prompt_value_to_text(value):
    """Render a template value: dicts and lists as indented JSON, anything else via str()."""
    if isinstance(value, (dict, list)):
        return json.dumps(value, indent=2)
    return str(value)


def generate_prompt(taxon_info, question, question_type, know_triple_list=None):
    """
    Build a prompt by filling the placeholders of an instruction template.

    Args:
    taxon_info (dict | list | Any): Context substituted for "[Taxon]". Ignored for the
        "INT" and "CK_PK" question types, whose templates only receive the question.
    question (str | Any): The question substituted for "[Question]".
    question_type (str): Key of instruction_template_dict selecting the template.
    know_triple_list (dict | list | Any, optional): Knowledge substituted for
        "[KNOWLEDGE]" when that placeholder is present. None is rendered as an
        empty JSON list.

    Returns:
    str: The formatted prompt.

    Raises:
    KeyError: If question_type is not a key of instruction_template_dict.
    """
    instruction_template = instruction_template_dict[question_type]
    question_str = question if isinstance(question, str) else str(question)

    if question_type in ("INT", "CK_PK"):
        prompt = instruction_template.replace("[Question]", question_str)
    else:
        prompt = instruction_template.replace(
            "[Taxon]", _prompt_value_to_text(taxon_info)
        ).replace("[Question]", question_str)

    if "[KNOWLEDGE]" in prompt:
        # Always produce a replacement string: previously a missing or empty
        # knowledge list left the variable unbound and raised UnboundLocalError.
        knowledge = [] if know_triple_list is None else know_triple_list
        prompt = prompt.replace("[KNOWLEDGE]", _prompt_value_to_text(knowledge))

    return prompt


def evaluation(input_file, output_file, question_type, model):
    """
    Query the model for every question in the input file and score exact matches.

    For each of the first 3200 items that has a question, context knowledge and at
    least one answer, a prompt is built from the template selected by
    question_type, the model is queried, and the text after "Final Choice:" is
    compared with the first reference answer. The parsed result and its
    correctness are stored on the item under 'result' and 'acc', and the data is
    written to output_file every 300 items and once at the end.

    Args:
    input_file (str): The path to the input JSON file containing questions and taxon information.
    output_file (str): The path to the output file where the annotated items are saved.
    question_type (str): Key of instruction_template_dict selecting the prompt template.
    model (str): Model name; names containing "gpt" go to generate_chatgpt_response,
        all others to chat_completion.

    Returns:
    list: The prompts that were sent, in order.
    """
    data = read_json_file(input_file)
    acc = []
    prompts = []

    for index, item in enumerate(tqdm(data[:3200])):
        question = item.get('question')
        taxon_info = item.get('new_knowledge')
        answers = item.get('answers')

        # Items without answers are skipped instead of crashing on answers[0].
        if question and taxon_info and answers:
            answer = answers[0]
            # Honor the requested question type (it used to be hard-coded to "KA").
            prompt = generate_prompt(taxon_info, question, question_type)
            prompts.append(prompt)

            if "gpt" in model:
                response = generate_chatgpt_response(prompt, model)
            else:
                response = chat_completion(prompt, model)
            # Other call sites in this file unpack the GPT helper's return value as
            # (text, price); accept that shape as well as a plain string.
            if isinstance(response, tuple):
                response = response[0]

            # NOTE: str.strip("unknown") strips the characters u/n/k/o/w from both
            # ends, not the literal word.
            result = response.split("Final Choice:")[-1].lower().strip().strip("unknown").strip().strip(".").strip("]").strip("[").strip(".").strip()
            print(result)
            print(answer)

            is_acc = (result == answer)

            acc.append(is_acc)
            item['result'] = result
            item['acc'] = is_acc
            data[index] = item
            print("current acc: ", np.mean(acc))

        # Periodic checkpoint so a long run can be inspected or resumed.
        if index % 300 == 0:
            write_json_file(data, output_file)

    write_json_file(data, output_file)
    return prompts


def _whole_eval_query(prompt, model):
    """Query the backend chosen by the model name and return (response_text, price)."""
    if "o1" in model:
        return generate_o1_response(prompt, model)
    if "gpt" in model:
        return generate_chatgpt_response(prompt, model)
    # chat_completion's result is used as plain text, so no price is tracked for it.
    return chat_completion(prompt, model), 0


def _whole_eval_score(raw_response, answer):
    """Return (is_acc, is_unknown) for one raw model response."""
    # NOTE: str.strip("unknown") strips the characters u/n/k/o/w from both ends,
    # not the literal word.
    choice = raw_response.split("Final Choice:")[-1].lower().strip().strip("unknown").strip().strip(".").strip("]").strip("[").strip(".").strip()
    # Only the first character of the parsed choice is scored.
    choice = choice[:1]

    is_unknown = ("4" in choice)
    try:
        is_acc = (answer in choice)
    except TypeError as e:
        # A non-string reference answer cannot be matched against the choice.
        print(e)
        is_acc = 0
    return is_acc, is_unknown


def _whole_eval_golden_triples(meta_data):
    """Return the item's hop triplets minus those that state the related property."""
    meta_data = meta_data or {}
    related_property = meta_data.get('related_property') or {}
    related_name = related_property.get('name')
    related_values = related_property.get('values') or []

    kept = []
    for triplet in meta_data.get('hop_triplets', []):
        _, relation, object_ = triplet
        if relation == related_name and object_ in related_values:
            continue
        kept.append(triplet)
    return kept


def whole_evaluation(input_file, output_file, model):
    """
    Run all four prompt variants on every item whose internal knowledge was verified.

    Only items with 'internal_knowledge_acc' == 1 among the first 3200 are
    processed. Items that already carry 'is_acc_ck_np' are not queried again;
    their stored scores are reused. For every other qualifying item the model is
    asked with, in order: the neutral prompt ("KA"), the trust-own prompt
    ("KA_TRUST_OWN"), the golden-knowledge prompt ("KA_GOLDEN_INT", given the hop
    triplets that do not state the related property) and the direct-ask prompt
    ("CK_PK", no context). An extra "4. Unknown" option is appended to each
    question. Raw responses and scores are stored on the item, running means are
    printed, and the data is written to output_file periodically and at the end.

    Args:
    input_file (str): The path to the input JSON file containing questions and taxon information.
        This path is now actually read; previously a hard-coded checkpoint path
        was loaded regardless of the argument.
    output_file (str): The path to the output file where the annotated items are saved.
    model (str): Model name; "o1" and "gpt" names use the priced helpers, all
        others use chat_completion.
    """
    data = read_json_file(input_file)

    acc_tokp_list = []
    unknown_tokp_list = []
    is_acc_np_list = []
    is_unknown_np_list = []
    is_acc_dr_list = []
    is_unknown_dr_list = []
    is_acc_golden_list = []
    is_unknown_golden_list = []

    total_price = 0

    for index, item in enumerate(tqdm(data[:3200])):
        if item.get('internal_knowledge_acc') == 1:
            if "is_acc_ck_np" in item:
                # Already evaluated in an earlier run: reuse the stored scores.
                is_acc_np_list.append(item['is_acc_ck_np'])
                is_unknown_np_list.append(item['is_unknown_ck_np'])
                acc_tokp_list.append(item['is_acc_ck_tokp'])
                unknown_tokp_list.append(item['is_unknown_ck_tokp'])
                is_acc_golden_list.append(item['acc_golden_int'])
                is_unknown_golden_list.append(item['unknown_golden_int'])
                is_acc_dr_list.append(item['no_context_acc'])
                is_unknown_dr_list.append(item['no_context_acc_unknwon'])
                continue

            question = item.get('question') + "\n\n4. Unknown"
            taxon_info = item.get('new_knowledge')
            answer = item.get('answers')[0]
            know_triple_list = _whole_eval_golden_triples(item.get('meta_data'))

            # (prompt, raw-response key, acc key, unknown key, acc list, unknown list)
            variants = (
                # neutral prompt, ck in context
                (generate_prompt(taxon_info, question, "KA"),
                 'np_result', 'is_acc_ck_np', 'is_unknown_ck_np',
                 is_acc_np_list, is_unknown_np_list),
                # trust own prompt, ck in context
                (generate_prompt(taxon_info, question, "KA_TRUST_OWN"),
                 'tokp_result', 'is_acc_ck_tokp', 'is_unknown_ck_tokp',
                 acc_tokp_list, unknown_tokp_list),
                # golden knowledge
                (generate_prompt(taxon_info, question, "KA_GOLDEN_INT", know_triple_list),
                 'result_golden_int', 'acc_golden_int', 'unknown_golden_int',
                 is_acc_golden_list, is_unknown_golden_list),
                # direct ask prompt
                (generate_prompt("", question, "CK_PK"),
                 'no_context_result', 'no_context_acc', 'no_context_acc_unknwon',
                 is_acc_dr_list, is_unknown_dr_list),
            )

            for prompt, raw_key, acc_key, unknown_key, acc_list, unknown_list in variants:
                result, price = _whole_eval_query(prompt, model)
                print(prompt)
                print(result)
                print(answer)
                total_price += price
                print("current total price: ", total_price)

                item[raw_key] = result
                is_acc, is_unknown = _whole_eval_score(result, answer)
                acc_list.append(is_acc)
                unknown_list.append(is_unknown)
                item[acc_key] = is_acc
                item[unknown_key] = is_unknown

            data[index] = item

            print("current direct ask acc: ", np.mean(is_acc_dr_list))
            print("current direct askunknown rate: ", np.mean(is_unknown_dr_list))

            print("current neutral prompt acc: ", np.mean(is_acc_np_list))
            print("current neutral prompt unknown: ", np.mean(is_unknown_np_list))

            print("current trust own prompt acc: ", np.mean(acc_tokp_list))
            print("current trust own prompt unknown: ", np.mean(unknown_tokp_list))

            print("current golden acc: ", np.mean(is_acc_golden_list))
            print("current golden unknown rate: ", np.mean(is_unknown_golden_list))

        # Periodic checkpoint so a long run can be resumed.
        if index % 20 == 0:
            write_json_file(data, output_file)

    write_json_file(data, output_file)

def _metrics_score(raw_response, answer):
    """Return (is_acc, is_unknown) for one stored raw model response."""
    # NOTE: str.strip("unknown") strips the characters u/n/k/o/w from both ends,
    # not the literal word.
    choice = raw_response.split("Final Choice:")[-1].lower().strip().strip("unknown").strip().strip(".").strip("]").strip("[").strip(".").strip()
    # Only the first character of the parsed choice is scored.
    choice = choice[:1]

    is_unknown = ("4" in choice)
    try:
        is_acc = (answer in choice)
    except TypeError as e:
        # A non-string reference answer cannot be matched against the choice.
        print(e)
        is_acc = 0
    return is_acc, is_unknown


def _metrics_report(label, values):
    """Print the running mean of values, skipping lists that are still empty."""
    if values:
        print(label, np.mean(values))


def whole_metrics(input_file, output_file, model):
    """
    Recompute and print accuracy / unknown rates from stored raw responses.

    Reads output_file and, for each of the first 3200 items with
    'internal_knowledge_acc' == 1, re-parses the stored 'np_result' and
    'tokp_result' responses and, when present, 'no_context_result' and
    'result_golden_int'. No model is queried and nothing is written. Running
    means are printed after every scored item; a metric is only printed once it
    has at least one sample.

    Args:
    input_file (str): Unused; kept for signature compatibility.
    output_file (str): The path to the JSON file holding the stored responses.
    model (str): Unused; kept for signature compatibility.

    Raises:
    KeyError: If a qualifying item lacks 'np_result' or 'tokp_result'.
    """
    data = read_json_file(output_file)

    acc_tokp_list = []
    unknown_tokp_list = []
    is_acc_np_list = []
    is_unknown_np_list = []
    is_acc_dr_list = []
    is_unknown_dr_list = []
    is_acc_golden_list = []
    is_unknown_golden_list = []

    for item in tqdm(data[:3200]):
        if item.get('internal_knowledge_acc') != 1:
            continue

        answer = item.get('answers')[0]

        # neutral prompt, ck in context
        is_acc_np, is_unknown_np = _metrics_score(item['np_result'], answer)
        is_acc_np_list.append(is_acc_np)
        is_unknown_np_list.append(is_unknown_np)

        # trust own prompt, ck in context
        is_acc_tokp, is_unknown_tokp = _metrics_score(item['tokp_result'], answer)
        acc_tokp_list.append(is_acc_tokp)
        unknown_tokp_list.append(is_unknown_tokp)

        # direct ask prompt
        if 'no_context_result' in item:
            is_acc_dr, is_unknown_dr = _metrics_score(item['no_context_result'], answer)
            is_acc_dr_list.append(is_acc_dr)
            is_unknown_dr_list.append(is_unknown_dr)

            # golden knowledge; only scored for items that also have a direct-ask
            # response
            if 'result_golden_int' in item:
                is_acc_golden, is_unknown_golden = _metrics_score(item['result_golden_int'], answer)
                is_acc_golden_list.append(is_acc_golden)
                is_unknown_golden_list.append(is_unknown_golden)

        # Report only after the current item has been added to every list, so the
        # printed means are never one item behind or taken over an empty list.
        _metrics_report("current direct ask acc: ", is_acc_dr_list)
        _metrics_report("current direct askunknown rate: ", is_unknown_dr_list)

        _metrics_report("current neutral prompt acc: ", is_acc_np_list)
        _metrics_report("current neutral prompt unknown: ", is_unknown_np_list)

        _metrics_report("current trust own prompt acc: ", acc_tokp_list)
        _metrics_report("current trust own prompt unknown: ", unknown_tokp_list)
        print("length ", len(unknown_tokp_list))

        _metrics_report("current golden acc: ", is_acc_golden_list)
        _metrics_report("current golden unknown rate: ", is_unknown_golden_list)


def trust_own_evaluation(input_file, output_file, question_type, model):
    """
    Evaluate the trust-own-knowledge prompt on items whose internal knowledge was verified.

    For each of the first 3200 items with 'internal_knowledge_acc' == 1, a
    "KA_TRUST_OWN" prompt is built, the model is queried, and the text after
    "Final Choice:" is compared with the first reference answer. The parsed result
    and its correctness are stored under 'result_trust_own' and 'acc_trust_own';
    the data is written to output_file every 100 items and once at the end.

    Args:
    input_file (str): The path to the input JSON file containing questions and taxon information.
    output_file (str): The path to the output file where the annotated items are saved.
    question_type (str): Unused; the "KA_TRUST_OWN" template is always applied.
    model (str): Model name; names containing "gpt" go to generate_chatgpt_response,
        all others to chat_completion.
    """
    data = read_json_file(input_file)
    acc = []

    for index, item in enumerate(tqdm(data[:3200])):
        if item.get("internal_knowledge_acc") == 1:
            answer = item.get('answers')[0]
            prompt = generate_prompt(item.get('new_knowledge'), item.get('question'), "KA_TRUST_OWN")

            if "gpt" in model:
                response = generate_chatgpt_response(prompt, model)
            else:
                response = chat_completion(prompt, model)
            # Other call sites in this file unpack the GPT helper's return value as
            # (text, price); accept that shape as well as a plain string.
            if isinstance(response, tuple):
                response = response[0]

            # NOTE: str.strip("unknown") strips the characters u/n/k/o/w from both
            # ends, not the literal word.
            result = response.split("Final Choice:")[-1].lower().strip().strip("unknown").strip().strip(".").strip("]").strip("[").strip(".").strip()

            print(result)
            print(answer)

            is_acc = (result == answer)

            acc.append(is_acc)
            item['result_trust_own'] = result
            item['acc_trust_own'] = is_acc
            data[index] = item
            print("current acc: ", np.mean(acc))

        # Periodic checkpoint so a long run can be inspected or resumed.
        if index % 100 == 0:
            write_json_file(data, output_file)

    write_json_file(data, output_file)

def translate_triplets_to_boolean_question(meta_data):
    """
    Turn an item's hop triplets into yes/no questions.

    Each (subject, relation, object) triplet under 'hop_triplets' becomes the
    question "Is the <relation> of <subject> <object as JSON>?". Triplets whose
    relation equals the related property's name and whose object is one of its
    values are left out.

    Args:
    meta_data (dict): Item metadata with optional 'hop_triplets' and 'related_property' entries.

    Returns:
    list: The generated question strings, in triplet order.
    """
    triplets = meta_data.get('hop_triplets', [])
    target = meta_data.get('related_property', {})
    target_name = target.get('name')
    target_values = target.get('values', [])

    return [
        f"Is the {rel} of {subj} {json.dumps(obj)}?"
        for subj, rel, obj in triplets
        # Keep everything except the triplet that states the related property.
        if rel != target_name or obj not in target_values
    ]


def internal_evaluation(input_file, output_file, model, start=1210, stop=3200,
                        resume_file="./output/ALCUNA/KA_o1-preview_internal_old.json",
                        max_kept=668):
    """
    Keep the items for which the model confirms every supporting fact.

    Each item in data[start:stop] has its hop triplets turned into yes/no
    questions, which are put to the model with the "INT" prompt. An item is kept
    only if every answer after "Final Answer:" parses to "yes"; the per-question
    scores and the overall flag are stored under 'sub_question_acc' and
    'internal_knowledge_acc'. Kept items are appended to the list loaded from
    resume_file and written to output_file every 20 items and at the end.

    Args:
    input_file (str): The path to the input JSON file containing questions and taxon information.
    output_file (str): The path to the output file where the kept items are saved.
    model (str): Model name; "o1" and "gpt" names use the priced helpers, all
        others use chat_completion.
    start (int): Index of the first item to process. The default resumes an
        earlier run.
    stop (int): Index one past the last item to process.
    resume_file (str | None): JSON file of previously kept items to extend;
        None (or empty) starts from an empty list.
    max_kept (int | None): Stop once more than this many items are kept;
        None disables the cap.
    """
    data = read_json_file(input_file)
    acc = []
    all_acc = []
    total_price = 0
    print("length of data: ", len(data))
    filtered_data = read_json_file(resume_file) if resume_file else []

    for index, item in enumerate(tqdm(data[start:stop])):
        taxon_info = item.get('new_knowledge')
        meta_data = item.get('meta_data')
        sub_question_acc = {}
        internal_knowledge_acc = 1

        for que in translate_triplets_to_boolean_question(meta_data):
            prompt = generate_prompt(taxon_info, que, "INT")

            if 'o1' in model:
                result, price = generate_o1_response(prompt, model)
            elif 'gpt' in model:
                result, price = generate_chatgpt_response(prompt, model)
            else:
                # chat_completion's result is used as plain text; no price tracked.
                result, price = chat_completion(prompt, model), 0

            total_price += price
            print("current total price: ", total_price)
            print(prompt)
            print(result)

            result = result.split("Final Answer:")[-1].lower().strip().strip(".").strip("]").strip("[").strip(".").strip()

            # Anything other than an exact "yes" counts as a failed fact and
            # disqualifies the whole item.
            if result != "yes":
                sub_question_acc[que] = 0
                internal_knowledge_acc = 0
            else:
                sub_question_acc[que] = 1

            acc.append(sub_question_acc[que])

        all_acc.append(internal_knowledge_acc)

        item['internal_knowledge_acc'] = internal_knowledge_acc
        item['sub_question_acc'] = sub_question_acc

        print("current internal_knowledge_acc: ", np.mean(acc))
        print("current all internal_knowledge_acc: ", np.mean(all_acc))

        if internal_knowledge_acc == 1:
            filtered_data.append(item)

        # Periodic checkpoint so a long run can be resumed.
        if index % 20 == 0:
            write_json_file(filtered_data, output_file)

        if max_kept is not None and len(filtered_data) > max_kept:
            break

    write_json_file(filtered_data, output_file)


def ckpk_evaluation(input_file, output_file, model):
    """
    Ask each verified item's question directly, without any context knowledge.

    For each of the first 3200 items with 'internal_knowledge_acc' == 1, the
    question (with an extra "4. Unknown" option) is sent with the "CK_PK" prompt.
    The raw response is stored under 'no_context_result'; the first character of
    the parsed choice is scored against the first reference answer
    ('no_context_acc') and checked for the unknown option
    ('no_context_acc_unknwon'). The data is written to output_file every 300
    items and once at the end.

    Args:
    input_file (str): The path to the input JSON file containing questions and taxon information.
    output_file (str): The path to the output file where the annotated items are saved.
    model (str): Model name; names containing "gpt" go to generate_chatgpt_response,
        all others to chat_completion.
    """
    data = read_json_file(input_file)
    unknown_acc = []
    all_context_acc = []

    for index, item in enumerate(tqdm(data[:3200])):
        if item.get('internal_knowledge_acc') == 1:
            # Question and answer are only read for items that are evaluated, so
            # skipped items with missing fields no longer abort the run.
            answer = item.get('answers')[0]
            prompt = generate_prompt("", item.get('question') + "\n\n4. Unknown", "CK_PK")

            if 'gpt' in model:
                result = generate_chatgpt_response(prompt, model)
            else:
                result = chat_completion(prompt, model)
            # Other call sites in this file unpack the GPT helper's return value as
            # (text, price); accept that shape as well as a plain string.
            if isinstance(result, tuple):
                result = result[0]

            print(prompt)
            print(result)
            print(answer)

            item['no_context_result'] = result
            # NOTE: str.strip("unknown") strips the characters u/n/k/o/w from both
            # ends, not the literal word.
            result = result.split("Final Choice:")[-1].lower().strip().strip("unknown").strip().strip(".").strip("]").strip("[").strip(".").strip()
            # Only the first character of the parsed choice is scored.
            result = result[:1]

            is_unknown = ("4" in result)
            try:
                is_acc = (answer in result)
            except TypeError as e:
                # A non-string reference answer cannot be matched against the choice.
                print(e)
                is_acc = 0

            all_context_acc.append(is_acc)
            unknown_acc.append(is_unknown)

            item['no_context_acc'] = is_acc
            item['no_context_acc_unknwon'] = is_unknown

            data[index] = item

            print("current no ckpk acc: ", np.mean(all_context_acc))
            print("current unknown rate: ", np.mean(unknown_acc))

        # Periodic checkpoint so a long run can be inspected or resumed.
        if index % 300 == 0:
            write_json_file(data, output_file)

    write_json_file(data, output_file)



def internal_context_knowledge_evaluation(input_file, output_file, model):
    """
    Re-ask each verified item's fact questions with the hop triplets given as context.

    For each of the first 3200 items with 'internal_knowledge_acc' == 1, every
    question stored in 'sub_question_acc' is sent with the "INT_CONTEXT" prompt,
    using the item's hop triplets as context. A question passes only if the text
    after "Final Answer:" parses to "yes". Per-question scores are stored under
    'sub_question_context_acc' and the all-passed flag under
    'in_context_internal_knowledge'. The data is written to output_file every 300
    items and once at the end.

    Args:
    input_file (str): The path to the input JSON file containing questions and taxon information.
    output_file (str): The path to the output file where the annotated items are saved.
    model (str): Model name; names containing "gpt" go to generate_chatgpt_response,
        all others to chat_completion.
    """
    data = read_json_file(input_file)
    context_acc = []
    all_context_acc = []

    for index, item in enumerate(tqdm(data[:3200])):
        if item.get('internal_knowledge_acc') == 1:
            meta_data = item.get('meta_data')
            sub_question_context_acc = {}
            in_context_internal_knowledge = 1

            # check internal knowledge; only the question texts are needed, the
            # stored scores are not consulted
            for que in item.get('sub_question_acc'):
                prompt = generate_prompt(meta_data['hop_triplets'], que, "INT_CONTEXT")

                if 'gpt' in model:
                    result = generate_chatgpt_response(prompt, model)
                else:
                    result = chat_completion(prompt, model)
                # Other call sites in this file unpack the GPT helper's return
                # value as (text, price); accept that shape as well as a string.
                if isinstance(result, tuple):
                    result = result[0]
                print(result)

                # Same normalization as internal_evaluation, so answers such as
                # "Yes.\n" or "[Yes]" are recognized here too.
                result = result.split("Final Answer:")[-1].lower().strip().strip(".").strip("]").strip("[").strip(".").strip()

                if result != "yes":
                    sub_question_context_acc[que] = 0
                    in_context_internal_knowledge = 0
                else:
                    sub_question_context_acc[que] = 1

                context_acc.append(sub_question_context_acc[que])
            all_context_acc.append(in_context_internal_knowledge)

            item['in_context_internal_knowledge'] = in_context_internal_knowledge
            item['sub_question_context_acc'] = sub_question_context_acc
            data[index] = item

            print("current internal_knowledge_acc: ", np.mean(context_acc))
            print("current all internal_knowledge_acc: ", np.mean(all_context_acc))

        # Periodic checkpoint so a long run can be inspected or resumed.
        if index % 300 == 0:
            write_json_file(data, output_file)

    write_json_file(data, output_file)



def golden_internal_evaluation(input_file, output_file, model):
    """
    Evaluate verified items with the golden-knowledge prompt.

    For each of the first 3200 items with 'internal_knowledge_acc' == 1, the
    question (with an extra "4. Unknown" option) is sent with the
    "KA_GOLDEN_INT" prompt, together with the hop triplets that do not state the
    related property. The raw response is stored under 'result_golden_int'; the
    first character of the parsed choice is scored against the first reference
    answer ('acc_golden_int') and checked for the unknown option
    ('unknown_golden_int'). The data is written to output_file every 100 items
    and once at the end.

    Args:
    input_file (str): The path to the input JSON file containing questions and taxon information.
    output_file (str): The path to the output file where the annotated items are saved.
    model (str): Model name; names containing "gpt" go to generate_chatgpt_response,
        all others to chat_completion.
    """
    data = read_json_file(input_file)
    acc = []
    unknown = []

    for index, item in enumerate(tqdm(data[:3200])):
        if item.get('internal_knowledge_acc') == 1:
            question = item.get('question') + "\n\n4. Unknown"
            taxon_info = item.get('new_knowledge')
            answer = item.get('answers')[0]
            meta_data = item.get('meta_data')

            related_property = meta_data.get('related_property', {})
            related_name = related_property.get('name')
            related_values = related_property.get('values', [])

            know_triple_list = []
            for triplet in meta_data.get('hop_triplets', []):
                _, relation, object_ = triplet

                # Skip triplet if it matches the related_property
                if relation == related_name and object_ in related_values:
                    continue
                know_triple_list.append(triplet)
            prompt = generate_prompt(taxon_info, question, "KA_GOLDEN_INT", know_triple_list)

            if "gpt" in model:
                result = generate_chatgpt_response(prompt, model)
            else:
                result = chat_completion(prompt, model)
            # Other call sites in this file unpack the GPT helper's return value as
            # (text, price); accept that shape as well as a plain string.
            if isinstance(result, tuple):
                result = result[0]

            print(prompt)
            print(result)
            print(answer)
            item['result_golden_int'] = result
            # NOTE: str.strip("unknown") strips the characters u/n/k/o/w from both
            # ends, not the literal word.
            result = result.split("Final Choice:")[-1].lower().strip().strip("unknown").strip().strip(".").strip("]").strip("[").strip(".").strip()
            # Only the first character of the parsed choice is scored.
            result = result[:1]

            is_unknown = ("4" in result)
            try:
                is_acc = (answer in result)
            except TypeError as e:
                # A non-string reference answer cannot be matched against the choice.
                print(e)
                is_acc = 0

            acc.append(is_acc)
            unknown.append(is_unknown)
            item['acc_golden_int'] = is_acc
            item['unknown_golden_int'] = is_unknown
            data[index] = item
            print("current acc: ", np.mean(acc))
            print("current unknown rate: ", np.mean(unknown))

        # Checkpoint on the item index regardless of whether this item qualified,
        # matching the other evaluation loops in this file.
        if index % 100 == 0:
            write_json_file(data, output_file)

    write_json_file(data, output_file)


# Example usage
question_type = "KA" # KU or KA
model = "o1-preview" # "gpt-4o-mini" or "LLAMA_3_70B" or "Qwen_2_7B"

input_file = f'./data/ALCUNA/questions_{question_type}.json'
output_file = f'./output/ALCUNA/w/unk/{question_type}_{model}.json'

# Start the model-querying run only when this file is executed as a script, so
# importing the module for its functions does not trigger an evaluation.
if __name__ == "__main__":
    internal_evaluation(input_file, output_file, model)
    # whole_evaluation(internal_output_file, output_file_whole, model)