import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt

from scipy import io
from sklearn.preprocessing import MinMaxScaler

from skglm.utils import ST_vec


def primal_lasso(X, y, alpha, w):
    """Evaluate the Lasso primal objective at ``w``.

    The value returned is ``0.5 * ||X w - y||_2^2 + alpha * ||w||_1``.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Design matrix.

    y : array, shape (n_samples,)
        Target vector.

    alpha : float
        Regularization parameter (weight of the L1 term).

    w : array, shape (n_features,)
        Coefficient vector at which the objective is evaluated.

    Returns
    -------
    p_obj : float
        The primal objective.
    """
    residual = X @ w - y
    # Half the squared Euclidean norm of the residual.
    datafit = (0.5 * residual) @ residual
    penalty = alpha * norm(w, ord=1)
    return datafit + penalty


def extrapolate(last_K_u, K, extrap_type="LPinf", s=100):
    """Build an extrapolation direction from the stored past iterates.

    The successive differences of the columns of ``last_K_u`` are formed and
    the most recent difference is regressed (least squares, via the normal
    equations) on the earlier ones. The fitted coefficients fill the last
    column of a companion-like matrix ``C`` whose spectral radius decides
    whether an extrapolation is performed.

    Parameters
    ----------
    last_K_u : array, shape (n_features, K)
        Array of past iterates, one iterate per column.

    K : int
        Number of past iterates to use for extrapolation.

    extrap_type : str, (`LP` | `LPinf`), optional
        Extrapolation type. Any value other than `LPinf` is handled as `LP`.

    s : int
        LP steps (used if extrap_type=`LP`).

    Returns
    -------
    w : array, shape (n_features,)
        Extrapolated vector. It is all zeros when the spectral radius of
        ``C`` is not below 1, and it is the most recent difference when the
        normal equations cannot be solved.
    """
    order = K - 2

    # Differences between consecutive stored iterates (K - 1 columns).
    diffs = np.diff(last_K_u, 1, axis=-1)
    newest_diff = diffs[:, -1]
    older_diffs = diffs[:, :-1]
    newer_diffs = diffs[:, 1:]

    # Least-squares coefficients through the normal equations.
    gram = older_diffs.T @ older_diffs
    try:
        coefs = np.linalg.solve(gram, older_diffs.T @ newest_diff)
    except np.linalg.LinAlgError:
        # Singular system: fall back on the most recent difference.
        return newest_diff

    # Iteration matrix: ones on the sub-diagonal, coefficients in last column.
    C = np.diag(np.ones(order - 1), -1)
    C[:, -1] = coefs

    # Largest eigenvalue modulus of C.
    spectral_radius = norm(np.linalg.eigvals(C), ord=np.inf)

    if spectral_radius < 1:
        # LPinf sums all powers of C; LP stops the sum before the s-th power.
        numerator = C
        if extrap_type != "LPinf":
            numerator = C - np.linalg.matrix_power(C, s)
        eye_minus_C = np.eye(order) - C
        # Right division numerator @ inv(I - C), done with a transposed solve.
        S = np.linalg.solve(eye_minus_C.T, numerator.T).T
    else:
        # No extrapolation when the spectral radius is not below 1.
        S = 0 * C

    return newer_diffs @ S[:, -1]


def admm(X, y, alpha, gamma=2., max_iter=1000, tol=1e-5, check_gap_freq=50, a=0, K=6,
         use_accel=True, verbose=True):
    """Run Alternate Direction Method of multipliers optimization scheme for Lasso.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Design matrix.

    y : array, shape (n_samples,)
        Target vector.

    alpha : float
        Regularization parameter.

    gamma : float
        Augmented Lagrangian parameter.

    max_iter : int
        Maximum number of iterations.

    tol : float
        Tolerance on the residual ``norm(u - u_prev)``. It is only tested
        every ``check_gap_freq`` iterations.

    check_gap_freq : int
        Frequency for checking convergence.

    a : float
        Inertia parameter.

    K : int
        Number of past iterates to compute extrapolated point.

    use_accel : bool
        Use extrapolation (attempted every ``K + 1`` iterations).

    verbose : bool
        Verbosity. When True, the residual and the primal objective are
        printed at every convergence check.

    Returns
    -------
    w : array, shape (n_features,)
        Coefficient vector.

    residuals : list of float
        Norm of the change in ``u`` over each iteration, measured before any
        extrapolation step is applied.

    iterates : list of array, each of shape (n_features,)
        Value of ``w`` at each iteration.
    """
    # Local import: the file only imports ``scipy.io`` at module level.
    from scipy.linalg import cho_solve

    n_features = X.shape[1]
    residuals = []
    iterates = []

    # Acceleration variables
    last_K_u = np.zeros((n_features, K))

    # Optimization variables
    w = np.ones(n_features)  # Primal iterates
    psi = np.ones(n_features)  # Dual iterates
    u = psi + gamma * w
    u_bar = u

    # Pre-compute useful quantities
    XtX_scaled = X.T @ X / gamma
    Xty_scaled = X.T @ y / gamma
    # Factorize once so that each iteration only needs triangular solves
    # instead of a fresh factorization of the same matrix.
    chol_lower = np.linalg.cholesky(XtX_scaled + np.eye(n_features))

    for n_iter in range(1, max_iter + 1):
        u_prev = u.copy()

        # Proximal step for datafit
        z = cho_solve((chol_lower, True), Xty_scaled + u_bar / gamma)
        psi = u_bar - gamma * z  # Dual update

        # Proximal step for pen
        w = ST_vec((u_bar - 2 * psi) / gamma, alpha / gamma)
        u = psi + gamma * w
        iterates.append(w)

        # Inertial step
        v = u - u_prev
        u_bar = u + a * v

        # Sliding window over the last K values of u (oldest column dropped)
        last_K_u = np.column_stack((last_K_u[:, 1:], u))

        if use_accel and n_iter % (K + 1) == 0:
            e = extrapolate(last_K_u, K)
            with np.errstate(divide="ignore"):
                # norm(e) may be zero; the division then yields inf and the
                # minimum below clips the coefficient to 1.
                # Parameter safeguard - avoid numerical errors
                coeff = np.minimum(1., 1e5 / (n_iter**1.1 * norm(e)))
            u = u + coeff * e
            u_bar = u

        # Residual uses v computed before the extrapolation step
        res = norm(v)
        residuals.append(res)

        if n_iter % check_gap_freq == 0:
            if verbose:
                # The objective is only needed for the log line
                p_obj = primal_lasso(X, y, alpha, w)
                print(f"iter {n_iter} :: residual {res:.5f} :: obj {p_obj:.4f}")

            if res < tol:
                break
    return w, residuals, iterates


if __name__ == "__main__":
    # The data files used below can be downloaded from:
    # https://github.com/jliang993/A3DMM/tree/master/codes/data
    X = io.loadmat('covtype_sample.mat')["h"]
    y = np.ravel(io.loadmat('covtype_label.mat')["l"])

    # Convert to a dense array, then rescale each feature to [-1, 1].
    scaler = MinMaxScaler(feature_range=(-1, 1))
    X = scaler.fit_transform(X.toarray())

    alpha = 1
    solver_kwargs = dict(tol=1e-6, max_iter=50_000)

    # Same problem solved twice: with and without extrapolation.
    w, residuals, iterates = admm(
        X, y, alpha, use_accel=True, check_gap_freq=10, **solver_kwargs)
    print("#" * 25)
    w_no_accel, residuals_no_acc, iterates_no_acc = admm(
        X, y, alpha, use_accel=False, check_gap_freq=100, **solver_kwargs)

    # Both runs must agree on the solution.
    np.testing.assert_allclose(w, w_no_accel, rtol=1e-4)

    # Plotting: log-distance of every iterate to the final iterate of its run.
    # The last stored iterate is the returned solution itself, so its
    # distance is 0 and its log is -inf.
    norms_accel = [np.log(norm(w_k - w)) for w_k in iterates]
    norms_no_accel = [np.log(norm(w_k - w_no_accel)) for w_k in iterates_no_acc]
    for curve, label in ((norms_accel, "Accelerated"), (norms_no_accel, "No accel")):
        plt.plot(curve, label=label)
    plt.legend()
    plt.title("ADMM - ||x - x^*||")
    plt.show()
