import operator

import numpy as np
import scipy.spatial

import cvxpy

import tools

# Several tests below draw their inputs from NumPy's global RNG
# (np.random.rand / np.random.permutation); seeding it once at import time
# makes those inputs reproducible for a fixed test execution order.
np.random.seed(seed=123456)


def test_same_unique_rows():
    """A random matrix and a row-shuffled copy of it have the same unique rows."""
    num_rows, num_cols = 10, 3
    original = np.random.rand(num_rows, num_cols)
    shuffled = original[np.random.permutation(num_rows), :]
    assert tools.same_unique_rows(original, shuffled)


def test_v_linear():
    """Smoke test: ``tools.v_to_h`` accepts a generator matrix whose rows come
    in +/- pairs.

    Only checks that the conversion does not raise; the resulting
    H-representation is not inspected.
    """
    half = np.array([[1, 1, 1],
                     [1, 0, 1],
                     [1, 1, 0]])
    v = np.vstack([half, -half])

    show_plot = False  # set to True to eyeball the input while debugging
    if show_plot:
        import plotting
        plotting.convex_hull_plot_simple(v)

    # TODO: assert something about the returned H-representation.
    tools.v_to_h(v, None)


def test_h_to_v0():
    """
    The set of (x, y) with

        0 <= 1 + x - y
        0 <= 1 - x - y
        0 <= 1 + x + y
        0 <= 1 - x + y

    is the diamond with vertices (0, 1), (1, 0), (0, -1), (-1, 0), i.e. the
    V-representation

        [[+1, +0, +1],
         [+1, +1, +0],
         [+1, +0, -1],
         [+1, -1, +0]]
    """
    # Row [b, a_x, a_y] stands for the inequality 0 <= b + a_x * x + a_y * y.
    inequalities = np.array([[+1, +1, -1],
                             [+1, -1, -1],
                             [+1, +1, +1],
                             [+1, -1, +1]])
    no_linearities = np.empty((0, inequalities.shape[1]))
    vertices = tools.h_to_v(inequalities, no_linearities)

    expected_vertices = np.array([[+1, +0, +1],
                                  [+1, +1, +0],
                                  [+1, +0, -1],
                                  [+1, -1, +0]])
    assert tools.same_unique_rows(vertices, expected_vertices)


def test_h_to_v1():
    """
    The closed halfspace y <= x, i.e. the single inequality 0 <= 0 + x - y,
    is generated by the rows (leading 0: a direction rather than a point)

          [[+0, +1, +0],
           [+0, +1, +1],
           [+0, -1, -1]]

    In words:
    It is the set of nonnegative combinations of the directions
      (+1, +1),
      (-1, -1) and
      (+1, +0).

    In pictures:


                    |     /*
                    |    /**
                    |   /***
                    |  /****
                    | /*****
                    |/******
             -------+-------
                   /|*******
                  /*|*******
                 /**|*******
                /***|*******
               /****|*******
              /*****|*******

    is generated by

                    |     ^
                    |    /
                    |   /
                    |  /
                    | /
                    |/
             -------+------>
                   /|
                  / |
                 /  |
                /   |
               /    |
              V     |
    """
    halfspace = np.array([[0, +1, -1]])
    no_linearities = np.empty((0, halfspace.shape[1]))
    generators = tools.h_to_v(halfspace, no_linearities)

    expected_generators = np.array([[+0, +1, +0],
                                    [+0, +1, +1],
                                    [+0, -1, -1]])
    assert tools.same_unique_rows(generators, expected_generators)


def test_v_to_h0():
    """``tools.v_to_h`` finds the four facets of an unbounded 2-D polyhedron.

    The example is from:
    http://eaton.math.rpi.edu/faculty/Mitchell/courses/matp6640/notes/02n_resolutionB/02n_resolutionB.pdf

    Generator rows are ``[t, x_1, x_2]`` (``t = 0``: ray, ``t = 1``: vertex);
    inequality rows ``[b, a_1, a_2]`` encode ``0 <= b + a_1 x_1 + a_2 x_2``.
    The reference lists one inequality as ``2x_1 - 2x_2 >= -14``; the
    equivalent reduced form ``x_1 - x_2 >= -7`` is used here.
    """
    v_repr = np.array([[0, 4, 1], [0, 2, 2], [1, 5, 1], [1, 3, 2], [1, 1, 8]])
    expected_h_repr = np.array(
        [[-11, +3, +1], [+7, +1, -1], [+1, -1, +4], [-7, +1, +2]]
    )
    # Sanity-check the fixture itself: every generator satisfies every expected
    # inequality (t * b + a . x >= 0 covers vertices and rays alike).
    assert np.all(v_repr @ expected_h_repr.T >= 0)

    h_repr = tools.v_to_h(v_repr, None)

    def _comparable(rows):
        # An inequality row is only defined up to a positive factor, so scale
        # every row to unit length before comparing.
        rows = np.asarray(rows, dtype=float)
        rows = rows / np.linalg.norm(rows, axis=1, keepdims=True)
        # For an unbounded polyhedron a cdd-style conversion may also emit the
        # always-true row [1, 0, 0] (0 <= 1); it constrains nothing, so drop it.
        trivial = (rows[:, 0] > 0) & np.all(np.abs(rows[:, 1:]) < 1e-9, axis=1)
        # Round away floating-point noise so that row matching is meaningful.
        return np.around(rows[~trivial], 9)

    assert tools.same_unique_rows(_comparable(h_repr), _comparable(expected_h_repr))


def assert_allclose_unsorted(v1: np.ndarray, v2: np.ndarray,
                             rtol: float = 1e-7, atol: float = 0.0) -> None:
    """Assert that two 2-D arrays contain the same rows, ignoring row order.

    The rows of each array are sorted lexicographically (column 0 first, then
    column 1, ...) and the sorted arrays are compared with
    ``np.testing.assert_allclose``.  Rows are paired by an exact sort, so two
    rows whose leading entries differ only by round-off may be paired in the
    wrong order; the check is reliable when leading entries are either equal
    or clearly separated.

    :param v1: array of shape (nr, dim).
    :param v2: array with the same shape as ``v1``.
    :param rtol: relative tolerance forwarded to ``np.testing.assert_allclose``.
    :param atol: absolute tolerance forwarded to ``np.testing.assert_allclose``.
    :raises AssertionError: if the inputs are not 2-D arrays of equal shape,
        or if the sorted rows are not close.
    """
    v1 = np.asarray(v1)
    v2 = np.asarray(v2)
    assert v1.ndim == 2 and v2.ndim == 2, "Expected two 2-D arrays"
    assert v1.shape == v2.shape, "Cannot be close if they are not the same size"
    if v1.shape[1] == 0:
        # Rows without columns are all identical; there is nothing to compare.
        return

    def _sorted_rows(v: np.ndarray) -> np.ndarray:
        # np.lexsort uses its LAST key as the primary one, so pass the columns
        # in reverse order to make column 0 the primary sort key.
        return v[np.lexsort(v.T[::-1])]

    np.testing.assert_allclose(_sorted_rows(v1), _sorted_rows(v2),
                               rtol=rtol, atol=atol)


def test_simple_recovery():
    """V -> H -> V round trip on the convex hull of random 3-D points.

    Based on the scipy example at
    https://docs.scipy.org/doc/scipy-0.19.0/reference/generated/scipy.spatial.ConvexHull.html
    """
    cloud = np.random.rand(30, 3)
    hull_points = cloud[scipy.spatial.ConvexHull(cloud).vertices, :]
    # Prepend a column of ones: each hull point becomes a vertex row [1, x, y, z].
    homogenizing_column = np.ones((hull_points.shape[0], 1))
    v_repr = np.hstack([homogenizing_column, hull_points])

    h_repr = tools.v_to_h(v_repr, None)
    v_repr_recovered = tools.h_to_v(h_repr, np.empty((0, h_repr.shape[1])))
    assert_allclose_unsorted(v_repr, v_repr_recovered)


def test_unit_cube_4d():
    """``unit_cube_v_repr(4)`` lists all 16 corners of the unit 4-cube."""
    # Corner coordinates, ordered by how many coordinates are 1.
    corners = np.array(
        [
            [0, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1],
            [1, 1, 1, 0],
            [1, 1, 0, 1],
            [1, 0, 1, 1],
            [0, 1, 1, 1],
            [1, 1, 1, 1],
        ],
        dtype=float,
    )
    # Every corner is a vertex row: leading 1, then its coordinates.
    expected = np.hstack([np.ones((corners.shape[0], 1)), corners])
    assert tools.same_unique_rows(expected, tools.unit_cube_v_repr(4))


def test_intersection1():
    """Intersecting the unit 3-cube with itself gives back the unit 3-cube."""
    cube_a = tools.unit_cube_v_repr(3)
    cube_b = tools.unit_cube_v_repr(3)
    assert tools.same_unique_rows(cube_a, tools.intersect_v_reprs(cube_a, cube_b))


def test_intersection2():
    """Cutting the unit 3-cube with the halfspace x_1 <= 0.5 keeps the half
    [0, 0.5] x [0, 1] x [0, 1]."""
    cube = tools.unit_cube_v_repr(3)

    # 0 <= 0.5 - x_1, i.e. x_1 <= 0.5
    cut = np.array([[0.5, -1, 0, 0]])
    halfspace = tools.h_to_v(cut, np.empty((0, cut.shape[1])))

    clipped = tools.intersect_v_reprs(cube, halfspace)
    # The eight corners of the remaining box, each with a leading 1.
    expected = np.array(
        [
            [+1.0, +0.5, +1.0, +0.0],
            [+1.0, +0.0, +1.0, +0.0],
            [+1.0, +0.5, +0.0, +0.0],
            [+1.0, +0.0, +0.0, +0.0],
            [+1.0, +0.0, +0.0, +1.0],
            [+1.0, +0.5, +0.0, +1.0],
            [+1.0, +0.0, +1.0, +1.0],
            [+1.0, +0.5, +1.0, +1.0],
        ]
    )
    assert tools.same_unique_rows(expected, clipped)


def test_coalesce():
    """``union_v_reprs_if_convex`` merges two unit squares only when their
    union is convex, and returns None otherwise."""
    c0 = tools.unit_cube_v_repr(2)
    just_vertices = np.array([[0, 1, 1]])
    just_col1 = np.array([[0, 1, 0]])
    # Adding [0, dx, dy] to every row shifts the points by (dx, dy) while
    # leaving the leading column untouched.
    c1 = c0 + just_vertices  # shifted by (1, 1)
    c2 = c0 + just_col1      # shifted by (1, 0)

    # c0 union c1 is not convex (the squares only touch at a corner)
    # c1 union c2 is convex
    # c0 union c2 is convex
    if False:
        import plotting
        fig, ax = plotting.list_of_convex_hull_plot_simple([c0, c1, c2])
        ax.text(*tuple(np.mean(c0[:, 1:], axis=0)), "c0")
        ax.text(*tuple(np.mean(c1[:, 1:], axis=0)), "c1")
        ax.text(*tuple(np.mean(c2[:, 1:], axis=0)), "c2")

    expected_c12 = np.array([[1, 2, 0],
                             [1, 1, 0],
                             [1, 1, 2],
                             [1, 2, 2]])
    expected_c02 = np.array([[1, 0, 0],
                             [1, 0, 1],
                             [1, 2, 0],
                             [1, 2, 1]])

    c01 = tools.union_v_reprs_if_convex(c0, c1)
    assert c01 is None

    # Nothing here fixes the row order of the returned vertices, so compare
    # the rows irrespective of order.
    c02 = tools.union_v_reprs_if_convex(c0, c2)
    assert c02 is not None, "c0 union c2 is convex and should be merged"
    assert_allclose_unsorted(expected_c02, c02)

    c12 = tools.union_v_reprs_if_convex(c1, c2)
    assert c12 is not None, "c1 union c2 is convex and should be merged"
    assert_allclose_unsorted(expected_c12, c12)


def test_canonicalize_h():
    """Of two rows differing only in the constant term, canonicalization keeps
    just [1, 1, 1] -- also when the other constant is barely larger."""
    redundant_pairs = [
        np.array([[2, 1, 1],
                  [1, 1, 1]]),
        np.array([[1 + 1e-4, 1, 1],
                  [1, 1, 1]]),
    ]
    for h in redundant_pairs:
        np.testing.assert_allclose(tools.canonicalize_h_form(h),
                                   np.array([[1, 1, 1]]))


def test_coalesce1():
    """The H- and V-based convex unions of two polygons sharing two vertices
    agree with a reference built from ``convex_union_v_reprs``."""
    v1 = np.array([[1., 0.10043197, -0.22050298],
                   [1., 0.13243185, -0.19660494],
                   [1., 0.56677966, -0.36431513],
                   [1., 0.10969621, -0.54239196]])
    v2 = np.array([[1., 0.6185182, +0.16641233],
                   [1., 0.13243185, -0.19660494],
                   [1., 0.56677966, -0.36431513]])

    # Rows 1 and 2 are the vertices the two polygons have in common.
    for shared_row in (1, 2):
        assert np.all(v1[shared_row, :] == v2[shared_row, :])

    h1 = tools.v_to_h(v1, None, True)
    h2 = tools.v_to_h(v2, None, True)

    actual_h = tools.union_h_reprs_if_convex(h1, h2)
    actual_v = tools.union_v_reprs_if_convex(v1, v2)

    # Reference: V -> H of the (possibly redundant) combined generators, with
    # positively proportional rows removed -- not necessarily canonical.
    tol = 1e-7
    reference_v_redundant = tools.convex_union_v_reprs(v1, v2)
    reference_h_redundant = tools.v_to_h(reference_v_redundant, None, True)
    reference_h = tools.drop_rows_positive_proportional_to_another(reference_h_redundant, tol)
    reference_v = tools.h_to_v(reference_h, np.empty((0, reference_h.shape[1])))

    # Round away floating-point noise before matching rows.
    assert tools.same_unique_rows(np.around(actual_v, 12), np.around(reference_v, 12))
    assert tools.same_unique_rows(np.around(actual_h, 11), np.around(reference_h, 11))


def test_envelope():
    """Envelope of two mirror-image quadrilaterals.

    Example from F. Borrelli, A. Bemporad, M. Morari,
    Predictive Control for linear and hybrid systems.
    """
    left = np.array([[+1, -6, +0],
                     [+1, -3, -6],
                     [+1, -3, +6],
                     [+1,  0,  0]])
    right = np.array([[+1, +0, +0],
                      [+1, +3, +6],
                      [+1, +3, -6],
                      [+1, +6, +0]])

    h_left = tools.v_to_h(left, None)
    h_right = tools.v_to_h(right, None)

    # Only the envelope's H-representation is checked; the other two returned
    # values are not used here.
    envelope_h, _, _ = tools.envelope(h_left, h_right)
    envelope_v = tools.h_to_v(envelope_h, np.empty((0, envelope_h.shape[1])))

    expected_envelope_v = np.array([[+1, -6, +0],
                                    [+1, +0, +12],
                                    [+1, +0, -12],
                                    [+1, +6, +0]])
    assert tools.same_unique_rows(envelope_v, expected_envelope_v)


def test_unit_cubes():
    """The V- and H-representations of the unit 5-cube convert into each other."""
    n = 5
    cube_v = tools.unit_cube_v_repr(n)
    cube_h = tools.unit_cube_h_repr(n)

    recovered_v = tools.h_to_v(cube_h, np.empty((0, cube_h.shape[1])))
    recovered_h = tools.v_to_h(cube_v, np.empty((0, cube_v.shape[1])))

    assert tools.same_unique_rows(cube_v, recovered_v)
    assert tools.same_unique_rows(cube_h, recovered_h)


def test_is_h_empty():
    """Adding the ``h_lin`` row [-2, 1, 0, ...] to the unit 5-cube makes it
    empty, while [-.2, 1, 0, ...] does not.

    Presumably ``h_lin`` holds equality (linearity) rows, i.e. x_1 = 2 lies
    outside the cube and x_1 = 0.2 inside it -- confirm against ``tools``.
    """
    dim = 5
    cases = [
        (np.array([[-2, 1, 0, 0, 0, 0]]), True),
        (np.array([[-.2, 1, 0, 0, 0, 0]]), False),
    ]
    for h_lin, should_be_empty in cases:
        h_ineq = tools.unit_cube_h_repr(dim)
        assert bool(tools.is_h_form_empty(h_ineq, h_lin)) is should_be_empty


def test_prototype_getter1():
    """With ``cvxpy.Minimize`` and ``p = 1``, the prototype built from the unit
    5-cube is [1, 0, 0, 0, 0, 0]."""
    dim = 5
    h_ineq = tools.unit_cube_h_repr(dim)
    # H rows are [b, a_1, ..., a_dim] (dim + 1 columns), so the empty
    # linearity block must have the same width as h_ineq, as in every other
    # test in this file; np.empty((0, dim)) is one column short.
    h_lin = np.empty((0, h_ineq.shape[1]))
    p = 1
    objective_sense = cvxpy.Minimize
    fall_back_to_vacuous_criterion = False
    prototype1 = tools.build_prototype_from_h_form(h_ineq,
                                                   h_lin,
                                                   objective_sense,
                                                   p,
                                                   fall_back_to_vacuous_criterion)

    np.testing.assert_allclose(1, prototype1[0])
    np.testing.assert_allclose(0, prototype1[1:], atol=1e-9)
    # TODO: also cover objective_sense = cvxpy.Maximize.


def test_inner_box():
    """The maximum-volume box inside the diamond |x| + |y| <= 1 is
    [-0.5, 0.5] x [-0.5, 0.5]."""
    diamond_v = np.array([[1, 0, -1],
                          [1, 0, +1],
                          [1, -1, 0],
                          [1, +1, 0]])
    diamond_h = tools.v_to_h(diamond_v, np.empty((0, diamond_v.shape[1])))

    lower, upper = tools.compute_maximum_volume_inner_box(diamond_h)
    half = 0.5 * np.ones((2, 1))
    np.testing.assert_allclose(lower, -half)
    np.testing.assert_allclose(upper, +half)


if __name__ == "__main__":
    # Ad-hoc debugging entry point that runs a hand-picked subset of tests;
    # run the whole module with a test runner (e.g. pytest) instead.
    # Other candidates: test_v_linear, test_envelope, test_unit_cubes,
    # test_is_h_empty, test_prototype_getter1, test_inner_box.
    for selected_test in (test_coalesce1,):
        selected_test()
