import numpy as np
import cvxpy as cp
import time

# Used for checking purposes - too slow in practice
def cvxpy_l1_solver(P, q, solver=None):
    """Solve min_x ||P x - q||_1 directly with cvxpy's p-norm atom.

    This is a straightforward reference formulation used to check the faster
    solvers; it is too slow to use in practice.

    Args:
        P: (m, n) design matrix.
        q: target values with m elements; flattened with ``ravel()``.
        solver: cvxpy solver name. ``None`` (default) keeps the original
            behavior of using CVXOPT.

    Returns:
        ``var.value`` after solving: the length-n minimizer, or whatever cvxpy
        leaves there (typically None) if no solution was found.
    """
    if solver is None:
        solver = cp.CVXOPT
    n = P.shape[1]
    var = cp.Variable(n)
    loss = cp.pnorm(cp.matmul(P, var) - q.ravel(), p=1)
    problem = cp.Problem(cp.Minimize(loss))
    problem.solve(solver=solver)
    return var.value

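# Faster l1 regression solver (LP formulation) - this is the one we use.
# Returns a tuple (l1 regression error, regression solution), or None if the
# solver does not return a solution (e.g. infeasible/unbounded status).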
def solve_l1_regression_v2(A, b, c_idx=None, debug=False):
    """Solve min_x ||A x - b||_1 as a linear program.

    The l1 objective is linearized with one slack variable per row of A:

        minimize sum(t)  subject to  -t <= A x - b <= t,  t >= 0

    MOSEK is tried first. If it raises (e.g. not installed or unlicensed),
    GLPK is used instead.

    Args:
        A: (m, n) design matrix.
        b: target values with m elements; flattened with ``ravel()``.
        c_idx: unused; kept for call-site compatibility.
        debug: unused; kept for call-site compatibility.

    Returns:
        ``(error, x)`` where ``error`` is ``sum(|A x - b|)`` and ``x`` is the
        solution, or None if the solver did not produce a solution.
    """
    m, n = A.shape
    b = b.ravel()
    x = cp.Variable(n)
    t = cp.Variable(m)

    # Constraints: t bounds the absolute residual from both sides.
    residual = cp.matmul(A, x) - b
    constraints = [residual <= t, -residual <= t, t >= 0]
    problem = cp.Problem(cp.Minimize(cp.sum(t)), constraints)

    try:
        problem.solve(solver=cp.MOSEK, verbose=False)
    except Exception:  # not bare: let KeyboardInterrupt/SystemExit through
        print("MOSEK FAILED")
        problem.solve(solver=cp.GLPK, verbose=False)

    # Any non-solution status (including the *_inaccurate variants) leaves
    # x.value as None, which would otherwise crash the residual computation.
    if x.value is None or problem.status in ["infeasible", "unbounded"]:
        print('Problem status: {}'.format(problem.status))
        return None
    return np.sum(np.abs(A @ x.value - b)), x.value

# This l1 regression solver uses the IRLS approach described at
# https://en.wikipedia.org/wiki/Iteratively_reweighted_least_squares
# Seems to require more time/iterations compared to the MOSEK-based solver.
def solve_l1_regression_irls(A, b, num_iterations=1000, delta=1e-18, c_idx=None, debug=False):
    """Approximate min_beta ||A beta - b||_1 by iteratively reweighted least squares.

    Each iteration solves the weighted normal equations
    (A^T W A) beta = A^T W b with W = diag(weights), then sets
    weights = 1 / max(|A beta - b|, delta).

    Args:
        A: (m, n) design matrix.
        b: targets of shape (m,) or (m, 1). beta has the matching shape,
            (n,) or (n, 1).
        num_iterations: number of reweighting iterations; must be >= 1.
        delta: floor on the absolute residuals, which keeps the weights
            finite when a residual is zero.
        c_idx: unused; kept for call-site compatibility.
        debug: unused; kept for call-site compatibility.

    Returns:
        ``(error, beta)`` where ``error = sum(|A beta - b|)``.

    Raises:
        ValueError: if ``num_iterations`` < 1 (no solution would be computed).
    """
    if num_iterations < 1:
        raise ValueError('num_iterations must be >= 1, got {}'.format(num_iterations))
    num_rows = A.shape[0]
    weights = np.ones(num_rows)
    beta = None
    for _ in range(num_iterations):
        # Scaling column j of A.T by weights[j] equals A.T @ diag(weights),
        # but avoids materializing a dense (m, m) diagonal matrix.
        weighted_At = A.T * weights
        beta = np.linalg.pinv(weighted_At @ A) @ weighted_At @ b

        # Update weights. The residual may be (m, 1) when b is a column,
        # so flatten it back to (m,).
        residual = np.maximum(np.abs(A @ beta - b), delta)
        weights = (1 / residual).reshape(num_rows)

    error = np.sum(np.abs(A @ beta - b))
    return error, beta


if __name__ == '__main__':
    # Demo: solve one random l1 regression with the LP solver and with IRLS,
    # then print both solutions and their l1 errors.
    m = 500
    n = 500
    P = np.random.randn(m, n)
    q = np.random.randn(m).reshape(m, 1)
    print(P.shape, q.shape)

    # solve_l1_regression_v2 returns None when no solution is available;
    # exit with a clear message instead of failing on tuple unpacking.
    lp_result = solve_l1_regression_v2(P, q, debug=True)
    if lp_result is None:
        raise SystemExit('LP l1 solver returned no solution')
    error, x = lp_result
    print('l1 solver: ', x)

    # Optional cross-check against the plain cvxpy/CVXOPT formulation.
    # Disabled by default because it is too slow in practice.
    compare_with_cvxpy = False
    if compare_with_cvxpy:
        x_cvxpy = cvxpy_l1_solver(P, q)
        print('cvxpy solver: ', x_cvxpy)
        # abs() must apply to the difference, not to the boolean comparison.
        print('[Answer Comparison: ]', (np.abs(x - x_cvxpy) <= 1e-6).all())

    # Compare to IRLS
    error_irls, x_irls = solve_l1_regression_irls(P, q, num_iterations=100)
    print('l1 solver IRLS: ', x_irls)

    print("==============")
    print("IRLS Error: ", error_irls)
    print("LP Error: ", error)