import os

import numpy as np

from optimizer import inexact_newton, localSGD, fedAC, newton
from function import logistic
from data import loada9a
from run_help import tune_lr

# A function to run the tuning experiment. This is essentially multiple
# calls to the tuning function in run_help with the correct arguments.


def _tuning_runs(F, X, Y, mu, M, K, R):
    """Build the ordered list of tuning runs for one value of ``R``.

    Returns a list of ``(store_key, optimizer, label, kwargs)`` tuples, in
    the order the runs are executed. ``kwargs`` is forwarded verbatim to
    ``tune_lr``; every dict carries ``"lr": 0`` as a placeholder
    (presumably overwritten by ``tune_lr`` for each candidate learning
    rate -- confirm in run_help).
    """
    def newton_kwargs(momentum):
        # Inexact Newton receives T=R and R=1, with K steps and localSGD
        # selected as the quadratic solver.
        return {"F": F, "X": X, "Y": Y, "T": R, "alpha": 1.25, "M": M,
                "K": K, "R": 1, "lr": 0, "momentum": momentum, "mu": mu,
                "quadSolver": "localSGD", "damp": True}

    def sgd_kwargs(machines, local_steps, momentum):
        return {"F": F, "X": X, "Y": Y, "M": machines, "K": local_steps,
                "R": R, "mu": mu, "u": None, "lr": 0, "momentum": momentum,
                "forHVP": False}

    def fedac_kwargs(ver):
        return {"F": F, "X": X, "Y": Y, "M": M, "K": K, "R": R,
                "mu": mu, "lr": 0, "ver": ver, "u": None, "forHVP": False}

    return [
        ("Newton_w_M", inexact_newton, "Inexact Newton w/ Momentum",
         newton_kwargs(0.9)),
        ("Newton", inexact_newton, "Inexact Newton", newton_kwargs(0)),
        ("LSGD_w_M", localSGD, "Local SGD w/ Momentum",
         sgd_kwargs(M, K, 0.9)),
        ("LSGD", localSGD, "Local SGD", sgd_kwargs(M, K, 0)),
        # Mini-batch SGD is run through localSGD with M*K machines and a
        # single local step per round.
        ("MBSGD_w_M", localSGD, "Mini-batch SGD w/ Momentum",
         sgd_kwargs(M * K, 1, 0.9)),
        ("MBSGD", localSGD, "Mini-batch SGD", sgd_kwargs(M * K, 1, 0)),
        ("FEDAC1", fedAC, "FedAC-1", fedac_kwargs(1)),
        ("FEDAC2", fedAC, "FedAC-2", fedac_kwargs(2)),
    ]


def tune(F, X, Y, mu, M, lrs, T, R_values):
    """Tune the learning rate of every optimizer variant for each ``R``.

    For each ``R`` in ``R_values`` the step count is ``K = int(T / R)``,
    and ``tune_lr`` is called once per variant (inexact Newton, local SGD
    and mini-batch SGD, each with and without momentum 0.9, plus FedAC
    versions 1 and 2) over the candidate learning rates ``lrs``.

    Args:
        F, X, Y: objective and data, forwarded unchanged to the optimizers.
        mu: forwarded as the ``mu`` argument of every optimizer.
        M: forwarded as ``M`` (``M * K`` for the mini-batch runs).
        lrs: candidate learning rates handed to ``tune_lr``.
        T: total budget; split as ``K = int(T / R)`` for each ``R``.
        R_values: iterable of ``R`` values to sweep.

    Returns:
        dict mapping each variant key ("Newton", "LSGD", "MBSGD",
        "Newton_w_M", "LSGD_w_M", "MBSGD_w_M", "FEDAC1", "FEDAC2") to a
        list holding one ``(min_lr, min_loss)`` tuple per ``R``, in the
        order of ``R_values``.
    """
    store = {"Newton": [], "LSGD": [], "MBSGD": [], "Newton_w_M": [],
             "LSGD_w_M": [], "MBSGD_w_M": [], "FEDAC1": [], "FEDAC2": []}

    for R in R_values:
        # Truncates when R does not divide T, so K * R may be below T.
        K = int(T / R)
        for key, optimizer, label, kwargs in _tuning_runs(F, X, Y, mu, M, K, R):
            # tune_lr returns four values; only the first two are stored.
            min_lr, min_loss, _, _ = tune_lr(optimizer, label, lrs, **kwargs)
            store[key].append((min_lr, min_loss))

    return store


# Experiment driver: sweep mu and M, tune every optimizer for each pair and
# save one result file per (mu, M) under results/<T>/.
F = logistic
X, Y = loada9a()
n, d = X.shape
Y = Y.reshape((n, 1))  # reshape Y into an (n, 1) column
mu_values = [1e-2, 1e-4, 1e-6, 1e-8]
M_values = [20, 50, 100, 200, 500]

lrs = [0.0001, 0.0005, 0.001, 0.002, 0.005, 0.01,
       0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50]

T = 1000  # i.e. K*R = 1000
if T == 1000:
    R_values = [1, 5, 10, 25, 50, 100, 500, 1000]
elif T == 100:
    R_values = [1, 5, 10, 25, 50, 100]
else:
    # Fail here rather than with a NameError on R_values further down.
    raise ValueError(f"No R_values defined for T = {T}; expected 1000 or 100")

# Create the output directory before any tuning starts: np.save does not
# create missing directories, and failing after a full tuning pass would
# throw that work away.
out_dir = f"results/{T}"
os.makedirs(out_dir, exist_ok=True)

for mu in mu_values:
    print(f"[*] Running for mu = {mu}")
    for M in M_values:
        print(f"[*] Running for M = {M}")
        store = tune(F, X, Y, mu, M, lrs, T, R_values)
        # store is a dict, so it is saved as a pickled object array; read
        # it back with np.load(path, allow_pickle=True).item().
        np.save(f"{out_dir}/store_{mu}_{M}.npy", store)
