import pytest
from pytest import raises

from torchjd.tree import EmptyTree, Leaf, Node, NonEmptyTree, Tree


@pytest.mark.parametrize(
    "tree,expected_depth",
    [
        (Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4)), 3),
        (Node(Leaf(2)), 1),
        (Leaf(4), 0),
    ],
)
def test_depth(tree: NonEmptyTree, expected_depth: int):
    """Check ``depth()`` on nested nodes, a single-child node and a bare leaf (depth 0)."""
    actual_depth = tree.depth()
    assert actual_depth == expected_depth


def test_eq():
    """Check that identical trees compare equal and that differing trees do not."""
    tree1 = Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4))
    tree2 = Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4))
    # Same shape as tree1, but one leaf value differs.
    different_value = Node(Node(Leaf(1), Node(Leaf(2), Leaf(5))), Leaf(4))
    # Same leaf values as tree1, but a different nesting structure.
    different_shape = Node(Node(Leaf(1), Leaf(2), Leaf(3)), Leaf(4))

    assert tree1 == tree2
    # Negative cases: without them, an __eq__ that always returns True would pass this test.
    assert tree1 != different_value
    assert tree1 != different_shape


def test_map():
    """Check that ``map`` applies the function to every leaf value and keeps the tree shape."""
    original = Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4))

    def double_plus_one(value):
        return 2 * value + 1

    result = original.map(double_plus_one)

    assert result == Node(Node(Leaf(3), Node(Leaf(5), Leaf(7))), Leaf(9))


def test_flatmap_leaf_1():
    """Check ``flatmap`` on a leaf when the function returns another leaf."""
    result = Leaf(1).flatmap(lambda value: Leaf(2 * value + 1))

    assert result == Leaf(3)


def test_flatmap_leaf_2():
    """Check ``flatmap`` on a leaf when the function returns a node, which replaces the leaf."""

    def expand(value):
        return Node(Leaf(2 * value + 1), Node(Leaf(3 * value - 1)))

    source = Leaf(1)
    result = source.flatmap(expand)

    assert result == Node(Leaf(3), Node(Leaf(2)))


def test_flatmap_node_1():
    """Check ``flatmap`` on a node when the function maps each leaf to a single leaf."""
    source = Node(Leaf(1), Leaf(2))

    result = source.flatmap(lambda value: Leaf(2 * value + 1))

    assert result == Node(Leaf(3), Leaf(5))


def test_flatmap_node_2():
    """Check ``flatmap`` on nested nodes when each leaf is expanded into a subtree."""

    def expand(value):
        return Node(Leaf(2 * value + 1), Node(Leaf(3 * value - 1)))

    source = Node(Leaf(1), Node(Leaf(2), Leaf(3)))
    result = source.flatmap(expand)

    # Each original leaf sits where its expanded subtree now is.
    subtree_for_1 = Node(Leaf(3), Node(Leaf(2)))
    subtree_for_2 = Node(Leaf(5), Node(Leaf(5)))
    subtree_for_3 = Node(Leaf(7), Node(Leaf(8)))
    assert result == Node(subtree_for_1, Node(subtree_for_2, subtree_for_3))


def test_flatmap_node_empty():
    """Check that ``flatmap`` yields an EmptyTree when every leaf maps to an EmptyTree."""
    source = Node(Leaf(1), Node(Leaf(2), Leaf(3)))

    result = source.flatmap(lambda _: EmptyTree())

    assert result == EmptyTree()


@pytest.mark.parametrize(
    "tree,expected_str",
    [
        (Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4)), "<<1, <2, 3>>, 4>"),
        (Node(Leaf(2)), "<2>"),
        (Leaf(4), "4"),
        (EmptyTree(), "EmptyTree"),
    ],
)
def test_str(tree: Tree, expected_str: str):
    """Check the string rendering: nodes as ``<child, ...>``, leaves as their value."""
    rendered = str(tree)
    assert rendered == expected_str


@pytest.mark.parametrize(
    ["tree", "expected_numbers"],
    [
        (Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4)), [1, 2, 3, 4]),
        (Node(Leaf(2)), [2]),
        (Leaf(4), [4]),
        (EmptyTree(), []),
    ],
)
def test_iter(tree: Tree, expected_numbers: list):
    """Check that iterating a tree yields its leaf values in left-to-right order."""
    # list() consumes the iterator directly; an identity comprehension adds nothing.
    numbers = list(tree)

    assert numbers == expected_numbers


def test_zip_leafs():
    """Check that zipping leaves produces one leaf holding the tuple of their values."""
    first, second, third = Leaf(2), Leaf(3), Leaf(5)

    zipped = first.zip(second, third)

    assert zipped == Leaf((2, 3, 5))


def test_zip_empty_trees():
    """Check that zipping empty trees together gives an empty tree."""
    trees = [EmptyTree() for _ in range(3)]

    zipped = trees[0].zip(*trees[1:])

    assert zipped == EmptyTree()


def test_zip_complex_trees():
    """Check that zipping same-shaped trees groups corresponding leaf values into tuples."""
    ones = Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4))
    tens = Node(Node(Leaf(10), Node(Leaf(20), Leaf(30))), Leaf(40))
    hundreds = Node(Node(Leaf(100), Node(Leaf(200), Leaf(300))), Leaf(400))

    zipped = ones.zip(tens, hundreds)

    left_branch = Node(Leaf((1, 10, 100)), Node(Leaf((2, 20, 200)), Leaf((3, 30, 300))))
    assert zipped == Node(left_branch, Leaf((4, 40, 400)))


def test_zip_no_others():
    """Check that ``zip`` with no other trees wraps each leaf value in a 1-tuple."""
    source = Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4))

    zipped = source.zip()

    assert zipped == Node(Node(Leaf((1,)), Node(Leaf((2,)), Leaf((3,)))), Leaf((4,)))


def test_zip_different_structures_1():
    """Check that zipping trees whose roots have different child counts raises ValueError."""
    two_children = Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4))
    three_children = Node(Leaf(10), Node(Leaf(20), Leaf(30)), Leaf(40))

    with raises(ValueError):
        two_children.zip(three_children)


def test_zip_different_structures_2():
    """Check that zipping trees where a leaf lines up with a node raises an error."""
    ends_with_leaf = Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4))
    ends_with_node = Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Node(Leaf(4)))

    # Ideally this structural mismatch would raise a ValueError, but that is harder to
    # implement, so the test currently expects a TypeError.
    with raises(TypeError):
        ends_with_leaf.zip(ends_with_node)


@pytest.mark.parametrize(
    "tree,expected_length",
    [
        (Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4)), 4),
        (Node(Leaf(2)), 1),
        (Leaf(4), 1),
        (EmptyTree(), 0),
    ],
)
def test_len(tree: Tree, expected_length: int):
    """Check that ``len`` counts the leaves of a tree (zero for an EmptyTree)."""
    size = len(tree)
    assert size == expected_length


@pytest.mark.parametrize(
    ["tree", "expected"],
    [
        (Node(Node(Leaf(1), Node(Leaf(2), Leaf(3))), Leaf(4)), Node(Node(Node(Leaf(2))), Leaf(4))),
        (Node(Leaf(2)), Node(Leaf(2))),
        (Leaf(3), EmptyTree()),
        (Leaf(6), Leaf(6)),
        (EmptyTree(), EmptyTree()),
    ],
)
def test_filter(tree: Tree, expected: Tree):
    """Check that ``filter`` keeps only leaves whose values satisfy the predicate.

    Rejected leaves are dropped from their parent nodes, and the surrounding node structure
    is preserved. A rejected bare leaf becomes an EmptyTree.
    """
    output = tree.filter(lambda x: x % 2 == 0)
    assert output == expected
