#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@author: Haoyue
@file: MAG_tools.py
@time: 9/21/24 05:14
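@desc: Helpers for MAG/PAG causal discovery: CI-test-based skeleton search and PAG edge orientation (rules R0-R4).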
"""

import networkx as nx
from itertools import chain, combinations, permutations, product
def powerset(iterable):
    """Return all subsets of *iterable* as a list of tuples, smallest first.

    Subsets of equal size appear in ``itertools.combinations`` order:
    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    items = list(iterable)  # materialize once so one-shot iterators work
    subsets = []
    for size in range(len(items) + 1):
        subsets.extend(combinations(items, size))
    return subsets

# Edge-end marks stored in PAG edge tuples: arrowhead, tail, circle.
AROW = 'AROW'
DASH = 'DASH'
CIRC = 'CIRC'
# Selects which entry of a (node1, node2) edge tuple is read:
# LEFT is the mark at node1, RIGHT the mark at node2.
LEFT = 'LEFT'
RIGHT = 'RIGHT'

def get_PAG_from_skeleton_and_sepsets(
    nodelist,                      # collection of hashable nodes; iterated more than once, so not a one-shot iterator
    skeleton_edges,                # collection of (node, node) adjacencies; need not be symmetric; iterated more than once
    sepsets,                       # dict{(node1, node2): set_of_nodes}; every pair NOT in skeleton_edges must have a sepset (either order)
    sure_direct_edges=None,        # optional collection of (node1, node2) pairs forced to node1 -> node2 before orientation
    sure_no_latents=False          # if True, the rule loop additionally turns every ⚬-> edge into ->
):
    """Orient a skeleton into a PAG using separating sets and orientation rules R0-R4.

    Every skeleton edge starts as ⚬-⚬. Optional background knowledge
    (``sure_direct_edges``) is written in first, without a consistency check.
    Unshielded triples are then oriented as colliders using the separating
    sets (R0). Rules R1-R4, plus the no-latents rule when ``sure_no_latents``
    is True, are applied repeatedly until a full pass changes nothing.

    Edge marks are stored per ordered pair as
    ``CURREDGES[(a, b)] = (mark_at_a, mark_at_b)``. A mark is AROW
    (arrowhead), DASH (tail) or CIRC (circle). ``CURREDGES[(b, a)]`` always
    holds the reversed tuple.

    Returns:
        dict with keys '->', '<->', '--', '⚬--', '⚬->', '⚬-⚬'. Each value is a
        set of ordered ``(node1, node2)`` pairs. Asymmetric edge types are
        listed once, in the direction the key reads. '<->', '--' and '⚬-⚬'
        contain both ``(a, b)`` and ``(b, a)``.

    Raises:
        AssertionError: raised when any of the following holds:
            - edges, sepsets or sure_direct_edges mention nodes outside
              ``nodelist``;
            - a pair is both adjacent and separated;
            - some non-adjacent ordered pair has no sepset.
            These checks are plain ``assert`` statements, so they are
            skipped under ``python -O``.
    """
    assert set().union(*skeleton_edges) <= set(nodelist)
    assert all(set(k) | set(v) <= set(nodelist) for k, v in sepsets.items())
    if sure_direct_edges is not None: assert set().union(*sure_direct_edges) <= set(nodelist)

    ALLNODES = set(nodelist)
    # CURREDGES: current (mark_at_first, mark_at_second) for every adjacent ordered pair, both directions stored.
    # SEPSETS: separating set for every non-adjacent ordered pair, both directions stored.
    CURREDGES, SEPSETS = {}, {}
    for x, y in skeleton_edges: CURREDGES[(x, y)] = CURREDGES[(y, x)] = (CIRC, CIRC)
    for (node1, node2), Z in sepsets.items(): SEPSETS[(node1, node2)] = SEPSETS[(node2, node1)] = set(Z)
    # every ordered pair of distinct nodes must be either adjacent or separated, and never both
    assert len(set(CURREDGES.keys()) & set(SEPSETS.keys())) == 0
    assert set(CURREDGES.keys()) | set(SEPSETS.keys()) == {(x, y) for x, y in product(nodelist, nodelist) if x != y}

    def get_curr_edge_type(node1, node2, end=LEFT):
        """Return one end mark of the current node1-node2 edge, or False if the nodes are not adjacent.

        end=LEFT returns the mark at node1; end=RIGHT returns the mark at node2.
        """
        if (node1, node2) not in CURREDGES: return False
        if end == LEFT:
            return CURREDGES[(node1, node2)][0]
        elif end == RIGHT:
            return CURREDGES[(node1, node2)][1]
        assert False

    def update_edge(node1, node2, type1, type2):
        """Set the node1-node2 edge to (type1 at node1, type2 at node2), keeping both directions in sync.

        Passing None for a mark keeps that end's current mark. The edge must
        already exist in CURREDGES whenever None is passed.
        """
        if type1 is None: type1 = CURREDGES[(node1, node2)][0]
        if type2 is None: type2 = CURREDGES[(node1, node2)][1]
        CURREDGES[(node1, node2)] = (type1, type2)
        CURREDGES[(node2, node1)] = (type2, type1)

    def _RO():
        """R0: orient each unshielded triple whose middle node is not in the sepset as a collider.

        For non-adjacent alpha, beta and a common neighbour gamma with
        gamma not in Sepset(alpha, beta), set alpha *-> gamma <-* beta.
        Only the mark at gamma is changed.
        """
        # print('RO...')
        for alpha, beta in combinations(ALLNODES, 2):  # safe here; it's symmetric
            if (alpha, beta) in CURREDGES: continue
            for gamma in ALLNODES - {alpha, beta}:
                if (alpha, gamma) in CURREDGES and (beta, gamma) in CURREDGES:
                    gamma_not_in_sepset = gamma not in SEPSETS[(alpha, beta)]
                    if gamma_not_in_sepset:
                        update_edge(alpha, gamma, None, AROW)
                        update_edge(beta, gamma, None, AROW)

    def _R1():
        """R1 (rule statement below). Returns True if any edge was updated."""
        # print('R1...')
        # If α∗→ β◦−−∗ γ, and α and γ are not adjacent, then orient the triple as α∗→ β →γ
        changed_something = False
        for alpha in ALLNODES:
            for beta in [bt for bt in ALLNODES - {alpha} if get_curr_edge_type(alpha, bt, RIGHT) == AROW]:
                for gamma in [gm for gm in ALLNODES - {alpha, beta} if
                              get_curr_edge_type(beta, gm, LEFT) == CIRC and (alpha, gm) not in CURREDGES]:
                    update_edge(beta, gamma, DASH, AROW)
                    changed_something = True
        return changed_something

    def _R2():
        """R2 (rule statement below). Returns True if any edge was updated."""
        # print('R2...')
        # If α→β∗→ γ or α∗→ β →γ, and α ∗−◦ γ,then orient α ∗−◦ γ as α∗→γ.
        changed_something = False
        for alpha in ALLNODES:
            for beta in [bt for bt in ALLNODES - {alpha} if
                         get_curr_edge_type(alpha, bt, RIGHT) == AROW]:
                for gamma in [gm for gm in ALLNODES - {alpha, beta} if
                              get_curr_edge_type(beta, gm, RIGHT) == AROW]:
                    # a CIRC mark at gamma also implies alpha and gamma are adjacent
                    if get_curr_edge_type(alpha, gamma, RIGHT) == CIRC:
                        if get_curr_edge_type(alpha, beta, LEFT) == DASH or \
                           get_curr_edge_type(beta, gamma, LEFT) == DASH:
                            changed_something = True
                            update_edge(alpha, gamma, None, AROW)
        return changed_something

    def _R3():
        """R3 (rule statement below). Returns True if any edge was updated."""
        # print('R3...')
        # If α∗→ β ←∗γ, α ∗−◦ θ ◦−∗ γ, α and γ are not adjacent, and θ ∗−◦ β, then orient θ ∗−◦ β as θ∗→ β.
        changed_something = False
        for alpha, gamma in combinations(ALLNODES, 2):  # safe here; it's symmetric
            if (alpha, gamma) in CURREDGES: continue
            for beta in [bt for bt in ALLNODES - {alpha, gamma} if
                         get_curr_edge_type(alpha, bt, RIGHT) == AROW and \
                         get_curr_edge_type(bt, gamma, LEFT) == AROW]:
                for theta in [th for th in ALLNODES - {alpha, beta, gamma} if
                         get_curr_edge_type(alpha, th, RIGHT) == CIRC and \
                         get_curr_edge_type(th, gamma, LEFT) == CIRC]:
                    if get_curr_edge_type(theta, beta, RIGHT) == CIRC:
                        changed_something = True
                        update_edge(theta, beta, None, AROW)
        return changed_something

    def _R4():
        """R4, the discriminating-path rule (statement below). Returns True if any edge was updated.

        The path condition is checked on
        ``path = [theta, v1, ..., vk (= alpha), beta, gamma]`` with k >= 1.
        Every vi must be a parent of gamma, which is enforced by building
        paths only through ``gamma_parents``. Every vi must also be a collider
        on the path, which is checked by the ``all(...)`` below.
        """
        # print('R4...')
        # If u = <θ, ...,α,β,γ> is a discriminating path between θ and γ for β, and β◦−−∗γ;
        # then if β ∈ Sepset(θ,γ), orient β◦−−∗ γ as β →γ; otherwise orient the triple <α,β,γ> as α ↔β ↔γ.
        changed_something = False
        for theta in ALLNODES:
            for gamma in ALLNODES - {theta}:
                if (theta, gamma) in CURREDGES: continue
                for beta in {bt for bt in ALLNODES - {theta, gamma} if
                             get_curr_edge_type(bt, gamma, LEFT) == CIRC}:
                    # nodes af with af -> gamma (tail at af, arrowhead at gamma)
                    gamma_parents = {af for af in ALLNODES - {theta, gamma, beta} if \
                                     get_curr_edge_type(af, gamma, LEFT) == DASH and
                                     get_curr_edge_type(af, gamma, RIGHT) == AROW}
                    if len(gamma_parents) < 1: continue
                    # Search paths in an undirected subgraph on gamma_parents | {theta, beta} only, instead of the
                    # whole graph. This keeps nx.all_simple_paths tractable and forces every intermediate node
                    # to be a parent of gamma. Enumerating all simple paths can still be expensive on dense subgraphs.
                    subgraph = nx.Graph()   # undirected
                    subgraph.add_nodes_from(gamma_parents | {theta, beta})
                    subgraph.add_edges_from([(x, y) for x, y in combinations(gamma_parents | {theta, beta}, 2) if (x, y) in CURREDGES])
                    for theta_beta_path in nx.all_simple_paths(subgraph, theta, beta):
                        if len(theta_beta_path) < 3: continue  # need at least one node strictly between theta and beta
                        path = theta_beta_path + [gamma]
                        # indices 1 .. len(path)-3 are the nodes strictly between theta and beta; each must be a collider
                        if all(
                                get_curr_edge_type(path[i - 1], path[i], RIGHT) == AROW and
                                get_curr_edge_type(path[i], path[i + 1], LEFT) == AROW
                                for i in range(1, len(path) - 2)
                        ):
                            changed_something = True
                            # print('    R4 applied!')
                            # NOTE(review): once beta-gamma has been oriented by an earlier path, later paths for the
                            # same (theta, gamma, beta) are still applied here, although β◦−−∗γ no longer holds.
                            if beta in SEPSETS[(theta, gamma)]:
                                update_edge(beta, gamma, DASH, AROW)
                            else:
                                update_edge(path[-3], beta, AROW, AROW)  # path[-3] is alpha
                                update_edge(beta, gamma, AROW, AROW)
        return changed_something

    def _R_no_latents():
        """Turn every ⚬-> edge into ->. Returns True if any edge was updated."""
        # when we are sure that there are no latents, we confirm all ⚬-> as ->
        # update_edge only reassigns existing keys, so CURREDGES does not change size during this iteration
        changed_something = False
        for (node1, node2), (type1, type2) in CURREDGES.items():
            if type1 == CIRC and type2 == AROW:
                update_edge(node1, node2, DASH, AROW)
                changed_something = True
        return changed_something

    # ============================= main part ======================================
    # first apply background knowledge (for now we don't do a consistency check; just trust it)
    # NOTE(review): pairs are not checked against the skeleton. A pair that is not a skeleton edge would be
    # added to CURREDGES while still being present in SEPSETS.
    if sure_direct_edges is not None:
        for node1, node2 in sure_direct_edges:
            CURREDGES[(node1, node2)] = (DASH, AROW)
            CURREDGES[(node2, node1)] = (AROW, DASH)

    # then fix the unshielded triples using observed CIs
    _RO()

    # then iteratively apply the rules until no more changes.
    # Termination: update_edge never writes CIRC, and every rule is triggered by a CIRC mark. So each pass
    # that reports a change has removed at least one CIRC mark.
    RULES = [_R1, _R2, _R3, _R4]
    if sure_no_latents: RULES.append(_R_no_latents)
    while True:
        changed_something = False
        for rule in RULES:
            changed_something |= rule()
        if not changed_something:
            break

    # finally, return the result, bucketed by edge type.
    # Each edge is stored in both directions; the reversed orientations (AROW, DASH), (DASH, CIRC) and
    # (AROW, CIRC) match no branch below, so asymmetric edges appear once, in the direction of their key.
    pag_edges = {'->': set(), '<->': set(), '--': set(),
                      '⚬--': set(), '⚬->': set(), '⚬-⚬': set()}
    for (node1, node2), (type1, type2) in CURREDGES.items():
        if type1 == DASH and type2 == AROW:
            pag_edges['->'].add((node1, node2))
        elif type1 == AROW and type2 == AROW:
            pag_edges['<->'].add((node1, node2))  # has symmetric repeats
        elif type1 == DASH and type2 == DASH:
            pag_edges['--'].add((node1, node2))  # has symmetric repeats
        elif type1 == CIRC and type2 == DASH:
            pag_edges['⚬--'].add((node1, node2))
        elif type1 == CIRC and type2 == AROW:
            pag_edges['⚬->'].add((node1, node2))
        elif type1 == CIRC and type2 == CIRC:
            pag_edges['⚬-⚬'].add((node1, node2))  # has symmetric repeats

    return pag_edges

def get_skeleton_and_sepsets(nodelist, CI_tester):
    """Recover an undirected skeleton and separating sets with PC-style CI tests.

    The search starts from the complete graph over ``nodelist`` and removes
    edges level by level. At level ``l``, each remaining pair ``(i, j)`` whose
    endpoints still have at least ``l`` other neighbours is tested against the
    size-``l`` subsets of ``adj(i) - {j}`` and of ``adj(j) - {i}``.

    The first subset (in set-iteration order) for which ``CI_tester`` reports
    independence removes the edge and is stored as the pair's separating set.
    Adjacencies are updated immediately as edges are removed, so results may
    depend on the order in which pairs are visited. The search stops at the
    first level at which no remaining pair is tested.

    Args:
        nodelist: iterable of hashable nodes. They must be mutually orderable,
            because ``min``/``max`` canonicalize each pair.
        CI_tester: callable ``(i, j, cond_set) -> bool``. ``cond_set`` is a
            ``frozenset``. Return True when ``i`` and ``j`` are judged
            independent given ``cond_set``.

    Returns:
        ``(skeleton, sepsets)``, where:
            - ``skeleton`` is a set of ``(i, j)`` pairs with ``i < j``;
            - ``sepsets`` maps both ``(i, j)`` and ``(j, i)`` of every removed
              pair to its separating set (a ``set``).
    """
    all_nodes = set(nodelist)
    # Insertion-ordered dict used as an ordered set: O(1) removal while the
    # visiting order of pairs stays the original combination order.
    skeleton = dict.fromkeys((min(i, j), max(i, j)) for i, j in combinations(all_nodes, 2))
    neighbors = {i: all_nodes - {i} for i in all_nodes}
    sepsets = {}
    level = 0
    while True:
        found_something = False
        # One in-order pass over a snapshot of this level's pairs.
        # - Neighbour sets only shrink within a level, so a pair that fails the
        #   degree check can never pass it later.
        # - Only the pair under test is ever removed.
        # Together these mean one pass visits the same pairs, in the same order,
        # as rescanning the whole skeleton after every test, but without the
        # quadratic rescans.
        for i, j in list(skeleton):
            others_i = neighbors[i] - {j}
            others_j = neighbors[j] - {i}
            if len(others_i) < level and len(others_j) < level:
                continue
            found_something = True
            candidates = set(map(frozenset, combinations(others_i, level))) | \
                         set(map(frozenset, combinations(others_j, level)))
            for subset in candidates:
                if CI_tester(i, j, subset):
                    del skeleton[(i, j)]
                    neighbors[i].remove(j)
                    neighbors[j].remove(i)
                    sepsets[(i, j)] = sepsets[(j, i)] = set(subset)
                    break
        if not found_something:
            break
        level += 1
    return set(skeleton), sepsets