
from sklearn.linear_model import LassoCV, ElasticNetCV, lasso_path, enet_path, OrthogonalMatchingPursuit
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize, MaxAbsScaler, StandardScaler
from sklearn.feature_selection import VarianceThreshold
from sklearn.pipeline import make_pipeline
from scipy import sparse
import numpy as np

def F(X, y, w, lam):
    """Return the L2-penalised mean squared error of the linear model ``X.dot(w)``.

    Computes ``mean_i((y_i - x_i . w)**2 + lam/2 * ||w||_2**2)``. The penalty is a
    constant, so adding it inside the mean equals adding it once outside.
    ``X`` may be a single row (then ``y`` is a scalar) or a matrix of rows.
    """
    residual = y - X.dot(w)
    penalty = lam / 2 * np.linalg.norm(w) ** 2
    return np.mean(residual ** 2 + penalty)

def _zo_gradient(X, y, w, lam, mu, q, rs):
    """Return a zeroth-order (finite-difference) estimate of the gradient of F at ``w``.

    Averages ``q`` forward differences of ``F(X, y, ., lam)`` with step ``mu``.
    Each difference is taken along a direction drawn from ``rs.randn(d)`` and
    normalised to unit length, and is scaled by ``d = len(w)``. Exactly ``q``
    vectors are drawn from ``rs``.

    ``F`` is called ``q + 1`` times. The caller is responsible for charging those
    calls to its query counter.
    """
    d = w.shape[0]
    base_loss = F(X, y, w, lam)
    ghat = np.zeros(d, dtype=np.float64)
    for _ in range(q):
        u = rs.randn(d)
        u /= np.linalg.norm(u)
        ghat += d * (F(X, y, w + mu * u, lam) - base_loss) / mu * u
    ghat /= q
    return ghat


def _zo_gradient_table(X, y, w, lam, mu, q, rs):
    """Return an (n_samples, n_features) table of per-sample ZO gradient estimates at ``w``.

    Row ``r`` is ``_zo_gradient`` evaluated on sample ``r`` alone. Rows are
    filled in order, so the RNG stream is consumed sample by sample.
    """
    n_samples, n_features = X.shape
    table = np.zeros((n_samples, n_features), dtype=np.float64)
    for row in range(n_samples):
        table[row] = _zo_gradient(X[row], y[row], w, lam, mu, q, rs)
    return table


def optimize(algorithm, rs, X, y, max_iter, mu, q, eta, lam, k, beta=0.5, batch_size=None, update_freq=None, plot_freq=None, zomax=None): 
    """Run a zeroth-order hard-thresholding optimiser on ``F(X, y, w, lam)``.

    ``X`` is first rescaled: ``MaxAbsScaler`` for scipy-sparse input,
    ``StandardScaler`` otherwise. ``w`` then starts at zero. Every iteration
    steps along a zeroth-order gradient estimate and then applies
    ``hard_threshold(w, k)``.

    Parameters
    ----------
    algorithm : str
        One of:

        - ``'fgzoht'``: full-batch estimate.
        - A name starting with ``'szoht'``: minibatch estimate.
        - ``'svrgzoht'`` / ``'b-svrgzoht'``: SVRG-style control variate around a
          snapshot. The ``b-`` variant weights the correction by ``beta``.
        - ``'sarah-zht'``: recursive estimate restarted from a full-batch
          estimate every ``update_freq`` steps.
        - ``'saga-zht'``: per-sample table of estimates.
        - ``'q-saga-zht-<p>'``: SAGA-style table where ``p`` randomly chosen
          rows are overwritten each step.
    rs : numpy.random.RandomState-like
        Source of randomness. ``randn``, ``randint`` and ``choice`` are used.
    X, y : array-like or scipy sparse matrix, array-like
        Data matrix and targets.
    max_iter : int
        Maximum number of (inner) iterations.
    mu : float
        Finite-difference step of the gradient estimator.
    q : int
        Number of random directions per gradient estimate.
    eta : float
        Step size.
    lam : float
        L2 penalty weight passed to ``F``.
    k : int
        Number of entries kept by ``hard_threshold``.
    beta : float, default 0.5
        Correction weight, used by ``'b-svrgzoht'`` only.
    batch_size : int, optional
        Minibatch size. Required by every algorithm except ``'fgzoht'``. The
        SAGA variants read ``table[batch_idx][0]``, i.e. they assume 1.
    update_freq : int, optional
        Inner-loop length between full-batch refreshes. Required by the SVRG
        and SARAH variants.
    plot_freq : int, optional
        Record history every ``plot_freq`` iterations. ``None`` records every
        iteration.
    zomax : number, optional
        Budget on the query counter. Loops stop once it is reached. ``None``
        means unlimited.

    Returns
    -------
    it_count, hist, nizo, nht : list
        Recorded at iteration 0 and then every ``plot_freq`` iterations:

        - ``it_count``: the iteration index.
        - ``hist``: the full objective ``F(X, y, w, lam)``.
        - ``nizo``: the cumulative count of per-sample F evaluations.
        - ``nht``: the cumulative number of hard-thresholding steps.

    Raises
    ------
    ValueError
        If ``algorithm`` is not recognised, or a ``q-saga-zht`` name has no
        integer suffix.
    """
    if zomax is None:
        zomax = np.inf  # no query budget
    if plot_freq is None:
        plot_freq = 1  # previously crashed on `i % None`; record every iteration

    scl = MaxAbsScaler() if sparse.issparse(X) else StandardScaler()
    X = scl.fit_transform(X)
    n_samples, n_features = X.shape
    w = np.zeros(n_features, dtype=np.float64)
    print('Training...')
    it_count, hist, nizo, nht = [], [], [], []
    nizo_count = 0  # cumulative per-sample evaluations of F
    nht_count = 0   # cumulative hard-thresholding steps

    def record(step):
        # Closure: reads the *current* w / nizo_count / nht_count of optimize.
        it_count.append(step)
        hist.append(F(X, y, w, lam))
        nizo.append(nizo_count)
        nht.append(nht_count)

    record(0)

    if algorithm == 'fgzoht':
        i = 1
        while i <= max_iter and nizo_count < zomax:
            ghat = _zo_gradient(X, y, w, lam, mu, q, rs)
            nizo_count += (q + 1) * n_samples
            w = hard_threshold(w - eta * ghat, k)
            nht_count += 1
            if i % plot_freq == 0:
                record(i)
            i += 1

    # `elif` (was `if`): otherwise 'fgzoht' fell through to the final `else` and raised.
    elif algorithm.startswith('szoht'):
        i = 1
        while i <= max_iter and nizo_count < zomax:
            batch_idx = rs.randint(0, n_samples, size=batch_size)
            ghat = _zo_gradient(X[batch_idx], y[batch_idx], w, lam, mu, q, rs)
            nizo_count += (q + 1) * batch_size
            w = hard_threshold(w - eta * ghat, k)
            nht_count += 1
            if i % plot_freq == 0:
                record(i)
            i += 1

    elif algorithm.startswith(('svrgzoht', 'b-svrgzoht')):
        # Plain SVRG is the b-SVRG update with a correction weight of exactly 1.0
        # (multiplying by 1.0 is exact, so results are unchanged).
        b = beta if algorithm.startswith('b-svrgzoht') else 1.0
        outer_its = 0
        total_its = 0
        # `total_its < max_iter` avoids paying for a snapshot that no inner step would use.
        while (outer_its <= max_iter // update_freq and total_its < max_iter
               and nizo_count < zomax):
            anchor = w.copy()  # snapshot point
            full_ghat = _zo_gradient(X, y, anchor, lam, mu, q, rs)
            nizo_count += (q + 1) * n_samples
            inner_its = 0
            while inner_its < update_freq and total_its < max_iter and nizo_count < zomax:
                batch_idx = rs.randint(0, n_samples, size=batch_size)
                X_batch, y_batch = X[batch_idx], y[batch_idx]
                ghat = _zo_gradient(X_batch, y_batch, w, lam, mu, q, rs)
                ghat_anchor = _zo_gradient(X_batch, y_batch, anchor, lam, mu, q, rs)
                nizo_count += 2 * (q + 1) * batch_size
                w = hard_threshold(w - eta * (b * (ghat - ghat_anchor) + full_ghat), k)
                nht_count += 1
                inner_its += 1
                total_its += 1
                if total_its % plot_freq == 0:
                    record(total_its)
            outer_its += 1

    elif algorithm.startswith('sarah-zht'):
        outer_its = 0
        total_its = 0
        while (outer_its <= max_iter // update_freq and total_its < max_iter
               and nizo_count < zomax):
            # Each outer iteration restarts the estimate from a full-batch estimate at w.
            v = _zo_gradient(X, y, w, lam, mu, q, rs)
            nizo_count += (q + 1) * n_samples
            w_prev = w
            inner_its = 0
            while inner_its < update_freq and total_its < max_iter and nizo_count < zomax:
                if inner_its > 0:
                    # Recursive update on a fresh minibatch: v <- g_B(w) - g_B(w_prev) + v.
                    batch_idx = rs.randint(0, n_samples, size=batch_size)
                    X_batch, y_batch = X[batch_idx], y[batch_idx]
                    ghat = _zo_gradient(X_batch, y_batch, w, lam, mu, q, rs)
                    ghat_prev = _zo_gradient(X_batch, y_batch, w_prev, lam, mu, q, rs)
                    nizo_count += 2 * (q + 1) * batch_size
                    v = ghat - ghat_prev + v
                # w is always rebound to a new array below, so aliasing is safe.
                w_prev = w
                w = hard_threshold(w - eta * v, k)
                nht_count += 1
                inner_its += 1
                total_its += 1
                if total_its % plot_freq == 0:
                    record(total_its)
            outer_its += 1

    elif algorithm.startswith(('saga-zht', 'q-saga-zht')):
        # 'q-saga-zht-<p>': instead of refreshing the sampled row, p random rows are overwritten.
        num_p = int(algorithm.split('-')[-1]) if algorithm.startswith('q-saga-zht') else None
        i = 1
        while i <= max_iter and nizo_count < zomax:
            if i == 1:
                # Fill row r with sample r's estimate (was `table[i]`, which only filled row 1).
                # Each row costs q + 1 single-sample evaluations.
                table = _zo_gradient_table(X, y, w, lam, mu, q, rs)
                nizo_count += (q + 1) * n_samples
                table_avg = 1/n_samples * np.sum(table, axis=0)
                # NOTE(review): the first step uses the last sample's estimate, as before;
                # confirm whether the table average was intended here.
                w -= eta * table[-1]
            else:
                batch_idx = rs.randint(0, n_samples, size=batch_size)
                ghat = _zo_gradient(X[batch_idx], y[batch_idx], w, lam, mu, q, rs)
                nizo_count += (q + 1) * batch_size
                old_grad = table[batch_idx][0]  # assumes batch_size == 1
                if num_p is None:
                    w -= eta * (ghat - old_grad + table_avg)
                    table[batch_idx] = ghat
                    table_avg = table_avg + 1/n_samples * (-old_grad + ghat)
                else:
                    # Recomputed by brute force because several rows change per step.
                    table_avg = np.mean(table, axis=0)
                    w -= eta * (ghat - old_grad + table_avg)
                    rows = rs.choice(range(n_samples), size=num_p, replace=False)
                    table[rows] = ghat
            w = hard_threshold(w, k)
            nht_count += 1
            if i % plot_freq == 0:
                record(i)
            i += 1

    else:
        # Raising a bare string is itself a TypeError in Python 3.
        raise ValueError(f"Algorithm Unknown: {algorithm!r}")

    print('Training done')
    return it_count, hist, nizo, nht

def hard_threshold(arr, k):
    """Keep the ``k`` largest-magnitude entries of the 1-D array ``arr``; zero the rest.

    Returns a new array with ``arr``'s dtype; ``arr`` itself is not modified.

    - ``k <= 0`` gives all zeros. Previously ``k == 0`` sliced ``[-0:]`` and
      returned everything.
    - ``k >= arr.size`` keeps every entry. Previously ``k > arr.size`` raised
      inside ``np.argpartition``.
    - Ties at the threshold are resolved by ``np.argpartition``'s ordering.
    """
    arr = np.asarray(arr)
    if k <= 0:
        return np.zeros_like(arr)
    if k >= arr.size:
        return arr.copy()
    top_k_indices = np.argpartition(np.abs(arr), -k)[-k:]
    thresholded_arr = np.zeros_like(arr)
    thresholded_arr[top_k_indices] = arr[top_k_indices]
    return thresholded_arr