import ast
import csv
import json
import os
import random
import sys

from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from llm.avior_api import chat_completion
from llm.gpt import *
from utils.utils import *
from prompt.OpenbookQA_prompt import *

# Maps each prompt format accepted by generate_prompt to its template string
# (templates come from prompt.OpenbookQA_prompt).
instruction_template_dict = dict(
    EXTRACT_KNOW=EXTRACT_KNOWLEDGE_PROMPT,
    CONSTRUCT_KNOW=CONSTRUCT_KNOW_PROMPT,
    CONSTRUCT_CON=CONSTRUCT_CONFLICT_KNOW_PROMPT,
)


def generate_prompt(data_item, prompt_format="EXTRACT_KNOW", templates=None):
    """Build the LLM prompt for one OpenBookQA item.

    Args:
        data_item: OpenBookQA record. Always read: ``question.stem`` and
            ``question.choices`` (each choice has ``label`` and ``text``).
            The CONSTRUCT_* formats also read ``internal_knowledge`` (a string
            or a list of strings) and ``fact1``; CONSTRUCT_KNOW additionally
            reads ``answerKey``. The record is not modified.
        prompt_format: One of ``"EXTRACT_KNOW"``, ``"CONSTRUCT_KNOW"`` or
            ``"CONSTRUCT_CON"``.
        templates: Optional mapping from prompt format to template string.
            Defaults to the module-level ``instruction_template_dict``.

    Returns:
        The selected template with its ``[...]`` placeholders filled in.

    Raises:
        ValueError: If ``prompt_format`` is not a supported format.
    """
    if templates is None:
        templates = instruction_template_dict

    question_stem = data_item['question']['stem']
    choices = data_item['question']['choices']
    # One "LABEL. text" line per answer choice.
    formatted_choices = "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices)

    if prompt_format == "EXTRACT_KNOW":
        return (templates[prompt_format]
                .replace("[question_stem]", question_stem)
                .replace("[formatted_choices]", formatted_choices))

    if prompt_format not in ("CONSTRUCT_KNOW", "CONSTRUCT_CON"):
        raise ValueError(f"Unknown prompt_format: {prompt_format!r}")

    # Build a fresh list: appending to data_item['internal_knowledge'] directly
    # would mutate the caller's record. A single string is treated as a
    # one-element list for both CONSTRUCT formats.
    knowledge = data_item['internal_knowledge']
    facts = [knowledge] if isinstance(knowledge, str) else list(knowledge)
    facts.append(data_item['fact1'])

    if prompt_format == "CONSTRUCT_KNOW":
        return (templates[prompt_format]
                .replace("[question_stem]", question_stem)
                .replace("[formatted_choices]", formatted_choices)
                .replace("[commonsense_knowledge]", str(facts))
                .replace("[original_answer_key]", data_item['answerKey']))

    # CONSTRUCT_CON
    return templates[prompt_format].replace("[Taxon]", str(facts))


def _query_extract_model(prompt, model):
    """Send ``prompt`` to ``model`` and return ``(response_text, price)``.

    Names containing ``'o1'`` go to ``generate_o1_response`` and names
    containing ``'gpt'`` go to ``generate_chatgpt_response``; both return a
    ``(response, price)`` pair. Any other model goes through
    ``chat_completion`` and is counted as costing 0.
    """
    # 'o1' is tested first, so it wins when a name contains both substrings.
    if 'o1' in model:
        return generate_o1_response(prompt, model)
    if 'gpt' in model:
        return generate_chatgpt_response(prompt, model)
    return chat_completion(prompt, model), 0


def _parse_extract_response(response):
    """Split an EXTRACT_KNOW response into ``(knowledge, answer)``.

    The text before ``"Final Answer Choice:"`` must be a Python literal,
    optionally preceded by a ``"Knowledge:"`` header. ``answer`` is the
    stripped text after the first ``"Final Answer Choice:"`` marker (up to any
    second marker), or ``''`` when the marker is absent.

    Raises:
        ValueError, SyntaxError: If the knowledge text is not a valid literal.
    """
    parts = response.split("Final Answer Choice:")
    head = parts[0].strip()
    # Remove the header as a prefix. str.strip("Knowledge:") would remove any
    # of those *characters* from both ends, and would not reach the header at
    # all when whitespace precedes it.
    if head.startswith("Knowledge:"):
        head = head[len("Knowledge:"):].strip()
    # literal_eval instead of eval: the text is untrusted model output.
    knowledge = ast.literal_eval(head)
    answer = parts[1].strip() if len(parts) > 1 else ''
    return knowledge, answer


def extract_all_knowledge(model, max_retries=5):
    """Ask ``model`` for supporting knowledge and an answer for every test item.

    Reads ``./data/openbookQA/test_complete.jsonl``. Items whose parsed answer
    equals ``answerKey`` get ``internal_knowledge`` and ``answer`` fields and
    are kept. The kept items are written to
    ``./data/openbookQA/test_complete_internal_{model}.json`` every 20 items
    and once more at the end.

    Args:
        model: Model name, used both to pick the API and to name the output.
        max_retries: Attempts per item (query + parse) before skipping it.
    """
    output_path = f"./data/openbookQA/test_complete_internal_{model}.json"
    data = read_jsonl_file("./data/openbookQA/test_complete.jsonl")
    filtered_data = []
    total_price = 0
    print("length of data", len(data))
    for index, item in enumerate(tqdm(data)):
        prompt = generate_prompt(item, "EXTRACT_KNOW")
        # Reset for every item so a failed item can never reuse the previous
        # item's knowledge/answer.
        parsed = None
        for _ in range(max_retries):
            try:
                response, price = _query_extract_model(prompt, model)
                total_price += price
                print(prompt, response)
                parsed = _parse_extract_response(response)
                break
            except Exception as e:
                print(e)

        if parsed is None:
            print(f"Skipping item {index}: no parsable response after {max_retries} attempts")
        else:
            extra_knowledge, answer = parsed
            if item['answerKey'] == answer:
                item["internal_knowledge"] = extra_knowledge
                item["answer"] = answer
                filtered_data.append(item)

        print(f"Total Price: {total_price}")
        if index % 20 == 0:
            write_json_file(filtered_data, output_path)
    write_json_file(filtered_data, output_path)

def _parse_constructed_question(response_text):
    """Parse a CONSTRUCT_KNOW response into ``(question, choices, knowledge, answer)``.

    Each field is the text after its marker, up to the next field's marker.
    Splitting on a marker already removes it, so only surrounding whitespace
    is stripped. ``str.strip("New Question:")`` would instead delete any of
    those characters from the ends of the real content.

    Raises:
        ValueError, SyntaxError: If the "New Knowledge:" section is not a
            valid Python literal.
    """
    new_answer = response_text.split("New Answer Key:")[-1].strip()
    knowledge_text = response_text.split("New Knowledge:")[-1].split("New Answer Key:")[0].strip()
    # literal_eval instead of eval: the text is untrusted model output.
    new_knowledge = ast.literal_eval(knowledge_text)
    new_question = response_text.split("New Question:")[-1].split("New Choices:")[0].strip()
    new_choices = response_text.split("New Choices:")[-1].split("New Knowledge:")[0].strip()
    return new_question, new_choices, new_knowledge, new_answer


def generate_question(model, generator_model="LLAMA_3_70B", max_retries=5):
    """Generate a new question/choices/knowledge/answer for each extracted item.

    Reads ``./data/openbookQA/test_complete_internal_{model}.json`` and
    processes only items whose ``answer`` equals ``answerKey``. Each item gets
    ``new_question``, ``new_choices`` and ``new_knowledge`` fields, and its
    ``answer`` is overwritten with the new answer key. Results are written to
    ``./data/openbookQA/test_complete_{model}_complement.json`` every 50 items
    and once more at the end.

    Args:
        model: Name used only to select the input and output files.
        generator_model: Model passed to ``chat_completion`` for generation.
        max_retries: Attempts per item (query + parse) before skipping it.
    """
    output_path = f"./data/openbookQA/test_complete_{model}_complement.json"
    gen_question = []
    data = read_json_file(f"./data/openbookQA/test_complete_internal_{model}.json")
    for index, item in enumerate(tqdm(data)):
        if item['answer'] != item['answerKey']:
            continue
        prompt = generate_prompt(item, "CONSTRUCT_KNOW")
        # Reset for every item so a failed item can never be saved with the
        # previous item's generated fields.
        parsed = None
        for _ in range(max_retries):
            try:
                response = chat_completion(prompt, generator_model)
                print(prompt)
                print("xxxxxxxxxxx\n", response)
                parsed = _parse_constructed_question(response)
                break
            except Exception as e:
                print(e)

        if parsed is None:
            print(f"Skipping item {index}: no parsable response after {max_retries} attempts")
        else:
            new_question, new_choices, new_knowledge, new_answer = parsed
            item["new_question"] = new_question
            item["new_choices"] = new_choices
            item["new_knowledge"] = new_knowledge
            item["answer"] = new_answer
            gen_question.append(item)

        if index % 50 == 0:
            write_json_file(gen_question, output_path)

    write_json_file(gen_question, output_path)


def _parse_conflict_question(response_text):
    """Parse a CONSTRUCT_CON response into ``(conflict_knowledge, new_question)``.

    ``conflict_knowledge`` is the text after the last "Conflicting Knowledge:"
    marker, up to the next "Question:". ``new_question`` is the text after the
    last "Question:", up to "Let's think step by step.". Only whitespace is
    stripped. ``str.strip("Question:")`` would delete any of those characters
    from the ends of the real content (e.g. a trailing "question" -> "q").
    """
    conflict_knowledge = response_text.split("Conflicting Knowledge:")[-1].split("Question:")[0].strip()
    new_question = response_text.split("Question:")[-1].split("Let's think step by step.")[0].strip()
    return conflict_knowledge, new_question


def generate_conflict_questions(model, generator_model="LLAMA_3_70B", max_retries=5):
    """Generate a conflicting-knowledge question for every extracted item.

    Reads ``./data/openbookQA/test_complete_internal_{model}.json`` and adds
    ``new_question`` and ``conflict_knowledge`` fields to each item. Results
    are written to
    ``./data/openbookQA/test_complete_harder_{model}_conflict.json`` every 50
    items and once more at the end.

    Args:
        model: Name used only to select the input and output files.
        generator_model: Model passed to ``chat_completion`` for generation.
        max_retries: Attempts per item before skipping it.
    """
    output_path = f"./data/openbookQA/test_complete_harder_{model}_conflict.json"
    gen_question = []
    data = read_json_file(f"./data/openbookQA/test_complete_internal_{model}.json")
    for index, item in enumerate(tqdm(data)):
        prompt = generate_prompt(item, "CONSTRUCT_CON")
        # Reset for every item so a failed item can never be saved with the
        # previous item's generated fields.
        parsed = None
        for _ in range(max_retries):
            try:
                response = chat_completion(prompt, generator_model)
                print(prompt)
                print("xxxxxxxxxxx\n", response)
                parsed = _parse_conflict_question(response)
                break
            except Exception as e:
                print(e)

        if parsed is None:
            print(f"Skipping item {index}: no response after {max_retries} attempts")
        else:
            conflict_knowledge, new_question = parsed
            item["new_question"] = new_question
            item["conflict_knowledge"] = conflict_knowledge
            gen_question.append(item)

        if index % 50 == 0:
            write_json_file(gen_question, output_path)

    write_json_file(gen_question, output_path)

  
# Default model whose outputs are processed; it also names the intermediate
# and output JSON files. Override by passing a model name as the first CLI
# argument.
DEFAULT_MODEL = "o1-preview"
# Defined at import time (outside the main guard) because code in this module
# may read the global ``total_price``.
total_price = 0

# Guard the pipeline so importing this module does not trigger paid API calls.
if __name__ == "__main__":
    model = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_MODEL
    extract_all_knowledge(model)
    generate_question(model)
    generate_conflict_questions(model)