import numpy as np
import json
import argparse
from json.decoder import JSONDecodeError

# Module-wide random generator with a fixed seed (42). The functions visible
# in this file draw all of their random numbers from this single object, so
# results are reproducible but depend on the order in which draws happen.
RS = np.random.RandomState(42)

def init_algo(
        n=100, d=50, k=5, n_tasks=1, rho=0.001, snr=0., eta=None, update_freq=None):
    r"""Simulate a noiseless sparse linear regression problem.

    Adapted from
    https://github.com/benchopt/benchopt/blob/81d7d6f188d1886af999981b864a08c4976bf8bf/benchopt/datasets/simulated.py

    The columns of the design matrix follow an AR(1) process with
    coefficient ``rho`` (independent standard normal columns when
    ``rho == 0``). The ground-truth weights have their first ``k`` entries
    drawn from a standard normal and the remaining ``d - k`` set to zero.
    The targets are ``X @ w_true`` with no noise added.

    All random numbers come from the module-level generator ``RS``.

    Parameters
    ----------
    n : int
        Number of samples (rows of ``X``).
    d : int
        Number of features (columns of ``X``).
    k : int
        Number of non-zero entries in the true weights (cast with ``int``).
    n_tasks : int
        Only selects the shape of the returned arrays: ``1`` flattens the
        targets and weights to 1-D, any other value keeps them as columns.
    rho : float
        Correlation between consecutive columns of ``X``.
    snr, eta, update_freq
        Accepted for call-site compatibility; not used by this function.

    Returns
    -------
    X : ndarray of shape (n, d)
    Y : ndarray of shape (n,) if ``n_tasks == 1`` else (n, 1)
    w_true : ndarray of shape (d,) if ``n_tasks == 1`` else (d, 1)
    """
    rng = RS
    n_samples, n_features = n, d
    nnz = int(k)

    if rho != 0:
        # X is generated column by column with an AR(1) model of
        # correlation rho and innovation variance sigma^2 = 1 - rho ** 2:
        #     X[:, j+1] = rho * X[:, j] + eps_j,
        # where eps_j = sigma * rng.randn(n_samples).
        sigma = np.sqrt(1 - rho * rho)
        U = rng.randn(n_samples)

        X = np.empty([n_samples, n_features], order='F')
        X[:, 0] = U
        for j in range(1, n_features):
            # U is updated in place; assigning into X copies its values.
            U *= rho
            U += sigma * rng.randn(n_samples)
            X[:, j] = U
    else:
        X = rng.randn(n_samples, n_features)

    # Drawn after X so the sequence of random draws stays unchanged.
    w_true = np.concatenate(
        [rng.randn(nnz), np.zeros(n_features - nnz)])[:, None]

    Y = X @ w_true

    if n_tasks == 1:
        return X, Y.flatten(), w_true.flatten()
    # Return the weights here too, so both branches yield three values
    # (the original returned only ``(X, Y)`` in this branch).
    return X, Y, w_true

def F(x, y, w):
    """Return the squared Euclidean norm of the residual ``x @ w - y``.

    The norm is taken along axis 0. For 1-D ``y`` this is the plain sum of
    squared residuals (it is not divided by the number of samples). For
    2-D ``y`` the squared norm of each column is computed and the result
    is the mean over columns.
    """
    residual = x @ w - y
    squared_norms = np.linalg.norm(residual, axis=0) ** 2
    return np.mean(squared_norms)

def _zo_gradient(X, y, w, mu, q, base_loss):
    """Average ``q`` random-direction finite-difference estimates at ``w``.

    Each estimate draws a direction ``u`` from the module-level generator
    ``RS``, normalizes it to unit length, and uses
    ``d * (F(X, y, w + mu * u) - base_loss) / mu * u`` where ``d`` is the
    number of columns of ``X`` and ``base_loss`` is ``F(X, y, w)`` computed
    by the caller.
    """
    dim = X.shape[1]
    ghat = np.zeros(dim, dtype=np.float64)
    for _ in range(q):
        u = RS.randn(dim).astype(np.float64)
        u /= np.linalg.norm(u)
        ghat += dim * (F(X, y, w + mu * u) - base_loss) / mu * u
    ghat /= q
    return ghat


def run_algo(log):
    """Run the variance-reduced zeroth-order method and save the results.

    A dataset is generated once with ``init_algo``. Then, for each of
    ``log['n']`` runs starting from ``w = 0``, the method alternates
    between:

    * an outer step that stores an anchor copy of ``w`` and estimates the
      full-data gradient at the anchor from function values only, and
    * up to ``log['update_freq']`` inner steps that combine mini-batch
      estimates at ``w`` and at the anchor with the full estimate, take a
      step of size ``eta`` and keep the ``k`` largest-magnitude entries
      via ``hard_threshold``.

    A run stops once the query counter reaches ``log['zomax']``.

    Parameters
    ----------
    log : dict
        Settings; the keys read are ``'N'``, ``'D'``, ``'k'``, ``'eta'``,
        ``'update_freq'``, ``'mu'``, ``'batch_size'``, ``'zomax'``, ``'q'``
        and ``'n'``. The dict is modified in place: ``'hist'`` receives
        one list of full-data losses per run, and ``'nizo'`` / ``'nht'``
        receive the query and hard-thresholding counters of the last run
        (empty lists when there are no runs).

    Side effects
    ------------
    Stores ``log`` in ``./results/output.json`` under the key
    ``"{q}_{update_freq}_{eta}"``, keeping the other entries already in
    that file. A missing or non-JSON file is treated as empty, and the
    directory is created if needed.
    """
    import os

    X, y, _ = init_algo(
        n=log['N'], d=log['D'], k=log['k'],
        eta=log['eta'], update_freq=log['update_freq'])

    eta = log['eta']
    mu = log['mu']
    batch_size = log['batch_size']
    zomax = log['zomax']
    q = log['q']
    k = log['k']
    nruns = log['n']
    update_freq = log['update_freq']
    n_samples = X.shape[0]

    log['hist'] = []
    # Defined up front so they exist even when nruns == 0.
    nizo, nht = [], []
    for _ in range(nruns):
        w = np.zeros(log['D'])

        print('Training...')
        hist = [F(X, y, w)]
        nizo_count = 0
        nht_count = 0
        nizo, nht = [nizo_count], [nht_count]

        while nizo_count < zomax:
            inner_its = 0
            anchor = w + 0.  # copy, so later updates of w leave it intact
            full_loss_anchor = F(X, y, anchor)
            # One full pass for the anchor loss plus one per direction.
            full_ghat = _zo_gradient(X, y, anchor, mu, q, full_loss_anchor)
            nizo_count += n_samples * (1 + q)

            while (inner_its < update_freq) and (nizo_count < zomax):
                batch_idx = RS.randint(n_samples, size=batch_size)
                X_batch, y_batch = X[batch_idx], y[batch_idx]

                # NOTE(review): only the two unperturbed batch evaluations
                # are counted below; the 2*q perturbed evaluations are
                # not, unlike in the full-gradient step above. Kept as in
                # the original so the stopping point does not change.
                batch_loss = F(X_batch, y_batch, w)
                nizo_count += batch_size
                ghat = _zo_gradient(X_batch, y_batch, w, mu, q, batch_loss)

                batch_loss_anchor = F(X_batch, y_batch, anchor)
                nizo_count += X_batch.shape[0]
                # Uses fresh random directions, independent of those
                # drawn for ``ghat``.
                ghat_anchor = _zo_gradient(
                    X_batch, y_batch, anchor, mu, q, batch_loss_anchor)

                w = hard_threshold(
                    w - eta * (ghat - ghat_anchor + full_ghat), k)
                nht_count += 1
                inner_its += 1

                hist.append(F(X, y, w))
                nizo.append(nizo_count)
                nht.append(nht_count)

        log['hist'].append(hist)
    # The counters depend only on the sizes, not on the random draws, so
    # the last run's lists are the ones stored.
    log['nizo'] = nizo
    log['nht'] = nht

    filename = './results/output.json'
    entry_key = f"{log['q']}_{log['update_freq']}_{log['eta']}"

    # Make sure the target directory exists so the results are not lost
    # to an I/O error after the computation has finished.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    try:
        with open(filename) as infile:
            prev_data = json.load(infile)
    except (FileNotFoundError, JSONDecodeError):
        # No previous results (or an empty/invalid file): start fresh.
        prev_data = {}
    prev_data[entry_key] = log

    with open(filename, 'w') as outfile:
        json.dump(prev_data, outfile)

def hard_threshold(arr, k):
    """Keep the ``k`` largest-magnitude entries of ``arr`` and zero the rest.

    Parameters
    ----------
    arr : ndarray
        Input vector (expected to be 1-D); it is not modified.
    k : int
        Number of entries to keep. ``0`` gives an all-zero result, and any
        value greater than or equal to the length of ``arr`` keeps every
        entry.

    Returns
    -------
    ndarray
        A new array with the same shape and dtype as ``arr``.

    Raises
    ------
    ValueError
        If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    thresholded_arr = np.zeros_like(arr)
    if k == 0:
        # ``[-0:]`` would select every index, so handle this explicitly.
        return thresholded_arr
    if k >= arr.size:
        # Nothing to drop; argpartition would reject an out-of-range kth.
        return arr.copy()
    top_k_indices = np.argpartition(np.abs(arr), -k)[-k:]
    thresholded_arr[top_k_indices] = arr[top_k_indices]
    return thresholded_arr

if __name__ == '__main__':
    # Command-line entry point: collect the experiment settings and pass
    # them to run_algo as a plain dict keyed by option name.
    cli = argparse.ArgumentParser(description='Benchmark sparse linear regression models')
    cli.add_argument('-a', '--algorithm', help='Algorithm to use')

    # (short flag, long flag, type, help text, default)
    typed_options = [
        ('-n', '--n', int, 'Number of runs', 3),
        ('-N', '--N', int, 'Number of samples in the dataset', 200),
        ('-D', '--D', int, 'Dimension of the dataset', 10),
        ('-e', '--eta', float, 'Learning rate', 0.001),
        ('-q', '--q', int, 'Number of random directions', 20),
        ('-m', '--mu', float, 'Random smoothing Radius', 0.0001),
        ('-k', '--k', int, 'Constrained sparsity', 5),
        ('-b', '--batch_size', int, 'Batch size for SGD', 1),
        ('-u', '--update_freq', int, 'Update frequency for SVRG', 10),
        ('-z', '--zomax', int, 'Max number of ZO oracle', 100000),
    ]
    for short_flag, long_flag, caster, description, fallback in typed_options:
        cli.add_argument(
            short_flag, long_flag, type=caster, help=description, default=fallback)

    run_algo(vars(cli.parse_args()))

