import pytest

from do_shap.frontiers import *


@pytest.fixture
def graph3() -> DAG:
    """Return a DAG on nodes 0-3 with edges 0->1, 0->2, 1->2, 1->3, 2->3."""
    edge_list = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]

    dag = DAG(4)
    for source, target in edge_list:
        dag.add_edge(Edge(source, target))
    return dag


@pytest.fixture
def graph5() -> DAG:
    """Return a DAG on nodes 0-5: a diamond 0->{1,2}->3 followed by the chain 3->4->5."""
    edge_list = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4), (4, 5)]

    dag = DAG(6)
    for source, target in edge_list:
        dag.add_edge(Edge(source, target))
    return dag


def test_parents(graph5: DAG) -> None:
    """Every node of graph5 reports exactly its direct predecessors."""
    expected_parents = {
        0: set(),
        1: {0},
        2: {0},
        3: {1, 2},
        4: {3},
        5: {4},
    }
    for node, want in expected_parents.items():
        assert set(graph5.parents(node)) == want


def test_children(graph5: DAG) -> None:
    """Every node of graph5 reports exactly its direct successors."""
    expected_children = {
        0: {1, 2},
        1: {3},
        2: {3},
        3: {4},
        4: {5},
        5: set(),
    }
    for node, want in expected_children.items():
        assert set(graph5.children(node)) == want


def test_ancestors(graph5: DAG) -> None:
    """ancestors(node) returns a set that includes the node itself and all its predecessors."""
    expected_ancestors = {
        0: {0},
        1: {0, 1},
        2: {0, 2},
        3: {0, 1, 2, 3},
        4: {0, 1, 2, 3, 4},
        5: {0, 1, 2, 3, 4, 5},
    }
    for node, want in expected_ancestors.items():
        # Compared directly (no set() wrapper): the result itself must equal a set.
        assert graph5.ancestors(node) == want


def test_descendants(graph5: DAG) -> None:
    """descendants(node) returns a set that includes the node itself and everything reachable from it."""
    expected_descendants = {
        0: {0, 1, 2, 3, 4, 5},
        1: {1, 3, 4, 5},
        2: {2, 3, 4, 5},
        3: {3, 4, 5},
        4: {4, 5},
        5: {5},
    }
    for node, want in expected_descendants.items():
        # Compared directly (no set() wrapper): the result itself must equal a set.
        assert graph5.descendants(node) == want


def test_fr1(graph3: DAG) -> None:
    """Run FR1 on every combination from parts_of(graph3.X) and check both caches."""
    fr1 = FR1(graph3)
    for combination in parts_of(graph3.X):
        fr1.run(combination)

    # Sets in which x (the maximum element) is a parent of Y are trivially
    # False, so they are expected to be absent from fr_cache.
    expected_fr_cache = {
        (0,): False,
        (1, 0): False,
        (2, 0): False,
        (2, 1, 0): True,
    }
    assert fr1.fr_cache == expected_fr_cache

    expected_v_cache = {(), (0,), (1,), (2,), (1, 0), (2, 0), (2, 1)}
    assert fr1.v_cache == expected_v_cache


def test_fr1_equals_fr2(graph5: DAG) -> None:
    """FR1's caches must match FR2's decoded caches after both run over the same combinations."""
    finished = []
    # Build and fully run FR1 before FR2 is constructed, mirroring a sequential setup.
    for algorithm in (FR1, FR2):
        runner = algorithm(graph5)
        for combination in parts_of(graph5.X):
            runner.run(combination)
        finished.append(runner)
    fr1, fr2 = finished

    assert fr1.fr_cache == fr2.decoded_fr_cache
    assert fr1.v_cache == fr2.decoded_v_cache
