import numpy as np
from libsvmdata import fetch_libsvm
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import shuffle
from utils import one_sample_F_and_grad, full_batch_F_and_grad
import matplotlib.pyplot as plt
import importlib
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
import pickle
import sys


# Module-level problem setup: load and preprocess the libsvm "dna" dataset and
# derive the constants (lam, L, mu) read as globals by the solvers below.
rs = np.random.RandomState(42)
X_unshuffled, labels = fetch_libsvm("dna")
# labels[:, None] turns the label vector into a single column, as OneHotEncoder expects.
X_sparse, labels = shuffle(X_unshuffled, labels[:, None], random_state=42)
# `.toarray()` instead of the deprecated (and since removed) scipy sparse `.A` attribute.
X = X_sparse.toarray()

# Standardize features, then drop zero-variance columns.
processor = make_pipeline(StandardScaler(), VarianceThreshold())
X = processor.fit_transform(X)
lam = 10
Y = OneHotEncoder().fit_transform(labels).toarray()
n_classes, n_samples, d = Y.shape[1], X.shape[0], X.shape[1]
# We use the upper bound on the Hessian from https://trungvietvu.github.io/notes/2016/MLR
# NOTE: previously written as `c_mat = np.mat = ...`, which silently replaced the
# `numpy.mat` function with this array for the whole process.
c_mat = 1/2 * (np.eye(n_classes) - 1/n_classes * np.ones((n_classes, n_classes)))
# Smoothness constant: spectral norm of the Hessian bound plus the 2*lam*I term.
# NOTE(review): 2*lam presumably matches a lam*||W||^2 penalty inside
# full_batch_F_and_grad (utils), in which case the strong-convexity constant
# would be 2*lam rather than mu = lam below — confirm against utils.
L = np.linalg.norm(1/n_samples * np.kron(c_mat, X.T @ X) + 2 * lam * np.eye(n_classes * X.shape[1]), ord=2)
mu = lam


def projl2groupball(W, D):
    """Project every column of ``W`` onto the Euclidean ball of radius ``D``.

    Columns whose l2 norm is at least ``D`` are rescaled to have norm ``D``;
    columns already strictly inside the ball are returned unchanged.

    Parameters
    ----------
    W : ndarray
        Matrix whose columns are projected independently (norms over axis 0).
    D : float
        Radius of the l2 ball.

    Returns
    -------
    ndarray
        New array with the same shape as ``W``.
    """
    col_norms = np.linalg.norm(W, ord=2, axis=0)
    outside = col_norms >= D
    # Divisor is 1 for columns inside the ball and norm / D for the others.
    scale = np.logical_not(outside) + (outside * col_norms) / D
    return W / scale


def hard_threshold(arr, k):
    """Keep the ``k`` largest-magnitude entries of ``arr`` and zero the rest.

    The selection is global over all elements: the array is flattened, not
    processed row- or column-wise. Ties in magnitude are broken arbitrarily
    by ``np.argpartition``.

    Parameters
    ----------
    arr : ndarray
        Input array of any shape.
    k : int
        Number of entries to keep. ``k <= 0`` yields an all-zero array and
        ``k >= arr.size`` yields an unchanged copy.

    Returns
    -------
    ndarray
        New array with the same shape and dtype as ``arr``.
    """
    flattened_arr = arr.ravel()
    if k >= flattened_arr.size:
        # Nothing to drop. argpartition would also reject kth < -size here.
        return arr.copy()
    thresholded_arr = np.zeros_like(flattened_arr)
    if k > 0:
        # Guarded because with k == 0 the slice [-0:] would select *every* index.
        top_k_indices = np.argpartition(np.abs(flattened_arr), -k)[-k:]
        thresholded_arr[top_k_indices] = flattened_arr[top_k_indices]
    return thresholded_arr.reshape(arr.shape)


def run_ours_ZO(X, Y, k=150, max_iter=50, miu=0.000001):
    """Run "our" zeroth-order hard-thresholding method.

    Each outer iteration builds a gradient estimate from function values only.
    It averages ``ceil(tau / omega**i)`` random-direction finite differences
    (so the count grows geometrically), takes a step of size
    ``eta = 1 / (alpha * L)``, then keeps the ``k`` largest-magnitude entries
    of W via ``hard_threshold``.

    The objective is ``full_batch_F_and_grad(W, X, Y, lam)`` (first return
    value). The module-level constants ``L``, ``mu`` and ``lam`` are read as
    globals.

    Parameters
    ----------
    X : ndarray of shape (n_samples, d)
        Feature matrix.
    Y : ndarray of shape (n_samples, n_classes)
        Label matrix (one-hot in this script).
    k : int
        Number of nonzero entries kept in W after each step.
    max_iter : int
        Number of outer iterations.
    miu : float
        Finite-difference smoothing radius.

    Returns
    -------
    dict
        'losses', 'nht', 'izo': lists of length ``max_iter + 1`` holding the
        objective, the cumulative number of hard-thresholding steps and the
        cumulative number of function queries after each iteration.
        'W' is the final iterate. 'eta' and 'tau' are the step size and the
        initial number of random directions.
    """
    n_classes, d = Y.shape[1], X.shape[1]
    s2 = d
    s = 3 * k
    epsilonf = 2*d/(s2+2) * ((s-1)*(s2 - 1)/(d-1) + 3)
    alpha = 2

    kappa = L / mu
    omega = 1 - 1/(8*alpha*kappa)
    tau = 16*kappa*epsilonf/(alpha-1)
    eta = 1/(alpha * L)

    rs = np.random.RandomState(42)
    W = np.zeros((d, n_classes))
    nht, izo = 0, 0
    # F is the objective at the current iterate. It serves both as the logged
    # loss and as the base point of the finite differences, so it is evaluated
    # only once per iterate. This assumes full_batch_F_and_grad is deterministic.
    F, _ = full_batch_F_and_grad(W, X, Y, lam)
    results = {'losses': [F], 'nht': [nht], 'izo': [izo]}
    for i in range(max_iter):
        n_rand_dir = int(np.ceil(tau / omega**i))
        izo += 1  # the query for F at the current iterate
        G_total = np.zeros_like(W)
        for j in range(n_rand_dir):
            eps = rs.randn(*W.shape)
            eps /= np.linalg.norm(eps)  # uniform direction on the unit sphere
            F_eps, _ = full_batch_F_and_grad(W + miu * eps, X, Y, lam)
            # NOTE(review): sphere-sampling estimators are usually scaled by the
            # dimension of the search space, here W.size = d * n_classes, not d.
            # Kept as d to match epsilonf — confirm against the reference method.
            G_current = d/miu * (F_eps - F) * eps
            G_total += 1/(j+1) * (G_current - G_total)  # stable running mean
            izo += 1
        W -= eta * G_total
        W = hard_threshold(W, k)
        nht += 1
        F, _ = full_batch_F_and_grad(W, X, Y, lam)
        results['izo'].append(izo)
        results['nht'].append(nht)
        results['losses'].append(F)
    results['W'] = W
    results['eta'] = eta
    results['tau'] = tau
    return results


def run_theirs_ZO(X, Y, k=150, max_iter=50, miu=0.000001):
    """Run the baseline zeroth-order hard-thresholding method.

    Unlike the growing schedule in ``run_ours_ZO``, every outer iteration
    averages a fixed number ``ceil(n_rand_dir)`` of random-direction finite
    differences. It then steps with ``eta = mu / ((4*epsilon_F + 1) * L**2)``
    and keeps the ``k`` largest-magnitude entries of W via ``hard_threshold``.

    The objective is ``full_batch_F_and_grad(W, X, Y, lam)`` (first return
    value). The module-level constants ``L``, ``mu`` and ``lam`` are read as
    globals.

    Parameters
    ----------
    X : ndarray of shape (n_samples, d)
        Feature matrix.
    Y : ndarray of shape (n_samples, n_classes)
        Label matrix (one-hot in this script).
    k : int
        Number of nonzero entries kept in W after each step.
    max_iter : int
        Number of outer iterations.
    miu : float
        Finite-difference smoothing radius.

    Returns
    -------
    dict
        'losses', 'nht', 'izo': lists of length ``max_iter + 1`` holding the
        objective, the cumulative number of hard-thresholding steps and the
        cumulative number of function queries after each iteration.
        'W' is the final iterate. 'eta' is the step size. 'n_rand_dir' is the
        (unrounded) number of random directions per iteration.
    """
    n_classes, d = Y.shape[1], X.shape[1]
    s2 = d
    s = 3 * k
    epsilonf = 2*d/(s2+2) * ((s-1)*(s2 - 1)/(d-1) + 3)
    alpha = 2
    kappa = L / mu
    n_rand_dir = 16*kappa*epsilonf/(alpha-1)
    epsilon_F = 2*d/(n_rand_dir*(s2+2)) * ((s-1)*(s2 - 1)/(d-1) + 3) + 2
    eta = mu/((4*epsilon_F+1)*L**2)
    n_dirs = int(np.ceil(n_rand_dir))

    rs = np.random.RandomState(42)
    W = np.zeros((d, n_classes))
    nht, izo = 0, 0
    # F is the objective at the current iterate. It serves both as the logged
    # loss and as the base point of the finite differences, so it is evaluated
    # only once per iterate. This assumes full_batch_F_and_grad is deterministic.
    F, _ = full_batch_F_and_grad(W, X, Y, lam)
    results = {'losses': [F], 'nht': [nht], 'izo': [izo]}
    for _ in range(max_iter):
        izo += 1  # the query for F at the current iterate
        G_total = np.zeros_like(W)
        for j in range(n_dirs):
            eps = rs.randn(*W.shape)
            eps /= np.linalg.norm(eps)  # uniform direction on the unit sphere
            F_eps, _ = full_batch_F_and_grad(W + miu * eps, X, Y, lam)
            # NOTE(review): sphere-sampling estimators are usually scaled by the
            # dimension of the search space, here W.size = d * n_classes, not d.
            # Kept as d to match epsilonf — confirm against the reference method.
            G_current = d/miu * (F_eps - F) * eps
            G_total += 1/(j+1) * (G_current - G_total)  # stable running mean
            izo += 1
        W -= eta * G_total
        W = hard_threshold(W, k)
        nht += 1
        F, _ = full_batch_F_and_grad(W, X, Y, lam)
        results['izo'].append(izo)
        results['nht'].append(nht)
        results['losses'].append(F)
    results['W'] = W
    results['eta'] = eta
    results['n_rand_dir'] = n_rand_dir
    return results


if __name__ == "__main__":
    # Usage: python <script> K
    # Runs both solvers with sparsity level K and pickles their result dicts
    # into ./results/.
    import os

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} K  (number of nonzero entries kept in W)")
    k = int(sys.argv[1])
    # open(..., 'wb') below fails if the output directory does not exist yet.
    os.makedirs('./results', exist_ok=True)

    results_theirszo = run_theirs_ZO(X, Y, k)
    with open(f'./results/results_theirszo_k_{k}.pkl', 'wb') as file:
        pickle.dump(results_theirszo, file)

    results_ourszo = run_ours_ZO(X, Y, k)
    with open(f'./results/results_ourszo_k_{k}.pkl', 'wb') as file:
        pickle.dump(results_ourszo, file)
