import itertools
import logging
import functools
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from collections import namedtuple

import numpy as np
import scipy.spatial
import scipy.linalg
import scipy.sparse

import cdd
import cvxpy


# Lightweight record types used across this module.
# `Polytope` pairs two fields named `h` and `v` (presumably the H- and
# V-representations of the same set -- confirm against callers).
Polytope = namedtuple('Polytope', ['h', 'v'])
# H-representation record: `inequality` rows, `linear` (equality) rows and an
# `is_empty` flag.
HRepresentation = namedtuple("HRepresentation", ["inequality", "linear", "is_empty"])
# V-representation record: generator rows (`vertices`) and an `is_empty` flag.
VRepresentation = namedtuple("VRepresentation", ["vertices", "is_empty"])
# A region is a list of polytopes.
Region = List[Polytope]


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# np.set_printoptions(linewidth=1000)


@functools.lru_cache(maxsize=16)
def powerset(i: Iterable) -> List[tuple]:
    """
    Return every sub-collection of ``i`` as a list of tuples.

    Subsets are ordered by size first (the empty tuple comes first, the full
    collection last) and, within one size, in ``itertools.combinations`` order.

    :param i: the elements to combine. Because the function is wrapped in
        ``lru_cache`` the argument itself must be hashable (e.g. a tuple).
    :return: list of ``2 ** len(i)`` tuples. The list is cached and shared
        between calls with the same argument, so callers should not mutate it.
    """
    # Materialise first: the argument may be an iterable without ``len``
    # (e.g. a generator), and it must only be consumed once.
    s = list(i)
    if len(s) >= 22:
        warnings.warn("Computing the power set of collections this large can take significant amounts of time")
    chained = itertools.chain.from_iterable(
        itertools.combinations(s, r) for r in range(len(s) + 1)
    )
    return list(chained)


def onehot_encode(q: np.ndarray) -> np.ndarray:
    """
    One-hot encode the row-wise argmax of a score matrix.

    :param q: 2-D array with one row per sample and one column per class.
    :return: float array of the same shape as ``q`` holding a single 1 per
        row, placed at the column of that row's (first) maximum.
    """
    class_indicators = np.eye(q.shape[1])
    winners = np.argmax(q, axis=1)
    # Picking row `k` of the identity yields the indicator vector for class k.
    return class_indicators[winners]


def h_to_v(h: np.ndarray,
           h_lin: np.ndarray,
           is_rational: bool = False) -> np.ndarray:
    """
    Convert an H-representation into a V-representation using cdd.

    :param h: inequality rows; the comments below read a row as
        ``row @ [1, x] >= 0``.
    :param h_lin: equality rows (``row @ [1, x] == 0``) with the same number
        of columns as ``h``; may have zero rows.
    :param is_rational: when True the cdd matrix is built from the string form
        of the entries with ``number_type="fraction"`` and the generators are
        converted back to floats afterwards.
    :return: the generator matrix reported by cdd, with a negated copy of
        every generator that cdd marked as linear appended at the end. An
        empty ``(0, h.shape[1])`` array is returned when cdd yields nothing.
    """
    assert h is not None
    assert h_lin is not None

    nc = h.shape[1]
    assert h_lin.shape[1] == nc
    assert not (h.shape[0] == 0 and h_lin.shape[0] == 0), "Check this case"
    # assert h.shape[0] > 0 or h_lin.shape[0] > 0, "Zero-row matrix will be rejected by cdd"

    if h.shape[0] == 0:
        warnings.warn("To avoid linearly dependent vertices I am appending a (meaningless, given the equality constraints) inequality constraint")
        warnings.warn("See: https://github.com/cddlib/cddlib/blob/master/doc/cddlibman.tex#L245")
        h = h_lin  # h [1, x] >= 0 holds whenever h [1, x] == 0, so the only purpose of this is to fool cdd into not
        # returning linear vertices (see checks below)

    if is_rational:
        f_cdd = cdd.Matrix(h.astype(str), number_type="fraction")
    else:
        f_cdd = cdd.Matrix(h, number_type="float")
    f_cdd.rep_type = cdd.RepType.INEQUALITY
    assert 0 == len(f_cdd.lin_set), "Unexpected colinearity, all linearities should be in h_lin"
    if h_lin.shape[0] > 0:
        # NOTE(review): the equality rows are passed as strings even when the
        # matrix was created with number_type="float" -- confirm cdd accepts this.
        f_cdd.extend(h_lin.astype(str), linear=True)

    p = cdd.Polyhedron(f_cdd)
    g = p.get_generators()

    if is_rational:
        # Convert the exact generators to floats, carrying the linearity set over.
        lin_set = g.lin_set
        g = cdd.Matrix(g, number_type="float")
        g.lin_set = lin_set
    v = np.array(g)

    # The first column must be (numerically) non-negative.
    assert 0 == v.size or np.min(v[:, 0]) >= -1e-7, "Not yet wanting to handle this case (could in principle, but defer for now)"

    lin_set_list = list(g.lin_set)
    if len(lin_set_list) > 0:
        # from https://github.com/cddlib/cddlib/blob/master/doc/cddlibman.tex#L245:
        # [Linearity]
        #   This means for each such a ray $r_k$,
        #   the line generated by $r_k$ is in the polyhedron,
        #   and for each such a vertex $v_k$, its coefficient is no longer nonnegative
        #   but still the coefficients for all $v_i$'s must sum up to one.
        #   It is highly unlikely that one needs to
        #   use linearity for vertex generators, and it is defined mostly
        #   for formality.

        # ind = np.arange(v.shape[0])
        # is_lin_set = np.in1d(ind, lin_set_list)
        # is_ray = (0 == v[:, 0])
        # is_origin_vertex = np.logical_and(1 == v[:, 0], np.all(0 == v[:, 1:], axis=1))

        # Only linear *rays* are supported: every linear generator must have a
        # (numerically) zero first entry.
        max_allowable_linear_vertex_weight = 1e-9
        max_linear_vertex_weight = np.max(np.abs(v[lin_set_list, 0]))
        assert np.all(max_linear_vertex_weight < max_allowable_linear_vertex_weight), "Not yet handling linear vertices, only rays"
        # Express each line as a pair of opposite rays.
        v = np.vstack([v, -1 * v[lin_set_list, :]])
    if 0 == v.size:
        # Normalise "no generators" to a well-shaped empty matrix.
        v = np.empty((0, h.shape[1]), dtype=v.dtype)
    return v


def get_point_from_v_repr(v_repr: np.ndarray) -> np.ndarray:
    """
    Return one point of a V-representation: the mean of its vertices.

    :param v_repr: generator matrix; rows whose first entry equals 1 are
        treated as vertices, all other rows are ignored (rays get weight 0).
    :return: column vector holding the average of the vertex coordinates
        (the first column of ``v_repr`` is not part of the result).
    """
    vertex_mask = v_repr[:, 0] == 1
    vertex_rows = np.atleast_2d(v_repr[vertex_mask, 1:])
    centroid = vertex_rows.mean(axis=0)
    # Rays contribute nothing; the explicit zero term keeps that visible.
    ray_contribution = np.zeros(centroid.shape)
    return vec(centroid + ray_contribution)


def v_to_h(v: np.ndarray,
           v_lin: Optional[np.ndarray],
           is_rational: bool = False) -> np.ndarray:
    """
    Convert a V-representation into an H-representation using cdd.

    :param v: generator rows; must contain at least one row.
    :param v_lin: optional generator rows to be registered as linear; ``None``
        or a zero-row array means there are none.
    :param is_rational: when True the cdd matrix is built from the string form
        of the entries and the result is converted back to floats afterwards.
    :return: the inequality matrix reported by cdd, with a negated copy of
        every row that cdd marked as linear appended at the end. An empty
        ``(0, v.shape[1])`` array is returned when cdd yields nothing.
    """
    assert v.shape[0] > 0, "Zero-row matrix will be rejected by cdd"
    if is_rational:
        # NOTE(review): unlike h_to_v, no explicit number_type="fraction" is
        # passed here; presumably cdd infers it from the string input -- confirm.
        f_cdd = cdd.Matrix(v.astype(str))
    else:
        f_cdd = cdd.Matrix(v, number_type="float")
    f_cdd.rep_type = cdd.RepType.GENERATOR
    assert 0 == len(f_cdd.lin_set), "Unexpected colinearity"

    if (v_lin is not None) and (v_lin.shape[0] > 0):
        f_cdd.extend(v_lin.astype(str), linear=True)

    p = cdd.Polyhedron(f_cdd)
    i = p.get_inequalities()

    if is_rational:
        # Convert the exact result to floats, carrying the linearity set over.
        lin_set = i.lin_set
        i = cdd.Matrix(i, number_type="float")
        i.lin_set = lin_set

    h = np.array(i)
    lin_set_list = list(i.lin_set)
    if len(lin_set_list) > 0:
        # Express each equality as a pair of opposite inequalities.
        h = np.vstack([h, -1 * h[lin_set_list, :]])

    if 0 == h.size:
        # Normalise "no inequalities" to a well-shaped empty matrix.
        h = np.empty((0, v.shape[1]), dtype=v.dtype)
    return h


def canonicalize_h_form(h: np.ndarray) -> np.ndarray:
    """
    Canonicalize an inequality-only H-representation with cdd.

    NB. This works strictly with the inequality form: the input must not
    carry any linearity rows.

    :param h: inequality matrix with at least one row.
    :return: the canonicalized matrix as floats; rows that cdd marks as linear
        after canonicalization are followed (at the end of the matrix) by
        their negated copies.
    """
    assert h.shape[0] > 0, "Zero-row matrix will be rejected by cdd"
    exact = cdd.Matrix(h.astype(str))
    exact.rep_type = cdd.RepType.INEQUALITY
    assert 0 == len(exact.lin_set), "Unexpected colinearity"

    exact.canonicalize()
    # The float copy does not inherit the linearity set, so carry it over.
    linear_rows = exact.lin_set
    as_float = cdd.Matrix(exact, number_type="float")
    as_float.lin_set = linear_rows

    canonical = np.array(as_float)
    linear_indices = list(as_float.lin_set)
    if not linear_indices:
        return canonical
    return np.vstack([canonical, -1 * canonical[linear_indices, :]])


def canonicalize_v_form(v: np.ndarray) -> np.ndarray:
    """
    Canonicalize a V-representation (generator matrix) with cdd.

    :param v: generator matrix with at least one row and no linearity rows.
    :return: the canonicalized matrix as floats; generators that cdd marks as
        linear after canonicalization are followed (at the end of the matrix)
        by their negated copies.
    """
    assert v.shape[0] > 0, "Zero-row matrix will be rejected by cdd"
    exact = cdd.Matrix(v.astype(str))
    exact.rep_type = cdd.RepType.GENERATOR
    assert 0 == len(exact.lin_set), "Unexpected colinearity"

    exact.canonicalize()
    # The float copy does not inherit the linearity set, so carry it over.
    linear_rows = exact.lin_set
    as_float = cdd.Matrix(exact, number_type="float")
    as_float.lin_set = linear_rows

    canonical = np.array(as_float)
    linear_indices = list(as_float.lin_set)
    if not linear_indices:
        return canonical
    return np.vstack([canonical, -1 * canonical[linear_indices, :]])


def vec(x: np.ndarray) -> np.ndarray:
    """Reshape ``x`` into a single column (shape ``(x.size, 1)``), in row-major order."""
    column_shape = (-1, 1)
    return np.reshape(x, column_shape)


def eliminate_sign_repeated_rows(m: np.ndarray, tol: float = 0.0) -> np.ndarray:
    """
    Drop rows from a matrix where -1 * that row is also in the matrix.
    The first row will be kept, the second one will be dropped.

    A row is dropped when it is the negation of *any* earlier row (including
    an earlier row that is itself dropped).

    :param m: 2-D matrix whose rows are compared.
    :param tol: two rows count as negations of each other when the Euclidean
        distance between one and minus the other is at most ``tol``. The
        default of 0 requires an exact match.
    :return: ``m`` without the dropped rows, in the original order.
    """
    # dm[i, j] is the distance between row i and the negation of row j.
    dm = scipy.spatial.distance_matrix(m, -1 * m)
    # Only the strictly lower triangle matters: row i versus earlier rows j < i.
    negates_earlier_row = np.tril(dm <= tol, k=-1)
    is_drop = np.any(negates_earlier_row, axis=1)
    return m[~is_drop, :]


def _prepend_ones(x: np.ndarray) -> np.ndarray:
    """Return ``x`` with an extra leading column filled with ones."""
    leading_column = np.full((x.shape[0], 1), 1.0)
    return np.concatenate([leading_column, x], axis=1)


def _prepend_zeros(x: np.ndarray) -> np.ndarray:
    """Return ``x`` with an extra leading column filled with zeros."""
    leading_column = np.full((x.shape[0], 1), 0.0)
    return np.concatenate([leading_column, x], axis=1)


def points_in_polytope(
    points: np.ndarray,
    v_repr: np.ndarray,
    is_rational: bool = False,
    tolerance: float = 0,
) -> np.ndarray:
    """
    Test which points lie inside the polytope given by a V-representation.

    The V-representation is converted to inequalities ``h @ [1, x] >= 0`` and
    every point is checked against all of them.

    :param points: one point per row.
    :param v_repr: generator matrix with one more column than ``points``; an
        empty array means the empty set.
    :param is_rational: forwarded to ``v_to_h``.
    :param tolerance: each inequality value must be ``>= tolerance``.
    :return: boolean column vector with one entry per point.
    """
    # There are faster algorithms, cf.:
    # https://en.wikipedia.org/wiki/Point_in_polygon
    num_points, points_dim = points.shape
    if 0 == v_repr.size:
        # Nothing lies in the empty set.
        return np.full((num_points, 1), False)

    h_repr = v_to_h(v_repr, None, is_rational)
    _, num_columns = h_repr.shape
    assert num_columns == points_dim + 1
    homogeneous_points = np.hstack([np.ones((num_points, 1)), points])
    slack = (h_repr @ homogeneous_points.T).T
    return np.all(slack >= tolerance, axis=1, keepdims=True)


@functools.lru_cache(maxsize=16)
def rn_v_repr(n: int) -> np.ndarray:
    """
    Build a V-representation of the whole space R^n using rays only.

    :param n: dimension of the space (an ``int``: it is used both as an
        ``lru_cache`` key and as the size passed to ``np.eye``).
    :return: ``(2 * n, n + 1)`` array; the first column is all zeros (marking
        every row as a ray) and the remaining columns hold ``+e_k`` for the
        first ``n`` rows and ``-e_k`` for the last ``n`` rows. The array is
        cached and shared between calls, so callers should not modify it in
        place.
    """
    identity = np.eye(n)
    rays = np.vstack([+1 * identity, -1 * identity])
    ray_markers = np.zeros((2 * n, 1))
    return np.hstack([ray_markers, rays])


@functools.lru_cache(maxsize=16)
def _gen_all_01_rows(n: int) -> np.ndarray:
    """
    Enumerate every 0/1 vector of length ``n`` as the rows of a matrix.

    Row ``k`` has ones exactly at the positions contained in the ``k``-th
    element of ``powerset(range(n))``, so rows are ordered by their number of
    ones (the all-zero row first, the all-one row last).

    :param n: length of the vectors.
    :return: float64 array of shape ``(2 ** n, n)``. For ``n == 0`` this is a
        single empty row. The array is cached and shared between calls, so
        callers should not modify it in place.
    """
    index_powerset = powerset(tuple(range(n)))
    # Fill a dense array with an explicit shape: inferring the shape from the
    # index lists (as a sparse constructor would) breaks down when there are
    # no entries at all, i.e. for n == 0.
    all_01_rows = np.zeros((len(index_powerset), n), dtype=np.float64)
    for row, subset in enumerate(index_powerset):
        all_01_rows[row, np.asarray(subset, dtype=np.intp)] = 1
    return all_01_rows


@functools.lru_cache(maxsize=16)
def unit_cube_v_repr(n: int) -> np.ndarray:
    """
    V-representation of the unit cube [0, 1]^n.

    :param n: dimension of the cube.
    :return: array with one row per 0/1 corner; the first column is all ones
        (marking every row as a vertex), the remaining columns hold the corner.
    """
    corners = _gen_all_01_rows(n)
    vertex_markers = np.ones((corners.shape[0], 1))
    return np.concatenate([vertex_markers, corners], axis=1)


@functools.lru_cache(maxsize=16)
def unit_cube_h_repr(n: int) -> np.ndarray:
    """
    H-representation of the unit cube [0, 1]^n.

    :param n: dimension of the cube.
    :return: ``(2 * n, n + 1)`` array. The first ``n`` rows are ``[1, -e_k]``
        (x_k <= 1) and the last ``n`` rows are ``[0, +e_k]`` (x_k >= 0).
    """
    identity = np.eye(n)
    # First column: constant terms; remaining columns: coefficient vectors.
    offsets = np.concatenate([np.ones((n, 1)), np.zeros((n, 1))], axis=0)
    normals = np.concatenate([-1 * identity, identity], axis=0)
    return np.hstack([offsets, normals])


@functools.lru_cache(maxsize=16)
def nth_canonical_basis(n: int, dim: int) -> np.ndarray:
    """
    Return the ``n``-th canonical basis vector of R^dim as a column.

    :param n: zero-based index of the entry that is set to 1.
    :param dim: length of the vector.
    :return: float column vector of shape ``(dim, 1)``; all zeros when ``n``
        does not match any index in ``range(dim)``.
    """
    is_nth_position = np.equal(np.arange(dim), n)
    return vec(is_nth_position.astype(float))


def distance_between_h_reprs(
    h_ineq1: np.ndarray, h_ineq2: np.ndarray, p: Union[str, float]
) -> Tuple[float, Tuple[np.ndarray, np.ndarray]]:
    """
    Compute the p-norm distance between two polyhedra given in H-form.

    Both sets are described by rows ``r`` with ``r @ [1, x] >= 0``. The
    distance is the optimal value of

        minimize ||p1 - p2||_p  s.t.  h_ineq1 @ p1 >= 0, h_ineq2 @ p2 >= 0,

    where ``p1`` and ``p2`` are homogeneous points whose first entry is pinned
    to 1.

    MOSEK is tried first. If that attempt raises (for example because MOSEK
    is not installed), a warning is issued and the problem is re-solved with
    cvxpy's default solver; an error from that second attempt propagates.

    :param h_ineq1: inequality matrix of the first set (at least one row).
    :param h_ineq2: inequality matrix of the second set (at least one row,
        same number of columns as ``h_ineq1``).
    :param p: norm order, forwarded to ``cvxpy.pnorm``.
    :return: ``(distance, (p1, p2))`` where ``p1`` / ``p2`` are the optimal
        homogeneous points as reported by the solver.
    """
    num_rows1, dim1 = h_ineq1.shape
    assert num_rows1 > 0, "Not defined for empty set first args"
    num_rows2, dim = h_ineq2.shape
    assert num_rows2 > 0, "Not defined for empty set second args"
    assert dim1 == dim, "Dimension mismatch"

    p1 = cvxpy.Variable(dim)
    p2 = cvxpy.Variable(dim)

    # Selector for the leading (homogeneous) coordinate.
    iota0 = vec(0 == np.arange(dim)).astype(float)

    constraints = [h_ineq1 @ p1 >= 0,
                   h_ineq2 @ p2 >= 0,
                   iota0.T @ p1 == 1,
                   iota0.T @ p2 == 1]

    # https://www.cvxpy.org/api_reference/cvxpy.atoms.other_atoms.html#pnorm-func
    objective = cvxpy.Minimize(cvxpy.pnorm(p1 - p2, p=p))
    prob = cvxpy.Problem(objective, constraints)

    try:
        prob.solve(solver=cvxpy.MOSEK)
    except Exception as err:
        # Do not swallow the failure silently: report it, then retry with
        # whatever solver cvxpy selects by default.
        warnings.warn("Caught {} but continuing".format(err))
        prob.solve()

    distance = prob.value
    assert distance is not None
    opt_points = (p1.value, p2.value)
    return distance, opt_points


def find_point_multipliers(
    point: np.ndarray, v_repr: np.ndarray, w: np.ndarray, b: np.ndarray
) -> np.ndarray:
    """
    Find the minimum-norm multipliers that reproduce the image of a point.

    With ``target = relu(w @ point.T + b)``, this solves

        minimize 1/2 * ||l||^2  s.t.  l >= 0,  generators.T @ l == target,

    and, when the V-representation contains at least one non-ray row, the
    multipliers of the non-ray rows are additionally required to sum to 1.

    :param point: row vector of shape ``(1, m)``.
    :param v_repr: generator matrix; its first column flags rays with 0 and
        the remaining ``n`` columns hold the generators.
    :param w: weight matrix of shape ``(n, m)``.
    :param b: bias added to ``w @ point.T``.
    :return: column vector with one multiplier per generator.
    """
    n, m = w.shape
    generators = v_repr[:, 1:]
    num_generators, generator_dim = generators.shape
    assert generator_dim == n
    assert (1, m) == point.shape

    is_pol = v_repr[:, 0] != 0

    layer1 = relu((w @ point.T) + b).T
    multipliers = cvxpy.Variable(num_generators)

    constraints = [0 <= multipliers,
                   generators.T @ multipliers == layer1.flatten()]
    if np.any(is_pol):
        # Convexity constraint over the non-ray generators.
        constraints.append(vec(is_pol).T @ multipliers == 1)

    half_squared_norm = (1 / 2) * cvxpy.quad_form(multipliers, np.eye(num_generators))
    prob = cvxpy.Problem(cvxpy.Minimize(half_squared_norm), constraints)
    prob.solve()

    # assert np.abs(prob.objective.value) < 1e-6, "Apparently infeasible"
    point_multipliers = np.vstack(multipliers.value)

    # Sanity check: the multipliers must reproduce the target.
    residual = vec(layer1) - generators.T @ point_multipliers
    assert np.linalg.norm(residual, np.inf) < 1e-13
    return point_multipliers


def build_bounding_box_h_form(x_lower: np.ndarray,
                              x_upper: np.ndarray) -> np.ndarray:
    """
    Build the H-form (rows ``r`` with ``r @ [1, x] >= 0``) of an axis-aligned box.

    :param x_lower: column vector of shape ``(m, 1)`` with the lower bounds;
        non-finite entries (e.g. ``-inf``) mean "no lower bound".
    :param x_upper: column vector of shape ``(m, 1)`` with the upper bounds;
        non-finite entries (e.g. ``+inf``) mean "no upper bound".
    :return: matrix with ``m + 1`` columns: first the rows for the finite
        lower bounds, then the rows for the finite upper bounds.
    :raises AssertionError: if either argument is not an ``(m, 1)`` column.
    """
    m, num_columns = x_lower.shape
    assert 1 == num_columns
    # Compare against the tuple: `assert m, 1 == ...` would only test `m`.
    assert (m, 1) == x_upper.shape
    # x >= x_lower iff
    # -x_lower + I x >= 0
    # and
    # x <= x_upper iff
    # +x_upper - I x >= 0
    eyem = np.eye(m)
    h_lower = np.hstack([-1 * x_lower, +1 * eyem])
    h_upper = np.hstack([+1 * x_upper, -1 * eyem])

    # Infinite bounds impose no constraint, so their rows are dropped.
    h_lower = h_lower[np.isfinite(x_lower).flatten(), :]
    h_upper = h_upper[np.isfinite(x_upper).flatten(), :]

    h_bounds = np.vstack([h_lower, h_upper])
    return h_bounds


def apply_linear_transformation_to_v_repr(aa: np.ndarray,
                                          w: np.ndarray,
                                          b: np.ndarray) -> np.ndarray:
    """
    Given a set a = {Rl, l >= 0, p'l = 1}
    compute {wx + b : x in a}
    fix an x, since x in a, there exists an l such that

    x = sum_j r_j l_j

    thus, using p'l = sum_j p_j l_j = 1,

    wx + b =
    w (sum_j r_j l_j) + b =
    sum_j w r_j l_j + b sum_j p_j l_j =
    sum_j (w r_j + b p_j) l_j

    i.e. every generator is multiplied by w, and b is added only to the
    generators with p_j = 1 (vertices), not to rays (p_j = 0).

    :param aa: generator matrix; the first column flags rays with 0, the
        remaining ``dim_in`` columns hold the generators. If it contains
        rays only, an origin vertex is prepended before transforming.
    :param w: matrix of shape ``(dim_out, dim_in)``.
    :param b: column vector of shape ``(dim_out, 1)``.
    :return: generator matrix of the transformed set, with the first column
        of (the possibly extended) ``aa`` kept as is.
    :raises AssertionError: if ``w`` does not have shape ``(dim_out, dim_in)``.
    """
    dim_in = aa.shape[1] - 1
    dim_out = len(b)
    # Compare against the tuple: `assert dim_out, dim_in == ...` would only
    # test the truthiness of `dim_out`.
    assert (dim_out, dim_in) == w.shape

    if np.all(0 == aa[:, 0]):  # no vertices, just rays
        origin_vertex = np.hstack((np.eye(1), np.zeros((1, dim_in))))
        aa = np.vstack([origin_vertex, aa])
    is_ray = (0 == aa[:, 0])
    is_pol = ~is_ray

    vertices = aa[:, 1:]
    # The offset applies to non-ray rows only; ray rows receive zeros.
    to_add = np.reshape(is_pol, (-1, 1)) @ b.T
    transformed_vertices = (w @ vertices.T).T + to_add
    transformed_a = np.hstack((np.reshape(aa[:, 0], (-1, 1)), transformed_vertices))
    return transformed_a


def relu(x: np.ndarray) -> np.ndarray:
    """Element-wise rectifier: negative entries become 0, others are kept."""
    floor = 0
    return np.maximum(x, floor)


def build_polytope_where_nth_coordinate_is_greatest(n: int,
                                                    dim: int,
                                                    margin: float) -> np.ndarray:
    """
    Build H-form rows comparing coordinate ``n`` against every other coordinate.

    For each ``i != n`` the result holds the row ``[margin, e_n - e_i]``,
    i.e. the inequality ``margin + x_n - x_i >= 0``.

    :param n: zero-based index of the distinguished coordinate.
    :param dim: number of coordinates.
    :param margin: constant term placed in the first column of every row.
    :return: array of shape ``(dim - 1, dim + 1)`` when ``n`` is a valid index.
    """
    iota = nth_canonical_basis(n, dim)
    # Row i of this matrix is e_n - e_i (row n is therefore all zeros).
    pairwise_differences = np.ones((dim, 1)) @ iota.T - np.eye(dim)
    is_other_coordinate = iota.flatten() != 1
    class_c = pairwise_differences[is_other_coordinate, :]
    thresholds = np.ones((class_c.shape[0], 1)) * margin
    return np.hstack((thresholds, class_c))


def build_prototype_from_v_form(v_gen: np.ndarray,
                                v_lin: np.ndarray) -> np.ndarray:
    """
    Return a representative point of a V-representation.

    The representative is the uniform average of all rows whose first entry
    is non-zero (the vertices); rows with a zero first entry (rays) get zero
    weight. Whole rows are averaged, so the result keeps the leading column.

    :param v_gen: generator matrix.
    :param v_lin: linear generators; must have zero rows (not supported yet).
    :return: 1-D array with ``v_gen.shape[1]`` entries.
    """
    assert 0 == v_lin.shape[0], "Not supporting nonempty v_lin yet"
    vertex_rows = v_gen[v_gen[:, 0] != 0, :]
    return vertex_rows.mean(axis=0)


def build_prototype_from_h_form(h_ineq: np.ndarray,
                                h_lin: np.ndarray,
                                objective_sense: cvxpy.problems.objective,
                                p: float,
                                fall_back_to_vacuous_criterion: bool) -> np.ndarray:
    """
    Find a representative point of a set given in H-form.

    Solves, over homogeneous points ``x`` whose first entry is pinned to 1,

        minimize ||x||_p  s.t.  h_lin @ x == 0,  h_ineq @ x >= 0

    (each constraint family is only added when it has at least one row).

    :param h_ineq: inequality rows.
    :param h_lin: equality rows.
    :param objective_sense: must be ``cvxpy.Minimize``.
    :param p: norm order passed to ``cvxpy.norm``.
    :param fall_back_to_vacuous_criterion: if the first solve raises, retry
        with the constant objective 0 (pure feasibility) using cvxpy's default
        solver.
    :return: the optimiser as a column vector, or ``None`` when the optimal
        value is infinite (including when every solve attempt raised).
    """
    assert objective_sense == cvxpy.Minimize, "only minimize supported for now"
    # all_back_to_v_calc = True
    # https://www.cvxpy.org/tutorial/advanced/index.html#solve-method-options
    use_cvxopt = False

    # Solver choice is hard-wired: with use_cvxopt False the OSQP settings are used.
    if use_cvxopt:
        kwargs = {
            "solver": cvxpy.CVXOPT,
            "abstol": 1e-6
            }
    else:
        kwargs = {
            "solver": cvxpy.OSQP,
            'max_iter': 15000,
            'eps_abs': 1e-5,
            'eps_rel': 1e-5
        }

    verbose = False
    dim = h_ineq.shape[1]
    x = cvxpy.Variable(dim)

    # Selector for the leading (homogeneous) coordinate.
    iota0 = vec(0 == np.arange(dim)).astype(float)

    constraints = [iota0.T @ x == 1]
    if 0 < h_lin.shape[0]:
        constraints += [h_lin @ x == 0]

    if 0 < h_ineq.shape[0]:
        constraints += [h_ineq @ x >= 0]

    objective = objective_sense(cvxpy.norm(x, p))  # try to get non-zero-ish answers
    prob = cvxpy.Problem(objective, constraints)

    try:
        # value = prob.solve(verbose=verbose)
        # value = prob.solve(verbose=verbose, eps_abs=eps_abs, solver=solver)
        # value = prob.solve(verbose=verbose, eps_abs=eps_abs)
        value = prob.solve(verbose=verbose, **kwargs)
    except Exception as e:
        warnings.warn("Caught {} but continuing".format(e))
        if fall_back_to_vacuous_criterion:
            objective = cvxpy.Minimize(0)  # constant objective: any feasible point will do
            prob = cvxpy.Problem(objective, constraints)
            try:
                value = prob.solve(verbose=verbose)
            except Exception as e:
                warnings.warn("Caught {} but continuing".format(e))
                # Treat a second failure like an infeasible problem.
                value = np.inf
        else:
            # No fallback requested: treat the failure like an infeasible problem.
            value = np.inf
    # prob.solve(verbose=True)
    is_empty = np.isinf(value)
    if is_empty:
        prototype = None
    else:
        prototype = vec(x.value)
    return prototype


def _minkowski_sum(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Return all pairwise sums of the rows of ``x`` and the rows of ``y``.

    :param x: 2-D array.
    :param y: 2-D array with the same number of columns as ``x``.
    :return: array with one row per (row of x, row of y) pair; the pairs are
        ordered with the rows of ``x`` varying slowest.
    """
    assert x.ndim == 2
    assert y.ndim == 2

    _, num_columns_x = x.shape
    _, num_columns_y = y.shape
    assert num_columns_x == num_columns_y

    rows_of_y = y.tolist()
    pairwise_sums = [
        np.add(row_x, row_y)
        for row_x in x.tolist()
        for row_y in rows_of_y
    ]
    return np.array(pairwise_sums)


def intersect_v_reprs(v1: np.ndarray,
                      v2: np.ndarray) -> np.ndarray:
    """
    Intersect two sets given as V-representations.

    Both inputs are converted to H-form, the inequalities are stacked, and
    the stacked system is converted back to a V-representation.

    :param v1: generator matrix of the first set.
    :param v2: generator matrix of the second set (same number of columns).
    :return: generator matrix of the intersection.
    """
    assert (
        v1.shape[1] == v2.shape[1]
    ), "Can only intersect polytopes of the same dimension"
    stacked_inequalities = np.vstack([v_to_h(v1, None), v_to_h(v2, None)])
    no_equalities = np.empty((0, stacked_inequalities.shape[1]))
    return h_to_v(stacked_inequalities, no_equalities)


def convex_union_v_reprs(v1: np.ndarray,
                         v2: np.ndarray) -> np.ndarray:
    """
    Pool the generators of two V-representations, removing duplicate rows.

    :param v1: generator matrix of the first set.
    :param v2: generator matrix of the second set.
    :return: matrix whose rows are the distinct rows of ``v1`` and ``v2``;
        the row order follows set iteration order.
    """
    pooled_rows = _matrix_to_row_set(v1) | _matrix_to_row_set(v2)
    return _row_set_to_matrix(pooled_rows)


def _matrix_to_row_set(m: np.ndarray) -> set:
    """Return the distinct rows of ``m`` as a set of tuples."""
    return {tuple(row) for row in m}


def _row_set_to_matrix(rs: set) -> np.ndarray:
    """Stack the tuples of ``rs`` into a matrix, one row per tuple (set iteration order)."""
    rows = [*rs]
    return np.array(rows)


def same_unique_rows(m1: np.ndarray, m2: np.ndarray) -> bool:
    """ Check whether matrices have the same unique rows (order and multiplicity are ignored) """
    assert 2 == m1.ndim and 2 == m2.ndim, "Only matrices supported"
    rows_of_m1 = _matrix_to_row_set(m1)
    rows_of_m2 = _matrix_to_row_set(m2)
    return rows_of_m1 == rows_of_m2


def _symmetric_difference_matrix(v1: np.ndarray,
                                 v2: np.ndarray) -> np.ndarray:
    """
    Return the rows that occur in exactly one of the two matrices.

    :param v1: first matrix.
    :param v2: second matrix.
    :return: matrix of the rows found in ``v1`` or ``v2`` but not in both;
        the row order follows set iteration order.
    """
    rows_in_exactly_one = _matrix_to_row_set(v1) ^ _matrix_to_row_set(v2)
    return _row_set_to_matrix(rows_in_exactly_one)


def is_inequality_valid(a: np.ndarray,
                        b: float,
                        poly_a: np.ndarray,
                        poly_b: np.ndarray,
                        tol: float = 0) -> bool:
    """
    Check whether ``a @ x <= b`` holds on the set ``{x : poly_a @ x <= poly_b}``.

    The check maximises ``a @ x`` over the set (using the GUROBI solver) and
    compares the optimal value with ``b + tol``.

    :param a: coefficient vector of the inequality under test.
    :param b: right-hand side of the inequality under test.
    :param poly_a: constraint matrix of the set.
    :param poly_b: right-hand side of the set's constraints.
    :param tol: slack added to ``b`` before comparing.
    :return: True when the maximum of ``a @ x`` does not exceed ``b + tol``.
    """
    x = cvxpy.Variable(len(a))
    prob = cvxpy.Problem(cvxpy.Maximize(a @ x), [poly_a @ x <= poly_b])
    # solver = cvxpy.MOSEK
    prob.solve(solver=cvxpy.GUROBI)

    worst_case = prob.value
    return worst_case <= b + tol


def compute_valid_inequalities_for_poly(ineqs: np.ndarray,
                                        poly: np.ndarray) -> np.ndarray:
    """
    Flag which inequalities hold everywhere on a polyhedron.

    Both arguments use rows ``r`` with ``r @ [1, x] >= 0``; each is rewritten
    as ``a @ x <= b`` with ``a = -r[1:]`` and ``b = r[0]`` before being passed
    to ``is_inequality_valid`` (tolerance 1e-8).

    :param ineqs: the inequalities to test, one per row.
    :param poly: the H-form of the polyhedron to test against.
    :return: boolean vector with one entry per row of ``ineqs``.
    """
    tol = 1e-8
    # Developer switches: with `diagnose` enabled the sets and the tested
    # inequalities are plotted via the project-local `plotting` module
    # (the line drawing below uses a[0] and a[1] only, i.e. assumes 2-D).
    diagnose = False
    # diagnose = True

    # plot_all = True
    plot_all = False
    if diagnose:
        ineqs_lin = np.empty((0, ineqs.shape[1]))
        poly_lin = np.empty((0, poly.shape[1]))

        v_ineqs = h_to_v(ineqs, ineqs_lin)
        v_poly = h_to_v(poly, poly_lin)

        import plotting
        lv = [v_ineqs, v_poly]
        fig, ax = plotting.list_of_convex_hull_plot_simple(lv)
        xlim = ax.get_xlim()
        plot_x = np.linspace(*xlim)

    num_ineqs = ineqs.shape[0]
    are_inequalities_valid_for_poly = np.full((num_ineqs,), False)

    # Rewrite r @ [1, x] >= 0 as (-r[1:]) @ x <= r[0].
    poly_a = -1 * poly[:, 1:]
    poly_b = poly[:, 0]

    for idx in range(num_ineqs):
        # idx = 0
        a = -1 * ineqs[idx, 1:]
        b = ineqs[idx, 0]
        iv = is_inequality_valid(a, b, poly_a, poly_b, tol)
        are_inequalities_valid_for_poly[idx] = iv
        if diagnose:

            # Draw the boundary line a @ x == b.
            intercept = b / (a[1])
            slope = -1 * a[0] / a[1]
            plot_y = intercept + slope * plot_x
            if iv or plot_all:
                ax.plot(plot_x, plot_y)

    return are_inequalities_valid_for_poly


def _inner_optimisation(h1: np.ndarray,
                        h2: np.ndarray,
                        hbar: np.ndarray) -> float:
    """
    Solve the linear program used to test a pair of discarded inequalities.

    With ``x`` a homogeneous point (first entry pinned to 1) and ``eps`` a
    scalar, this maximises ``eps`` subject to

        -h1 @ x - eps >= 0,   -h2 @ x - eps >= 0,   hbar @ x >= 0,

    i.e. it looks for the point of the set described by ``hbar`` that
    violates both ``h1`` and ``h2`` by the largest common amount.

    :param h1: a single inequality as a 1-D array of length ``dim``.
    :param h2: a single inequality as a 1-D array of length ``dim``.
    :param hbar: inequality matrix with ``dim`` columns.
    :return: the optimal ``eps`` as reported by the solver (GUROBI).
    :raises AssertionError: if ``h1`` or ``h2`` is not 1-D of length ``dim``.
    """
    dim = hbar.shape[1]

    # The rows must be 1-D: they are extended with a scalar via np.hstack
    # below, which only works for 1-D input.
    assert (dim,) == h1.shape
    assert (dim,) == h2.shape
    xe = cvxpy.Variable(dim + 1)
    # Selectors for the homogeneous coordinate and for the trailing eps entry.
    iota_0 = vec(0 == np.arange(dim + 1)).astype(float)
    iota_e = vec(dim == np.arange(dim + 1)).astype(float)

    c1 = np.hstack([-1 * h1, -1])
    c2 = np.hstack([-1 * h2, -1])
    # eps does not enter the envelope constraints.
    cenv = np.hstack([hbar, np.zeros((hbar.shape[0], 1))])

    constraints = [c1 @ xe >= 0,
                   c2 @ xe >= 0,
                   cenv @ xe >= 0,
                   iota_0.T @ xe == 1]

    objective = cvxpy.Maximize(iota_e.T @ xe)
    prob = cvxpy.Problem(objective, constraints)

    # solver = cvxpy.MOSEK
    solver = cvxpy.GUROBI
    prob.solve(solver=solver)
    # prob.solve()
    eps_star = prob.value
    return eps_star


def envelope(h1: np.ndarray,
             h2: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the envelope of two sets given in H-form.

    The envelope keeps the inequalities of ``h1`` that are valid on the set
    described by ``h2`` and the inequalities of ``h2`` that are valid on the
    set described by ``h1``.

    :param h1: inequality matrix of the first set.
    :param h2: inequality matrix of the second set.
    :return: ``(envelope_h, valid1for2, valid2for1)`` -- the stacked kept
        rows (those of ``h1`` first) and the two boolean masks that selected
        them.
    """
    valid1for2 = compute_valid_inequalities_for_poly(ineqs=h1, poly=h2)
    valid2for1 = compute_valid_inequalities_for_poly(ineqs=h2, poly=h1)

    envelope_h = np.vstack([h1[valid1for2, :], h2[valid2for1, :]])

    do_check = True
    if do_check:
        def _exact_generators(h_form: np.ndarray) -> np.ndarray:
            # Rational conversion with no equality rows.
            return h_to_v(h_form, np.zeros((0, h_form.shape[1])), True)

        v1 = _exact_generators(h1)
        v2 = _exact_generators(h2)
        envelope_v = _exact_generators(envelope_h)

        inputs_have_no_rays = np.all(v1[:, 0] == 1) and np.all(v2[:, 0] == 1)
        # Note: this tests for the presence of at least one vertex row.
        envelope_has_vertex = np.any(envelope_v[:, 0] == 1)
        if inputs_have_no_rays:
            assert envelope_has_vertex

    return envelope_h, valid1for2, valid2for1


def drop_rows_positive_proportional_to_another(r: np.ndarray,
                                               tol: float) -> np.ndarray:
    """
    Drop every row that is a positive multiple of a later row.

    Two rows count as positively proportional when the cosine of the angle
    between them is at least ``1 - tol``. Of a group of such rows only the
    last one is kept. All-zero rows have an undefined direction and are
    always kept.

    Example: with

    r = np.array([[1, 2, 3],
                  [2, 4, 6],
                  [3, 6, 9]])

    only the last row ``[3, 6, 9]`` is returned.

    :param r: matrix whose rows are compared.
    :param tol: tolerance on the cosine similarity.
    :return: ``r`` without the dropped rows, in the original order.
    """
    dotprods = r @ r.T
    norms = np.reshape(np.diag(dotprods) ** .5, (-1, 1))
    denominators = norms @ norms.T
    # Zero rows give 0 / 0 = nan, which never passes the threshold test
    # below; silence the accompanying floating point warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        quotients = dotprods / denominators

    # Strict upper triangle: compare row i only against later rows j > i.
    triu_quotients = np.triu(quotients, k=+1)
    is_positively_proportional_to_another_row = np.any(triu_quotients >= 1 - tol, axis=1)
    dropped_r = r[~is_positively_proportional_to_another_row, :]
    return dropped_r


def union_h_reprs_if_convex(h1: np.ndarray, h2: np.ndarray) -> Optional[np.ndarray]:
    """
    Return an H-form of the union of two sets if that union is convex.

    The candidate is the envelope of the two sets (with positively
    proportional rows removed). For every pair made of one inequality of
    ``h1`` and one of ``h2`` that did *not* make it into the envelope, a
    linear program looks for a point of the candidate that violates both. If
    any such program has an optimal value above 1e-13 the union is reported
    as not convex.

    :param h1: inequality matrix of the first set.
    :param h2: inequality matrix of the second set.
    :return: the H-form of the union, or ``None`` when the test fails.
    """
    env, valid1, valid2 = envelope(h1, h2)

    # Inequalities that were discarded when forming the envelope.
    h1tilde = h1[~valid1, :]
    h2tilde = h2[~valid2, :]

    tol = .0001
    hbar = drop_rows_positive_proportional_to_another(env, tol)
    eps_thresh = 1e-13

    for discarded1 in h1tilde:
        for discarded2 in h2tilde:
            eps_star = _inner_optimisation(discarded1, discarded2, hbar)
            if eps_star > eps_thresh:
                # One violating pair settles the question; stop solving LPs.
                return None
    return hbar


def build_h_repr_of_point(point: np.ndarray,
                          is_rational: bool) -> np.ndarray:
    """
    Build the H-form of the set containing a single point.

    :param point: coordinates of the point (any shape; it is flattened).
    :param is_rational: forwarded to ``v_to_h``.
    :return: the H-form produced by ``v_to_h`` for the one-vertex
        V-representation ``[1, point]``.
    """
    coordinates_as_row = vec(point).T
    single_vertex = np.hstack([np.ones((1, 1)), coordinates_as_row])
    return v_to_h(single_vertex, None, is_rational)


def compute_data_distances_between_polytopes(
    ld: List[np.ndarray], x: np.ndarray, p: Union[str, float]
) -> np.ndarray:
    """
    Compute the distance from every data point to every polytope.

    :param ld: list of polytopes in H-form.
    :param x: data matrix with one point per row.
    :param p: norm order, forwarded to ``distance_between_h_reprs``.
    :return: array of shape ``(number of points, number of polytopes)``.
    """
    num_data = x.shape[0]
    data_polytopes_distances = np.full((num_data, len(ld)), np.nan)
    log_every = 10
    for point_idx, point in enumerate(x):
        if 0 == point_idx % log_every:
            logger.debug("{} / {}".format(point_idx, num_data))
        # The singleton's H-form is shared by all polytopes for this point.
        point_set_singleton = build_h_repr_of_point(point, True)
        for poly_idx, poly in enumerate(ld):
            result = distance_between_h_reprs(point_set_singleton, poly, p)
            distance, (_, _) = result
            data_polytopes_distances[point_idx, poly_idx] = distance
    return data_polytopes_distances


def union_v_reprs_if_convex(
        v1: np.ndarray, v2: np.ndarray) -> np.ndarray:
    """
    Return a V-representation of the union of two sets if that union is convex.

    :param v1: generator matrix of the first set.
    :param v2: generator matrix of the second set (same number of columns).
    :return: generator matrix of the union, or ``None`` when
        ``union_h_reprs_if_convex`` reports that the union is not convex.
    """
    assert (
            v1.shape[1] == v2.shape[1]
    ), "Can only union polytopes of the same dimension"
    # todo: build a real algorithm, as in
    # http://cse.lab.imtlucca.it/~bemporad/publications/papers/compgeom-polyunion.pdf
    # rather than lean on the h-form algo

    h_unioned = union_h_reprs_if_convex(v_to_h(v1, None), v_to_h(v2, None))
    if h_unioned is None:
        return None
    no_equalities = np.empty((0, h_unioned.shape[1]))
    return h_to_v(h_unioned, no_equalities, True)


def is_union_of_h_reprs_convex(h1: np.ndarray, h2: np.ndarray) -> bool:
    """Return True when ``union_h_reprs_if_convex`` finds a convex union for the two H-forms."""
    return union_h_reprs_if_convex(h1, h2) is not None


def h_tuple_to_matrix(h_tuple: Tuple[np.ndarray, np.ndarray]) -> np.ndarray:
    """
    Flatten an (inequalities, equalities) pair into one inequality matrix.

    Every equality row is replaced by a pair of opposite inequality rows.

    :param h_tuple: sequence whose first item holds the inequality rows and
        whose second item holds the equality rows.
    :return: the inequality rows, followed by the equality rows, followed by
        the negated equality rows.
    """
    inequalities, equalities = h_tuple[0], h_tuple[1]
    # r @ z == 0 holds iff both r @ z >= 0 and -r @ z >= 0.
    pieces = [inequalities, equalities, -1 * equalities]
    return np.vstack(pieces)


def consolidate_list_of_h_forms(h_tuples: List[Tuple[np.ndarray, np.ndarray]]) -> List[np.ndarray]:
    """
    Convert a list of H-form tuples to matrices and probe pairwise unions.

    NOTE(review): this looks unfinished. Every pair of H-forms is tested for
    a convex union (for ``max_passes`` passes) but the outcome ``tf`` is never
    used, nothing is merged, and ``done`` is never read: the function returns
    the converted matrices unchanged. Progress is printed to stdout.

    :param h_tuples: items accepted by ``h_tuple_to_matrix``.
    :return: one inequality matrix per input item, in the input order.
    """
    h_forms = [h_tuple_to_matrix(_) for _ in h_tuples]

    done = True
    max_passes = 10
    for p in range(max_passes):
        # num_h = len(h_forms)
        for idx1, h_form1 in enumerate(h_forms):
            # get the first h that will form a convex union in the remainder:
            for idx2 in range(idx1 + 1, len(h_forms)):
                print("{} -> {}".format(idx1, idx2))
                h_form2 = h_forms[idx2]

                # Two identical H-forms in the list are treated as an error.
                assert (h_form1.shape != h_form2.shape) or \
                       not np.all(h_form1 == h_form2)
                tf = is_union_of_h_reprs_convex(h_form1, h_form2)
    consolidated_h_forms = h_forms
    return consolidated_h_forms


def is_h_form_empty(h_ineq: np.ndarray,
                    h_lin: np.ndarray) -> bool:
    """
    Decide whether a set given in H-form is empty.

    A feasibility problem (constant objective) is solved over homogeneous
    points whose first entry is pinned to 1; the set is reported as empty
    when the returned optimal value is infinite.

    :param h_ineq: inequality rows (``row @ [1, x] >= 0``).
    :param h_lin: equality rows (``row @ [1, x] == 0``).
    :return: True when the solver reports an infinite optimal value.
    """
    dim = h_ineq.shape[1]
    p1 = cvxpy.Variable(dim)

    iota0 = vec(0 == np.arange(dim)).astype(float)

    # first column is of ones
    constraints = [iota0.T @ p1 == 1]
    if h_lin.shape[0] > 0:
        constraints.append(h_lin @ p1 == 0)
    if h_ineq.shape[0] > 0:
        constraints.append(h_ineq @ p1 >= 0)

    feasibility_problem = cvxpy.Problem(cvxpy.Minimize(0), constraints)

    # solver = cvxpy.MOSEK
    # solver = cvxpy.GUROBI
    # value = prob.solve(solver=solver)
    optimal_value = feasibility_problem.solve()
    return np.isinf(optimal_value)


def compute_hull_volume(vertices: np.ndarray) -> float:
    """
    Return the volume of the convex hull of a set of points.

    :param vertices: one point per row, ``dim`` columns.
    :return: the hull volume; 0 when there are no points or when the points
        do not span all ``dim`` dimensions (they are affinely degenerate).
    """
    num_points, dim = vertices.shape
    if 0 == num_points:
        return 0.0
    # The hull is flat exactly when the differences to one of the points have
    # rank below `dim`. Testing the rank of the raw coordinates instead would
    # miss flat sets that do not pass through the origin (and sets with too
    # few points), which Qhull then rejects with an error.
    offsets = vertices - vertices[0]
    if np.linalg.matrix_rank(offsets) < dim:
        return 0.0
    hull = scipy.spatial.ConvexHull(vertices)
    return hull.volume


def compute_maximum_volume_inner_box(h_ineq: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Find a maximum-volume axis-aligned box inside a set given in H-form.

    The rows ``r @ [1, x] >= 0`` are rewritten as ``a @ x <= b``. With ``x``
    the lower corner and ``y`` the edge lengths of the box, the program is

        maximize sum(log(y))  s.t.  a @ x + a_plus @ y <= b,

    where ``a_plus`` is ``a`` with its negative entries clipped to 0.

    :param h_ineq: inequality matrix of the set.
    :return: ``(lower, upper)`` corners of the box as column vectors.
    """
    # Algorithm 1 from https://doi.org/10.1016/S0925-7721(03)00048-8
    a = -1 * h_ineq[:, 1:]
    b = np.vstack(h_ineq[:, 0])
    a_plus = np.clip(a, a_min=0, a_max=None)

    dim = a.shape[1]
    x = cvxpy.Variable(dim)
    y = cvxpy.Variable(dim)

    constraints = [a @ x + a_plus @ y <= b.flatten()]
    # Maximising the sum of logs maximises the product of the edge lengths.
    objective = cvxpy.Maximize(cvxpy.sum(cvxpy.log(y)))
    prob = cvxpy.Problem(objective, constraints)

    try:
        verbose = False
        # solver = cvxpy.MOSEK
        # solver = cvxpy.GUROBI
        # prob.solve(solver=solver, verbose=verbose)
        prob.solve(verbose=verbose)
    except Exception as err:
        # NOTE(review): solver errors are only printed; a failed solve is
        # then caught by the assertion on prob.value below.
        print(err)

    prob_value = prob.value
    assert prob_value is not None
    x_value = x.value
    y_value = y.value

    lower = np.vstack(x_value)
    upper = np.vstack(x_value + y_value)
    # if False:
    #     h_lin = np.empty((0, h_ineq.shape[1]))
    #     v = h_to_v(h_ineq, h_lin)
    #
    #     assert np.all(1 == v[:, 0])
    #     vertices = v[:, 1:]
    #     hull_volume = compute_hull_volume(vertices)
    #     box_volume = np.prod(upper - lower)
    #     assert box_volume <= hull_volume
    #
    #     lower1 = np.vstack((np.eye(1), lower))
    #     upper1 = np.vstack((np.eye(1), upper))
    #
    #     violations_lower = h_ineq @ lower1
    #     violations_upper = h_ineq @ upper1
    return lower, upper


def v_repr_dim(v_repr: np.ndarray) -> int:
    """
    Return the dimension of a set given as a V-representation.

    The dimension is computed as the rank of the generator matrix (including
    its leading 0/1 column) minus one.

    :param v_repr: generator matrix whose first column contains only 0 (ray)
        or 1 (vertex).
    :return: ``matrix_rank(v_repr) - 1``.
    :raises AssertionError: if the first column holds anything but 0 and 1.
    """
    assert np.all(np.isin(v_repr[:, 0], [0, 1]))
    return np.linalg.matrix_rank(v_repr.T) - 1
