import os

import numpy as np

from optimizer import inexact_newton, localSGD, fedAC, newton
from function import logistic
from data import loada9a

# Experiments with the a9a dataset from LIBSVM.

# Experiment-wide constants.
T = 1000  # budget split as R = int(T / K) outer rounds by the sweep below
M = 500   # forwarded to inexact_newton; meaning defined in optimizer (not visible here)
mu = 0    # forwarded to inexact_newton as the `mu` keyword

# Objective and data.
F = logistic
X, Y = loada9a()
n, d = X.shape         # X is 2-D; presumably (samples, features) — confirm in data.loada9a
Y = Y.reshape((n, 1))  # labels reshaped into an (n, 1) column


print("[*] Running inexact Newton w/ local SGD w/ momentum")
# Sweep over (K, lr) pairs. K is forwarded to inexact_newton as its 7th
# positional argument and the number of outer rounds is R = T // K, so for
# every pair listed here R * K == T.
values = [(40, 0.1), (20, 0.1), (10, 0.2), (2, 1), (1, 2)]
momentum = 0.9  # loop-invariant: identical for every configuration
out_dir = "runtime"
# Create the output directory before any expensive work, so np.save cannot
# fail only after all repetitions of a configuration have been computed.
os.makedirs(out_dir, exist_ok=True)

for K, lr in values:
    # Integer division: equal to the former int(T/K) for positive ints, but
    # without a float round-trip.
    R = T // K
    print(f"R, lr = {R, lr}")

    Losses = []  # one loss trajectory per repetition

    for rep in range(10):
        print(f"In {rep}-th repetition")
        # The positional constants 1.25 and 1 are forwarded unchanged; their
        # meaning is defined by optimizer.inexact_newton (not visible here).
        W, losses = inexact_newton(F, X, Y, R, 1.25, M, K,
                                   1, lr, momentum=momentum, mu=mu,
                                   quadSolver="localSGD", gap=1)
        Losses.append(losses)
        print(f"[+] Done, best_loss = {np.min(losses)}")

    # os.path.join avoids the doubled separator of 'runtime/' + '/losses_...'.
    np.save(os.path.join(out_dir, f"losses_newton_lsgd_{R}_{lr}.npy"), Losses)

# K_values = [10000, 2000, 1000, 200, 100]
# lr_values = [0.001, 0.0025, 0.005, 0.0075,
#              0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5]

# for K in K_values:
#     Losses = []
#     for lr in lr_values:
#         print(f"R, lr = ({int(T/K), lr})")
#         momentum = 0.9
#         W, losses = inexact_newton(F, X, Y, int(T/K), 1.25, M, K,
#                                    1, lr, momentum=momentum, mu=mu, quadSolver="localSGD", gap=1)
#         best_loss = np.min(losses)
#         print(f"best loss = {best_loss}")
#         Losses.append(best_loss)

#     print(f"Best lr for R={int(T/K)} is {lr_values[np.argmin(Losses)]}")
