import numpy as np
from libsvmdata import fetch_libsvm
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import shuffle
from utils import one_sample_F_and_grad, full_batch_F_and_grad
import matplotlib.pyplot as plt
import importlib
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
import pickle
import sys



def hard_threshold(arr, k):
    """Keep the ``k`` largest-magnitude entries of ``arr`` and zero the rest.

    The selection is global: the array is flattened, the ``k`` entries with
    the largest absolute value are kept (ties are broken arbitrarily by
    ``np.argpartition``), and the result is reshaped to ``arr.shape``.

    Parameters
    ----------
    arr : np.ndarray
        Input array of any shape. It is not modified.
    k : int
        Number of entries to keep. ``k == 0`` yields an all-zero array and
        ``k >= arr.size`` yields a copy of ``arr``.

    Returns
    -------
    np.ndarray
        New array with the same shape and dtype as ``arr``.

    Raises
    ------
    ValueError
        If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    flat = arr.ravel()
    if k >= flat.size:
        # Nothing has to be discarded; this also avoids argpartition's
        # out-of-bounds error when k exceeds the number of entries.
        return arr.copy()

    kept = np.zeros_like(flat)
    if k > 0:
        # Guarded because the slice [-0:] would select *every* index.
        top_k_indices = np.argpartition(np.abs(flat), -k)[-k:]
        kept[top_k_indices] = flat[top_k_indices]

    return kept.reshape(arr.shape)

# Module-level RNG; runSG and runZO create their own seeded generators.
rs = np.random.RandomState(42)
X_unshuffled, labels = fetch_libsvm("dna")

# Shuffle samples and labels together; labels are turned into a column so
# that OneHotEncoder receives a 2-D input below.
X_sparse, labels = shuffle(X_unshuffled, labels[:, None], random_state=42)
# toarray() is the explicit spelling of the `.A` shorthand, which newer SciPy
# releases no longer provide on every sparse container.
X = X_sparse.toarray()

# Standardise the features, then drop the zero-variance ones.
processor = make_pipeline(StandardScaler(), VarianceThreshold())
X = processor.fit_transform(X)
lam = 10  # regularisation weight passed to the objective helpers

Y = OneHotEncoder().fit_transform(labels).toarray()
n_classes, n_samples, d = Y.shape[1], X.shape[0], X.shape[1]
# We use the upper bound on the Hessian from https://trungvietvu.github.io/notes/2016/MLR
# NOTE: this used to read `c_mat = np.mat = ...`, a chained assignment that
# also overwrote `numpy.mat` with this array.
c_mat = 1/2 * (np.eye(n_classes) - 1/n_classes * np.ones((n_classes, n_classes)))
# L is the spectral norm of the bound plus the 2 * lam * I term
# (presumably the Hessian of the regulariser -- confirm against utils).
L = np.linalg.norm(
    1/n_samples * np.kron(c_mat, X.T @ X) + 2 * lam * np.eye(n_classes * d),
    ord=2,
)
mu = lam  # used as the denominator of the ratio kappa = L / mu in runSG / runZO

def projl2groupball(W, D):
    """Project every column of ``W`` onto the l2 ball of radius ``D``.

    Columns whose l2 norm is at least ``D`` are rescaled to have norm ``D``;
    all other columns are returned unchanged.

    Parameters
    ----------
    W : np.ndarray
        Array whose columns (norms are taken along axis 0) are the groups.
    D : float
        Radius of the ball; must be non-negative.

    Returns
    -------
    np.ndarray
        New floating-point array with the same shape as ``W``.

    Raises
    ------
    ValueError
        If ``D`` is negative.
    """
    if D < 0:
        raise ValueError(f"D must be non-negative, got {D}")
    if D == 0:
        # The ball of radius zero contains only the origin. Handling it here
        # avoids the 0/0 -> NaN that the generic formula gives for columns
        # that are already all zero.
        return np.zeros_like(W, dtype=float)

    norms = np.linalg.norm(W, axis=0, ord=2)
    # Divide by norm / D where the column is outside (or on) the ball, by 1 otherwise.
    denominator = np.where(norms >= D, norms / D, 1.0)
    return W / denominator

def runFG(X, Y, k=150, D=0.5):
    """Run full-gradient descent with hard thresholding and group projection.

    Starting from ``W = 0``, each of the 50 iterations takes a gradient step
    of size ``1 / L`` on the full-batch objective, keeps the ``k``
    largest-magnitude entries of ``W`` (``hard_threshold``) and projects each
    column of ``W`` onto the l2 ball of radius ``D`` (``projl2groupball``).
    Uses the module-level constants ``L`` and ``lam``.

    Parameters
    ----------
    X : array with one row per sample (``X.shape == (n_samples, d)``).
    Y : array with one column per class (``Y.shape[1] == n_classes``).
    k : int
        Number of entries of ``W`` kept by the hard-thresholding step.
    D : float
        Radius of the per-column l2 ball.

    Returns
    -------
    dict
        ``'losses'``: full-batch objective at the initial point and after
        every iteration; ``'ifo'``: cumulative count of per-sample gradient
        evaluations (``n_samples`` per iteration); ``'nht'``: cumulative
        number of hard-thresholding steps; ``'W'``: the final iterate.
    """
    eta = 1/L
    max_iter = 50
    n_classes, n_samples, d = Y.shape[1], X.shape[0], X.shape[1]
    W = np.zeros((d, n_classes))

    # One evaluation per iterate gives both the loss to record and the
    # gradient for the next step. The previous version evaluated the objective
    # twice at the same W in every iteration.
    # NOTE(review): assumes full_batch_F_and_grad is a pure function of its
    # arguments -- confirm against utils.
    loss, grad = full_batch_F_and_grad(W, X, Y, lam)
    results = {'losses': [loss], 'ifo': [0], 'nht': [0]}

    for it in range(1, max_iter + 1):
        W = W - eta * grad
        W = hard_threshold(W, k)
        W = projl2groupball(W, D)  # project each group onto the l2 ball
        loss, grad = full_batch_F_and_grad(W, X, Y, lam)
        results['ifo'].append(it * n_samples)
        results['nht'].append(it)
        results['losses'].append(loss)

    results['W'] = W
    return results

def runSG(X, Y, k=150, D=0.5):
    """Run mini-batch gradient descent with hard thresholding and a growing batch.

    Starting from ``W = 0``, each of the 500 iterations samples a mini-batch
    without replacement (seeded ``RandomState(42)``), takes a gradient step on
    that batch, keeps the ``k`` largest-magnitude entries of ``W`` and
    projects each column onto the l2 ball of radius ``D``. The batch size at
    iteration ``i`` is ``ceil(tau / omega**i)``, capped at the number of
    samples. Uses the module-level constants ``L``, ``mu`` and ``lam``.

    Parameters
    ----------
    X : array with one row per sample (``X.shape == (n_samples, d)``).
    Y : array with one column per class (``Y.shape[1] == n_classes``).
    k : int
        Number of entries of ``W`` kept by the hard-thresholding step.
    D : float
        Radius of the per-column l2 ball.

    Returns
    -------
    dict
        ``'losses'``: full-batch objective at the initial point and after
        every iteration; ``'ifo'``: cumulative number of samples used for
        gradients; ``'nht'``: cumulative number of hard-thresholding steps;
        ``'W'``: the final iterate.
    """
    n_classes, n_samples, d = Y.shape[1], X.shape[0], X.shape[1]

    # Constants of the step size and of the batch-size schedule.
    max_iter = 500
    rho = 0.9
    C = L
    B = 100000
    eta = 1/(L + C)
    tau = B*eta/C
    alpha = C/L + 1
    kappa = L/mu
    omega = 1 - 1/(4*alpha*kappa/rho)

    rng = np.random.RandomState(42)
    W = np.zeros((d, n_classes))

    initial_loss, _ = full_batch_F_and_grad(W, X, Y, lam)
    results = {'losses': [initial_loss], 'ifo': [0], 'nht': [0]}

    ifo = 0
    for step in range(max_iter):
        # The batch grows geometrically but never exceeds the dataset.
        batch_size = min(int(np.ceil(tau / omega**step)), n_samples)
        batch_idx = rng.choice(np.arange(n_samples), size=batch_size, replace=False)

        _, grad = full_batch_F_and_grad(W, X[batch_idx], Y[batch_idx], lam)
        ifo += batch_size

        # Gradient step, then sparsify, then project each group onto the l2 ball.
        W = projl2groupball(hard_threshold(W - eta * grad, k), D)

        results['ifo'].append(ifo)
        results['nht'].append(step + 1)

        # The recorded loss is always measured on the full dataset.
        full_loss, _ = full_batch_F_and_grad(W, X, Y, lam)
        results['losses'].append(full_loss)

    results['W'] = W
    return results

def runZO(X, Y, k=150, D=0.5):
    """Run zeroth-order (function-value only) descent with hard thresholding.

    Starting from ``W = 0``, each of the 50 iterations estimates the gradient
    from finite differences of the full-batch objective along random
    directions (Gaussian draws from a seeded ``RandomState(42)``, scaled to
    unit norm), averages the estimates, takes a step of size
    ``1 / (alpha * L)``, keeps the ``k`` largest-magnitude entries of ``W``
    and projects each column onto the l2 ball of radius ``D``. The number of
    directions at iteration ``i`` is ``ceil(tau / omega**i)``. Uses the
    module-level constants ``L``, ``mu`` and ``lam``.

    Parameters
    ----------
    X : array with one row per sample (``X.shape == (n_samples, d)``).
    Y : array with one column per class (``Y.shape[1] == n_classes``).
    k : int
        Number of entries of ``W`` kept by the hard-thresholding step.
    D : float
        Radius of the per-column l2 ball.

    Returns
    -------
    dict
        ``'losses'``: full-batch objective at the initial point and after
        every iteration; ``'izo'``: cumulative number of objective
        evaluations used by the estimator; ``'nht'``: cumulative number of
        hard-thresholding steps; ``'W'``: the final iterate.
    """
    n_classes, n_samples, d = Y.shape[1], X.shape[0], X.shape[1]

    # Constants of the step size and of the direction-count schedule.
    max_iter = 50
    rho = 0.9
    alpha = 2
    s2 = d
    s = 3*k
    kappa = L/mu
    miu = 0.000001  # finite-difference smoothing radius
    epsilonf = 2*d/(s2+2) * ((s-1)*(s2 - 1)/(d-1) + 3)
    omega = 1 - 1/(8*alpha*kappa/rho)
    tau = 16*kappa*epsilonf/(alpha-1)
    eta = 1/(alpha*L)

    rng = np.random.RandomState(42)
    W = np.zeros((d, n_classes))

    start_loss, _ = full_batch_F_and_grad(W, X, Y, lam)
    results = {'losses': [start_loss], 'izo': [0], 'nht': [0]}

    izo = 0
    for step in range(max_iter):
        n_rand_dir = int(np.ceil(tau / omega**step))

        # Reference value at the current iterate (counted as one query).
        F, _ = full_batch_F_and_grad(W, X, Y, lam)
        izo += 1

        G_mean = np.zeros_like(W)
        for draw in range(1, n_rand_dir + 1):
            direction = rng.randn(*W.shape)
            direction /= np.linalg.norm(direction)
            F_eps, _ = full_batch_F_and_grad(W + miu * direction, X, Y, lam)
            # NOTE(review): the estimate is scaled by d (number of features)
            # although W has d * n_classes entries -- confirm this is intended.
            G_current = d/miu * (F_eps - F) * direction
            # Incremental (running) mean over the directions drawn so far.
            G_mean = G_mean + 1/draw * (G_current - G_mean)
        izo += n_rand_dir  # one query per random direction

        # Descent step, then sparsify, then project each group onto the l2 ball.
        W = projl2groupball(hard_threshold(W - eta * G_mean, k), D)

        results['izo'].append(izo)
        results['nht'].append(step + 1)
        current_loss, _ = full_batch_F_and_grad(W, X, Y, lam)
        results['losses'].append(current_loss)

    results['W'] = W
    return results



if __name__ == "__main__":
    # Usage: python <script> D k
    #   D -- radius of the per-column l2 ball (float)
    #   k -- number of entries kept by hard thresholding (int)
    import os

    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} D k")

    D = float(sys.argv[1])
    k = int(sys.argv[2])

    # Create the output directory up front so that a missing folder does not
    # crash the script only after the first solver has already finished.
    os.makedirs('./results', exist_ok=True)

    # Run the three solvers in the original order and pickle each result dict.
    for tag, solver in (('fg', runFG), ('sg', runSG), ('zo', runZO)):
        results = solver(X, Y, k, D)
        with open(f'./results/results_{tag}_D_{D}_k_{k}.pkl', 'wb') as file:
            pickle.dump(results, file)