import argparse
import json
import math
import os
import re
import threading
import time

import openai
from langchain.callbacks import get_openai_callback
from langchain.llms import AzureOpenAI
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.prompts import PromptTemplate

from utils import (
    prepare_llm,
    process_data_and_extract_equations,
    extract_final_answer,
    create_computational_graph,
    compute_graph_values,
)
from prompt import template_graph_construct

# Command-line options for the mapping run.
parser = argparse.ArgumentParser(
    description="Build equation-to-computational-graph mappings with an LLM."
)
parser.add_argument(
    "--model_name",
    default="gpt-4-1106",
    choices=["gpt-4-1106"],
    type=str,
    help="Model name; also passed to prepare_llm as the engine.",
)
parser.add_argument(
    "--temperature",
    default=0.7,
    type=float,
    help="Sampling temperature.",
)
parser.add_argument(
    "--max_token",
    default=2048,
    # Token limits are whole numbers; type=float used to turn a CLI value such
    # as "1024" into 1024.0 before it was forwarded as max_tokens.
    type=int,
    help="Maximum number of tokens to generate.",
)
parser.add_argument(
    "--output_file",
    default="test_mapping.json",
    type=str,
    help="JSON Lines file that stores (and, if present, resumes) the results.",
)
parser.add_argument(
    # The misspelled flag is kept for backward compatibility; the correctly
    # spelled alias writes to the same destination.
    "--orginal_file",
    "--original_file",
    dest="orginal_file",
    default="test.jsonl",
    type=str,
    help="JSON Lines file with the source problems; only read when "
    "--output_file does not exist yet.",
)
args = parser.parse_args()

# Settings forwarded to prepare_llm; the model name doubles as the engine name
# and nucleus sampling is fixed at top_p=0.95.
llm_settings = dict(
    model_name=args.model_name,
    engine=args.model_name,
    max_tokens=args.max_token,
    temperature=args.temperature,
    top_p=0.95,
)
llm = prepare_llm(**llm_settings)

def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*.

    Blank lines are skipped: ``json.loads`` raises on an empty string, so a
    stray empty line would otherwise abort the whole run.
    """
    with open(path, "r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]


# Resume from a previous run's output when it exists; otherwise seed one
# {"Original": <problem>, "Mapping": {}} record for each of the first 100
# problems of the original file (an empty Mapping marks "not done yet").
if os.path.exists(args.output_file):
    output_data = _read_jsonl(args.output_file)
else:
    test_data = _read_jsonl(args.orginal_file)
    output_data = [{"Original": data, "Mapping": {}} for data in test_data[0:100]]

# Prompt text imported from the local `prompt` module.
template = template_graph_construct

# The LLM reply must carry exactly one field, "Mapping", holding a dictionary.
mapping_schema = ResponseSchema(
    name="Mapping",
    description="the dictionary of the mapping",
    type="dictionary",
)
response_schemas = [mapping_schema]

def _example_node(name, node_type, value):
    """Return one node entry: its symbolic name, its role and its value."""
    return {"Name": name, "type": node_type, "value": value}


def _example_equation(content, first, second, result):
    """Return one equation entry: its text, its two operands and its result."""
    return {
        "content": content,
        "operator 1": first,
        "operator 2": second,
        "result": result,
    }


# Hand-written sample of a "Mapping" (presumably mirroring what the prompt asks
# the LLM to return). Each equation names its two operands and its result as
# nodes typed "initial", "intermediate" or "final"; nodes shared between
# equations (C, E, H) reuse the same Name. Not referenced by the code in this
# file.
temp = {
    "Mapping": {
        "Equation1": _example_equation(
            "2000 - 1800 = 200",
            _example_node("A", "initial", 2000),
            _example_node("B", "initial", 1800),
            _example_node("C", "intermediate", 200),
        ),
        "Equation2": _example_equation(
            "200 / 250 = 4/5",
            _example_node("C", "intermediate", 200),
            _example_node("D", "initial", 250),
            _example_node("E", "intermediate", 0.8),
        ),
        "Equation3": _example_equation(
            "300 / 5 = 60",
            _example_node("F", "initial", 300),
            _example_node("G", "initial", 5),
            _example_node("H", "intermediate", 60),
        ),
        "Equation4": _example_equation(
            "60 * 4/5 = 48",
            _example_node("H", "intermediate", 60),
            _example_node("E", "intermediate", 0.8),
            _example_node("I", "final", 48),
        ),
    }
}


# Parser that pulls the "Mapping" field out of the LLM's reply; its format
# instructions are pre-filled into the prompt, leaving only the question and
# the extracted equations to supply per item.
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
prompt_partials = {"format_instructions": format_instructions}
prompt = PromptTemplate(
    input_variables=["question", "multiple_equation"],
    partial_variables=prompt_partials,
    template=template,
)
# Positions that are always skipped. They index the *reversed* first-100 slice,
# i.e. they are the numbers printed as "index:" by the loop below.
# NOTE(review): why these items are excluded is not visible here (presumably
# known-problematic cases) -- confirm.
edge_case_index = [11, 15, 86, 66, 60, 75, 84]


def _recompute_graph(parsed_output):
    """Build the graph for *parsed_output* and re-derive every non-initial value.

    The values the LLM wrote for intermediate/final nodes are discarded and
    recomputed with ``compute_graph_values`` from the initial nodes, so a
    mapping is only accepted if its own equations reproduce the answer.

    Returns the graph with recomputed values filled in, or ``None`` when the
    computation raised ("Failed") or returned ``None`` ("time out").
    """
    graph = create_computational_graph(parsed_output)
    for node_attribute in graph["nodes"].values():
        if node_attribute["type"] != "initial":
            # Default avoids a KeyError when the LLM left the value out.
            node_attribute.pop("value", None)
    try:
        values = compute_graph_values(graph)
    except Exception:
        # Returning here avoids reading an unbound (first item) or stale
        # `values` from a previous item, as the old flag-based code did.
        print("Failed")
        return None
    if values is None:
        print("time out")
        return None
    for node_name, value in values.items():
        graph["nodes"][node_name]["value"] = value
    return graph


def _final_node(graph):
    """Return the attributes of the first node whose type is "final", or None."""
    return next(
        (attrs for attrs in graph["nodes"].values() if attrs["type"] == "final"),
        None,
    )


def _write_jsonl(path, rows):
    """Rewrite *path* with one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in rows)


with get_openai_callback() as cb:
    for i, data in enumerate(reversed(output_data[0:100])):
        # A non-empty Mapping means this item was already solved (possibly in
        # an earlier, resumed run).
        if data["Mapping"] or i in edge_case_index:
            print("skip")
            continue
        print("index:" + str(i))
        multiple_equation, temp_answer = process_data_and_extract_equations(
            data["Original"]
        )
        temp_data = {"question": data["Original"]["question"], "answer": temp_answer}
        final_answer = float(extract_final_answer(temp_data))
        # NOTE(review): the whole question/answer dict, not only the question
        # text, is substituted for {question}; confirm this is intended.
        _input = prompt.format_prompt(
            question=temp_data, multiple_equation=multiple_equation
        )
        result = llm.invoke(_input.to_string())
        print("***********")
        print(result)
        print("***********")
        parsed_output = output_parser.parse(result)["Mapping"]

        # data["Mapping"] is already {} at this point, so every failure path
        # simply leaves it empty and the item is retried on the next run.
        graph = _recompute_graph(parsed_output)
        final_node = _final_node(graph) if graph is not None else None
        if final_node is not None:
            # Recomputed values come from float arithmetic, so exact equality
            # can reject a correct mapping over rounding error.
            if math.isclose(
                float(final_node["value"]), final_answer, rel_tol=1e-9, abs_tol=1e-9
            ):
                data["Mapping"] = parsed_output
                print("Succeed")
                data["Original"]["answer"] = temp_answer
            else:
                print("Not Match")
        # Checkpoint after every processed item so an interrupted run can resume.
        _write_jsonl(args.output_file, output_data)
    print(cb)
