import pandas as pd
import networkx as nx
from networkx.drawing.nx_agraph import to_agraph
from IPython.display import SVG, display

def plot_graph(g, fp="test.svg", font_size=12):
    """Render ``g`` with Graphviz ``dot`` into ``fp`` and display that file inline as SVG.

    Nodes whose key is the empty string are relabelled to ``"<empty>"`` before
    conversion. ``nx.relabel_nodes`` returns a relabelled copy by default, and
    that copy is what gets drawn. ``font_size`` becomes the default
    ``fontsize`` for nodes, edges, and the graph label.
    """
    def _placeholder(node):
        return "<empty>" if node == "" else node

    agraph = to_agraph(nx.relabel_nodes(g, _placeholder))

    size = str(font_size)
    for attrs in (agraph.node_attr, agraph.edge_attr, agraph.graph_attr):
        attrs.update(fontsize=size)

    agraph.layout("dot")
    agraph.draw(fp)
    display(SVG(filename=fp))
def truncate_string(s, max_length):
    """Shorten ``s`` to at most ``max_length`` characters, marking a cut with "..".

    Strings that already fit are returned unchanged. Longer strings keep their
    first ``max_length - 2`` characters followed by ``".."``.

    When ``max_length < 2`` there is no room for any original text. Only as
    much of the ``".."`` marker as fits is returned (``"."`` for 1, ``""`` for
    0 or less), so the result never exceeds ``max_length``.
    """
    if len(s) <= max_length:
        return s
    if max_length < 2:
        # A negative slice stop counts from the end of s, so
        # s[:max_length - 2] would keep most of s and overflow the limit.
        return '..'[:max(max_length, 0)]
    return s[:max_length - 2] + '..'

def visualize_graph(G, verified_nodes_only=False, truncate_strings=True, max_length=20, fp="test.svg"):
    """Style a copy of ``G`` for Graphviz and render it with ``plot_graph``.

    ``G`` itself is not modified. Styling is written into ``G.copy()``, or
    into a copy of the filtered subgraph when ``verified_nodes_only`` is set.

    Node attributes read:
    - ``label``: falls back to ``str(node)``.
    - ``expedience``: appended to the label with two decimals, after any
      truncation.
    - ``verified``, ``is_finished``, ``is_solution``, ``is_root``: styles are
      applied in this order, so later styles override earlier ones for shared
      keys. A root from which any solution node of ``G`` is reachable is
      additionally filled purple.

    Edge attributes read: ``label``, ``is_error`` and ``is_shortcut``.
    Shortcut styling overrides error/good styling.

    Parameters
    ----------
    G : graph
        Presumably a ``networkx.DiGraph``; ``out_degree`` is used when
        ``verified_nodes_only`` is set.
    verified_nodes_only : bool
        Keep only nodes that have at least one child or are marked ``verified``.
    truncate_strings : bool
        Shorten node/edge labels to ``max_length`` characters when true.
    max_length : int
        Label length limit used when ``truncate_strings`` is true.
    fp : str
        Output file path forwarded to ``plot_graph``.
    """
    if verified_nodes_only:
        good_nodes = [
            n for n, d in G.nodes(data=True)
            if G.out_degree(n) != 0 or d.get("verified", False)
        ]
        P = nx.subgraph(G, good_nodes).copy()
    else:
        P = G.copy()

    solutions = [n for n in G.nodes() if G.nodes[n].get("is_solution", False)]
    solution_set = set(solutions)
    # Prints the number of solution nodes found in G.
    print(len(solutions))

    def _fit(text):
        # str() so non-string labels neither break truncation nor the
        # expedience concatenation below.
        text = str(text)
        return truncate_string(text, max_length) if truncate_strings else text

    root_style = dict(shape="box", style="filled", fillcolor="gray")
    root_solved_style = dict(fillcolor="purple")

    unverified_style = dict(color="#00000020", style="dashed", fontcolor="red")
    verified_style = dict(shape="ellipse", style="filled", fillcolor="lightgreen")
    finished_style = dict(shape="ellipse", style="filled", fillcolor="orange")
    solution_style = dict(shape="ellipse", style="filled", fillcolor="lightblue")

    error_edge_style = dict(color="#FF000020", style='dashed', penwidth=3)
    good_edge_style = dict(color='green', style='solid', penwidth=2)
    shortcut_style = dict(color="pink", style='dashed', penwidth=3)

    # Style edges.
    for _, _, d in P.edges(data=True):
        if "label" in d:
            d["label"] = _fit(d["label"])
        d.update(error_edge_style if d.get("is_error", False) else good_edge_style)
        if d.get("is_shortcut", False):
            d.update(shortcut_style)

    # Style nodes. d is P's own attribute dict, so update() is equivalent
    # to P.add_node(n, **style) for an existing node.
    for n, d in P.nodes(data=True):
        d["label"] = _fit(d.get("label", n))
        if "expedience" in d:
            d["label"] += "\n{x:.2f}".format(x=d["expedience"])

        d.update(verified_style if d.get("verified", False) else unverified_style)
        if d.get("is_finished", False):
            d.update(finished_style)
        if d.get("is_solution", False):
            d.update(solution_style)
        if d.get("is_root", False):
            d.update(root_style)
            # Solved if the root is a solution itself (nx.has_path(G, n, n) is
            # True) or any solution is a descendant. This costs one traversal
            # per root instead of one has_path search per (root, solution) pair.
            if solution_set and (
                n in solution_set or not solution_set.isdisjoint(nx.descendants(G, n))
            ):
                d.update(root_solved_style)

    plot_graph(P, fp=fp)


###############################################################################################################################################################


from collections import deque

def reachable_leaves(G, sources):
    """Return the set of sink nodes (no successors) reachable from ``sources``.

    A source that has no successors counts as a leaf itself. The traversal is
    breadth-first and expands each node at most once, so cycles terminate.
    Only ``G.successors`` is used.
    """
    seen = set()
    sinks = set()
    frontier = deque(sources)

    while frontier:
        current = frontier.popleft()
        if current in seen:
            continue
        seen.add(current)

        successors = list(G.successors(current))
        if successors:
            frontier.extend(successors)
        else:
            sinks.add(current)

    return sinks


def all_paths_between(G, a, b):
    """Return every simple path from a node in ``a`` to a node in ``b``, as edge sequences.

    Paths are produced by ``nx.all_simple_edge_paths``, so each path is a
    tuple of edges, not of nodes. A walk p1 -> p2 -> ... -> pn, with p1 in
    ``a`` and pn in ``b``, is returned as ``((p1, p2), ..., (p_{n-1}, pn))``.
    Multigraph edges carry a key as a third element.

    The list is ordered by source (in ``a`` order), then by target (in ``b``
    order), then in the order networkx yields paths.
    """
    return [
        tuple(edge_path)
        for source in a
        for target in b
        for edge_path in nx.all_simple_edge_paths(G, source=source, target=target)
    ]


def get_solution_nodes(G):
    """List the nodes of ``G`` whose ``is_solution`` attribute is truthy, in node order."""
    flagged = []
    for node, attrs in G.nodes(data=True):
        if attrs.get("is_solution", False):
            flagged.append(node)
    return flagged

def get_clean_subgraph(G):
    """Return ``G`` without nodes whose ``is_error`` attribute is truthy.

    The result is an ``nx.subgraph`` view, not a copy, so it reflects later
    changes to ``G`` and cannot be modified directly.
    """
    keep = [node for node, attrs in G.nodes(data=True) if not attrs.get("is_error", False)]
    return nx.subgraph(G, keep)

def get_unsolved_roots(G):
    """Return the root nodes from which no solution node is reachable.

    A node is a root when its ``is_root`` attribute is truthy, and a solution
    when ``is_solution`` is truthy (see ``get_solution_nodes``). A root that is
    itself a solution counts as solved, matching ``nx.has_path(G, r, r)``.
    Roots that reach a solution are already "closed" and are filtered out.
    The rest are returned in ``G``'s node order.
    """
    solutions = set(get_solution_nodes(G))
    roots = [x for x, y in G.nodes(data=True) if y.get("is_root", False)]
    if not solutions:
        return roots
    # One descendants() traversal per root, instead of one has_path search
    # per (root, solution) pair. descendants() excludes the root itself, so
    # check it separately.
    return [
        r for r in roots
        if r not in solutions and solutions.isdisjoint(nx.descendants(G, r))
    ]

def get_solved_roots(G):
    """Return the root nodes (``is_root`` truthy) not reported by ``get_unsolved_roots``.

    In other words, the roots from which some solution node is reachable,
    listed in ``G``'s node order.
    """
    unsolved = set(get_unsolved_roots(G))
    return [
        node
        for node, attrs in G.nodes(data=True)
        if attrs.get("is_root", False) and node not in unsolved
    ]

def subgraph_from(G, n):
    """Return an independent copy of the subgraph induced by ``n`` and all its descendants."""
    keep = {n}
    keep.update(nx.descendants(G, n))
    return G.subgraph(keep).copy()

def get_winning_subgraph(G, min_expedience=-1):
    """Return the part of ``G`` that leads into a solution node.

    Steps:
    1. Drop error nodes (``get_clean_subgraph``).
    2. Keep root nodes unconditionally. Keep any other node only if its
       ``expedience`` attribute (default 0) is strictly greater than
       ``min_expedience``.
    3. Of what remains, keep the solution nodes (``is_solution`` truthy) and
       every node from which one of them is reachable.

    Returns
    -------
    An ``nx.subgraph`` view restricted to the step-3 nodes, or ``None`` when
    no solution node survives steps 1-2.
    """
    G = get_clean_subgraph(G)
    nodes = [
        n for n, d in G.nodes(data=True)
        if d.get("is_root", False) or d.get("expedience", 0) > min_expedience
    ]
    G = nx.subgraph(G, nodes)
    solution_nodes = get_solution_nodes(G)
    # If there are no solution nodes, there is no winning subgraph.
    if not solution_nodes:
        return None
    # Same node set as composing reverse BFS trees rooted at each solution,
    # without building the intermediate trees.
    winners = set(solution_nodes)
    for s in solution_nodes:
        winners |= nx.ancestors(G, s)
    return nx.subgraph(G, winners)

##################################################################################################################################################################


def append_chat_message_to_convo(conv, chat_msg, keys=None):
    """Return a new conversation list with ``chat_msg`` appended as a plain dict.

    Despite the name, ``conv`` is not modified; a new list is returned.

    Parameters
    ----------
    conv : list of dict
        Existing conversation turns.
    chat_msg : mapping or iterable of (key, value) pairs
        Anything ``dict()`` accepts. The original comment says this is an
        OpenAI ChatMessage, which is assumed to convert via ``dict()``.
    keys : iterable of str, optional
        Fields to copy from ``chat_msg``. Defaults to the keys of the last
        turn in ``conv``, so the new turn matches the conversation's shape.
        If ``conv`` is empty, all fields of ``chat_msg`` are copied.

    Raises
    ------
    KeyError
        If ``chat_msg`` lacks one of the selected keys.
    """
    msg = dict(chat_msg)
    if keys is None:
        keys = conv[-1].keys() if conv else msg.keys()
    return list(conv) + [{k: msg[k] for k in keys}]

def get_token_usage_aggregations(g):
    """Deprecated: always raises ``ValueError("deprecated")``.

    This helper formerly aggregated per-group prompt and completion token
    usage (``<base>_prompt_token_usage`` / ``<base>_completion_token_usage``)
    from a pandas GroupBy ``g``. It followed this column convention:
    - All ``*_responses`` columns hold ChatCompletionResponse objects.
    - Each ``[x]_follow_up_responses`` column has a parent ``[x]_responses``
      column containing the previous conversation turns.
    - Follow-up prompt tokens were reduced by the parent's prompt + completion
      tokens, to avoid double counting.

    The implementation used to sit after an unconditional ``raise`` and was
    unreachable. It has been dropped here; retrieve it from version control
    if it is ever needed again.

    Raises
    ------
    ValueError
        Always, with the message ``"deprecated"``.
    """
    raise ValueError("deprecated")

