from langchain_core.messages import HumanMessage, AIMessage

import numpy as np
import networkx as nx


def insert_message_into_llm(llm, response, message):
    """Append the previous AI reply and a new human turn to ``llm``'s prompt.

    The content of ``response["ai_message"]`` is wrapped in an ``AIMessage``
    and appended first; ``message`` is then wrapped in a ``HumanMessage`` and
    appended after it. The list at ``llm.prompt.template.messages`` is
    mutated in place.

    Args:
        llm: Object exposing ``prompt.template.messages`` (a list of messages).
        response: Mapping whose ``"ai_message"`` entry has a ``.content``.
        message: Content for the new human message.

    Returns:
        The same ``llm`` object that was passed in, after mutation.
    """
    ai_reply = AIMessage(content=response["ai_message"].content)
    history = llm.prompt.template.messages
    history.append(ai_reply)
    history.append(HumanMessage(content=message))
    return llm


def get_node_depth(node_edge, node_input, root_nodes=(0, 7, 14, 21, 28)):
    """Return the nodes that sit one BFS level above ``node_input``.

    Despite its name, this does not return the depth itself. It returns
    every node whose shortest-path distance from a root equals
    ``depth(node_input) - 1``.

    The procedure is:

    1. Build a directed graph from ``node_edge``.
    2. Run a BFS from each root in ``root_nodes`` that is present in the
       graph.
    3. Take ``node_input``'s distance from the *last* such root that reaches
       it as its depth.
    4. Collect, across all reachable roots' BFS results, every node at
       distance ``depth - 1``.

    Nodes reachable from several roots at that distance appear once per root
    (duplicates are kept).

    Args:
        node_edge: Iterable of ``(u, v)`` edge pairs for ``nx.DiGraph``.
        node_input: Node whose upper-level neighbours are wanted.
        root_nodes: Candidate BFS roots. Roots missing from the graph are
            skipped. The default preserves the previously hard-coded list.

    Returns:
        A list of nodes. It is empty when ``node_input`` is a root
        (depth 0) or is not reachable from any root.
    """
    graph = nx.DiGraph()
    graph.add_edges_from(node_edge)

    distances_list = [
        nx.single_source_shortest_path_length(graph, root)
        for root in root_nodes
        if root in graph.nodes
    ]

    # Track "not found" explicitly instead of using a magic sentinel depth.
    # The old sentinel (100) made an unreachable node return everything at
    # distance 99.
    depth = None
    for distances in distances_list:
        if node_input in distances:
            # Later roots override earlier ones, matching the original behaviour.
            depth = distances[node_input]
    if depth is None:
        return []

    return [
        node
        for distances in distances_list
        for node, distance in distances.items()
        if distance == depth - 1
    ]


def get_prob_correctness(response: dict) -> float:
    """Return the probability of the first digit token after an ``"accuracy"`` key.

    Scans ``response["ai_message"].response_metadata["logprobs"]["content"]``,
    a list of dicts with ``"token"`` and ``"logprob"`` keys, using a sliding
    window of three tokens.

    An ``accuracy`` token counts as a key when both neighbouring tokens
    contain a double quote. From that token onward, the first token that is
    all digits (ignoring surrounding whitespace) is taken. Its logprob is
    converted back to a probability with ``exp``. This assumes natural-log
    logprobs, as OpenAI returns them (TODO confirm for other providers).

    Because of the three-token window, the very first and very last tokens
    are never inspected as the candidate token.

    Args:
        response: Mapping whose ``"ai_message"`` carries ``response_metadata``
            with logprob information.

    Returns:
        The probability of the matched digit token, or ``0.0`` if no
        ``"accuracy"`` key followed by a digit token is found.
    """
    logprobs = response["ai_message"].response_metadata["logprobs"]["content"]
    inside_accuracy = False
    for before, current, after in zip(logprobs, logprobs[1:], logprobs[2:]):
        if not inside_accuracy:
            inside_accuracy = (
                current["token"] == "accuracy"
                and "\"" in before["token"]
                and "\"" in after["token"]
            )
        # Strip so tokens such as " 8" (digit with a leading space) also match.
        if inside_accuracy and current["token"].strip().isdigit():
            return float(np.exp(current["logprob"]))
    return 0.0
