import sys
import os,csv
from tqdm import tqdm
import random

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '.')))
from utils.utils import *

from llm.avior_api import *
from llm.gpt import *
from prompt.Musique_prompt import *


# Maps a prompt-format key to the template text that generate_prompt() fills in.
instruction_template_dict = dict(
    CONSTRUCT_CON=CONSTRUCT_CONFLICT_KNOW_PROMPT,
)

def _resolve_step_references(question_decomposition):
    """Inline the answers of earlier steps into later sub-questions, in place.

    In MuSiQue decompositions a ``#k`` placeholder stands for the answer of the
    k-th (1-based) step, i.e. ``question_decomposition[k - 1]['answer']``.
    """
    for position, step in enumerate(question_decomposition):
        # Only earlier steps can be referenced. Substitute the highest number
        # first so a shorter placeholder never clobbers part of a longer one.
        for ref in range(position, 0, -1):
            step['question'] = step['question'].replace(
                "#" + str(ref), question_decomposition[ref - 1]['answer']
            )


def _query_verifier(verify_prompt, model):
    """Send ``verify_prompt`` to the backend selected by ``model``.

    Returns a ``(result, price)`` pair. The ``chat_completion`` backend reports
    no price, so 0 is used for it.
    """
    if 'o1' in model:
        return generate_o1_response(verify_prompt, model)
    if 'gpt' in model:
        return generate_chatgpt_response(verify_prompt, model)
    return chat_completion(verify_prompt, model), 0


def filter_internal_knowledge(input_file, output_file, model):
    """Ask ``model`` to verify every decomposition step and record the verdict.

    For each item read from ``input_file`` the supporting paragraph is attached
    to every step, ``#k`` references are replaced by the referenced answers,
    and the model is asked whether "<question> >> <answer>" is true. Each step
    receives a boolean ``verified`` field (True only when the reply contains
    "true"). Processed items are written to ``output_file`` as a checkpoint
    whenever the item index is a multiple of 20, and once more at the end.

    Items whose first step already has a ``verified`` field, or that have no
    decomposition at all, are skipped. At most 401 "2hop", 200 "3hop" and
    200 "4hop" items (matched against the item id) are processed.

    Args:
        input_file: path handed to ``read_json_file``.
        output_file: path handed to ``write_json_file``.
        model: model name; a name containing "o1" or "gpt" selects the priced
            backends, anything else goes through ``chat_completion``.
    """
    data = read_json_file(input_file)
    num_2hop = 0
    num_3hop = 0
    num_4hop = 0
    new_data = []

    total_price = 0
    print("length of data", len(data))
    for index, item in enumerate(tqdm(data)):
        question_decomposition = item.get('question_decomposition')
        if not question_decomposition:
            # Nothing to verify; previously this crashed on the [0] lookup.
            continue
        if 'verified' in question_decomposition[0]:
            continue
        for step in question_decomposition:
            step['paragraph'] = item['paragraphs'][step['paragraph_support_idx']]
        if '2hop' in item['id']:
            num_2hop += 1
            if num_2hop > 401:
                continue
        if '3hop' in item['id']:
            num_3hop += 1
            if num_3hop > 200:
                continue
        if '4hop' in item['id']:
            num_4hop += 1
            if num_4hop > 200:
                # NOTE(review): this stops the whole run rather than skipping
                # one item - presumably 4-hop items come last; confirm.
                break

        _resolve_step_references(question_decomposition)

        for step in question_decomposition:
            verify_prompt = "Verify the statement: " + step['question'] + " >> " + step['answer'] + ".\n\nYour output can only be 'true' or 'false' or 'unknown'. Do not output anything other than these three words."
            result, price = _query_verifier(verify_prompt, model)

            total_price += price
            print(verify_prompt)
            print(result)
            print("current total price: ", total_price)

            step['verified'] = "true" in result.lower().strip()

        new_data.append(item)
        if index % 20 == 0:
            # Periodic checkpoint so an interrupted run keeps its progress.
            write_json_file(new_data, output_file)
    write_json_file(new_data, output_file)


def filter_paragraph(input_file, output_file, model):
    """Keep answerable items and reduce them to the fields used downstream.

    Every item whose ``answerable`` field equals True gets the supporting
    paragraph attached to each decomposition step, and is then reduced to
    ``id``, ``question``, ``question_decomposition``, ``answer`` and
    ``answer_aliases``. The result is written to ``output_file`` as a
    checkpoint whenever the item index is a multiple of 20, and once more
    after the loop.

    Args:
        input_file: path handed to ``read_json_file``.
        output_file: path handed to ``write_json_file``.
        model: unused; kept so all pipeline stages share one signature.
    """
    data = read_json_file(input_file)
    new_file = []
    for index, item in enumerate(tqdm(data)):
        question_decomposition = item.get('question_decomposition')
        if item['answerable'] == True:
            for step in question_decomposition:
                step['paragraph'] = item['paragraphs'][step['paragraph_support_idx']]
            new_file.append({
                "id": item.get('id'),
                "question": item.get('question'),
                "question_decomposition": question_decomposition,
                "answer": item.get('answer'),
                "answer_aliases": item.get('answer_aliases'),
            })
        # The checkpoint belongs inside the loop: outside it, `index` is
        # undefined for empty input and the write happened at most once.
        if index % 20 == 0:
            write_json_file(new_file, output_file)
    write_json_file(new_file, output_file)

def filter_question_decomposition_is_supporting(input_file, output_file, model):
    """Keep only items whose steps carry a mix of ``verified`` verdicts.

    An item is retained when its decomposition contains at least one step
    with ``verified`` True and at least one with ``verified`` False. Items
    whose first step has no ``verified`` field are ignored entirely. The
    running selection is written to ``output_file`` whenever the index of a
    considered item is a multiple of 20, and once more at the end.
    ``model`` is not used.
    """
    kept = []
    records = read_json_file(input_file)
    for position, record in enumerate(tqdm(records)):
        steps = record.get('question_decomposition')
        if 'verified' in steps[0].keys():
            verdicts = [step['verified'] for step in steps]
            has_confirmed = True in verdicts
            has_rejected = False in verdicts
            if has_confirmed and has_rejected:
                kept.append(record)
            print(len(kept))
            if position % 20 == 0:
                write_json_file(kept, output_file)
    write_json_file(kept, output_file)

def filter_inner_external_knowledge(input_file, output_file, model):
    """Split each item's supporting paragraphs by verification verdict.

    The paragraph text of every step with ``verified`` equal to True goes into
    the item's ``inner_knowledge`` list; the text of all other steps goes into
    ``external_knowledge``. The full data set is written to ``output_file``
    whenever the item index is a multiple of 20, and once more at the end.
    ``model`` is not used.
    """
    records = read_json_file(input_file)
    for position, record in enumerate(tqdm(records)):
        known, unknown = [], []
        for step in record.get('question_decomposition'):
            bucket = known if step['verified'] == True else unknown
            bucket.append(step['paragraph']['paragraph_text'])
        record['inner_knowledge'] = known
        record['external_knowledge'] = unknown

        if position % 20 == 0:
            write_json_file(records, output_file)
    write_json_file(records, output_file)


def generate_prompt(data_item, prompt_format="CONSTRUCT_CON"):
    """Build the prompt for ``data_item`` in the requested format.

    For "CONSTRUCT_CON" the ``[Taxon]`` placeholder of the matching template
    in ``instruction_template_dict`` is replaced by the string form of
    ``data_item['inner_knowledge']``.

    Args:
        data_item: mapping that provides ``inner_knowledge``.
        prompt_format: prompt format key; only "CONSTRUCT_CON" is supported.

    Returns:
        The filled-in prompt string.

    Raises:
        ValueError: if ``prompt_format`` is not a supported format.
    """
    if prompt_format != "CONSTRUCT_CON":
        # Previously this fell through to `return prompt` with the name unbound.
        raise ValueError(f"Unsupported prompt format: {prompt_format!r}")

    fact = data_item['inner_knowledge']
    return instruction_template_dict[prompt_format].replace("[Taxon]", str(fact))


def _parse_conflict_response(response_text):
    """Split a model reply into ``(conflict_knowledge, new_question)``.

    The conflicting knowledge is the text between the last
    "Conflicting Knowledge:" marker and the following "Question:" marker; the
    question is the text after the last "Question:" marker, cut at
    "Let's think step by step.".
    """
    # split() already removes the marker. str.strip(chars) treats its argument
    # as a set of characters, so it must not be used to drop the marker: it
    # would also eat matching letters at both ends of the real content.
    conflict_knowledge = (
        response_text.split("Conflicting Knowledge:")[-1]
        .split("Question:")[0]
        .strip()
    )
    new_question = (
        response_text.split("Question:")[-1]
        .split("Let's think step by step.")[0]
        .strip()
    )
    return conflict_knowledge, new_question


def generate_conflict_questions(inputfile, outputfile, model):
    """Generate conflicting knowledge and a new question for every item.

    Each item read from ``inputfile`` is turned into a "CONSTRUCT_CON" prompt
    and sent to the model, with up to five attempts. On success the parsed
    ``new_question`` and ``conflict_knowledge`` are stored on the item and the
    item is collected. An item whose five attempts all fail is skipped, so it
    never inherits the values of a previous item. The collected items are
    written to ``outputfile`` whenever the item index is a multiple of 50,
    and once more at the end.

    Args:
        inputfile: path handed to ``read_json_file``.
        outputfile: path handed to ``write_json_file``.
        model: currently unused.
            NOTE(review): generation is pinned to "LLAMA_3_70B" below -
            confirm whether ``model`` should select the backend instead.
    """
    max_attempts = 5
    gen_question = []
    data = read_json_file(inputfile)
    for index, item in enumerate(tqdm(data)):
        prompt = generate_prompt(item, "CONSTRUCT_CON")

        parsed = None
        for attempt in range(1, max_attempts + 1):
            try:
                response = chat_completion(prompt, "LLAMA_3_70B")
                print(prompt)
                print("xxxxxxxxxxx\n", response)
                parsed = _parse_conflict_response(response)
                break
            except Exception as err:
                # Best-effort retry; Ctrl-C and SystemExit still propagate.
                print(f"attempt {attempt}/{max_attempts} failed for item {index}: {err}")

        if parsed is None:
            print(f"skipping item {index}: no usable response after {max_attempts} attempts")
        else:
            conflict_knowledge, new_question = parsed
            item["new_question"] = new_question
            item["conflict_knowledge"] = conflict_knowledge
            gen_question.append(item)

        if index % 50 == 0:
            write_json_file(gen_question, outputfile)

    write_json_file(gen_question, outputfile)


# Pipeline configuration. `model` is passed to every stage and is embedded in
# each output file name below.
# NOTE(review): everything from here on runs at import time (there is no
# `if __name__ == "__main__":` guard).
# Alternative model choice:
# model = "gpt-4o-mini"
model = "o1-preview"
# NOTE(review): the input has a .jsonl extension but is loaded through
# read_json_file - confirm that helper handles JSON Lines.
input_file = "./data/Musique/musique_ans_v1.0_dev.jsonl"
output_file = f"./data/Musique/musique_ans_v1.0_dev_verified_{model}.json"
output_file_is_supporting = f"./data/Musique/musique_ans_v1.0_dev_is_supporting_{model}.json"
output_file_filtered = f"./data/Musique/musique_dev_verified_{model}_filtered.json"
output_file_inner_external = f"./data/Musique/musique_dev_verified_{model}_inner_external.json"
output_file_conflict = f"./data/Musique/musique_dev_verified_{model}_conflict.json"

# Each stage reads the file written by the previous one, so the order of the
# calls below must not change.
# 1. Verify every decomposition step with the model.
filter_internal_knowledge(input_file, output_file, model)
# 2. Keep answerable items and reduce them to the fields used downstream.
filter_paragraph(output_file, output_file_filtered, model)
# 3. Keep items that have both verified and unverified steps.
filter_question_decomposition_is_supporting(output_file_filtered, output_file_is_supporting, model)
# 4. Split supporting paragraphs into inner_knowledge / external_knowledge.
filter_inner_external_knowledge(output_file_is_supporting, output_file_inner_external, model)
# 5. Generate conflicting knowledge and a new question per item.
generate_conflict_questions(output_file_inner_external, output_file_conflict,model)
