import numpy as np
from optimizer import inexact_newton, localSGD, fedAC
from function import logistic
from run_help import repeat_run
from data import loada9a
import os


def _newton_args(F, X, Y, M, K, R, lr, momentum):
    """Build the keyword arguments for one ``inexact_newton`` run.

    Note the Newton runs pass the outer budget as ``T=R`` and fix ``R=1``.
    """
    return {"F": F, "X": X, "Y": Y, "T": R, "alpha": 1.25, "M": M, "K": K,
            "R": 1, "lr": lr, "momentum": momentum, "mu": 0,
            "quadSolver": "localSGD", "damp": True}


def _sgd_args(F, X, Y, M, K, R, lr, momentum):
    """Build the keyword arguments for one ``localSGD`` run (no HVP mode)."""
    return {"F": F, "X": X, "Y": Y, "M": M, "K": K, "R": R, "mu": 0,
            "u": None, "lr": lr, "momentum": momentum, "forHVP": False}


def _fedac_args(F, X, Y, M, K, R, mu, lr, ver):
    """Build the keyword arguments for one ``fedAC`` run of variant ``ver``."""
    return {"F": F, "X": X, "Y": Y, "M": M, "K": K, "R": R, "mu": mu,
            "lr": lr, "ver": ver, "u": None, "forHVP": False}


def repeat(F, X, Y, M, store, T, R_values, rep):
    """Re-run every tuned optimizer ``rep`` times for each value of ``R``.

    For each ``R`` in ``R_values`` the code sets ``K = int(T / R)``. It then
    runs, in this order: Inexact Newton (with / without momentum), Local SGD
    (with / without momentum), Mini-batch SGD (with / without momentum),
    FedAC-1 and FedAC-2. Each run goes through ``repeat_run``.

    Args:
        F: Objective forwarded unchanged to every optimizer.
        X, Y: Dataset arrays forwarded unchanged to every optimizer.
        M: Value forwarded as ``M``. The mini-batch SGD runs instead receive
            ``M * K`` with ``K = 1``.
        store: Tuned hyper-parameters. ``store[key][i]`` belongs to
            ``R_values[i]``. The entries are laid out as follows:
            - ``*_w_M`` keys: ``[0]`` is momentum and ``[1]`` is the
              learning rate.
            - ``"Newton"``, ``"LSGD"`` and ``"MBSGD"``: ``[0]`` is the
              learning rate.
            - ``"FEDAC1"`` and ``"FEDAC2"``: ``[0]`` is ``mu`` and ``[1]``
              is the learning rate.
        T: Budget used to derive ``K = int(T / R)``.
        R_values: Sequence of ``R`` values to sweep.
        rep: Number of repetitions, forwarded to ``repeat_run``.

    Returns:
        dict mapping each method key to a list with one entry per ``R`` in
        ``R_values``. Each entry is whatever ``repeat_run`` returned.
    """
    Losses = {key: [] for key in ("Newton_w_M", "LSGD_w_M", "MBSGD_w_M",
                                  "FEDAC1", "FEDAC2", "Newton", "LSGD", "MBSGD")}

    def run(key, optimizer, label, args):
        # Repeat one configuration and record its result under `key`.
        Losses[key].append(repeat_run(optimizer, label, rep, **args))

    for i, R in enumerate(R_values):
        K = int(T / R)

        # Inexact Newton, with and without momentum.
        hp = store["Newton_w_M"][i]
        run("Newton_w_M", inexact_newton, "Inexact Newton w/ Momentum",
            _newton_args(F, X, Y, M, K, R, lr=hp[1], momentum=hp[0]))
        run("Newton", inexact_newton, "Inexact Newton",
            _newton_args(F, X, Y, M, K, R, lr=store["Newton"][i][0], momentum=0))

        # Local SGD, with and without momentum.
        hp = store["LSGD_w_M"][i]
        run("LSGD_w_M", localSGD, "Local SGD w/ Momentum",
            _sgd_args(F, X, Y, M, K, R, lr=hp[1], momentum=hp[0]))
        run("LSGD", localSGD, "Local SGD",
            _sgd_args(F, X, Y, M, K, R, lr=store["LSGD"][i][0], momentum=0))

        # Mini-batch SGD: localSGD with M*K in place of M and a single local
        # step (K=1). This is presumably meant to match Local SGD's per-round
        # sample count; confirm against localSGD.
        hp = store["MBSGD_w_M"][i]
        run("MBSGD_w_M", localSGD, "Mini-batch SGD w/ Momentum",
            _sgd_args(F, X, Y, M * K, 1, R, lr=hp[1], momentum=hp[0]))
        run("MBSGD", localSGD, "Mini-batch SGD",
            _sgd_args(F, X, Y, M * K, 1, R, lr=store["MBSGD"][i][0], momentum=0))

        # FedAC variants 1 and 2. Their store entries are (mu, lr).
        for ver in (1, 2):
            key = f"FEDAC{ver}"
            hp = store[key][i]
            run(key, fedAC, f"FedAC-{ver}",
                _fedac_args(F, X, Y, M, K, R, mu=hp[0], lr=hp[1], ver=ver))

    return Losses


# Objective and dataset shared by every experiment.
F = logistic
X, Y = loada9a()
n, d = X.shape
Y = Y.reshape((n, 1))  # labels as a column vector, one row per sample
M_values = [20, 50, 100, 200, 500]

# Budget: repeat() uses K = int(T / R). Every R listed below divides T, so K*R = T.
T = 100
if T == 1000:
    R_values = [1, 5, 10, 25, 50, 100, 500, 1000]
elif T == 100:
    R_values = [1, 5, 10, 25, 50, 100]
else:
    # Fail fast. Otherwise R_values stays undefined and the loop below dies
    # with a NameError.
    raise ValueError(f"No R_values configured for T = {T}; expected 100 or 1000")

REPS = 50  # repetitions per (method, R) pair, forwarded to repeat()

for M in M_values:
    # Tuned hyper-parameters for this M. The file is loaded with
    # allow_pickle=True, so it must come from a trusted source (presumably an
    # earlier tuning run).
    store = np.load(
        f"results/{T}/store_{M}.npy", allow_pickle=True).item()
    print(f"[*] Running for M = {M}")
    Losses = repeat(F, X, Y, M, store, T, R_values, REPS)

    losses_path = f"results/{T}/Losses_{M}.npy"
    if os.path.exists(losses_path):
        # Combine each per-R entry with the previously saved one via `+=`.
        # This presumably concatenates per-repetition loss lists; confirm
        # repeat_run's return type.
        Losses_old = np.load(losses_path, allow_pickle=True).item()
        for key, new_losses in Losses.items():
            old_losses = Losses_old.get(key)
            # Skip methods that produced nothing or are missing from the old
            # file. A KeyError here would discard this whole run's results.
            if not new_losses or old_losses is None:
                continue
            for i in range(len(R_values)):
                new_losses[i] += old_losses[i]

    np.save(losses_path, Losses)
