import time
import numpy as np
import matplotlib.pyplot as plt
from DBO_helper_funcs import sort_with_index
from DBO_helper_funcs import ackley
from algebraic_graph_calculate import algebraic_graph_calculate

# ----------------------------------------------------------------------------------------------------------------------
# Distributed surrogate-based minimisation of the Ackley function.
#
# For every independent run the script loads initial observations, lets each of the `v` agents fit a linear model on
# cosine features of its own data, refines the per-agent weights with a combine-then-adapt (CTA) LMS loop over the
# communication graph, and then lets every agent query `ackley` at the grid point where its own prediction is lowest.
# The running-average regret of every run is stored column-wise in `CTA_regret_global` and saved to disk.
# ----------------------------------------------------------------------------------------------------------------------
plt.close('all')
start_time = time.time()

# ----------------------------------------------------------------------------------------------------------------------
# Experiment configuration
# ----------------------------------------------------------------------------------------------------------------------
v = 5                                    # number of agents (nodes of the communication graph)
initialSize = 10                         # not referenced below; kept as a record of the setup
independent_global_iteration_CTA = 5     # number of independent runs (one data directory each)
max_Iter = 5                             # optimisation steps per run; each step adds one sample per agent
K = 1000                                 # the CTA loop performs K + 1 sweeps per optimisation step

N_search = 2001                          # number of candidate points on the search grid over [-10, 10]
M = 200                                  # number of cosine features per sample
m = 1
N_data = 10                              # number of initial samples per agent
dim = 1
feature_dim = m * dim * M                # NOTE(review): the array shapes below only line up for m == dim == 1
u_CTA = 0.6                              # step size of the adapt step

CTA_regret_global = np.zeros((max_Iter, independent_global_iteration_CTA))
topology_list = ["fully-connected"]

# Per-agent regularisation weights (all ones).
sigma = np.ones((1, v))

# Frequencies and phases of the cosine features. They do not depend on the run, so they are loaded once.
# Presumably s is (M, 1) and b broadcasts against (1, M) -- TODO confirm against the files.
s = np.load('s.npy')
b = np.load('b.npy')
feature_scale = np.sqrt(2 / M)

# Search grid and its unit-norm feature rows. Both are independent of the run and of the optimisation step, so they
# are built once. np.linspace guarantees exactly N_search points including both end points, whereas np.arange with a
# floating-point step can return one element more than expected and then break the reshape.
X_search = np.linspace(-10, 10, N_search).reshape(N_search, 1)
features_search = np.zeros((N_search * dim, M))
for i in range(N_search):
    features_search[i, :] = feature_scale * np.cos(np.dot(X_search[i, :], s.T) + b)
    stored_row = features_search[i, :]
    features_search[i, :] = stored_row / np.sqrt(np.dot(stored_row, stored_row))

for topology in topology_list:
    plt.close('all')
    for iter_global in range(independent_global_iteration_CTA):
        print("iter_global: ", iter_global)

        # NOTE(review): the data directory is hard-coded to "fully_connected" and does not follow `topology`.
        run_dir = 'C:\\PycharmProjects\\pythonProject\\Ackley\\overlapping\\fully_connected\\' + str(iter_global) + '\\'
        x = np.load(run_dir + 'x.npy')
        y = np.load(run_dir + 'y.npy')
        yy = np.load(run_dir + 'yy.npy')

        # Reference value for the regret: first element of the sorted `yy` values.
        sorted_yy_array, _ = sort_with_index(yy.flatten())
        min_yy = sorted_yy_array[0]

        # Training sets: one column per agent, one row per sample; a row is appended after every step.
        X_train = x
        Y_train = y

        adjacent_matrix, degree_matrix, _laplacian_matrix = algebraic_graph_calculate(v, topology)

        # Combination weights: 1 / max(degree_i, degree_j) for every pair flagged in the adjacency matrix, with the
        # diagonal chosen so that every row sums to one.
        CC = np.zeros((v, v))
        for i in range(v):
            for j in range(i + 1, v):
                if adjacent_matrix[i, j] == 1:
                    CC[i, j] = 1 / np.maximum(degree_matrix[i, i], degree_matrix[j, j])
                    CC[j, i] = CC[i, j]

            CC[i, i] = 1 - np.sum(CC[i, :])

        for itr in range(max_Iter):
            print("itr: ", itr)
            n_samples = N_data + itr

            # Stacked feature matrix: rows [i * n_samples, (i + 1) * n_samples) belong to agent i.
            Phi = np.zeros((v * n_samples, M))
            for i in range(v):
                for j in range(n_samples):
                    features = feature_scale * np.cos(X_train[j, i] * s.T + b)
                    features = features / np.sqrt(np.dot(features, features.T))
                    Phi[n_samples * i + j, :] = features

            # Row-slice views of each agent's block (no copies).
            h_blocks = [Phi[i * n_samples:(i + 1) * n_samples] for i in range(v)]

            # Initial weights: each agent's regularised least-squares solution on its own data.
            # NOTE(review): sigma[0, 0] is used for every agent, as in the original; sigma is all ones, so using
            # sigma[0, i] would give the same value.
            W_CTA = np.zeros((dim * M, v))
            for i in range(v):
                h_i = h_blocks[i]
                T_i = np.dot(h_i.T, h_i) + sigma[0, 0] * np.identity(feature_dim)
                W_CTA[:, i] = np.dot(np.dot(np.linalg.inv(T_i), h_i.T), Y_train[:, i])

            # Targets, one row per agent.
            d = np.zeros((v, n_samples))
            for i in range(v):
                d[i, :] = Y_train[:, i]

            # Distributed CTA LMS Algorithm
            fai_CTA = np.zeros((dim * M, v))
            for _ in range(K + 1):
                # Combine: every agent mixes the current weights with column i of CC. All combinations are formed
                # before any weight is updated, so each sweep uses the previous sweep's weights.
                for i in range(v):
                    fai_CTA[:, i] = np.dot(W_CTA, CC[:, i])

                # Adapt: correction on the agent's own data minus a shrinkage term scaled by 2 / n_samples.
                for i in range(v):
                    h_i = h_blocks[i]
                    residual = d[i, :] - np.dot(h_i, fai_CTA[:, i])
                    W_CTA[:, i] = fai_CTA[:, i] + u_CTA * (np.dot(h_i.T, residual) - 2 / n_samples * fai_CTA[:, i])

            # Each agent evaluates its model on the grid and samples `ackley` at its own predicted minimiser.
            Y_search = np.dot(features_search, W_CTA)
            X_train_append = np.zeros((1, v))
            Y_train_append = np.zeros((1, v))
            for i in range(v):
                _, sorted_indexes_Y_pred = sort_with_index(Y_search[:, i])
                X_train_append[:, i] = X_search[int(sorted_indexes_Y_pred[0]), :]
                Y_train_append[:, i] = ackley(X_train_append[:, i])

            X_train = np.vstack((X_train, X_train_append))
            Y_train = np.vstack((Y_train, Y_train_append))

        # Regret of the samples added at each step, summed over agents and accumulated over steps, then divided by
        # the number of steps so far.
        Regret_sum = 0
        ave_Regret_CTA = np.zeros((max_Iter, 1))
        for itr in range(max_Iter):
            Regret_sum = Regret_sum + np.sum(Y_train[N_data + itr, :] - min_yy)
            ave_Regret_CTA[itr, :] = Regret_sum / (itr + 1)

        CTA_regret_global[:, iter_global] = ave_Regret_CTA.ravel()

    np.save('CTA_regret_global.npy', CTA_regret_global)
# ----------------------------------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------------------------------
end_time = time.time()
time_cost = end_time - start_time
print(f"Time taken: {time_cost} seconds")
