from collections import Counter, defaultdict
import networkx as nx
import os, random, re, tqdm
from pgmpy.estimators.CITests import *
from pgmpy.estimators import PC
from pgmpy.estimators import ExpertKnowledge, StructureEstimator
from pgmpy.base import DAG
from pgmpy.base import PDAG
from itertools import permutations
from openai import OpenAI

# NOTE(review): `...` (Ellipsis) is a placeholder for the real API key. os.environ
# only accepts str values, so this line has to be given a string (or be removed in
# favour of setting OPENAI_API_KEY outside the program) before the module can run.
os.environ["OPENAI_API_KEY"] = ...
# Module-wide OpenAI client used by run_prompt().
# Presumably reads the key from the environment variable set above — TODO confirm.
client = OpenAI()

# Pulls the single-letter choice (A–E) out of an "<Answer>X</Answer>" tag.
# Case-insensitive and tolerant of whitespace inside the tag.
_ans_re = re.compile(r"<\s*answer\s*>\s*([ABCDE])\s*<\s*/\s*answer\s*>", re.I)

# Maps the `ci_test` name accepted by build_skeleton_CI() to the callable that
# performs the test. The callables are presumably provided by the star import of
# pgmpy.estimators.CITests at the top of the file — verify against the installed
# pgmpy version.
supported_tests = {
    "chi_square": chi_square,
    "g_sq": g_sq,
    "log_likelihood": log_likelihood,
    "modified_log_likelihood": modified_log_likelihood,
    "pearsonr": pearsonr,
    "pillai": pillai_trace,
    "gcm": gcm,
}

_CHAIN_PROMPT_TMPL_FULL = """
You are a senior researcher in causal discovery. We are studying the following dataset:

{data_desc}

The two target variables under review are {u} and {v}.

Conditional-independence tests mentioning these variables:

{ci_bullets}

Neighbour chain(s) that must normally remain non-collider:

{chains}

The nodes involved are described as below: 

{node_desc}

Choose one explanation that best fits domain knowledge and/or decides a CI test is unreliable (avoid selecting D or E unless other options are strongly against common sense):

A. Undecided. We don't know enough to confidently pick a directionality.
B. Changing the state of {u} causally affects {v}, and {v} causally affects {u_theOther_2v}.
C. Changing the state of {v} causally affects {u}, and {u} causally affects {v_theOther_2u}.
D. Changing the state of {u} causally affects {v}, and {u_theOther_2v} also causally affects {v}, **violating corresponding CI tests**.
E. Changing the state of {v} causally affects {u}, and {v_theOther_2u} also causally affects {u}, **violating corresponding CI tests**.

Think step-by-step before selecting:
1.⁠ ⁠Mechanisms – What known causal pathways (biological, physical, etc.) support each direction?
2.⁠ ⁠Counterfactual test – What would happen if we intervened on one node? What would we expect?
3.⁠ ⁠Empirical check – Point to one key piece of information that favors/weakens a direction.
4.⁠ ⁠Comparison – Briefly weigh A vs B vs C vs D vs E and choose the most plausible.

Return exactly three lines:
1.⁠ ⁠Reasoning in support of one direction.
2.⁠ ⁠Reasoning against the weaker/less plausible direction.
3.⁠ ⁠Final choice: ⁠ <Answer>A/B/C/D/E</Answer> ⁠
"""

_CHAIN_PROMPT_TMPL_None2u = """
You are a senior researcher in causal discovery. We are studying the following dataset:

{data_desc}

The two target variables under review are {u} and {v}.

Conditional-independence tests mentioning these variables:

{ci_bullets}

Neighbour chain(s) that must normally remain non-collider:

{chains}

The nodes involved are described as below: 

{node_desc}

Choose one explanation that best fits domain knowledge and/or decides a CI test is unreliable (avoid selecting D unless other options are strongly against common sense):

A. Undecided. We don't know enough to confidently pick a directionality.
B. Changing the state of {u} causally affects {v}, and {v} causally affects {u_theOther_2v}.
C. Changing the state of {v} causally affects {u}.
D. Changing the state of {u} causally affects {v}, and {u_theOther_2v} also causally affects {v}, **violating corresponding CI tests**.

Think step-by-step before selecting:
1.⁠ ⁠Mechanisms – What known causal pathways (biological, physical, etc.) support each direction?
2.⁠ ⁠Counterfactual test – What would happen if we intervened on one node? What would we expect?
3.⁠ ⁠Empirical check – Point to one key piece of information that favors/weakens a direction.
4.⁠ ⁠Comparison – Briefly weigh A vs B vs C vs D and choose the most plausible.

Return exactly three lines:
1.⁠ ⁠Reasoning in support of one direction.
2.⁠ ⁠Reasoning against the weaker/less plausible direction.
3.⁠ ⁠Final choice: ⁠ <Answer>A/B/C/D</Answer> ⁠
"""

_CHAIN_PROMPT_TMPL_None2v = """
You are a senior researcher in causal discovery. We are studying the following dataset:

{data_desc}

The two target variables under review are {u} and {v}.

Conditional-independence tests mentioning these variables:

{ci_bullets}

Neighbour chain(s) that must normally remain non-collider:

{chains}

The nodes involved are described as below: 

{node_desc}

Choose one explanation that best fits domain knowledge and/or decides a CI test is unreliable (avoid selecting D unless other options are strongly against common sense):

A. Undecided. We don't know enough to confidently pick a directionality.
B. Changing the state of {u} causally affects {v}.
C. Changing the state of {v} causally affects {u}, and {u} causally affects {v_theOther_2u}.
D. Changing the state of {v} causally affects {u}, and {v_theOther_2u} also causally affects {u}, **violating corresponding CI tests**.

Think step-by-step before selecting:
1.⁠ ⁠Mechanisms – What known causal pathways (biological, physical, etc.) support each direction?
2.⁠ ⁠Counterfactual test – What would happen if we intervened on one node? What would we expect?
3.⁠ ⁠Empirical check – Point to one key piece of information that favors/weakens a direction.
4.⁠ ⁠Comparison – Briefly weigh A vs B vs C vs D and choose the most plausible.

Return exactly three lines:
1.⁠ ⁠Reasoning in support of one direction.
2.⁠ ⁠Reasoning against the weaker/less plausible direction.
3.⁠ ⁠Final choice: ⁠ <Answer>A/B/C/D</Answer> ⁠
"""

_CHAIN_PROMPT_TMPL_None = """
You are a senior researcher in causal discovery. We are studying the following dataset:

{data_desc}

The two target variables under review are {u} and {v}.

The nodes involved are described as below: 

{node_desc}

Choose one explanation that best fits domain knowledge:

A. Undecided. We don't know enough to confidently pick a directionality.
B. Changing the state of {u} causally affects {v}.
C. Changing the state of {v} causally affects {u}.

Think step-by-step before selecting:
1.⁠ ⁠Mechanisms – What known causal pathways (biological, physical, etc.) support each direction?
2.⁠ ⁠Counterfactual test – What would happen if we intervened on one node? What would we expect?
3.⁠ ⁠Empirical check – Point to one key piece of information that favors/weakens a direction.
4.⁠ ⁠Comparison – Briefly weigh A vs B vs C and choose the most plausible.

Return exactly three lines:
1.⁠ ⁠Reasoning in support of one direction.
2.⁠ ⁠Reasoning against the weaker/less plausible direction.
3.⁠ ⁠Final choice: ⁠ <Answer>A/B/C</Answer> ⁠
"""

def build_skeleton_CI(
    data = None,
    variant="orig",
    ci_test="chi_square",
    significance_level=0.01,
    max_cond_vars=5,
    expert_knowledge=None,
    enforce_expert_knowledge=False,
    **kwargs,
):
    """
    Estimate an undirected skeleton from data with the first phase of the PC
    algorithm, recording the separating sets and p-values that justified each
    edge removal.

    Starting from a complete graph over ``data.columns``, conditioning-set
    sizes 0, 1, ... are tried in turn; an edge ``u - v`` is removed when a
    conditional-independence test returns a p-value ``>= significance_level``.

    Parameters
    ----------
    data
        Required. Only ``data.columns`` is read here; the object itself is
        passed to the CI test (presumably a pandas DataFrame — TODO confirm).
    variant : {"orig", "stable", "conservative"}
        * ``"orig"``: remove an edge as soon as any separating set is found.
        * ``"stable"``: test every current edge at a level first, then remove
          all separated edges together; keep the sepset with the highest p-value.
        * ``"conservative"``: like ``"stable"`` but keep every separating set of
          minimal size found at that level, ordered by decreasing p-value.
    ci_test : str
        Key into the module-level ``supported_tests`` table.
    significance_level : float
        Threshold compared against the p-value returned by the CI test.
    max_cond_vars : int
        Largest conditioning-set size that is tried.
    expert_knowledge : ExpertKnowledge, optional
        Its ``search_space``, ``temporal_ordering``, ``forbidden_edges`` and
        ``required_edges`` are consulted. A fresh ``ExpertKnowledge()`` is used
        when omitted.
    enforce_expert_knowledge : bool
        When true, forbidden edges are removed up front and required edges are
        never tested.
    **kwargs
        Forwarded unchanged to the CI test.

    Returns
    -------
    (graph, separating_sets, p_value_sets)
        ``graph`` is the remaining ``nx.Graph``. ``separating_sets`` maps
        ``frozenset((u, v))`` to a *list* of separating sets. ``p_value_sets``
        maps ``frozenset((u, v))`` to the p-value recorded for that pair (the
        best one for the stable/conservative variants).

    Raises
    ------
    ValueError
        If ``data`` is None, or if ``variant`` is not recognised (the latter is
        only checked once the main loop is entered).
    KeyError
        If ``ci_test`` is not a key of ``supported_tests``.

    References
    ----------
    [1] Neapolitan, Learning Bayesian Networks, Section 10.1.2, Algorithm 10.2 (page 550)
        http://www.cs.technion.ac.il/~dang/books/Learning%20Bayesian%20Networks(Neapolitan,%20Richard).pdf
    [2] Koller & Friedman, Probabilistic Graphical Models - Principles and Techniques, 2009
        Section 3.4.2.1 (page 85), Algorithm 3.3
    """
    if data is None:
        raise ValueError(
            f"data must be specified."
        )
    # Current conditioning-set size (the "level" k of the PC algorithm).
    lim_neighbors = 0
    separating_sets = dict()
    p_value_sets = dict()
    # Resolve the test name to its callable; from here on `ci_test` is a function.
    ci_test = supported_tests[ci_test]

    if expert_knowledge is None:
        expert_knowledge = ExpertKnowledge()

    if expert_knowledge.search_space:
        expert_knowledge.limit_search_space(data.columns)

    variables = list(data.columns.values)
    # Step 1: Initialize a fully connected undirected graph
    graph = nx.complete_graph(n=variables, create_using=nx.Graph)
    temporal_ordering = expert_knowledge.temporal_ordering
    if enforce_expert_knowledge:
        graph.remove_edges_from(expert_knowledge.forbidden_edges)

    # Exit condition: 1. If all the nodes in graph has less than `lim_neighbors` neighbors.
    #             or  2. `lim_neighbors` is greater than `max_conditional_variables`.
    while not all(
        [len(list(graph.neighbors(var))) < lim_neighbors for var in variables]
    ):
        # Step 2: Iterate over the edges and find a conditioning set of
        # size `lim_neighbors` which makes u and v independent.
        if variant == "orig":
            # Standard PC: remove an edge as soon as you find ANY separating set at this k.
            for u, v in graph.edges():
                # NOTE(review): only the (u, v) orientation is looked up in
                # required_edges — confirm whether (v, u) should be protected too.
                if (enforce_expert_knowledge is False) or (
                    (u, v) not in expert_knowledge.required_edges
                ):
                    for separating_set in PC._get_potential_sepsets(
                        u, v, temporal_ordering, graph, lim_neighbors
                    ):
                        # With boolean=False the second element of the test's
                        # result is used as the p-value.
                        p_value = ci_test(
                            u,
                            v,
                            separating_set,
                            data=data,
                            independencies=None,
                            significance_level=significance_level,
                            boolean=False,
                            **kwargs,
                        )[1]
                        if p_value >= significance_level:
                            # Stored as a one-element list so all variants share
                            # the same "list of sepsets" value shape.
                            separating_sets[frozenset((u, v))] = [separating_set]
                            p_value_sets[frozenset((u, v))] = p_value
                            graph.remove_edge(u, v)
                            break
        elif variant == "stable" or variant == "conservative":
            # PC-stable (Colombo & Maathuis): freeze adjacency per level.
            # Do NOT remove edges immediately; collect removals and apply after scanning all pairs at this k.
            # CPC-style preparation: use the PC-stable *skeleton* procedure,
            # BUT collect *all minimal separating sets* at the first k where independence occurs.
            all_sepsets_at_k = {}
            current_edges = list(graph.edges())
            for u, v in current_edges:
                if (enforce_expert_knowledge is False) or (
                    (u, v) not in expert_knowledge.required_edges
                ):
                    potential_sepsets = PC._get_potential_sepsets(
                        u, v, temporal_ordering, graph, lim_neighbors
                    )
                    # Every candidate is tested (no early exit) so the best /
                    # all separating sets at this level can be chosen afterwards.
                    seps_found = []
                    pvals_found = []
                    for separating_set in potential_sepsets:
                        p_value = ci_test(
                            u,
                            v,
                            separating_set,
                            data=data,
                            significance_level=significance_level,
                            boolean=False,
                            **kwargs,
                        )[1]
                        if p_value >= significance_level:
                            seps_found.append(separating_set)
                            pvals_found.append(p_value)
                    if seps_found:
                        if variant == "stable":
                            # Keep exactly ONE sepset (choose the one with the highest p-value).
                            best_idx = max(range(len(pvals_found)), key=pvals_found.__getitem__)
                            all_sepsets_at_k[(u, v)] = [seps_found[best_idx]]
                            p_value_sets[frozenset((u, v))] = pvals_found[best_idx]
                        else:
                            # conservative: collect ALL minimal sepsets at first k where independence holds
                            min_len = min(len(s) for s in seps_found)
                            # all_sepsets_at_k[(u, v)] = [s for s in seps_found if len(s) == min_len]
                            # p_value_sets[frozenset((u, v))] = max(pvals_found)
                            # Pair each minimal sepset with its p-value and order
                            # them by decreasing p-value; the top p-value is recorded.
                            sp_ = [(s, p) for s, p in zip(seps_found, pvals_found) if len(s) == min_len]
                            sp_.sort(key=lambda sp: sp[1], reverse=True)
                            all_sepsets_at_k[(u, v)] = [s for s, _ in sp_]
                            p_value_sets[frozenset((u, v))] = sp_[0][1] if sp_ else None
            # Apply this level's removals in one go (the "stable" part).
            graph.remove_edges_from(all_sepsets_at_k.keys())
            for (u, v), seps in all_sepsets_at_k.items():
                # If the pair already has sepsets recorded, keep only new ones
                # that are no larger than the smallest previously stored one.
                if frozenset((u, v)) in separating_sets:
                    prev = separating_sets[frozenset((u, v))]
                    min_prev = min(len(s) for s in prev)
                    seps = [s for s in seps if len(s) <= min_prev]
                separating_sets[frozenset((u, v))] = seps
        else:
            raise ValueError(
                f"variant must be one of (orig, stable, conservative). Got: {variant}"
            )
        # Stop once the largest allowed conditioning-set size has been processed.
        if lim_neighbors >= max_cond_vars:
            break
        lim_neighbors += 1
    return graph, separating_sets, p_value_sets

def precision_recall_adj(inferred: nx.DiGraph, truth: nx.DiGraph):
    """
    Compare the adjacencies of two graphs, ignoring edge direction.

    Every edge is reduced to the unordered pair of its endpoints before the
    comparison. Returns ``(precision, recall, tp, fp, fn)``; a ratio whose
    denominator is zero is reported as 0.0.
    """
    predicted = set(map(frozenset, inferred.edges()))
    actual = set(map(frozenset, truth.edges()))

    tp = len(predicted & actual)
    fp = len(predicted) - tp   # predicted adjacencies absent from the truth
    fn = len(actual) - tp      # true adjacencies that were not predicted

    prec = tp / len(predicted) if predicted else 0.0
    rec = tp / len(actual) if actual else 0.0
    return prec, rec, tp, fp, fn

def fraction_valid_dsep(z_subsets, true_dag):
    """
    Fraction of entries ``(x, y) -> Z`` of `z_subsets` for which x and y are
    NOT d-connected given Z in the ground-truth graph.

    `true_dag` is wrapped in a pgmpy ``DAG`` before querying. An empty
    `z_subsets` yields 0.0.
    """
    dag = DAG(true_dag)
    # A pair counts as valid when no active trail links x and y given Z.
    n_valid = sum(
        1
        for (x, y), cond in z_subsets.items()
        if not dag.is_dconnected(x, y, observed=set(cond))
    )
    return n_valid / len(z_subsets) if z_subsets else 0.0


def orient_colliders(skeleton, separating_sets):
    """
    Orient v-structures in `skeleton` and return the result as a pgmpy PDAG.

    For every non-adjacent ordered pair (a, b) with a common neighbour `mid`,
    the triple becomes a -> mid <- b when `mid` appears in none of the
    separating sets stored under ``frozenset((a, b))`` and both edges are still
    unoriented (Algorithm 3.4 in Koller & Friedman, page 86).

    Adapted from
    https://github.com/pgmpy/pgmpy/blob/8f75787e3a22e14daa8cebc676e7577492c85935/pgmpy/estimators/PC.py
    """
    # Each undirected edge starts out as a pair of opposite arcs.
    pdag = skeleton.to_directed()

    for a, b in permutations(sorted(pdag.nodes()), 2):
        if skeleton.has_edge(a, b):
            continue
        for mid in set(skeleton.neighbors(a)) & set(skeleton.neighbors(b)):
            if any(mid in sepset for sepset in separating_sets[frozenset((a, b))]):
                continue
            arcs_needed = ((mid, a), (a, mid), (mid, b), (b, mid))
            if all(pdag.has_edge(src, dst) for src, dst in arcs_needed):
                # Dropping the arcs that leave `mid` keeps a -> mid and b -> mid.
                pdag.remove_edge(mid, a)
                pdag.remove_edge(mid, b)

    arcs = set(pdag.edges())
    undirected_edges = set()
    directed_edges = set()
    for src, dst in arcs:
        if (dst, src) in arcs:
            undirected_edges.add(tuple(sorted((src, dst))))
        else:
            directed_edges.add((src, dst))

    pdag_oriented = PDAG(directed_ebunch=directed_edges,
                         undirected_ebunch=undirected_edges)
    # Make sure isolated nodes are kept as well.
    pdag_oriented.add_nodes_from(pdag.nodes())
    return pdag_oriented
                    
def orient_colliders_pval(skeleton, separating_sets, p_value_sets):
    """
    Orient v-structures like `orient_colliders`, but visit the separated pairs
    in order of decreasing recorded p-value instead of node order.

    Only the pairs that are keys of `p_value_sets` are considered. Returns a
    pgmpy PDAG containing all nodes of the skeleton.
    """
    pdag = skeleton.to_directed()

    by_pvalue = sorted(p_value_sets, key=lambda pr: p_value_sets[pr], reverse=True)
    for pair in by_pvalue:
        a, b = sorted(pair)
        if skeleton.has_edge(a, b):
            continue
        for mid in set(skeleton.neighbors(a)) & set(skeleton.neighbors(b)):
            if any(mid in sepset for sepset in separating_sets[pair]):
                continue
            arcs_needed = ((mid, a), (a, mid), (mid, b), (b, mid))
            if all(pdag.has_edge(src, dst) for src, dst in arcs_needed):
                # Keep only the arcs pointing into `mid`: a -> mid <- b.
                pdag.remove_edge(mid, a)
                pdag.remove_edge(mid, b)

    arcs = set(pdag.edges())
    undirected_edges = set()
    directed_edges = set()
    for src, dst in arcs:
        if (dst, src) in arcs:
            undirected_edges.add(tuple(sorted((src, dst))))
        else:
            directed_edges.add((src, dst))

    pdag_oriented = PDAG(directed_ebunch=directed_edges,
                         undirected_ebunch=undirected_edges)
    pdag_oriented.add_nodes_from(pdag.nodes())
    return pdag_oriented

def order_PC(pdag_PC, vote):
    """
    Direct every undirected edge of `pdag_PC` according to a global node order
    derived from the votes.

    Parameters
    ----------
    pdag_PC
        Graph exposing ``nodes()``, ``edges()`` and ``to_undirected()``. An
        edge counts as undirected when both ``(u, v)`` and ``(v, u)`` are in
        ``edges()``.
    vote
        Mapping with ``.get``; keys are ``(a, b, 'fwd')`` / ``(a, b, 'bwd')``
        and values are vote counts for the ordered pair ``(a, b)``.

    Returns
    -------
    nx.DiGraph
        All nodes of `pdag_PC`, its already-directed edges, and each undirected
        edge oriented from the earlier to the later node of the vote order.
    """
    # Build a weighted preference graph from the votes.
    G = nx.DiGraph()
    G.add_nodes_from(pdag_PC.nodes())
    for u, v in pdag_PC.to_undirected().edges():
        uv_fwd = vote.get((u, v, 'fwd'), 0)
        uv_bwd = vote.get((u, v, 'bwd'), 0)
        vu_fwd = vote.get((v, u, 'fwd'), 0)
        vu_bwd = vote.get((v, u, 'bwd'), 0)
        # An arc is added only when both askings of the pair agree (ties
        # allowed); the weight is the net number of votes for that direction.
        if uv_fwd >= uv_bwd and vu_fwd <= vu_bwd:
            G.add_edge(u, v, weight=(uv_fwd + vu_bwd) - (uv_bwd + vu_fwd))
        elif uv_fwd <= uv_bwd and vu_fwd >= vu_bwd:
            G.add_edge(v, u, weight=(uv_bwd + vu_fwd) - (uv_fwd + vu_bwd))

    # Break cycles by repeatedly dropping the weakest arc of some cycle.
    while not nx.is_directed_acyclic_graph(G):
        cycle = nx.find_cycle(G)
        weakest = min(cycle, key=lambda e: G.edges[e]['weight'])
        G.remove_edge(*weakest)

    # Position of each node in the topological order; a dict gives O(1)
    # lookups instead of a linear list.index() scan per edge.
    rank = {node: pos for pos, node in enumerate(nx.topological_sort(G))}

    edges = set(pdag_PC.edges())
    undirected_edges = {tuple(sorted((u, v))) for u, v in edges if (v, u) in edges}
    directed_edges = {(u, v) for u, v in edges if (v, u) not in edges}

    for u, v in undirected_edges:
        if u in rank and v in rank:
            directed_edges.add((u, v) if rank[u] < rank[v] else (v, u))

    G_order = nx.DiGraph()
    G_order.add_nodes_from(pdag_PC.nodes())
    G_order.add_edges_from(directed_edges)
    return G_order

def find_halfdir_triples(directed_edges, undirected_edges, separating_sets, p_value_sets):
    """
    Propose orientations for undirected edges that sit in a half-directed
    triple ``x -> z - y`` where x and y are a separated (non-adjacent) pair.

    Two passes are made over the pairs of `p_value_sets`, each in order of
    decreasing p-value:

    1. if *every* separating set of {x, y} contains z, orient ``z -> y``
       (z is a non-collider on the path);
    2. if *no* separating set of {x, y} contains z, orient ``y -> z``
       (z is a collider).

    Orientations found earlier are taken into account by later steps. The
    inputs are not modified.

    Parameters
    ----------
    directed_edges : iterable of (u, v)
        Already oriented edges u -> v.
    undirected_edges : iterable of pairs
        Edges still eligible for orientation.
    separating_sets : mapping frozenset pair -> list of separating sets
    p_value_sets : mapping frozenset pair -> p-value

    Returns
    -------
    set of (src, dst)
        Newly proposed directed edges.
    """
    undirected = {tuple(sorted(e)) for e in undirected_edges}
    dir_out = defaultdict(set)
    # Snapshot of all adjacencies at call time; used to skip pairs that are
    # in fact connected.
    adjacent = set(undirected)
    for u, v in directed_edges:
        dir_out[u].add(v)
        adjacent.add(tuple(sorted((u, v))))

    ranked_pairs = sorted(p_value_sets, key=lambda pair: p_value_sets[pair], reverse=True)
    new_dir = set()

    def _sweep(as_collider):
        """Run one pass; `as_collider` selects rule 2 (True) or rule 1 (False)."""
        for pair in ranked_pairs:
            X, Y = sorted(pair)
            if (X, Y) in adjacent:
                continue
            for x_cur, y_cur in ((X, Y), (Y, X)):
                # x -> z must already be directed (including arcs added so far);
                # iterate a snapshot because dir_out may grow below.
                for z in list(dir_out.get(x_cur, ())):
                    if z in (x_cur, y_cur):
                        continue
                    key = tuple(sorted((y_cur, z)))
                    # y - z must still be undirected to be eligible.
                    if key not in undirected:
                        continue
                    S_list = separating_sets[pair]
                    if not S_list:
                        continue
                    if as_collider:
                        if not all(z not in S for S in S_list):
                            continue
                        src, dst = y_cur, z
                    else:
                        if not all(z in S for S in S_list):
                            continue
                        src, dst = z, y_cur
                    already_oriented = (
                        y_cur in dir_out.get(z, ())
                        or z in dir_out.get(y_cur, ())
                        or (y_cur, z) in new_dir
                        or (z, y_cur) in new_dir
                    )
                    if already_oriented:
                        continue
                    new_dir.add((src, dst))
                    dir_out[src].add(dst)
                    undirected.remove(key)

    _sweep(as_collider=False)
    _sweep(as_collider=True)
    return new_dir

def find_undir_v_triples(undirected_edges, separating_sets, p_value_sets):
    """
    Propose v-structures inside the still-undirected part of the graph.

    The pairs of `p_value_sets` are visited in order of *decreasing* p-value.
    For a pair {X, Y} and a node Z adjacent to both through undirected edges,
    if Z is in none of the separating sets of the pair and both X - Z and
    Y - Z are still available, they are oriented X -> Z <- Y and taken out of
    the pool so they cannot be reused.

    `undirected_edges` is expected to hold sorted 2-tuples (membership is
    checked with sorted tuples). The caller's collection is not modified.

    Returns the set of new directed edges ``(src, dst)``.
    """
    remaining = set(undirected_edges)

    # Adjacency over the undirected edges as they were on entry.
    adjacency = {}
    for a, b in remaining:
        adjacency.setdefault(a, set()).add(b)
        adjacency.setdefault(b, set()).add(a)

    oriented = set()
    ranked = sorted(p_value_sets, key=lambda pr: p_value_sets[pr], reverse=True)
    for pair in ranked:
        X, Y = sorted(pair)
        shared = adjacency.get(X, set()) & adjacency.get(Y, set())
        for Z in shared:
            if any(Z in sepset for sepset in separating_sets[pair]):
                continue
            xz = tuple(sorted((X, Z)))
            yz = tuple(sorted((Y, Z)))
            if xz not in remaining or yz not in remaining:
                continue
            remaining.remove(xz)
            remaining.remove(yz)
            oriented.update(((X, Z), (Y, Z)))
    return oriented
# ORIENT (d_triples z -> y) + consistent part of (v_triples y -> z)
# ORIENT v_triples in undir_edges

def run_prompt(prompt, temperature, model):
    """
    Send `prompt` as a single user message to the chat-completions endpoint and
    return the letter found inside an ``<Answer>X</Answer>`` tag of the reply.

    Parameters
    ----------
    prompt : str
        Full text of the user message.
    temperature
        Sampling temperature forwarded to the API.
    model
        Model identifier forwarded to the API.

    Returns
    -------
    str or None
        The upper-cased letter (one of A–E), or None when the request fails,
        the reply has no text, or no answer tag is present.
    """
    # Best-effort boundary: any failure of the request itself is reported and
    # turned into "no answer" so one bad call does not abort a voting run.
    try:
        resp = client.chat.completions.create(
            model=model,
            temperature=temperature,
            messages=[{"role": "user", "content": prompt}],
        )
        text = resp.choices[0].message.content
    except Exception as e:
        print("LLM call failed:", e)
        return None

    # A reply without text is not an API failure; there is simply nothing to
    # parse (previously this surfaced as a misleading "LLM call failed").
    if not text:
        return None
    m = _ans_re.search(text)
    return m.group(1).upper() if m else None

def evidence_for_edge(u, v, undir_graph, separating_sets):
    """
    Collect the third nodes whose removed edge constrains the edge u - v.

    For every separated pair in `separating_sets`:

    * if the pair contains u and all of its separating sets contain v, the
      other node of the pair is reported in the first list when it is
      adjacent to v in `undir_graph`;
    * otherwise, if the pair contains v and all of its separating sets
      contain u, the other node is reported in the second list when it is
      adjacent to u.

    Returns ``(u_theOther_2v, v_theOther_2u)`` as two lists.
    """
    beyond_v = []
    beyond_u = []

    def _adjacent(a, b):
        return undir_graph.has_edge(a, b) or undir_graph.has_edge(b, a)

    for pair, sepsets in separating_sets.items():
        if u in pair and all(v in sepset for sepset in sepsets):
            other = [node for node in pair if node != u][0]
            if _adjacent(v, other):
                beyond_v.append(other)
        elif v in pair and all(u in sepset for sepset in sepsets):
            other = [node for node in pair if node != v][0]
            if _adjacent(u, other):
                beyond_u.append(other)
    return beyond_v, beyond_u

def _ci_bullets_for_edge(u, v, u_theOther_2v, v_theOther_2u, separating_sets, pval_sets):
    lines = []
    if len(u_theOther_2v) > 0: 
        for theOther in u_theOther_2v:
            S = separating_sets[frozenset((u, theOther))]
            p = pval_sets[frozenset((u, theOther))]
            S_txt = ", ".join(S[0])
            lines.append(f"{u} ⟂ {u_theOther_2v} | {{{S_txt}}} isn't rejected because p-value={p:.3g}, the edge “{u} --- {u_theOther_2v}” is removed")
    if len(v_theOther_2u) > 0: 
        for theOther in v_theOther_2u:
            S = separating_sets[frozenset((v, theOther))]
            p = pval_sets[frozenset((v, theOther))]
            S_txt = ", ".join(S[0])
            lines.append(f"{v} ⟂ {v_theOther_2u} | {{{S_txt}}} isn't rejected because p-value={p:.3g}, the edge “{v} --- {v_theOther_2u}” is removed")
    if len(lines) > 0:
        return "\n".join(lines)
    else: 
        return ""

def vote_edges(data_description, node_description, undir_graph, separating_sets, pval_sets, n_votes, temperature, model):
    """
    Ask the LLM `n_votes` times about the direction of every edge of
    `undir_graph` and tally the answers.

    Each edge is asked in both orientations, as (u, v) and as (v, u). The
    prompt template depends on which CI evidence exists for the ordered pair
    (see `evidence_for_edge`): none, only beyond v, only beyond u, or both.

    Parameters
    ----------
    data_description : str
        Dataset description inserted into the prompt.
    node_description : mapping node -> str
        Description of each node; must contain every node that is mentioned.
    undir_graph
        Skeleton; ``to_directed()`` and ``has_edge`` are used.
    separating_sets, pval_sets
        Mappings keyed by frozenset pairs, as produced by the skeleton step.
    n_votes : int
        Number of voting rounds; nothing is asked when it is not positive.
    temperature, model
        Forwarded to `run_prompt`.

    Returns
    -------
    collections.Counter
        Keys are ``(u, v, 'fwd')`` (answer supports u -> v) and
        ``(u, v, 'bwd')`` (answer supports v -> u) for the ordered pair that
        was asked. Undecided or unparsable answers are not counted.
    """
    vote_dir = Counter()
    if n_votes <= 0:
        return vote_dir

    def _names(nodes):
        # Render a node list as plain text; a Python list repr such as
        # "['X']" must not leak into the prompt.
        return ", ".join(str(node) for node in nodes)

    # The prompt for an ordered pair does not change between rounds, so build
    # each one once instead of once per vote.
    queries = []
    for u, v in list(undir_graph.to_directed().edges()):
        u_theOther_2v, v_theOther_2u = evidence_for_edge(u, v, undir_graph, separating_sets)
        # dict.fromkeys de-duplicates while keeping a stable order for the prompt.
        related_nodes = dict.fromkeys([u, v] + u_theOther_2v + v_theOther_2u)
        node_desc = "\n".join(f"- **{node}**: {node_description[node]}" for node in related_nodes)

        if not u_theOther_2v and not v_theOther_2u:
            prompt = _CHAIN_PROMPT_TMPL_None.format(data_desc = data_description, node_desc = node_desc, u = u, v = v)
            answer_to_vote = {"B": "fwd", "C": "bwd"}
        else:
            ci_bullets = _ci_bullets_for_edge(u, v, u_theOther_2v, v_theOther_2u, separating_sets, pval_sets)
            beyond_v = _names(u_theOther_2v)
            beyond_u = _names(v_theOther_2u)
            if u_theOther_2v and not v_theOther_2u:
                chains = f"{u} —-- {v} —-- {beyond_v}"
                prompt = _CHAIN_PROMPT_TMPL_None2u.format(data_desc = data_description, node_desc = node_desc, u = u, v = v, u_theOther_2v = beyond_v, chains = chains, ci_bullets = ci_bullets)
                # D = u -> v with a collider at v: still a vote for u -> v.
                answer_to_vote = {"B": "fwd", "C": "bwd", "D": "fwd"}
            elif v_theOther_2u and not u_theOther_2v:
                chains = f"{beyond_u} --- {u} —-- {v}"
                prompt = _CHAIN_PROMPT_TMPL_None2v.format(data_desc = data_description, node_desc = node_desc, u = u, v = v, v_theOther_2u = beyond_u, chains = chains, ci_bullets = ci_bullets)
                # D = v -> u with a collider at u: still a vote for v -> u.
                answer_to_vote = {"B": "fwd", "C": "bwd", "D": "bwd"}
            else:
                chains = f"{beyond_u} --- {u} —-- {v} —-- {beyond_v}"
                prompt = _CHAIN_PROMPT_TMPL_FULL.format(data_desc = data_description, node_desc = node_desc, u = u, v = v, u_theOther_2v = beyond_v, v_theOther_2u = beyond_u, chains = chains, ci_bullets = ci_bullets)
                answer_to_vote = {"B": "fwd", "C": "bwd", "D": "fwd", "E": "bwd"}
        queries.append((u, v, prompt, answer_to_vote))

    for _ in range(n_votes):
        for u, v, prompt, answer_to_vote in queries:
            ans = run_prompt(prompt, temperature, model)
            # "A" (undecided), None, or a letter the template did not offer
            # casts no vote.
            direction = answer_to_vote.get(ans)
            if direction is not None:
                vote_dir[(u, v, direction)] += 1
    return vote_dir

def edges_from_votes(undir_input, vote_dir):
    """
    Turn vote tallies into edge orientations.

    For each pair (u, v) of `undir_input`, the tallies of both askings are
    read from `vote_dir` (keys ``(a, b, 'fwd')`` / ``(a, b, 'bwd')``, missing
    keys count as 0). The edge is directed only when the two askings strictly
    agree: u -> v if asking (u, v) favours 'fwd' and asking (v, u) favours
    'bwd', v -> u in the mirrored case. Anything else stays undirected and is
    returned as a sorted tuple.

    Returns ``(directed_edges, undirected_edges)`` as two lists.
    """
    oriented = []
    unresolved = []
    for u, v in list(undir_input):
        # Positive margin: that asking preferred "first node causes second".
        margin_uv = vote_dir.get((u, v, 'fwd'), 0) - vote_dir.get((u, v, 'bwd'), 0)
        margin_vu = vote_dir.get((v, u, 'fwd'), 0) - vote_dir.get((v, u, 'bwd'), 0)
        if margin_uv > 0 and margin_vu < 0:
            oriented.append((u, v))
        elif margin_vu > 0 and margin_uv < 0:
            oriented.append((v, u))
        else:
            unresolved.append(tuple(sorted((u, v))))
    return oriented, unresolved

def orient_edges(undir_graph, directed_edges, undirected_edges, separating_sets, pval_sets):
    """
    Propagate orientations through the remaining undirected edges until no
    further edge can be directed, and return the result as a pgmpy PDAG.

    Parameters
    ----------
    undir_graph
        Skeleton; its ``nodes()`` are iterated and it is passed on to
        `evidence_for_edge`.
    directed_edges : list of (u, v)
        Edges already oriented u -> v. NOTE: this list is appended to in
        place, so the caller's list grows (and it must support ``append``).
    undirected_edges : iterable of sorted 2-tuples
        Edges still undirected; copied into a local set.
    separating_sets, pval_sets
        Mappings keyed by frozenset pairs, forwarded to the helper rules.

    Returns
    -------
    PDAG
        Built from the final directed and undirected edges, with every node
        of `undir_graph` added.
    """
    undirected_edges = set(undirected_edges)
    # Outer fix-point: repeat until a full round leaves the number of
    # undirected edges unchanged.
    cnt_undirected = 0
    while len(undirected_edges) != cnt_undirected:
        cnt_undirected = len(undirected_edges)
        # Inner fix-point over the three propagation rules below.
        while True:
            start_len = len(undirected_edges)
            # If X -> Z -> Y and X - Y => X -> Y
            for z in undir_graph.nodes():
                xs = [u for (u, v) in directed_edges if v == z]
                ys = [v for (u, v) in directed_edges if u == z]
                for x in xs:
                    for y in ys:
                        tpl = tuple(sorted((x, y)))
                        if tpl in undirected_edges:
                            directed_edges.append((x, y))
                            undirected_edges.remove(tpl)
            
            # If X -> Z - Y, Z in separating_sets(X, Y)?
            new_dir_ = find_halfdir_triples(directed_edges, undirected_edges, separating_sets, pval_sets)
            for dir_ in new_dir_:
                directed_edges.append(dir_)
                undirected_edges.remove(tuple(sorted(dir_)))
            
            # If X - Z - Y, Z in separating_sets(X, Y)?
            new_dir__ = find_undir_v_triples(undirected_edges, separating_sets, pval_sets)
            for dir__ in new_dir__:
                directed_edges.append(dir__)
                undirected_edges.remove(tuple(sorted(dir__)))
            if len(undirected_edges) == start_len:
                break  # steps 1–3 converged

        # Orient with min violating
        # Iterate over a copy because edges are removed from the set inside the loop.
        for u, v in undirected_edges.copy():
            tpl = tuple(sorted((u, v)))
            u_theOther_2v, v_theOther_2u = evidence_for_edge(u, v, undir_graph, separating_sets)
            # uv_violated: nodes reported for direction u -> v that already
            # point into v; vu_violated: the same for direction v -> u and u.
            uv_violated = 0
            vu_violated = 0
            for node in u_theOther_2v:
                if (node, v) in directed_edges:
                    uv_violated = uv_violated + 1
            for node in v_theOther_2u:
                if (node, u) in directed_edges:
                    vu_violated = vu_violated + 1
            # print(f'{v_theOther_2u} --- {u} —-- {v} —-- {u_theOther_2v}', uv_violated, vu_violated)
            # Pick the direction with the smaller count; ties stay undirected.
            if uv_violated < vu_violated:
                directed_edges.append((u, v))
                undirected_edges.remove(tpl)
                # print('#violated', uv_violated, '<', vu_violated, (u, v))
            elif uv_violated > vu_violated:
                directed_edges.append((v, u))
                undirected_edges.remove(tpl)
                # print('#violated', uv_violated, '>', vu_violated, (v, u))
    pdag_oriented = PDAG(directed_ebunch=set(directed_edges),
                         undirected_ebunch=set(undirected_edges))
    # Ensure nodes without any edge are part of the result as well.
    pdag_oriented.add_nodes_from(list(undir_graph.nodes()))
    return pdag_oriented

def shd_pdag(truth, pred_dir, pred_und):
    """
    SHD between *pred* (possibly PDAG) and *truth* (DAG).

    Rules:
    * Missing edge   → +1 (add undirected + orient)
    * Extra edge     → +1 (remove edge)
    * Reversed edge  → +1 (remove one dir + add the other)
    * Undirected edge  → +0.5 (add direction)

    Parameters
    ----------
    truth
        Object whose ``edges()`` yields the true directed edges ``(u, v)``.
    pred_dir : iterable of (u, v)
        Predicted directed edges.
    pred_und : iterable of pairs
        Predicted undirected edges; each item is any 2-element iterable
        (tuple or frozenset).

    Returns
    -------
    float
        The accumulated distance.
    """
    truth_dir = set(truth.edges())
    pred_dir = list(pred_dir)
    pred_und = list(pred_und)
    dir_lookup = set(pred_dir)                       # O(1) membership
    und_lookup = {frozenset(uv) for uv in pred_und}  # orientation-free

    shd = 0.

    # --- 1. true edges that are reversed or entirely absent ----------------
    for u, v in truth_dir:
        if (u, v) in dir_lookup:
            continue                    # perfect
        if (v, u) in dir_lookup:
            shd += 1                    # reversed
        elif frozenset((u, v)) in und_lookup:
            # Present but undirected: charged 0.5 in step 3 only. Counting it
            # as "missing" here as well would make it cost 1.5.
            continue
        else:
            shd += 1                    # missing

    # --- 2. extra directed edges in the prediction -------------------------
    for u, v in pred_dir:
        if (u, v) not in truth_dir and (v, u) not in truth_dir:
            shd += 1

    # --- 3. undirected edges in the prediction -----------------------------
    for uv in pred_und:
        u, v = tuple(uv)
        if (u, v) in truth_dir or (v, u) in truth_dir:
            shd += 0.5                  # right adjacency, direction unknown
        else:
            shd += 1                    # extra undirected edge

    return shd

def accuracy_metrics(truth, pred_dir, pred_und):
    """
    Edge-level precision / recall / F1 and SHD of a predicted (P)DAG.

    Scoring:
    * a directed prediction is a true positive only with the right direction;
      a reversed edge counts as one false positive and one false negative;
    * an undirected prediction over a true adjacency earns half credit
      (tp += 0.5) and leaves half a miss (fn = 0.5 for that edge);
    * an undirected prediction over a non-adjacent pair is a false positive.

    Parameters
    ----------
    truth
        Object whose ``edges()`` yields the true directed edges.
    pred_dir : iterable of (u, v)
        Predicted directed edges.
    pred_und : iterable of pairs
        Predicted undirected edges (hashable 2-element iterables).

    Returns
    -------
    dict
        Keys: precision, recall, f1_score, true_positives, false_positives,
        false_negatives, shd, num_inferred_edges, num_truth_edges.
    """
    truth_edges = set(truth.edges())
    pred_dir = set(pred_dir)
    pred_und = set(pred_und)
    tp = len(pred_dir & truth_edges)
    fp = len(pred_dir - truth_edges)
    # Every true edge without an exact directed match starts as a full miss.
    fn = len(truth_edges - pred_dir)
    for uv in pred_und:
        u, v = tuple(uv)
        if (u, v) in truth_edges:
            true_edge = (u, v)
        elif (v, u) in truth_edges:
            true_edge = (v, u)
        else:
            fp += 1
            continue
        tp += 0.5
        # The edge is already in `fn` as a full miss; take half back so it
        # ends up as half hit / half miss. Adding 0.5 instead would make
        # tp + fn exceed the number of true edges. The guard covers the
        # degenerate case of an edge also predicted as correctly directed.
        if true_edge not in pred_dir:
            fn -= 0.5
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    shd = shd_pdag(truth, pred_dir, pred_und)

    return {
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "true_positives": tp,
        "false_positives": fp,
        "false_negatives": fn,
        "shd": shd,
        "num_inferred_edges": (len(pred_dir) + len(pred_und)),
        "num_truth_edges": len(truth_edges)
    }

def split_directed_undirected_edges(graph):
    """
    Split the edges of a directed graph into truly directed edges and
    undirected ones (edges present in both orientations).

    Parameters
    ----------
    graph
        Object exposing ``edges()`` (yielding ``(u, v)``) and
        ``has_edge(u, v)``.

    Returns
    -------
    (directed, undirected) : tuple of lists
        ``directed`` holds the edges without a reverse counterpart.
        ``undirected`` holds one representative per bidirectional pair: the
        orientation that is encountered first.
    """
    directed = []
    undirected = []
    # Mirrors `undirected` as a set so the duplicate check is O(1) instead of
    # a linear scan of the list for every edge.
    seen_undirected = set()
    for u, v in graph.edges():
        if not graph.has_edge(v, u):
            directed.append((u, v))
        elif (v, u) not in seen_undirected:  # reverse not recorded yet
            undirected.append((u, v))
            seen_undirected.add((u, v))
    return directed, undirected
