from collections import Counter

from src.util.reducers import *
from src.util.utils import insert_message_into_llm, get_prob_correctness, get_node_depth
from langchain_core.messages import HumanMessage, AIMessage

import numpy as np
import networkx as nx


def get_method(base_prompt, base_llm):
    """Create a graph node that asks the LLM for a solving method.

    Args:
        base_prompt: Zero-argument callable returning the method prompt.
        base_llm: Callable ``base_llm(prompt, temperature=...)`` returning an
            object exposing ``invoke(input_dict)``.

    Returns:
        A callable ``method(state)`` that returns the raw LLM result. ``state``
        is accepted for graph compatibility but not read.
    """
    def method(state):
        # Deterministic sampling; the prompt takes no input variables, so an
        # empty mapping is passed.
        runnable = base_llm(base_prompt(), temperature=0)
        return runnable.invoke({})
    return method


def get_answering(node_name, parent_node_name, base_prompt, base_llm, input_variables):
    """Create a graph node that asks the LLM to answer for ``node_name``.

    The root node (``node_name == parent_node_name``) is prompted with the
    task ``data`` and ``method`` only. Any other node is also given context.
    For each node returned by ``get_node_depth(state["node_graph"], node_name)``
    that has a non-zero entry in ``state["answer_probs"]`` (excluding the direct
    parent), it receives that node's aggregated rationale, inaccurate-info note
    and answer probability, highest probability first. The same three values
    for ``parent_node_name`` are always appended last.

    A child node whose answer repeats its parent's answer is pushed back once
    with a follow-up message and asked again.

    Args:
        node_name: Key under which this node's answer is stored.
        parent_node_name: Key of the parent node; equal to ``node_name`` for
            the root.
        base_prompt: Zero-argument callable returning the answering prompt.
        base_llm: Callable ``base_llm(prompt, temperature=...)`` returning an
            object exposing ``invoke(input_dict)``.
        input_variables: Not used by this node.

    Returns:
        A callable ``answering(state)`` returning
        ``{"answers": {node_name: ...}, "answer_rationales": {node_name: ...}}``.
        The callable raises ``RuntimeError`` if no attempt produced a response
        containing ``"answer"``.
    """
    def answering(state):
        answering_prompt = base_prompt()
        is_root = node_name == parent_node_name
        # The root, or a child of a temperature-0 parent, samples at
        # temperature 1; otherwise the parent's temperature is inherited.
        if is_root or state["temperature"][parent_node_name] == 0:
            temperature = 1
        else:
            temperature = state["temperature"][parent_node_name]
        llm = base_llm(answering_prompt, temperature=temperature)

        input_state = dict()
        input_state["data"] = state["data"]
        input_state["method"] = state["method"]

        if not is_root:
            answer_probs = state["answer_probs"]
            up_node_list = get_node_depth(state["node_graph"], node_name)
            # Dict keeps first-occurrence order and drops duplicate node names.
            filtered = {k: answer_probs[k] for k in up_node_list if k in answer_probs}
            ranked_nodes = sorted(filtered, key=filtered.get, reverse=True)
            # Other ancestors (best first) with non-zero probability; the
            # direct parent is always placed last.
            context_nodes = [
                n for n in ranked_nodes
                if n != parent_node_name and answer_probs[n] != 0
            ]
            context_nodes.append(parent_node_name)
            input_state["aggregate_rationale"] = [state["aggregate_rationales"][n] for n in context_nodes]
            input_state["inaccurate_info"] = [state["inaccurate_infos"][n] for n in context_nodes]
            input_state["answer_prob"] = [answer_probs[n] for n in context_nodes]
        input_state = unlist_state(input_state)

        # Remember the latest well-formed response so a malformed retry does
        # not throw away an earlier usable answer.
        best_res = None
        for _ in range(2):
            llm_res = llm.invoke(input_state)
            if not llm_res or "answer" not in llm_res:
                continue
            best_res = llm_res
            if is_root or llm_res["answer"] != state["answers"][parent_node_name]:
                break
            # The child repeated its parent's answer: push back and retry.
            message = "You response the same answer. Think again and answer with higher performance."
            llm = insert_message_into_llm(llm, llm_res, message)

        if best_res is None:
            raise RuntimeError(
                f"answering({node_name!r}): LLM returned no 'answer' after 2 attempts"
            )

        new_state = dict()
        new_state["answers"] = {node_name: best_res["answer"]}
        new_state["answer_rationales"] = {node_name: best_res["answer_rationale"]}
        return new_state
    return answering


def get_rationale_prob(node_name, parent_node_name, base_prompt, base_llm):
    """Create a graph node that scores the answer stored for ``node_name``.

    The LLM (temperature 0, ``logprobs=True``) is asked for an ``accuracy``
    score and an ``accuracy_rationale``. ``get_prob_correctness`` turns the
    response into a confidence value, presumably a probability in [0, 1]
    derived from token logprobs (TODO confirm). The LLM is pushed back and
    asked again (up to 2 attempts total) when:

    * ``accuracy > 90`` with confidence ``< 0.99``,
    * ``accuracy < 10`` with confidence ``< 0.99``, or
    * confidence ``< 0.5``.

    The final probability is ``accuracy / 100 * confidence ** (1 / e)``.

    Args:
        node_name: Key of the node whose answer is scored.
        parent_node_name: Not used by this node.
        base_prompt: Zero-argument callable returning the scoring prompt.
        base_llm: Callable ``base_llm(prompt, temperature=..., logprobs=...)``
            returning an object exposing ``invoke(input_dict)``.

    Returns:
        A callable ``rationale_prob(state)`` returning ``answer_accuracies``,
        ``answer_probs`` and ``answer_accuracy_rationales`` updates for
        ``node_name``. It raises ``RuntimeError`` if no attempt returned an
        ``"accuracy"`` field.
    """
    def rationale_prob(state):
        get_rationale_prob_prompt = base_prompt()
        llm = base_llm(get_rationale_prob_prompt, temperature=0, logprobs=True)

        input_state = dict()
        input_state["data"] = state["data"]
        input_state["evaluate_methods"] = state["evaluate_methods"]
        input_state["answer"] = state["answers"][node_name]
        input_state = unlist_state(input_state)

        # The latest scored response and its confidence, always kept as a
        # matching pair, so an unscored retry cannot mix with a stale value.
        scored_res = None
        scored_prob = None
        for i in range(2):
            llm_res = llm.invoke(input_state)
            if not llm_res or "accuracy" not in llm_res:
                print(f"llm_res doesn't include score.: {i+1}")
                continue
            accuracy_prob = get_prob_correctness(llm_res)
            print(f"log_prob:{accuracy_prob}")
            scored_res, scored_prob = llm_res, accuracy_prob

            accuracy = llm_res["accuracy"]
            if accuracy > 90 and accuracy_prob < 0.99:
                message = (f"You are not sure that you score that accuracy. You must give it an accuracy lower than 90. "
                           f"Reconsider the accuracy_rationale you just mentioned and add the missing details to your answer.")
            elif accuracy < 10 and accuracy_prob < 0.99:
                message = (f"You are not sure that you score that accuracy. You must give it an accuracy higher than 10. "
                           f"Reconsider the accuracy_rationale you just mentioned and add the missing details to your answer.")
            elif accuracy_prob < 0.5:
                message = f"Your confidence level is only accuracy: {accuracy_prob}. Re-evaluate your low confidence score."
            else:
                break
            llm = insert_message_into_llm(llm, llm_res, message)

        if scored_res is None:
            raise RuntimeError(
                f"rationale_prob({node_name!r}): LLM returned no 'accuracy' after 2 attempts"
            )

        prob = scored_res["accuracy"] / 100 * np.power(scored_prob, 1 / np.e)

        new_state = dict()
        new_state["answer_accuracies"] = {node_name: scored_res["accuracy"]}
        new_state["answer_probs"] = {node_name: prob}
        new_state["answer_accuracy_rationales"] = {node_name: scored_res["accuracy_rationale"]}
        return new_state
    return rationale_prob


def get_aggregate_rationales(node_name, parent_node_name, base_prompt, base_llm):
    """Create a graph node that merges the answer and accuracy rationales.

    The LLM (temperature 0) combines ``state["answer_rationales"][node_name]``
    and ``state["answer_accuracy_rationales"][node_name]`` into an
    ``aggregate_rationale`` and an ``inaccurate_info`` note. It is tried up
    to 2 times.

    Args:
        node_name: Key of the node whose rationales are merged.
        parent_node_name: Not used by this node.
        base_prompt: Zero-argument callable returning the aggregation prompt.
        base_llm: Callable ``base_llm(prompt, temperature=..., logprobs=...)``
            returning an object exposing ``invoke(input_dict)``.

    Returns:
        A callable ``aggregate_rationales(state)`` that always returns both
        ``aggregate_rationales`` and ``inaccurate_infos`` for ``node_name``.
        Each falls back to ``""`` when the LLM does not supply it.
    """
    def aggregate_rationales(state):
        aggregate_rationale_prompt = base_prompt()
        llm = base_llm(aggregate_rationale_prompt, temperature=0, logprobs=True)

        input_state = dict()
        input_state["answer_rationale"] = state["answer_rationales"][node_name]
        input_state["accuracy_rationale"] = state["answer_accuracy_rationales"][node_name]
        input_state = unlist_state(input_state)

        # Both keys are always emitted: later nodes index
        # state["inaccurate_infos"][node] directly and would fail on a miss.
        new_state = dict()
        new_state["aggregate_rationales"] = {node_name: ""}
        new_state["inaccurate_infos"] = {node_name: ""}
        for i in range(2):
            llm_res = llm.invoke(input_state)
            if llm_res and "aggregate_rationale" in llm_res:
                new_state["aggregate_rationales"] = {node_name: llm_res["aggregate_rationale"]}
                new_state["inaccurate_infos"] = {node_name: llm_res.get("inaccurate_info", "")}
                break
            print(f"llm_res doesn't include aggregate_rationale.: {i+1}")
        return new_state
    return aggregate_rationales


def _multiset_jaccard_percent(reference, prediction):
    """Return 100 * |reference ∩ prediction| / |reference ∪ prediction| as multisets.

    Two empty inputs are treated as identical (100.0) instead of dividing by zero.
    """
    ref_counts = Counter(list(reference))
    pred_counts = Counter(prediction)
    overlap = sum((ref_counts & pred_counts).values())
    total = sum((ref_counts | pred_counts).values())
    if total == 0:
        return 100.0
    return 100 * overlap / total


def get_aggregate_graph(base_prompt, base_llm, leaf_node):
    """Create the final graph node that combines all leaf answers.

    Leaves are ranked by ``state["answer_probs"]`` (highest first), and leaves
    with probability 0 are dropped. The LLM (temperature 0) is then given the
    task ``data`` and ``method`` plus the ranked probabilities, answers,
    aggregated rationales and inaccurate-info notes. It is retried up to 10
    times until a response contains ``"answer"``.

    Args:
        base_prompt: Zero-argument callable returning the aggregation prompt.
        base_llm: Callable ``base_llm(prompt, temperature=...)`` returning an
            object exposing ``invoke(input_dict)``.
        leaf_node: Iterable of leaf node keys to aggregate.

    Returns:
        A callable ``aggregate_graph(state)`` returning ``answers`` and
        ``answer_rationales`` under the ``"final"`` key. It also returns
        ``for_test_acc``, which holds the multiset Jaccard overlap (percent)
        between ``state["data"]`` and the final answer. The callable raises
        ``RuntimeError`` if no attempt returned an ``"answer"``.
    """
    def aggregate_graph(state):
        aggregate_prompt = base_prompt()
        llm = base_llm(aggregate_prompt, temperature=0)

        input_state = dict()
        input_state["data"] = state["data"]
        input_state["method"] = state["method"]

        ranked = sorted(
            (
                (state["answer_probs"][node], state["answers"][node],
                 state["aggregate_rationales"][node], state["inaccurate_infos"][node])
                for node in leaf_node
            ),
            key=lambda row: row[0],
            reverse=True,
        )
        ranked = [row for row in ranked if row[0] != 0]

        input_state["answer_prob"] = [row[0] for row in ranked]
        input_state["answer"] = [row[1] for row in ranked]
        input_state["aggregate_rationale"] = [row[2] for row in ranked]
        input_state["inaccurate_info"] = [row[3] for row in ranked]
        input_state = unlist_state(input_state)

        llm_res = None
        for _ in range(10):
            candidate = llm.invoke(input_state)
            if candidate and "answer" in candidate:
                llm_res = candidate
                break
        if llm_res is None:
            raise RuntimeError("aggregate_graph: LLM returned no 'answer' after 10 attempts")

        final_acc = _multiset_jaccard_percent(state["data"], llm_res["answer"])
        print(f"final acc: {final_acc}")

        new_state = dict()
        new_state["answers"] = {"final": llm_res["answer"]}
        new_state["answer_rationales"] = {"final": llm_res["answer_rationale"]}
        new_state["for_test_acc"] = {"final": final_acc, "temperature": "fix_1"}
        return new_state

    return aggregate_graph
