import pytest
import numpy as np

from region_graphs import QuadTree, RandomQuadTree2, DecisionLeafQuadGraph


@pytest.mark.parametrize("num_repetitions", [1, 12])
def test_quad_tree(num_repetitions: int):
    """Check QuadTree leaf coverage and structured decomposability.

    For a ``(1, 2, 2)`` input, every variable index in ``0..num_variables-1``
    must be the scope of a single-variable region node exactly
    ``num_repetitions`` times. The resulting region graph must also report
    itself as structured decomposable.
    """
    shape = (1, 2, 2)
    num_variables = shape[0] * shape[1] * shape[2]
    rg = QuadTree(shape, num_repetitions=num_repetitions, num_patch_splits=4)

    # The single variable of every univariate region node.
    leaf_scopes = [
        next(iter(node.scope)) for node in rg.region_nodes if len(node.scope) == 1
    ]
    variables, counts = np.unique(leaf_scopes, return_counts=True)
    # assert_array_equal also checks shapes. A plain np.all(a == b) can raise
    # or pass vacuously when the two arrays have different lengths.
    np.testing.assert_array_equal(variables, np.arange(num_variables))
    np.testing.assert_array_equal(counts, num_repetitions)

    assert rg.is_structured_decomposable

    assert rg.is_structured_decomposable


@pytest.mark.parametrize("num_repetitions", [1, 12])
def test_random_quad_tree_2(num_repetitions: int):
    """Check RandomQuadTree2 leaf coverage and decomposability.

    For a ``(1, 7, 7)`` input with a fixed seed, every variable index must be
    the scope of a single-variable region node exactly ``num_repetitions``
    times. The test expects the graph to be structured decomposable only when
    ``num_repetitions == 1``.
    """
    shape = (1, 7, 7)
    num_variables = shape[0] * shape[1] * shape[2]
    rg = RandomQuadTree2(shape, num_repetitions=num_repetitions, seed=42)

    # The single variable of every univariate region node.
    leaf_scopes = [
        next(iter(node.scope)) for node in rg.region_nodes if len(node.scope) == 1
    ]
    variables, counts = np.unique(leaf_scopes, return_counts=True)
    # assert_array_equal also checks shapes. A plain np.all(a == b) can raise
    # or pass vacuously when the two arrays have different lengths.
    np.testing.assert_array_equal(variables, np.arange(num_variables))
    np.testing.assert_array_equal(counts, num_repetitions)

    # Structured decomposability is expected exactly when there is a single
    # repetition.
    assert rg.is_structured_decomposable == (num_repetitions == 1)


def test_decleaf_quad_graph():
    """Check DecisionLeafQuadGraph variable coverage and decomposability.

    For a square ``(3, 32, 32)`` input with ``max_patch_size=8``, the region
    nodes whose scope size equals the number of channels must together cover
    every variable index. The graph is expected not to be structured
    decomposable.
    """
    num_channels = 3
    shape = (num_channels, 32, 32)
    # Precondition on the test fixture: the input image must be square.
    assert shape[1] == shape[2]
    num_variables = shape[0] * shape[1] * shape[2]
    rg = DecisionLeafQuadGraph(shape, max_patch_size=8)

    # NOTE(review): presumably these are the per-pixel leaf regions spanning
    # all channels; the filter only selects by scope size.
    leaf_scopes = [
        tuple(node.scope) for node in rg.region_nodes if len(node.scope) == num_channels
    ]
    # np.unique flattens the (num_leaves, num_channels) array, so this is the
    # set of distinct variables appearing in any leaf scope.
    variables = np.unique(leaf_scopes)
    # assert_array_equal also checks shapes. A plain np.all(a == b) can raise
    # or pass vacuously when the two arrays have different lengths.
    np.testing.assert_array_equal(variables, np.arange(num_variables))

    assert not rg.is_structured_decomposable

