import numpy as np
import scipy.linalg

import cdd

import tools

# np.set_printoptions(linewidth=1000)

# TODO: Cross-check the cdd-based results against LRS:
# http://cgm.cs.mcgill.ca/~avis/C/lrs.html


def test_ray():
    """Placeholder for ray-handling tests; not implemented yet, so it always passes."""


def test_fraction():
    """Smoke-test ``cdd.Matrix`` construction across number types and inputs.

    Nothing is asserted: the test fails only if a constructor raises. The
    float ``1.1`` and the string ``"1.1"`` are both built with
    ``number_type="fraction"``, and those two matrices plus one built from
    stringified random floats are printed for manual inspection.
    """
    int_rows = [[1, 2], [3, 4]]
    float_rows = [[1.1, 2], [3, 4]]

    # Constructed only to check that they do not raise.
    cdd.Matrix(int_rows)
    cdd.Matrix(int_rows, number_type="float")
    cdd.Matrix(float_rows)

    from_float = cdd.Matrix(float_rows, number_type="fraction")
    from_string = cdd.Matrix([["1.1", 2], [3, 4]], number_type="fraction")

    random_entries = np.random.rand(2, 2)
    from_random_strings = cdd.Matrix(random_entries.astype(str))

    for shown in (from_float, from_string, from_random_strings):
        print(shown)


def test_canonicalize():
    """Canonicalizing a generator matrix removes a redundant interior point.

    The four points (+-1, +-1) span a square. The extra point (0, 0) lies
    inside it, so after ``canonicalize()`` only the four corners should remain.
    """
    corners = np.array([[1, sx, sy] for sy in (-1, +1) for sx in (-1, +1)])
    interior_point = np.array([[1, 0, 0]])

    generators = cdd.Matrix(np.vstack([corners, interior_point]))
    generators.rep_type = cdd.RepType.GENERATOR
    generators.canonicalize()

    assert tools.same_unique_rows(np.array(generators), corners)


def test_empty_size():
    """An infeasible H representation converts to an empty V representation.

    The unit-cube H representation in 4 dimensions is stacked beneath the
    extra row ``[-1, -e_0]``. The resulting system is expected to be
    infeasible, so ``h_to_v`` must return an array with no entries.
    """
    dim = 4
    cube_h = tools.unit_cube_h_repr(dim)

    e0_row = tools.nth_canonical_basis(0, dim=dim).T
    cut_row = np.hstack([np.array([[-1]]), -e0_row])
    infeasible_h = np.vstack([cut_row, cube_h])

    no_linearities = np.empty((0, infeasible_h.shape[1]))
    assert tools.h_to_v(infeasible_h, no_linearities).size == 0


def _get_h_that_leads_to_linear_v() -> np.ndarray:
    h_ineq = np.array(
        [
            [-2.66543686e-01, +1.73966646e00],
            [-1.94172479e-01, -5.42370081e-01],
            [-2.95980266e00, -4.89011645e00],
            [+1.33176032e00, +3.19460511e00],
            [-1.65847051e-02, -4.80987364e-05],
            [+2.67734341e00, +5.13919830e00],
            [-4.13650423e-02, -2.00232193e-01],
            [-2.65334475e00, -4.81699085e00],
            [-2.86459918e00, -5.47769046e00],
            [-3.33203644e00, -6.57856417e00],
            [-2.83309763e00, -5.30819488e00],
        ]
    ).T
    return h_ineq


def _get_linear_v_data(is_rational: bool = True):
    """Convert the fixture H representation to a V representation with cdd.

    Args:
        is_rational: If True (the default, matching the previous hard-coded
            behaviour), the fixture is passed to cdd as strings using cdd's
            default number type. The resulting generators are then converted
            to a float matrix. If False, cdd works with
            ``number_type="float"`` throughout.

    Returns:
        ``(v, v_lin)``: the generator rows that are not in cdd's linearity
        set, and the rows that are.
    """
    h_ineq = _get_h_that_leads_to_linear_v()
    assert h_ineq.shape[0] > 0, "Zero-row matrix will be rejected by cdd"
    if is_rational:
        f_cdd = cdd.Matrix(h_ineq.astype(str))
    else:
        f_cdd = cdd.Matrix(h_ineq, number_type="float")

    f_cdd.rep_type = cdd.RepType.INEQUALITY
    assert 0 == len(f_cdd.lin_set), "Unexpected colinearity"

    g = cdd.Polyhedron(f_cdd).get_generators()

    if is_rational:
        # The linearity set is re-applied explicitly after building the float
        # matrix; presumably the conversion does not carry it over.
        lin_set = g.lin_set
        g = cdd.Matrix(g, number_type="float")
        g.lin_set = lin_set
    v_full = np.array(g)

    # np.in1d is deprecated since NumPy 2.0; np.isin is the drop-in
    # replacement for a 1-D element array.
    is_lin_row = np.isin(np.arange(v_full.shape[0]), list(g.lin_set))

    v = v_full[~is_lin_row, :]
    v_lin = v_full[is_lin_row, :]
    return v, v_lin


def _row_normalise(h: np.ndarray) -> np.ndarray:
    # Scaling a row of an H representation by a positive constant leaves the
    # described set unchanged. Dividing every row by its Euclidean norm
    # therefore gives a scale-free form that can be compared directly.
    row_norms = np.linalg.norm(h, axis=1)
    return h / tools.vec(row_norms)


def test_v_eq_high_dimension():
    """Both ways of passing V linearities to ``v_to_h`` recover the fixture H.

    The fixture H representation is converted to V with cdd and then back to
    H in two ways:

    1. Each linearity generator is passed as two opposite ordinary generators.
    2. The linearity generators are passed as an explicit linearity block.

    Each result is canonicalized and compared with the fixture after row
    normalisation and rounding.
    """
    decimals = 12

    def _comparable(h_repr):
        return np.around(_row_normalise(h_repr), decimals)

    v, v_lin = _get_linear_v_data()
    expected = _comparable(_get_h_that_leads_to_linear_v())

    split_lines = np.vstack([v, -1 * v_lin, v_lin])
    via_split = tools.canonicalize_h_form(tools.v_to_h(split_lines, None))
    assert tools.same_unique_rows(expected, _comparable(via_split))

    via_lin_block = tools.canonicalize_h_form(tools.v_to_h(v, v_lin, True))
    assert tools.same_unique_rows(expected, _comparable(via_lin_block))


def _get_h_linear_data():
    h_lin = np.array([[1., -0., -0., -0., -1., -1., -1., -1., -1., -1., -1., -1.],
                       [0., -0., -0., -0., -0., 1., -0., -0., 1., 1., -0., 1.],
                       [-0., -1., -0., -0., -0., -0., -0., 1., -0., 1., 1., 1.],
                       [-0., -0., -1., -0., -0., -0., 1., -0., 1., -0., 1., 1.]])
    h = np.array([[0., -0., -0., -0., 1., 0., 0., 0., 0., 0., 0., 0.],
                   [0., -0., -0., -0., 0., 1., 0., 0., 0., 0., 0., 0.],
                   [0., -0., -0., -0., 0., 0., 1., 0., 0., 0., 0., 0.],
                   [0., -0., -0., -0., 0., 0., 0., 1., 0., 0., 0., 0.],
                   [0., -0., -0., -0., 0., 0., 0., 0., 1., 0., 0., 0.],
                   [0., -0., -0., -0., 0., 0., 0., 0., 0., 1., 0., 0.],
                   [0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 1., 0.],
                   [0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 1.],
                   [-0., 1., 0., 0., -0., -0., -0., -0., -0., -0., -0., -0.],
                   [-0., 0., 1., 0., -0., -0., -0., -0., -0., -0., -0., -0.],
                   [-0., 0., 0., 1., -0., -0., -0., -0., -0., -0., -0., -0.],
                   [-0., -0., -0., -1., -0., -0., -0., -0., -0., -0., -0., -0.],
                   [0., 1., 0., 0., -0., -0., -0., -0., -0., -0., -0., -0.],
                   [0., 0., 1., 0., -0., -0., -0., -0., -0., -0., -0., -0.]])
    return h, h_lin


def test_h_linear():
    """Equalities as a linearity block match the same equalities split up.

    ``h_to_v`` is run once with ``h_lin`` as an explicit linearity block. It
    is run again with each equality row replaced by the two opposing
    ordinary rows ``h_lin`` and ``-h_lin``, and with no linearities. Both
    runs must produce the same set of generator rows.
    """
    h, h_lin = _get_h_linear_data()
    v_with_lin = tools.h_to_v(h, h_lin, True)

    h_split = np.vstack([h, h_lin, -1 * h_lin])
    no_linearities = np.empty((0, h_split.shape[1]))
    v_split = tools.h_to_v(h_split, no_linearities, True)

    assert tools.same_unique_rows(v_with_lin, v_split)


def test_purely_linear_h_simple():
    """A single equality in the plane yields one point and a line.

    The linearity row ``[1, -1, +1]`` encodes ``1 - x_1 + x_2 = 0``. The
    expected generators are the point (1, 0) and the direction (1, 1), given
    as a pair of opposite rays.
    """
    h_lin = np.array([[1, -1, +1]])
    h = np.empty((0, h_lin.shape[1]))

    expected_v = np.array([[1., 1., 0.],
                           [0., 1., 1.],
                           [-0., -1., -1.]])

    v = tools.h_to_v(h, h_lin, True)
    # Previously the comparison result was discarded, so the test could
    # never fail.
    assert tools.same_unique_rows(v, expected_v)


def test_analytical_v_repr_of_linear_mapping_of_v_repr():
    """Unfinished: V representation of an affine image of a V representation.

    Only the inputs are built so far: a 2 x 3 weight matrix, a zero bias
    column and the unit-cube V representation in three dimensions. Nothing
    is asserted yet.
    """
    # TODO: apply the map (weights, bias) to cube_v, canonicalize the result
    # with tools.canonicalize_v_form and compare it with a reference. The
    # earlier draft relied on a helper, _apply_linear_transformation_to_v_repr,
    # which does not exist in this module.
    weights = np.array([[0, 1, 1], [1, 0, 1]])
    bias = np.zeros((2, 1), dtype=int)
    cube_v = tools.unit_cube_v_repr(3)


def _analytical_v_repr_of_h_lin(h_lin: np.ndarray) -> np.ndarray:
    """
    Generate the V representation of the affine set defined by ``h_lin``.

    Each row ``[b, a_1, ..., a_n]`` of ``h_lin`` is read as the equality
    ``b + a . x = 0``. The code rewrites the system as ``C x = d`` with
    ``C = -h_lin[:, 1:]`` and ``d = h_lin[:, 0]``.

    The result is built analytically rather than with cdd. It consists of:

    - one point, ``pinv(C) @ d``;
    - the generators returned by ``tools.rn_v_repr(k)``, mapped through an
      orthonormal basis of the null space of ``C``.

    Rows use the layout ``[leading entry, coordinates...]``. Each row is
    translated by its leading entry times the centre, so rows with leading
    entry 1 are moved and rows with leading entry 0 are not.
    """
    # C in the system C x = d.
    c = -1 * h_lin[:, 1:]
    # Assumes tools.vec returns a column vector, making d of shape (m, 1) -- TODO confirm.
    d = tools.vec(h_lin[:, 0])
    # Orthonormal basis of ker(C), shape (n, k).
    c_perp = scipy.linalg.null_space(c)
    c_pinv = np.linalg.pinv(c)
    # d_in: ambient dimension n; d_out: null-space dimension k.
    d_in, d_out = c_perp.shape

    # Minimum-norm least-squares solution of C x = d; lies on the set when
    # the system is consistent.
    centre = c_pinv @ d
    # Presumably a V representation of R^k (judging by the name) -- confirm in tools.
    rn_v_repr = tools.rn_v_repr(d_out)
    # Coordinate part of each generator (leading column dropped).
    rn_v_repr_rays = rn_v_repr[:, 1:]

    # The origin of R^n written as a point row: [1, 0, ..., 0].
    origin = np.hstack((np.eye(1), np.zeros((1, d_in))))
    # Map each R^k generator into R^n through c_perp, keeping its leading entry.
    crn_v_repr = np.hstack((tools.vec(rn_v_repr[:, 0]), (c_perp @ rn_v_repr_rays.T).T))
    # Prepend the origin so the representation always contains a point.
    crn0_v_repr = np.vstack((origin, crn_v_repr))

    # Row i gets leading_entry_i * centre added to its coordinates: point rows
    # (leading 1) are translated to the affine set, direction rows (leading 0)
    # are left unchanged.
    to_add = np.hstack((np.zeros((crn0_v_repr.shape[0], 1)), tools.vec(crn0_v_repr[:, 0]) @ centre.T))
    linear_image_v_repr = crn0_v_repr + to_add

    return linear_image_v_repr


def test_purely_linear_h():
    """cdd and the analytical construction agree on a purely linear H repr.

    Only the equalities from ``_get_h_linear_data`` are used; the
    inequalities are replaced by an empty block. The two V representations
    are not compared directly. Instead, each is converted back to an H
    representation, rounded, and compared as sets of rows.
    """
    _, h_lin = _get_h_linear_data()
    h = np.empty((0, h_lin.shape[1]))
    decimals = 14

    v_old = tools.h_to_v(h, h_lin, True)
    v_new = _analytical_v_repr_of_h_lin(h_lin)

    v_to_h_new_r = np.around(tools.v_to_h(v_new, None, True), decimals)
    v_to_h_old_r = np.around(tools.v_to_h(v_old, None, True), decimals)

    # Order-insensitive comparison. The former debug `print(new - old)` is
    # gone: it assumed equal shapes and row order, and raised a broadcasting
    # error instead of letting this assertion report the mismatch.
    assert tools.same_unique_rows(v_to_h_new_r, v_to_h_old_r)



def test_v_linearity():
    """``h_to_v`` on a single equality in R^3 yields generators of that plane.

    The linearity row ``[+1, +1, -1, 0]`` encodes ``1 + x_1 - x_2 = 0``,
    with ``x_3`` free. Every generator row ``[t, x]`` of that set satisfies
    ``t * b + a . x = 0``, where ``t`` is 1 for points and 0 for directions.
    Equivalently, ``v @ h_lin.T`` must vanish.
    """
    h_lin = np.array([[+1, +1, -1, 0]])
    h = np.empty((0, h_lin.shape[1]))
    v = tools.h_to_v(h, h_lin)

    # Previously nothing was asserted, so the test could only fail by raising.
    assert v.size > 0, "The plane is non-empty, so generators are expected"
    assert v.shape[1] == h_lin.shape[1]
    assert np.allclose(v @ h_lin.T, 0)


if __name__ == "__main__":
    # Ad-hoc entry point for running a single test outside pytest; substitute
    # any other test_* function defined above while debugging.
    test_purely_linear_h()
