import copy
import os

import numpy as np

np.random.seed(1)


# generating data for different machines
def gen_data(M, N, d):
    """Generate a noiseless linear-regression dataset for each machine.

    For every machine a ground-truth vector ``a`` is drawn from a standard
    normal distribution in ``d`` dimensions, then ``N`` standard-normal
    feature rows ``X`` are drawn and the targets are set to ``y = X @ a``.
    Draws come from the global ``np.random`` state, in the order ``a`` then
    ``X`` for each machine.

    Args:
        M: number of machines (datasets) to generate.
        N: number of samples per machine.
        d: feature dimension.

    Returns:
        A list of ``M`` tuples ``(X, y, A, b, a)`` where ``X`` has shape
        ``(N, d)``, ``y`` has shape ``(N,)``, ``A = X^T X / N`` has shape
        ``(d, d)``, ``b = X^T y / N`` has shape ``(d,)`` and ``a`` is the
        ground-truth vector of shape ``(d,)``.
    """
    # Loop-invariant distribution parameters.
    mean = np.zeros(d)
    cov = np.eye(d)

    data = []
    for _ in range(M):
        a = np.random.multivariate_normal(mean, cov)
        X = np.random.multivariate_normal(mean, cov, N)
        y = X.dot(a)
        # Sum of per-sample outer products x x^T is X^T X; computing it as a
        # matrix product avoids building an (N, d, d) temporary.
        A = X.T.dot(X) / N
        # Sum of y_i * x_i over samples is X^T y.
        b = X.T.dot(y) / N
        data.append((X, y, A, b, a))
    return data


# Compute full gradient for given data matrix
def grad(A, b, w, theta):
    """Return the full gradient ``A (w + theta) - b``.

    ``w`` and ``theta`` are added to form the combined model, which is
    multiplied by ``A`` before ``b`` is subtracted.
    """
    combined = w + theta
    projected = A.dot(combined)
    return projected - b


# Compute stochastic gradient for given data matrix
def stoch_grad(X, y, w, theta, N):
    """Return a single-sample stochastic gradient for the model ``w + theta``.

    One row index is drawn uniformly from ``[0, N)`` using the global
    ``np.random`` state; the result is the prediction residual on that row
    multiplied by the row itself.
    """
    idx = np.random.randint(N)
    sample = X[idx]
    residual = (w + theta).dot(sample) - y[idx]
    return residual * sample


# running local SGD with and without personalization
def local_SGD(data, T, K, eta, alpha, personalize=True):
    """Run local SGD over all machines, optionally with personalization.

    Every machine ``m`` keeps a shared-model copy ``w[m]`` and a personal
    offset ``theta[m]``, both starting at zero. At each step every machine
    draws one stochastic gradient of its model ``w[m] + theta[m]`` and
    updates ``w[m]`` with step size ``alpha * eta``; when ``personalize`` is
    true it also updates ``theta[m]`` with step size ``eta / M``. Whenever
    ``(t + 1) % K == 0`` all ``w[m]`` are replaced by their average.

    Args:
        data: list of per-machine tuples as produced by ``gen_data``; entry 0
            is the feature matrix, entry 1 the targets and the last entry the
            ground-truth vector. All machines are assumed to share the shape
            of the first machine's feature matrix.
        T: number of SGD steps.
        K: averaging period, in steps.
        eta: base step size.
        alpha: multiplier applied to ``eta`` for the ``w`` update.
        personalize: if false, ``theta`` is never updated and stays zero.

    Returns:
        A tuple ``(losses, consensus, avg_consensus, w, theta)``:

        * ``losses``: array of shape ``(M,)`` with the squared distance
          ``||a_m - w_m - theta_m||^2`` of each machine's final model from
          its ground truth. With ``T == 0`` this is the loss of the zero
          initial iterate.
        * ``consensus``: list of length ``T``; entry ``t`` is the mean over
          machines of ``||w_m - w_avg||^2`` after step ``t`` (zero on steps
          where averaging was applied).
        * ``avg_consensus``: list of length ``T`` with the running mean of
          ``consensus``.
        * ``w``, ``theta``: final arrays of shape ``(M, d)``.
    """
    M = len(data)
    N, d = data[0][0].shape

    w = np.zeros((M, d))
    theta = np.zeros((M, d))
    consensus = []
    avg_consensus = []

    # Ground-truth vector of every machine, one row each.
    a = np.zeros((M, d))
    for m, machine in enumerate(data):
        a[m] = machine[-1]

    for t in range(T):
        for m in range(M):
            g_m = stoch_grad(data[m][0], data[m][1], w[m], theta[m], N)
            if personalize:
                theta[m] -= eta * g_m / M
            w[m] -= alpha * eta * g_m

        # The average is taken before the (optional) synchronization so the
        # consensus error below is measured against the current mean.
        w_avg = np.average(w, axis=0)

        if (t + 1) % K == 0:
            w = np.tile(w_avg, (M, 1))

        consensus.append(np.average(np.linalg.norm(w - w_avg, axis=1) ** 2))

        # Incremental running mean of the consensus error.
        if t == 0:
            avg_consensus.append(consensus[t])
        else:
            avg_consensus.append((avg_consensus[t - 1] * t + consensus[t]) / (t + 1))

    # Only the final loss is returned, so it is computed once after the loop
    # instead of being recomputed and discarded at every step.
    losses = np.linalg.norm(a - w - theta, axis=1) ** 2

    return losses, consensus, avg_consensus, w, theta


# Experiment configuration: samples per machine, number of SGD steps, base
# step size and the multiplier applied to it for the shared-model update.
N = 1000
T = 100000
eta = 0.0001
alpha = 2

# d_choices = [5, 10, 25, 100]
d_choices = [25, 100]
M_choices = [2, 5, 10, 25, 100]
K_choices = [10, 100, 1000, 10000]

# d_choices = [10]
# M_choices = [100]
# K_choices = [1000, 10000]

# Make sure the output directory exists before the long simulations start,
# so a run is not lost to a missing folder at save time.
os.makedirs("logs", exist_ok=True)

# Sweep every (d, M, K) combination; each one gets a freshly generated
# dataset on which both variants (with and without personalization) are run.
for d in d_choices:
    for M in M_choices:
        for K in K_choices:
            print(d, M, K)
            data = gen_data(M, N, d)
            losses, consensus, avg_consensus, w, theta = local_SGD(
                data, T, K, eta, alpha
            )
            losses_np, consensus_np, avg_consensus_np, w_np, theta_np = local_SGD(
                data, T, K, eta, alpha, personalize=False
            )
            file = f"logs/log_M={M}_T={T}_K={K}_N={N}_d={d}_eta={eta}_alpha={alpha}.npy"
            results = (
                losses,
                consensus,
                avg_consensus,
                w,
                theta,
                losses_np,
                consensus_np,
                avg_consensus_np,
                w_np,
                theta_np,
            )
            # The results have different shapes, and recent NumPy releases
            # refuse to convert such a ragged tuple implicitly. Pack them into
            # an explicit 1-D object array (one slot per result, in the order
            # above) so the file can be read back with
            # np.load(file, allow_pickle=True).
            packed = np.empty(len(results), dtype=object)
            for slot, item in enumerate(results):
                packed[slot] = item
            np.save(file, packed)


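# NOTE: everything below is a commented-out two-machine prototype of the
# experiment above, kept for reference only; none of it is executed.
# rensus.append((np.linalg.norm(theta_1 - theta_2) ** 2) / 4.0)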

# a_1 = np.random.multivariate_normal(np.zeros(d), np.eye(d))
# # print(a_1)

# a_2 = np.random.multivariate_normal(np.zeros(d), np.eye(d))
# # print(a_2)

# X_1 = np.random.multivariate_normal(np.zeros(d), np.eye(d), N)
# y_1 = X_1.dot(a_1)
# # print(X_1.shape, y_1.shape)
# X_2 = np.random.multivariate_normal(np.zeros(d), np.eye(d), N)
# y_2 = X_2.dot(a_2)
# # print(X_2[0], y_2[0])

# A_1 = np.sum(np.array([np.outer(x, x) for x in X_1]), axis=0) / N
# b_1 = np.sum(X_1 * y_1[:, None], axis=0) / N
# A_2 = np.sum(np.array([np.outer(x, x) for x in X_2]), axis=0) / N
# b_2 = np.sum(X_2 * y_2[:, None], axis=0) / N
# # print(A_1.shape, b_1.shape)


# print(grad(A_1, b_1, np.ones(d), np.zeros(d)).shape)

# T = 100000

# K_choices = [1, 2, 5, 10, 25, 50, 100, 125, 200, 250, 500, 1000, 1250, 2000, 2500, 5000]
# K_choices = [1000]
# max_consensus = []
# for K in K_choices:
#     print(K)
#     eta = 0.0001
#     alpha = 2

#     w_1 = np.zeros(d)
#     theta_1 = np.zeros(d)
#     w_2 = np.zeros(d)
#     theta_2 = np.zeros(d)
#     w_avg = np.zeros(d)
#     # local SGD with personalization
#     # rw_1 = np.zeros(d)
#     # rw_2 = np.zeros(d)
#     # rtheta_1 = np.zeros(d)
#     # rtheta_2 = np.zeros(d)

#     loss_1 = []
#     loss_2 = []
#     consensus = []
#     rensus = []
#     for t in range(T):
#         # g_1 = grad(A_1, b_1, w_1, theta_1)
#         # g_2 = grad(A_2, b_2, w_2, theta_2)
#         g_1 = stoch_grad(X_1, y_1, w_1, theta_1, N)
#         g_2 = stoch_grad(X_2, y_2, w_2, theta_2, N)
#         theta_1 -= eta * g_1 / 2
#         theta_2 -= eta * g_2 / 2
#         w_1 -= alpha * eta * g_1
#         w_2 -= alpha * eta * g_2
#         if (t + 1) % K == 0:
#             w_avg = (w_1 + w_2) / 2
#             w_1 = copy.copy(w_avg)
#             w_2 = copy.copy(w_avg)
#         # rw_1 = (t * rw_1 + w_1) / (t + 1)
#         # rw_2 = (t * rw_2 + w_2) / (t + 1)
#         # rtheta_1 = (t * rtheta_1 + w_1) / (t + 1)
#         # rtheta_2 = (t * rtheta_2 + w_2) / (t + 1)
#         # loss_1.append(np.linalg.norm(a_1 - w_1 - theta_1) ** 2)
#         # loss_2.append(np.linalg.norm(a_2 - w_2 - theta_2) ** 2)
#         consensus.append((np.linalg.norm(w_1 - w_2) ** 2) / 4.0)
#         # rensus.append((np.linalg.norm(theta_1 - theta_2) ** 2) / 4.0)
#         # print(w_1, w_2)
#     # print(check[:100])

#     # print(w_1+theta_1)
#     # print(w_2+theta_2)
#     # plt.plot(np.arange(T), np.log10(loss_1), label="agent_1")
#     # plt.plot(np.arange(T), np.log10(loss_2), label="agent_2")
#     # print(consensus)
#     # plt.plot(np.arange(T), np.array(consensus), label="consensus_w")
#     # # plt.plot(
#     # #     np.arange(T), np.log10(1e-15 + np.array(rensus)), label=r"consensus_$\theta$"
#     # # )
#     # plt.legend()
#     # plt.show()
#     # max_consensus.append(np.max(np.array(consensus)))

# # plt.plot(K_choices, max_consensus)
# # plt.show()

# max_consensus_lsgd = []
# for K in K_choices:
#     print(K)
#     eta = 0.0001
#     alpha = 2

#     w_1 = np.zeros(d)
#     theta_1 = np.zeros(d)
#     w_2 = np.zeros(d)
#     theta_2 = np.zeros(d)
#     w_avg = np.zeros(d)
#     # local SGD with personalization
#     # rw_1 = np.zeros(d)
#     # rw_2 = np.zeros(d)
#     # rtheta_1 = np.zeros(d)
#     # rtheta_2 = np.zeros(d)

#     loss_1 = []
#     loss_2 = []
#     consensus_lsgd = []
#     rensus = []
#     for t in range(T):
#         # g_1 = grad(A_1, b_1, w_1, theta_1)
#         # g_2 = grad(A_2, b_2, w_2, theta_2)
#         g_1 = stoch_grad(X_1, y_1, w_1, theta_1, N)
#         g_2 = stoch_grad(X_2, y_2, w_2, theta_2, N)
#         # theta_1 -= eta * g_1 / 2
#         # theta_2 -= eta * g_2 / 2
#         w_1 -= alpha * eta * g_1
#         w_2 -= alpha * eta * g_2
#         if (t + 1) % K == 0:
#             w_avg = (w_1 + w_2) / 2
#             w_1 = copy.copy(w_avg)
#             w_2 = copy.copy(w_avg)
#         # rw_1 = (t * rw_1 + w_1) / (t + 1)
#         # rw_2 = (t * rw_2 + w_2) / (t + 1)
#         # rtheta_1 = (t * rtheta_1 + w_1) / (t + 1)
#         # rtheta_2 = (t * rtheta_2 + w_2) / (t + 1)
#         # loss_1.append(np.linalg.norm(a_1 - w_1 - theta_1) ** 2)
#         # loss_2.append(np.linalg.norm(a_2 - w_2 - theta_2) ** 2)
#         consensus_lsgd.append((np.linalg.norm(w_1 - w_2) ** 2) / 4.0)
#         # rensus.append((np.linalg.norm(theta_1 - theta_2) ** 2) / 4.0)
#         # print(w_1, w_2)
#     # print(check[:100])

#     # print(w_1+theta_1)
#     # print(w_2+theta_2)
#     # plt.plot(np.arange(T), np.log(loss_1), label="agent_1")
#     # plt.plot(np.arange(T), np.log(loss_2), label="agent_2")
#     # print(consensus)
#     plt.plot(
#         np.arange(T), np.array(consensus_lsgd), label="consensus_w/o personalization"
#     )
#     plt.plot(np.arange(T), np.array(consensus), label="consensus_w/ personalization")
#     # plt.plot(
#     #     np.arange(T), np.log10(1e-15 + np.array(rensus)), label=r"consensus_$\theta$"
#     # )
#     plt.legend()
#     plt.savefig("vs_lsgd.png", dpi=400)
#     plt.show()
#     max_consensus_lsgd.append(np.max(np.array(consensus_lsgd)))

# # plt.plot(K_choices, max_consensus, label="w/ Personalization")
# # plt.plot(K_choices, max_consensus_lsgd, label="w/o Personalization")
# # plt.xlabel("Number of local steps")
# # plt.ylabel("Maximum consensus error")
# # plt.legend()
# # plt.savefig("consensus_w_K.png", dpi=400)
# # plt.show()
