"""
Routines for removing redundant (linearly dependent) equations from linear
programming equality constraints.
"""
# Author: Matt Haberland

from __future__ import division, print_function, absolute_import
import numpy as np
from scipy.linalg import svd
import scipy


def _row_count(A):
    """
    Counts the number of nonzeros in each row of input array A.
    Nonzeros are defined as any element with absolute value greater than
    tol = 1e-13. This value should probably be an input to the function.

    Parameters
    ----------
    A : 2-D array
        An array representing a matrix

    Returns
    -------
    rowcount : 1-D array
        Number of nonzeros in each row of A

    """
    tol = 1e-13
    return np.array((abs(A) > tol).sum(axis=1)).flatten()


def _get_densest(A, eligibleRows):
    """
    Returns the index of the densest row of A. Ignores rows that are not
    eligible for consideration.

    Parameters
    ----------
    A : 2-D array
        An array representing a matrix
    eligibleRows : 1-D logical array
        Values indicate whether the corresponding row of A is eligible
        to be considered

    Returns
    -------
    i_densest : int
        Index of the densest row in A eligible for consideration

    """
    rowCounts = _row_count(A)
    return np.argmax(rowCounts * eligibleRows)


def _remove_zero_rows(A):
    """
    Eliminates trivial equations from system of equations defined by Ax = b
   and identifies trivial infeasibilities

    Parameters
    ----------
    A : 2-D array
        An array representing the left-hand side of a system of equations
    b : 1-D array
        An array representing the right-hand side of a system of equations

    Returns
    -------
    A : 2-D array
        An array representing the left-hand side of a system of equations
    b : 1-D array
        An array representing the right-hand side of a system of equations
    status: int
        An integer indicating the status of the removal operation
        0: No infeasibility identified
        2: Trivially infeasible
    message : str
        A string descriptor of the exit status of the optimization.

    """
    status = 0
    message = ""
    i_zero = _row_count(A) == 0
    A = A[np.logical_not(i_zero), :]
    if not(np.allclose(b[i_zero], 0)):
        status = 2
        message = "There is a zero row in A_eq with a nonzero corresponding " \
                  "entry in b_eq. The problem is infeasible."
    b = b[np.logical_not(i_zero)]
    return A, b, status, message


def bg_update_dense(plu, perm_r, v, j):
    LU, p = plu

    u = scipy.linalg.solve_triangular(LU, v[perm_r], lower=True,
                                      unit_diagonal=True)
    LU[:j+1, j] = u[:j+1]
    l = u[j+1:]
    piv = LU[j, j]
    LU[j+1:, j] += (l/piv)
    return LU, p


def _remove_redundancy_dense(A, rhs):
    """
    Eliminates redundant equations from system of equations defined by Ax = b
    and identifies infeasibilities.

    Parameters
    ----------
    A : 2-D sparse matrix
        An matrix representing the left-hand side of a system of equations
    rhs : 1-D array
        An array representing the right-hand side of a system of equations

    Returns
    ----------
    A : 2-D sparse matrix
        A matrix representing the left-hand side of a system of equations
    rhs : 1-D array
        An array representing the right-hand side of a system of equations
    status: int
        An integer indicating the status of the system
        0: No infeasibility identified
        2: Trivially infeasible
    message : str
        A string descriptor of the exit status of the optimization.

    References
    ----------
    .. [2] Andersen, Erling D. "Finding all linearly dependent rows in
           large-scale linear programming." Optimization Methods and Software
           6.3 (1995): 219-227.

    """
    tolapiv = 1e-8
    tolprimal = 1e-8
    status = 0
    message = ""
    inconsistent = ("There is a linear combination of rows of A_eq that "
                    "results in zero, suggesting a redundant constraint. "
                    "However the same linear combination of b_eq is "
                    "nonzero, suggesting that the constraints conflict "
                    "and the problem is infeasible.")
    A, rhs, status, message = _remove_zero_rows(A, rhs)

    if status != 0:
        return A, rhs, status, message

    m, n = A.shape

    v = list(range(m))      # Artificial column indices.
    b = list(v)             # Basis column indices.
    # This is better as a list than a set because column order of basis matrix
    # needs to be consistent.
    k = set(range(m, m+n))  # Structural column indices.
    d = []                  # Indices of dependent rows
    lu = None
    perm_r = None

    A_orig = A
    A = np.hstack((np.eye(m), A))
    e = np.zeros(m)

    # Implements basic algorithm from [2]
    # Uses some of the suggested improvements (removing zero rows and
    # Bartels-Golub update idea).
    # Removing column singletons would be easy, but it is not as important
    # because the procedure is performed only on the equality constraint
    # matrix from the original problem - not on the canonical form matrix,
    # which would have many more column singletons due to slack variables
    # from the inequality constraints.
    # The thoughts on "crashing" the initial basis sound useful, but the
    # description of the procedure seems to assume a lot of familiarity with
    # the subject; it is not very explicit. I already went through enough
    # trouble getting the basic algorithm working, so I was not interested in
    # trying to decipher this, too. (Overall, the paper is fraught with
    # mistakes and ambiguities - which is strange, because the rest of
    # Andersen's papers are quite good.)

    B = A[:, b]
    for i in v:

        e[i] = 1
        if i > 0:
            e[i-1] = 0

        try:  # fails for i==0 and any time it gets ill-conditioned
            j = b[i-1]
            lu = bg_update_dense(lu, perm_r, A[:, j], i-1)
        except Exception:
            lu = scipy.linalg.lu_factor(B)
            LU, p = lu
            perm_r = list(range(m))
            for i1, i2 in enumerate(p):
                perm_r[i1], perm_r[i2] = perm_r[i2], perm_r[i1]

        pi = scipy.linalg.lu_solve(lu, e, trans=1)

        # not efficient, but this is not the time sink...
        js = np.array(list(k-set(b)))
        batch = 50

        # This is a tiny bit faster than looping over columns indivually,
        # like for j in js: if abs(A[:,j].transpose().dot(pi)) > tolapiv:
        for j_index in range(0, len(js), batch):
            j_indices = js[np.arange(j_index, min(j_index+batch, len(js)))]

            c = abs(A[:, j_indices].transpose().dot(pi))
            if (c > tolapiv).any():
                j = js[j_index + np.argmax(c)]  # very independent column
                B[:, i] = A[:, j]
                b[i] = j
                break
        else:
            bibar = pi.T.dot(rhs.reshape(-1, 1))
            bnorm = np.linalg.norm(rhs)
            if abs(bibar)/(1+bnorm) > tolprimal:  # inconsistent
                status = 2
                message = inconsistent
                return A_orig, rhs, status, message
            else:  # dependent
                d.append(i)

    keep = set(range(m))
    keep = list(keep - set(d))
    return A_orig[keep, :], rhs[keep], status, message


def _remove_redundancy_sparse(A, rhs):
    """
    Eliminates redundant equations from system of equations defined by Ax = b
    and identifies infeasibilities.

    Parameters
    ----------
    A : 2-D sparse matrix
        An matrix representing the left-hand side of a system of equations
    rhs : 1-D array
        An array representing the right-hand side of a system of equations

    Returns
    -------
    A : 2-D sparse matrix
        A matrix representing the left-hand side of a system of equations
    rhs : 1-D array
        An array representing the right-hand side of a system of equations
    status: int
        An integer indicating the status of the system
        0: No infeasibility identified
        2: Trivially infeasible
    message : str
        A string descriptor of the exit status of the optimization.

    References
    ----------
    .. [2] Andersen, Erling D. "Finding all linearly dependent rows in
           large-scale linear programming." Optimization Methods and Software
           6.3 (1995): 219-227.

    """

    tolapiv = 1e-8
    tolprimal = 1e-8
    status = 0
    message = ""
    inconsistent = ("There is a linear combination of rows of A_eq that "
                    "results in zero, suggesting a redundant constraint. "
                    "However the same linear combination of b_eq is "
                    "nonzero, suggesting that the constraints conflict "
                    "and the problem is infeasible.")
    A, rhs, status, message = _remove_zero_rows(A, rhs)

    if status != 0:
        return A, rhs, status, message

    m, n = A.shape

    v = list(range(m))      # Artificial column indices.
    b = list(v)             # Basis column indices.
    # This is better as a list than a set because column order of basis matrix
    # needs to be consistent.
    k = set(range(m, m+n))  # Structural column indices.
    d = []                  # Indices of dependent rows

    A_orig = A
    A = scipy.sparse.hstack((scipy.sparse.eye(m), A)).tocsc()
    e = np.zeros(m)

    # Implements basic algorithm from [2]
    # Uses only one of the suggested improvements (removing zero rows).
    # Removing column singletons would be easy, but it is not as important
    # because the procedure is performed only on the equality constraint
    # matrix from the original problem - not on the canonical form matrix,
    # which would have many more column singletons due to slack variables
    # from the inequality constraints.
    # The thoughts on "crashing" the initial basis sound useful, but the
    # description of the procedure seems to assume a lot of familiarity with
    # the subject; it is not very explicit. I already went through enough
    # trouble getting the basic algorithm working, so I was not interested in
    # trying to decipher this, too. (Overall, the paper is fraught with
    # mistakes and ambiguities - which is strange, because the rest of
    # Andersen's papers are quite good.)
    # I tried and tried and tried to improve performance using the
    # Bartels-Golub update. It works, but it's only practical if the LU
    # factorization can be specialized as described, and that is not possible
    # until the SciPy SuperLU interface permits control over column
    # permutation - see issue #7700.

    for i in v:
        B = A[:, b]

        e[i] = 1
        if i > 0:
            e[i-1] = 0

        pi = scipy.sparse.linalg.spsolve(B.transpose(), e).reshape(-1, 1)

        js = list(k-set(b))  # not efficient, but this is not the time sink...

        # Due to overhead, it tends to be faster (for problems tested) to
        # compute the full matrix-vector product rather than individual
        # vector-vector products (with the chance of terminating as soon
        # as any are nonzero). For very large matrices, it might be worth
        # it to compute, say, 100 or 1000 at a time and stop when a nonzero
        # is found.

        c = (np.abs(A[:, js].transpose().dot(pi)) > tolapiv).nonzero()[0]
        if len(c) > 0:  # independent
            j = js[c[0]]
            # in a previous commit, the previous line was changed to choose
            # index j corresponding with the maximum dot product.
            # While this avoided issues with almost
            # singular matrices, it slowed the routine in most NETLIB tests.
            # I think this is because these columns were denser than the
            # first column with nonzero dot product (c[0]).
            # It would be nice to have a heuristic that balances sparsity with
            # high dot product, but I don't think it's worth the time to
            # develop one right now. Bartels-Golub update is a much higher
            # priority.
            b[i] = j  # replace artificial column
        else:
            bibar = pi.T.dot(rhs.reshape(-1, 1))
            bnorm = np.linalg.norm(rhs)
            if abs(bibar)/(1 + bnorm) > tolprimal:
                status = 2
                message = inconsistent
                return A_orig, rhs, status, message
            else:  # dependent
                d.append(i)

    keep = set(range(m))
    keep = list(keep - set(d))
    return d, status, message


def _remove_redundancy(A):
    """
    Eliminates redundant equations from system of equations defined by Ax = b
    and identifies infeasibilities.

    Parameters
    ----------
    A : 2-D array
        An array representing the left-hand side of a system of equations
    b : 1-D array
        An array representing the right-hand side of a system of equations

    Returns
    -------
    A : 2-D array
        An array representing the left-hand side of a system of equations
    b : 1-D array
        An array representing the right-hand side of a system of equations
    status: int
        An integer indicating the status of the system
        0: No infeasibility identified
        2: Trivially infeasible
    message : str
        A string descriptor of the exit status of the optimization.

    References
    ----------
    .. [2] Andersen, Erling D. "Finding all linearly dependent rows in
           large-scale linear programming." Optimization Methods and Software
           6.3 (1995): 219-227.

    """

    # A, status, message = _remove_zero_rows(A)
    status = 0
    message = ""

    # if status != 0:
    #     return A, b, status, message

    A_ = A.copy()
    m, n = A_.shape
    eps = np.finfo(float).eps
    U, s, Vh = svd(A_)
    s_min = s[-1] if m <= n else 0
    tol = s.max() * max(A.shape) * eps
    d = []

    # this algorithm is faster than that of [2] when the nullspace is small
    # but it could probably be improvement by randomized algorithms and with
    # a sparse implementation.
    # it relies on repeated singular value decomposition to find linearly
    # dependent rows (as identified by columns of U that correspond with zero
    # singular values). Unfortunately, only one row can be removed per
    # decomposition (I tried otherwise; doing so can cause problems.)
    # It would be nice if we could do truncated SVD like sp.sparse.linalg.svds
    # but that function is unreliable at finding singular values near zero.
    # Finding max eigenvalue L of A A^T, then largest eigenvalue (and
    # associated eigenvector) of -A A^T + L I (I is identity) via power
    # iteration would also work in theory, but is only efficient if the
    # smallest nonzero eigenvalue of A A^T is close to the largest nonzero
    # eigenvalue.

    while abs(s_min) < tol:
        v = U[:, -1]  # TODO: return these so user can eliminate from problem?
        # rows need to be represented in significant amount
        eligibleRows = np.abs(v) > tol * 10e6
        # if not np.any(eligibleRows) or np.any(np.abs(v.dot(A)) > tol):
        #     status = 4
        #     message = ("Due to numerical issues, redundant equality "
        #                "constraints could not be removed automatically. "
        #                "Try providing your constraint matrices as sparse "
        #                "matrices to activate sparse presolve, try turning "
        #                "off redundancy removal, or try turning off presolve "
        #                "altogether.")
        #     break
        # if np.any(np.abs(v.dot(b)) > tol * 100):  # factor of 100 to fix 10038 and 10349
        #     status = 2
        #     message = ("There is a linear combination of rows of A_eq that "
        #                "results in zero, suggesting a redundant constraint. "
        #                "However the same linear combination of b_eq is "
        #                "nonzero, suggesting that the constraints conflict "
        #                "and the problem is infeasible.")
        #     break

        i_remove = _get_densest(A_, eligibleRows)
        A_ = np.delete(A_, i_remove, axis=0)
        # b = np.delete(b, i_remove)
        U, s, Vh = svd(A_)
        m, n = A_.shape
        s_min = s[-1] if m <= n else 0
        d.append(i_remove)

    return d, status, message
