import argparse
import csv
import os
import time

import numpy as np

from datasets import get_data
from algorithms import optimize

def benchmark_model(dataset, algo, nruns, max_iter, mu, q, eta, lam, k, beta, batch_size=None, update_freq=None, plot_freq=None, grid=None, zomax=None, pt='./results/'):
    """Benchmark ``algo`` on ``dataset`` and append the best step size's results to CSV.

    For every candidate step size (each entry of ``grid`` if given, otherwise
    just ``eta``) the optimizer is run ``nruns`` times. The per-run histories
    are averaged, and the step size whose mean history has the smallest final
    value is kept. One row is then appended to each of ``results_mean.csv``,
    ``results_std.csv``, ``it_count.csv``, ``nizo.csv`` and ``nht.csv``
    inside ``pt``. Every row starts with
    ``[dataset, algo, nruns, timestamp, lam, n_features, best_eta]`` followed
    by the stringified payload for that file.

    Args:
        dataset: Dataset identifier forwarded to ``get_data``.
        algo: Algorithm identifier forwarded to ``optimize``.
        nruns: Number of repeated runs per step size; must be >= 1.
        max_iter, mu, q, lam, k, beta: Forwarded positionally to ``optimize``.
        eta: Step size used when ``grid`` is None.
        batch_size, update_freq, plot_freq, zomax: Forwarded to ``optimize``
            as keyword arguments.
        grid: Optional iterable of step sizes searched instead of ``eta``.
        pt: Output directory (created if missing). ``None`` falls back to
            ``'./results/'``.

    Raises:
        ValueError: If ``nruns`` is not a positive integer, or if no step size
            produced a final mean loss below infinity (e.g. every grid point
            diverged to NaN/inf, or ``grid`` was empty).
    """
    if nruns is None or nruns < 1:
        raise ValueError(f'nruns must be a positive integer, got {nruns!r}')
    if pt is None:
        # The CLI passes None when -t is omitted; honour the documented default.
        pt = './results/'

    X, y = get_data(dataset)
    # A single RNG is shared by every run and every grid point, so the whole
    # benchmark is reproducible. Note that different etas therefore see
    # different random draws (presumably consumed inside optimize).
    random_state = np.random.RandomState(42)

    if grid is not None:
        print("Grid is not None: using the grid-search rather than the given eta.")
    else:
        grid = [eta]

    best_loss = np.inf
    best_eta = best_data = None
    best_hist_mean = best_hist_std = None
    best_it_count = best_nizo = best_nht = None
    for cand_eta in grid:
        histories = []
        for _ in range(nruns):
            it_count, hist, nizo, nht = optimize(
                algo, random_state, X, y, max_iter, mu, q, cand_eta, lam, k, beta,
                batch_size=batch_size, update_freq=update_freq,
                plot_freq=plot_freq, zomax=zomax)
            histories.append(hist)

        # NOTE(review): '%D%H:%M' renders e.g. '03/05/2412:30' (no separator
        # between year and hour); kept unchanged so existing CSVs stay uniform.
        tstamp = time.strftime("%D%H:%M", time.localtime(time.time()))
        data = [dataset, algo, nruns, tstamp, lam, X.shape[1], cand_eta]

        # Stacking assumes every run returns a history of the same length.
        hist_array = np.array(histories)
        hist_mean = np.mean(hist_array, axis=0)
        hist_std = np.std(hist_array, axis=0)
        # Strict '<' keeps the earliest eta on ties; NaN never compares less,
        # so diverged grid points are never selected.
        if hist_mean[-1] < best_loss:
            best_loss = hist_mean[-1]
            best_eta = cand_eta
            best_data = data
            best_hist_mean = hist_mean
            best_hist_std = hist_std
            # Counters come from the last run of the *selected* eta, so the
            # row is internally consistent with best_eta.
            best_it_count, best_nizo, best_nht = it_count, nizo, nht

    if best_data is None:
        raise ValueError(
            f'No step size in {grid!r} produced a final mean loss below inf '
            '(all runs diverged?); no results written.')

    print(f'Algorithm: {algo}, best eta: {best_eta}')

    os.makedirs(pt, exist_ok=True)
    # .tolist() yields plain Python floats, so the text is '[0.5, ...]' on every
    # NumPy version (list(ndarray) prints 'np.float64(0.5)' on NumPy >= 2).
    payloads = (
        ('results_mean.csv', str(best_hist_mean.tolist())),
        ('results_std.csv', str(best_hist_std.tolist())),
        ('it_count.csv', str(best_it_count)),
        ('nizo.csv', str(best_nizo)),
        ('nht.csv', str(best_nht)),
    )
    for fname, payload in payloads:
        # newline='' is required by the csv module to avoid blank rows on Windows.
        with open(os.path.join(pt, fname), 'a', newline='') as file:
            csv.writer(file).writerow(list(best_data) + [payload])


def build_parser():
    """Return the argument parser for the sparse-regression benchmark CLI."""
    parser = argparse.ArgumentParser(description='Benchmark sparse linear regression models')
    parser.add_argument('-d', '--dataset', help='Type of dataset to use')
    parser.add_argument('-a', '--algorithm', help='Algorithm to use')
    parser.add_argument('-n', '--nruns', type=int, help='Number of runs')
    parser.add_argument('-i', '--iter', type=int, help='Number of iterations')
    parser.add_argument('-e', '--eta', type=float, help='Learning rate')
    parser.add_argument('-q', '--q', type=int, help='Number of random directions')
    parser.add_argument('-m', '--mu', type=float, help='Random smoothing Radius')
    parser.add_argument('-l', '--lam', type=float, help='Regularization parameter')
    parser.add_argument('-k', '--k', type=int, help='Constrained sparsity')
    parser.add_argument('-b', '--batch_size', type=int, help='Batch size for SGD')
    parser.add_argument('-u', '--update_freq', type=int, help='Update frequency for SVRG')
    parser.add_argument('-p', '--plot_freq', type=int, help='Monitor the full loss each XX iteration')
    parser.add_argument('-g', '--grid', nargs='+', type=float, help='List of etas for the grid-search')
    parser.add_argument('-z', '--zomax', type=int, help='Max number of ZO oracle')
    # Default matches benchmark_model's own default; without it argparse
    # yields None and results would be written under a literal 'None/' path.
    parser.add_argument('-t', '--thepath', type=str, default='./results/',
                        help='The path for storing the results')
    parser.add_argument('-B', '--beta', type=float, help='bias of svrzht')
    return parser


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and run ``benchmark_model``."""
    args = build_parser().parse_args(argv)
    print("===alg start===")
    print(f'{args.mu}')
    benchmark_model(args.dataset, args.algorithm, args.nruns, args.iter, args.mu, args.q, args.eta, args.lam, args.k, args.beta,
                    batch_size=args.batch_size, update_freq=args.update_freq, plot_freq=args.plot_freq, grid=args.grid, zomax=args.zomax, pt=args.thepath)


if __name__ == '__main__':
    main()