import unittest
import warnings
from typing import Any, Dict, List, Tuple
from collections import namedtuple

import pytest
import numpy as np

import cdd
import decomp
import tools
import inversion

# Loose structural aliases for the plain dictionaries built in the tests below.
# NOTE(review): none of these aliases is referenced in the visible part of the
# file; the key descriptions are taken from how the tests build their inputs.
# A polytope dict as built in the tests: {"h": <h-repr dict>, "v": <v-repr dict or None>}.
Polytope = Dict[str, Any]
# An H-representation dict as built in the tests: keys "inequality", "linear", "is_empty".
HRepresentation = Dict[str, Any]
# A V-representation dict as built in the tests: keys "vertices", "is_empty".
VRepresentation = Dict[str, Any]
# A region is a list of polytopes (e.g. `r1 = [p]` in the layer-inversion tests).
Region = List[Polytope]

# Very wide print width so that printed matrices are not wrapped.
np.set_printoptions(linewidth=1000)


def _prepend_zero_row(x: np.ndarray) -> np.ndarray:
    """Return a copy of the 2-d array ``x`` with one all-zero row stacked on top.

    The zero row is floating point, so the result follows numpy's usual
    promotion rules (an integer ``x`` comes back as float).
    """
    assert x.ndim == 2
    zero_row = np.zeros((1, x.shape[1]))
    return np.concatenate((zero_row, x), axis=0)


def _wrapped_checker(m1: np.ndarray, m2: np.ndarray) -> bool:
    """ Compare two representation matrices via ``tools.same_unique_rows``
    after prepending an all-zero row to each of them.

    Adding a row of zeros does not affect the polytope (in either the
    h or the v representation), and it means both arguments have at least
    one row when they are compared.
    """
    padded_1, padded_2 = (_prepend_zero_row(m) for m in (m1, m2))
    return tools.same_unique_rows(padded_1, padded_2)


def test0():
    """
    What is the subset of [0, infty)^2 such that
    [[1, 1]] relu(x - [[1], [1]]) <= 1?
    Obvious first observations:
      - [[0, 2]] is in it, in fact [[1, 2]] is in it
      - [[2, 0]] is in it, in fact [[2, 1]] is in it
      - (1.5, 1.5) is in it, but (1.5, 1.51) is not
      - (1.49, 1.51) is in it

    Conditional on
      - no relus binding
      x0 > 1, x1 > 1, x0 + x1 < 3
      - first relu not binding, second binding
      x0 > 1, x1 < 1, x0 < 2 [x1 > 0 implicit]
      - first relu binding, second not binding
      x0 < 1, x1 > 1, x1 < 2 [x0 > 0 implicit]
      - both relus binding
      x0 < 1, x1 < 1, [x0 > 0, x1 > 0 implicit]
    """
    # Image set: one h-row [b, a0, a1] read as b + a0 * z0 + a1 * z1 >= 0,
    # i.e. z0 + z1 <= 1.
    image_ineq = np.array([[+1, -1, -1]])
    image_lin = np.empty((0, image_ineq.shape[1]))
    image_vertices = tools.h_to_v(image_ineq, image_lin)
    image = {
        "h": {"inequality": image_ineq, "linear": image_lin, "is_empty": False},
        "v": {"vertices": image_vertices, "is_empty": False},
    }

    # Layer: identity weights, bias -1 on both units.
    w = np.array([[1.0, 0.0], [0.0, 1.0]])
    b = np.array([[-1.0], [-1.0]])

    x_lower = np.zeros((2, 1))
    x_upper = np.full((2, 1), np.inf)

    unit_indices = tuple(range(w.shape[0]))
    patterns = tools.powerset(unit_indices)
    is_rational, need_v = False, True

    pieces_v, _pieces_h = inversion.relu_decomposition(
        patterns, image, w, b, x_lower, x_upper, is_rational, need_v
    )
    vertex_sets = [piece["vertices"] for piece in pieces_v]

    # Piece 0 is compared through its h-representation, the rest as vertices.
    expected_h0 = np.array([[+3, -1, -1.0], [-1, +1, +0.0], [-1, +0, +1]])
    expected_vs = (
        np.array([[1.0, 0.0, 2.0], [1.0, 0.0, 1.0], [1.0, 1.0, 2.0], [1.0, 1.0, 1.0]]),
        np.array([[1, 1, 0], [1, 2, 0], [1, 1, 1], [1, 2, 1]]),
        np.array([[1, 1, 0], [1, 1, 1], [1, 0, 1], [1, 0, 0]]),
    )

    actual_h0 = tools.v_to_h(vertex_sets[0], None)
    assert _wrapped_checker(expected_h0, actual_h0)
    for position, expected in enumerate(expected_vs, start=1):
        assert _wrapped_checker(vertex_sets[position], expected)


def test1():
    """
    What are the x in [0, infty)^2 such that the first coordinate of
    relu(x - [[1], [1]])
    is weakly greater than the second?

    Obviously:
      [0, 1]^2 is contained, since this all gets mapped to [[0], [0]]
      [[x], [1.0001]] is not, for any x < 1.0001
      [[x+s], [x]] is, for x > 1, s > 0.

    This is a simple test of the ray (unbounded bits).
    """
    # Image set: z0 - z1 >= 0, written as the h-row [b, a0, a1].
    image_ineq = np.array([[+0, +1, -1]])
    w = np.array([[1.0, 0.0], [0.0, 1.0]])
    b = np.array([[-1.0], [-1.0]])

    image_lin = np.empty((0, image_ineq.shape[1]))
    image_vertices = tools.h_to_v(image_ineq, image_lin)
    image = {
        "h": {"inequality": image_ineq, "linear": image_lin, "is_empty": False},
        "v": {"vertices": image_vertices, "is_empty": False},
    }

    x_lower = np.zeros((2, 1))
    x_upper = np.full((2, 1), np.inf)

    patterns = tools.powerset(tuple(range(w.shape[0])))
    is_rational, need_v = False, True
    pieces_v, _pieces_h = inversion.relu_decomposition(
        patterns, image, w, b, x_lower, x_upper, is_rational, need_v
    )
    vertex_sets = [piece["vertices"] for piece in pieces_v]

    # x >= y and y >= 1
    want_h0 = np.array([[+0, +1, -1], [-1, +0, +1]])
    want_v0 = tools.h_to_v(want_h0, np.empty((0, want_h0.shape[1])))
    assert _wrapped_checker(want_v0, vertex_sets[0])

    # the line between [[0], [1]] and [[1], [1]]
    want_v1 = np.array([[+1, +0, +1], [+1, +1, +1]])
    assert _wrapped_checker(want_v1, vertex_sets[1])

    # the infinitely-long rectangle [1, infty) x [0, 1]
    want_h2 = np.array([[+1, +0, -1], [-1, +1, +0], [+0, +0, +1]])
    want_v2 = tools.h_to_v(want_h2, np.empty((0, want_h2.shape[1])))
    assert _wrapped_checker(want_v2, vertex_sets[2])

    # both relued => [0, 1]^2
    want_v3 = np.array([[+1, +1, +0], [+1, +0, +0], [+1, +0, +1], [+1, +1, +1]])
    assert _wrapped_checker(want_v3, vertex_sets[3])


def test2():
    """
    Sort-of tricky example of negative elements in w:

    the image set is the halfplane below y = .5 + x

    Observations:
      - for x, y large enough, obviously everything will be thresholded
        thus, we expect "most" of the 0 activation pattern to satisfy.
      - x and y are not symmetric
      - both units are positive when x < 1 and y < 1; on that piece the
        constraint on the inputs becomes y >= x - .5 (see expected_h0 below)
    """
    image_ineq = np.array([[+0.5, +1, -1]])
    w = np.array([[-1.0, 0.0], [0.0, -1.0]])
    b = np.array([[+1.0], [+1.0]])
    in_dim = w.shape[1]

    image_lin = np.empty((0, image_ineq.shape[1]))
    image_vertices = tools.h_to_v(image_ineq, image_lin)
    image = {
        "h": {"inequality": image_ineq, "linear": image_lin, "is_empty": False},
        "v": {"vertices": image_vertices, "is_empty": False},
    }

    x_lower = np.zeros((in_dim, 1))
    x_upper = np.full((in_dim, 1), np.inf)

    patterns = tools.powerset(tuple(range(w.shape[0])))
    is_rational, need_v = False, True
    pieces_v, _pieces_h = inversion.relu_decomposition(
        patterns, image, w, b, x_lower, x_upper, is_rational, need_v
    )
    vertex_sets = [piece["vertices"] for piece in pieces_v]

    want_h0 = np.array(
        [
            [+1, -2, +2],  # y >= x - .5
            [+0, +1, +0],  # x >= 0
            [+0, +0, +1],  # y >= 0
            [+1, +0, -1],  # y <= 1
            [+1, -1, +0],  # x <= 1
        ]
    )
    want_v0 = tools.h_to_v(want_h0, np.empty((0, want_h0.shape[1])))
    assert _wrapped_checker(want_v0, vertex_sets[0])

    # y <= 1 means that it won't be thresholded, so this is
    # x >= 1, y <= 1, but we need y >= .5
    want_h1 = np.array([[+1, +0, -1], [-1, +1, +0], [-1, +0, +2]])
    want_v1 = tools.h_to_v(want_h1, np.empty((0, want_h1.shape[1])))
    assert _wrapped_checker(want_v1, vertex_sets[1])

    # same reasoning as above, though there is no extra constraint on x,
    # so this is just [0, 1] x [1, infty)
    want_h2 = np.array([[+1, -1, +0], [+0, +1, +0], [-1, +0, +1]])
    want_v2 = tools.h_to_v(want_h2, np.empty((0, want_h2.shape[1])))
    assert _wrapped_checker(want_v2, vertex_sets[2])

    # 0 < .5, so this is [1, infty)^2
    want_h3 = np.array([[-1, +1, +0], [-1, +0, +1]])
    want_v3 = tools.h_to_v(want_h3, np.empty((0, want_h3.shape[1])))
    assert _wrapped_checker(want_v3, vertex_sets[3])


def test_xor():
    """
    Simple disconnected region example.
    relu(x + y) + 2 * relu(x + y - 1) >= .5
    See: http://www.cs.columbia.edu/~mcollins/ff.pdf

    Observations:
        - The neighbourhood of [[0], [0]] is not included, since the two need to
          sum to more than .5.
        - [[.5], [0]] is a boundary point, where the first term is nonzero
        - [[0], [.5]] is another boundary point, where the first term is nonzero
        - actually, any (x, y) where .5 < x + y < 1.0 will have the first term
          nonzero and the second term thresholded.
        - Note: it is impossible for the first term to be thresholded, and
          the second one to be positive, since the pre-activation values of
          the first are strictly greater than those of the second.
    """
    image_ineq = np.array([[-0.5, +1.0, +2.0]])
    w = np.array([[+1.0, +1.0], [+1.0, +1.0]])
    b = np.array([[0], [-1]])

    image_lin = np.empty((0, image_ineq.shape[1]))
    image_vertices = tools.h_to_v(image_ineq, image_lin)
    image = {
        "h": {"inequality": image_ineq, "linear": image_lin, "is_empty": False},
        "v": {"vertices": image_vertices, "is_empty": False},
    }

    in_dim = w.shape[1]
    x_lower = np.zeros((in_dim, 1))
    x_upper = np.full((in_dim, 1), np.inf)

    patterns = tools.powerset(tuple(range(w.shape[0])))
    is_rational, need_v = False, True
    pieces_v, _pieces_h = inversion.relu_decomposition(
        patterns, image, w, b, x_lower, x_upper, is_rational, need_v
    )
    vertex_sets = [piece["vertices"] for piece in pieces_v]

    # x >= 0, y >= 0, x + y >= 1
    want_h0 = np.array([[+0, +1, +0], [+0, +0, +1], [-1, +1, +1]])
    want_v0 = tools.h_to_v(want_h0, np.empty((0, want_h0.shape[1])))
    assert _wrapped_checker(want_v0, vertex_sets[0])

    # expected to be empty
    want_v1 = np.empty((0, 3))
    assert _wrapped_checker(want_v1, vertex_sets[1])

    # .5 <= x + y <= 1 with x >= 0, y >= 0
    want_h2 = np.array([[-1, +2, +2], [+0, +1, +0], [+1, -1, -1], [+0, +0, +1]])
    want_v2 = tools.h_to_v(want_h2, np.empty((0, want_h2.shape[1])))
    assert _wrapped_checker(want_v2, vertex_sets[2])

    # expected to be empty
    want_v3 = np.empty((0, 3))
    assert _wrapped_checker(want_v3, vertex_sets[3])


def test_empty():
    """ Test that we get entirely empty decompositions when we expect """
    # For x in [0, infty)^2 both units relu(-x) are zero, so the image
    # constraint -0.0001 + z0 + z1 >= 0 cannot be satisfied.
    image_ineq = np.array([[-0.0001, +1, +1]])
    w = np.array([[-1.0, 0.0], [0.0, -1.0]])
    b = np.array([[+0.0], [+0.0]])

    image_lin = np.empty((0, image_ineq.shape[1]))
    image_vertices = tools.h_to_v(image_ineq, image_lin)
    image = {
        "h": {"inequality": image_ineq, "linear": image_lin, "is_empty": False},
        "v": {"vertices": image_vertices, "is_empty": False},
    }

    in_dim = w.shape[1]
    x_lower = np.zeros((in_dim, 1))
    x_upper = np.full((in_dim, 1), np.inf)

    patterns = tools.powerset(tuple(range(w.shape[0])))
    is_rational, need_v = False, True
    pieces_v, _pieces_h = inversion.relu_decomposition(
        patterns, image, w, b, x_lower, x_upper, is_rational, need_v
    )
    vertex_sets = [piece["vertices"] for piece in pieces_v]

    # Every piece must come back with no vertices at all.
    assert all(vertices.size == 0 for vertices in vertex_sets)


def test_simple_iterated_decomposition():
    """
    As a reduced form of a repeated decomposition, consider now decomposing
    something which is not a halfspace.

    Namely, an axis-oriented unit diamond centered at 1, 1.

    Observations:
      - With no thresholding, this will just be the same, shifted up by
        (1, 1)
      - Since the set touches zero at (0, 1) and (1, 0) there is some
        narrow scope for preimages of these sparsity patterns.
      - [Same argument as above] zero is not contained in the set, so
        nothing in the (0, 0) sparsity pattern.
    """

    v_repr = np.array([[+1, +1, +2],
                       [+1, +2, +1],
                       [+1, +1, +0],
                       [+1, +0, +1]])

    v_lin = np.empty((0, v_repr.shape[1]))
    h_ineq = tools.v_to_h(v_repr, v_lin, False)
    h_lin = np.empty((0, h_ineq.shape[1]))
    v = dict(vertices=v_repr, is_empty=False)
    h = dict(inequality=h_ineq, linear=h_lin, is_empty=False)

    image = dict(h=h, v=v)

    w = np.array([[1.0, 0.0], [0.0, 1.0]])
    b = np.array([[-1.0], [-1.0]])

    in_dim = w.shape[1]
    x_lower = np.zeros((in_dim, 1))
    x_upper = np.full((in_dim, 1), np.inf)

    index_powerset = tools.powerset(tuple(range(w.shape[0])))
    is_rational = False
    need_v = True

    relu_decomp_vs, _relu_decomp_hs = inversion.relu_decomposition(index_powerset,
                                                                   image,
                                                                   w,
                                                                   b,
                                                                   x_lower,
                                                                   x_upper,
                                                                   is_rational,
                                                                   need_v)
    vertex_sets = [piece["vertices"] for piece in relu_decomp_vs]

    # NOTE: the checks below were previously bare `_wrapped_checker(...)` calls
    # whose boolean result was discarded, so this test could never fail.

    # no thresholding: the diamond shifted by (1, 1)
    expected_v0 = v_repr + np.ones((4, 1)) * np.array([[0, 1, 1]])
    assert _wrapped_checker(expected_v0, vertex_sets[0])

    # segment [0, 1] x {2}
    expected_v1 = np.array([[+1, +0, +2], [+1, +1, +2]])
    assert _wrapped_checker(expected_v1, vertex_sets[1])

    # segment {2} x [0, 1]
    expected_v2 = np.array([[+1, +2, +0], [+1, +2, +1]])
    assert _wrapped_checker(expected_v2, vertex_sets[2])

    # zero is not in the diamond, so this piece is empty
    expected_v3 = np.empty((0, 3))
    assert _wrapped_checker(expected_v3, vertex_sets[3])


def test_3d():
    """
    Let's move beyond two dimensions now.
    Since we cannot plot things as easily, let's start with the simplest
    possible case.

    Unit cube in 3d, unit mapping.

    Observations:
      - The un-thresholded sparsity pattern generates the whole image.
      - The whole preimage is itself just the unit cube.
      - zeroed dimensions correspond to faces and lines and corners of the cube.
    """

    # The eight corners of the unit cube, each with a leading 1.
    cube_vertices = np.array(
        [
            [+1, +0, +0, +0],
            [+1, +0, +0, +1],
            [+1, +0, +1, +0],
            [+1, +1, +0, +0],
            [+1, +0, +1, +1],
            [+1, +1, +0, +1],
            [+1, +1, +1, +0],
            [+1, +1, +1, +1],
        ]
    )

    v = {"vertices": cube_vertices, "is_empty": False}
    no_linear_v = np.empty((0, cube_vertices.shape[1]))
    cube_ineq = tools.v_to_h(v["vertices"], no_linear_v, False)
    h = {
        "inequality": cube_ineq,
        "linear": np.empty((0, cube_ineq.shape[1])),
        "is_empty": False,
    }
    image = {"h": h, "v": v}

    # Identity layer with zero bias.
    w = np.eye(3)
    b = np.zeros((3, 1))
    x_lower = np.zeros((3, 1))
    x_upper = np.full((3, 1), np.inf)

    patterns = tools.powerset(tuple(range(w.shape[0])))
    is_rational, need_v = False, True
    pieces_v, _pieces_h = inversion.relu_decomposition(
        patterns, image, w, b, x_lower, x_upper, is_rational, need_v
    )

    # First piece: the whole cube.
    assert _wrapped_checker(cube_vertices, pieces_v[0]["vertices"])

    # Last piece: just the origin.
    origin_3d = np.array([[+1, +0, +0, +0]])
    assert _wrapped_checker(origin_3d, pieces_v[-1]["vertices"])

#
# def get_test_inputs1() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
#     w_ = np.array(
#         [
#             [-0.01905135, +0.69748960],
#             [-1.96483330, -0.83005710],
#             [-0.10843029, +0.03259321],
#             [-0.18028592, +1.08891880],
#             [+0.11887044, +0.83100160],
#             [-3.10236450, -1.32591560],
#             [+0.18438303, -1.20181450],
#             [-0.29148576, +0.02619374],
#             [-0.03064194, +0.66856210],
#             [-2.14885540, -1.17508550],
#         ]
#     )
#
#     b_ = np.array(
#         [
#             [+0.67673033],
#             [-0.47278836],
#             [-0.30218090],
#             [1.26878170],
#             [0.60808855],
#             [-0.76127154],
#             [1.67438660],
#             [-0.65596200],
#             [-0.12204660],
#             [-0.32245210],
#         ]
#     )
#
#     a_ = np.array(
#         [
#             [+1.0, 1.80199382, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [-0.0, -1.00000000, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [+0.0, -8.33034471, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [+0.0, -0.04171375, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [+0.0, -1.48250181, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [+0.0, -1.70009400, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [+0.0, -25.05845022, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
#             [+0.0, 2.31168188, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
#             [+0.0, -0.11294138, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
#             [+0.0, -0.33199848, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
#             [+0.0, -5.36232103, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
#             [-0.0, 8.33034471, -1.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, 0.04171375, -0.0, -1.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, 1.48250181, -0.0, -0.0, -1.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, 1.70009400, -0.0, -0.0, -0.0, -1.0, -0.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, 25.05845022, -0.0, -0.0, -0.0, -0.0, -1.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, -2.31168188, -0.0, -0.0, -0.0, -0.0, -0.0, -1.0, -0.0, -0.0, -0.0],
#             [-0.0, 0.11294138, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -1.0, -0.0, -0.0],
#             [-0.0, 0.33199848, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -1.0, -0.0],
#             [-0.0, 5.36232103, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -1.0],
#         ]
#     )
#     is_relued_ = np.array(
#         [True, False, True, False, True, False, False, True, True, False]
#     )
#     return w_, b_, a_, is_relued_
#
# #
# def get_test_inputs2() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
#     is_relued_ = np.array(
#         [True, False, False, False, False, True, False, True, False, True]
#     )
#     b_ = np.array(
#         [
#             [1.8356802],
#             [-5.4187746],
#             [-1.705672],
#             [-1.3153728],
#             [-0.9695644],
#             [-0.5793407],
#             [1.8497604],
#             [3.7820776],
#             [-2.5795283],
#             [0.38880357],
#         ]
#     )
#     w_ = np.array(
#         [
#             [-1.7420893, -0.6828302],
#             [4.64374, 0.8248184],
#             [0.55591327, -10.933058],
#             [3.2613926, 4.365506],
#             [1.9887098, 2.9086561],
#             [0.1505127, 0.09131753],
#             [-3.2723374, -5.165686],
#             [-3.7218924, -0.65449214],
#             [3.948692, 1.4463147],
#             [-6.889521, 2.580719],
#         ]
#     )
#     a_ = np.array(
#         [
#             [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [1.0, 4.072712, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [0.0, -5.123364, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [0.0, -4.538653, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [0.0, 2.171898, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [0.0, 0.986171, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
#             [0.0, -0.043758, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
#             [0.0, 4.232144, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
#             [0.0, -3.924413, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
#             [0.0, -4.260749, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
#             [0.0, 3.339741, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
#             [-0.0, 5.123364, -1.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, 4.538653, -0.0, -1.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, -2.171898, -0.0, -0.0, -1.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, -0.986171, -0.0, -0.0, -0.0, -1.0, -0.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, 0.043758, -0.0, -0.0, -0.0, -0.0, -1.0, -0.0, -0.0, -0.0, -0.0],
#             [-0.0, -4.232144, -0.0, -0.0, -0.0, -0.0, -0.0, -1.0, -0.0, -0.0, -0.0],
#             [-0.0, 3.924413, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -1.0, -0.0, -0.0],
#             [-0.0, 4.260749, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -1.0, -0.0],
#             [-0.0, -3.339741, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -1.0],
#         ]
#     )
#     return w_, b_, a_, is_relued_


def test_rational():
    """Check the h-representation of a single point in both build modes.

    The two results of ``tools.build_h_repr_of_point`` (second argument True
    vs. False; presumably the rational / float switch) must agree, and
    converting the first one back with ``tools.h_to_v`` must give the point
    again once the leading column is dropped.
    """
    pt = np.array([[-0.3439656, 0.00813415]])
    h_rational = tools.build_h_repr_of_point(pt, True)
    h_float = tools.build_h_repr_of_point(pt, False)
    np.testing.assert_allclose(h_rational, h_float)

    no_linear = np.empty((0, h_rational.shape[1]))
    recovered = tools.h_to_v(h_rational, no_linear, True)
    np.testing.assert_allclose(pt, recovered[:, 1:])


def test_invert_relu_layer1():
    """Invert a relu layer over a re-centred, doubled unit cube in 2-d.

    The image polytope is ``tools.unit_cube_v_repr(2)`` shifted so that its
    mean is at the origin and scaled by 2 (presumably [-1, 1]^2 - confirm
    against ``tools.unit_cube_v_repr``).  The four pieces returned by
    ``inversion.invert_relu_layer_kernel`` are compared, in order, against:
    the third quadrant, a vertical strip, a horizontal strip and the unit
    cube itself.
    """
    n = 2
    cube = tools.unit_cube_v_repr(n)
    cube_center = np.mean(cube[:, 1:], axis=0)
    # Keep the leading column, re-centre the coordinates and double them.
    v11 = np.hstack((tools.vec(cube[:, 0]), (cube[:, 1:] - cube_center) * 2))
    v_lin = np.empty((0, v11.shape[1]))

    h_ineq = tools.v_to_h(v11, v_lin)
    # The empty block of linear rows needs the same number of columns as the
    # inequality rows.  The previous `h_ineq.shape[0] - 1` only gave that
    # value when h_ineq happened to have exactly one more row than columns.
    h_lin = np.empty((0, h_ineq.shape[1]))

    # Round trip back to vertices; only used by the optional plot below.
    v_roundtrip = tools.h_to_v(h_ineq, h_lin)
    h = dict(inequality=h_ineq, linear=h_lin, is_empty=False)
    p = dict(h=h, v=None)
    r1 = [p]

    is_rational = True
    need_v = True
    to_flatten = [inversion.invert_relu_layer_kernel(poly, is_rational, need_v) for poly in r1]
    r0 = _flatten_list_of_lists(to_flatten)

    # Rows with a leading 1 are points, rows with a leading 0 are rays.
    q3 = np.array([[ 1.,  0.,  0.],
                   [ 0., -1.,  0.],
                   [-0., -0., -1.]])  # third quadrant (-inf, 0]^2

    vert_strip = np.array([[1., 1., 0.],
                           [1., 0., 0.],
                           [-0., -0., -1.]])  # [0, +1] x (-inf, 0]

    horz_strip = np.array([[ 1.,  0.,  1.],
                           [ 1.,  0.,  0.],
                           [-0., -1., -0.]])  # (-inf, 0] x [0, +1]

    v0_0 = r0[0]["v"]["vertices"]
    v0_1 = r0[1]["v"]["vertices"]
    v0_2 = r0[2]["v"]["vertices"]
    v0_3 = r0[3]["v"]["vertices"]

    assert _wrapped_checker(v0_0, q3)
    assert _wrapped_checker(v0_1, vert_strip)
    assert _wrapped_checker(v0_2, horz_strip)
    assert _wrapped_checker(v0_3, tools.unit_cube_v_repr(2))

    # Flip to True to eyeball the pieces; needs the project's plotting module.
    do_plot = False
    if do_plot:
        import plotting

        lv_repr = [v0_0, v0_1, v0_2, v0_3]
        plotting.convex_hull_plot_simple_vectorised(lv_repr)

        actual = [q3, vert_strip, horz_strip, v_roundtrip]
        plotting.convex_hull_plot_simple_vectorised(actual)


def _flatten_list_of_lists(ll: List[list]) -> list:
    """Concatenate the inner lists of ``ll`` into one new list (one level only)."""
    flattened: list = []
    for inner in ll:
        flattened.extend(inner)
    return flattened


def test_invert_relu_layer2():
    """Compare two ways of inverting through a relu for a diamond-shaped image.

    The image (a diamond with vertices (1, 2), (2, 1), (1, 0), (0, 1)) is
    first pulled back through a linear layer with
    ``decomp.invert_linear_layer_kernel``.  The resulting region ``r1`` is
    then inverted through a relu in two ways:

      - ``inversion.invert_relu_layer_kernel`` applied to every polytope in
        ``r1`` and flattened, and
      - ``decomp.invert_linearthenrelu_layer_kernel`` applied to ``r1`` with
        identity weights and zero bias.

    The first list, reversed, must match the second piece by piece.
    """
    # Flip to True to eyeball the polytopes; needs the project's plotting module.
    do_plot = False

    v_repr = np.array([[+1, +1, +2],
                       [+1, +2, +1],
                       [+1, +1, +0],
                       [+1, +0, +1]])
    if do_plot:
        import plotting
        plotting.convex_hull_plot_simple(v_repr)

    # The empty block of linear rows must have as many columns as v_repr
    # (previously this used the row count, v_repr.shape[0]).
    v_lin = np.empty((0, v_repr.shape[1]))

    h_ineq = tools.v_to_h(v_repr, v_lin, False)

    h_lin = np.empty((0, h_ineq.shape[1]))
    h = dict(inequality=h_ineq, linear=h_lin, is_empty=False)
    v = dict(vertices=v_repr, is_empty=False)

    w = np.array([[2.0, 0.0], [0.0, 1.0]])
    b = np.array([[+1.0], [+1.5]])

    in_dim = w.shape[1]

    # Unbounded input box.
    x_lower = -1 * np.full((in_dim, 1), np.inf)
    x_upper = +1 * np.full((in_dim, 1), np.inf)
    is_rational = True
    need_v = True

    r2 = [dict(h=h, v=v)]
    r1 = decomp.invert_linear_layer_kernel(w,
                                           b,
                                           r2,
                                           x_lower,
                                           x_upper,
                                           is_rational,
                                           need_v)
    to_flatten = [inversion.invert_relu_layer_kernel(poly, is_rational, need_v) for poly in r1]
    r0 = _flatten_list_of_lists(to_flatten)

    # Reference: the combined linear-then-relu kernel with an identity layer.
    w = np.eye(2)
    b = np.zeros((2, 1))
    r_expected = decomp.invert_linearthenrelu_layer_kernel(w,
                                                           b,
                                                           r1,
                                                           x_lower,
                                                           x_upper,
                                                           is_rational,
                                                           need_v)

    # The per-polytope inversion yields its pieces in the opposite order.
    r0_vlist = list(reversed([piece["v"]["vertices"] for piece in r0]))
    r_expected_vlist = [piece["v"]["vertices"] for piece in r_expected]

    for idx, actual in enumerate(r0_vlist):
        assert _wrapped_checker(actual, r_expected_vlist[idx])

    if do_plot:
        import plotting

        plotting.convex_hull_plot_simple_vectorised(r0_vlist)
        plotting.convex_hull_plot_simple_vectorised(r_expected_vlist)


if __name__ == "__main__":
    # Run a hand-picked subset of the tests directly, without pytest.
    # Not run from here: test2, test_xor, test_empty, test_rational.
    for _selected_test in (
        test0,
        test1,
        test_simple_iterated_decomposition,
        test_invert_relu_layer1,
        test_invert_relu_layer2,
        test_3d,
    ):
        _selected_test()
