import argparse
import copy
import json
import os

import numpy as np
import opacus
from opacus.accountants.utils import get_noise_multiplier

from datasets_directory.dataset_loader import Mydatasets
from my_logistic_regression import MyLogisticRegression
from opt_algs  import newton, gd_priv, gd_priv_optls, private_newton, DoubleNoiseMech, CompareAlgs,sgd_priv

def _mean_and_stderr(values, num_rep):
    """Return (mean, std / sqrt(num_rep)) of ``values`` along axis 0.

    Both results are converted with ``.tolist()`` so they are plain Python
    lists/floats that ``json.dump`` can serialise.
    """
    arr = np.array(values)
    mean = np.mean(arr, axis=0).tolist()
    stderr = (np.std(arr, axis=0) / np.sqrt(num_rep)).tolist()
    return mean, stderr


def helper_fun(datasetname, pb, num_rep):
    """Run DP-SGD ``num_rep`` times on a dataset and save summary metrics as JSON.

    Args:
        datasetname: name of a loader method on ``Mydatasets``. It is called
            with no arguments and must return ``(X, y, w_opt)``.
        pb: parameter dictionary. Reads ``"total"`` (privacy budget passed as
            ``target_epsilon``), ``"batch_size"`` and ``"num_iteration"``.
            NOTE: this function also writes ``pb["noise_multiplier"]``, so the
            caller's dict is mutated in place.
        num_rep: number of times the optimization is repeated; the report
            aggregates over all repetitions. Must be at least 1.

    Raises:
        ValueError: if ``num_rep`` is smaller than 1.

    The result is written to
    ``results-stochastic-new/sgd_<dataset>_<total>_<num_iteration>_<batch_size>.txt``
    (the directory is created if missing). For every algorithm the ``*_avg``
    entries are means along axis 0 of the accumulated values, and the
    ``*_std`` entries are std / sqrt(num_rep), i.e. a standard error, not the
    raw standard deviation.
    """
    # Fail fast: with zero repetitions nothing would be accumulated, and the
    # original code only crashed (UnboundLocalError) after the expensive setup.
    if num_rep < 1:
        raise ValueError(f"num_rep must be at least 1, got {num_rep}")

    datasets = Mydatasets()
    X, y, w_opt = getattr(datasets, datasetname)()
    dataset = X, y
    priv_param = pb["total"]
    batch_size = pb["batch_size"]
    num_samples = len(y)
    # delta is tied to the dataset size: delta = 1 / n^2.
    delta = (1.0 / num_samples) ** 2
    # Calibrate the noise scale with opacus' RDP accountant for
    # `num_iteration` steps at sampling rate batch_size / n.
    noise_multiplier = get_noise_multiplier(
        target_epsilon=priv_param,
        target_delta=delta,
        sample_rate=batch_size / num_samples,
        epochs=None,
        steps=pb["num_iteration"],
        accountant='rdp',
        epsilon_tolerance=0.1,
    )
    # Stored in pb, which is handed to CompareAlgs below.
    pb["noise_multiplier"] = noise_multiplier
    lr = MyLogisticRegression(X, y, reg=1e-9)
    c = CompareAlgs(lr, dataset, w_opt, iters=pb["num_iteration"], pb=pb)

    for rep in range(num_rep):
        print(str(rep + 1) + " expriment out of " + str(num_rep))
        c.add_algo(sgd_priv, "DPSGD")

        losses_dict = c.loss_vals()
        gradnorm_dict = c.gradnorm_vals()
        accuracy_dict = c.accuracy_vals()
        wall_clock_dict = c.wall_clock_alg()
        if rep == 0:
            # Copy instead of aliasing the returned containers. If CompareAlgs
            # hands back (or later reuses) its internal dicts/lists, extending
            # them in place would corrupt the accumulated results.
            losses_total = copy.deepcopy(losses_dict)
            gradnorm_total = copy.deepcopy(gradnorm_dict)
            accuracy_total = copy.deepcopy(accuracy_dict)
            wall_clock_total = copy.deepcopy(wall_clock_dict)
        else:
            for name in losses_total:
                losses_total[name].extend(losses_dict[name])
                gradnorm_total[name].extend(gradnorm_dict[name])
                accuracy_total[name].extend(accuracy_dict[name])
                wall_clock_total[name].extend(wall_clock_dict[name])

    result = {}
    accuracy_wopt = c.accuracy_np()
    result['num-samples'] = num_samples
    result['acc-best'] = accuracy_wopt.tolist()
    # (output-key prefix, accumulated values); order fixes the key order in the JSON.
    metrics = (
        ("loss", losses_total),
        ("gradnorm", gradnorm_total),
        ("acc", accuracy_total),
        ("clock_time", wall_clock_total),
    )
    for alg in losses_total:
        result[alg] = {}
        for prefix, totals in metrics:
            avg, stderr = _mean_and_stderr(totals[alg], num_rep)
            result[alg][prefix + "_avg"] = avg
            result[alg][prefix + "_std"] = stderr

    out_dir = "results-stochastic-new"
    os.makedirs(out_dir, exist_ok=True)
    out_path = (
        out_dir + "/sgd_" + datasetname + "_" + str(priv_param) + "_"
        + str(pb["num_iteration"]) + "_" + str(pb["batch_size"]) + ".txt"
    )
    # Context manager guarantees the file is flushed and closed.
    with open(out_path, 'w') as f:
        json.dump(result, f)



def main():
    """Parse command-line arguments and run the DP-SGD experiment.

    Positional arguments: dataset name, total privacy budget (float), number
    of iterations (int) and batch size (int). The optional ``--num_rep``
    (default 10) sets how many repetitions are aggregated.
    """
    parser = argparse.ArgumentParser(
        description="Run DP-SGD and save averaged metrics to a JSON file."
    )
    parser.add_argument("datasetname", help="name of a loader method on Mydatasets")
    # type= lets argparse report malformed numbers as a usage error
    # instead of a raw ValueError traceback.
    parser.add_argument("total", type=float, help="total privacy budget")
    parser.add_argument("numiter", type=int, help="number of iterations")
    parser.add_argument("batch_size", type=int, help="mini-batch size")
    parser.add_argument(
        "--num_rep",
        type=int,
        default=10,
        help="number of repetitions for averaging over the randomness (default: 10)",
    )
    args = parser.parse_args()
    if args.num_rep < 1:
        parser.error("--num_rep must be at least 1")

    datasetname = args.datasetname
    total = args.total  # total privacy budget
    num_iter = args.numiter  # number of iterations
    batch_size = args.batch_size
    pb = {
        "total": total,  # Total privacy budget
        "num_iteration": num_iter,
        "batch_size": batch_size,
    }
    num_rep = args.num_rep  # the number of repetitions for averaging over the randomness
    print("the dataset is ", str(datasetname))
    print(f"total is {total} num_iter {num_iter}")
    helper_fun(datasetname, pb, num_rep=num_rep)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()
