import os
import time

import matplotlib.pyplot as plt
import numpy as np

from DBO_agent import et_dbo_agent
from DBO_helper_funcs import levy
from DBO_helper_funcs import sort_with_index
from algebraic_graph_calculate import algebraic_graph_calculate

plt.close('all')
start_time = time.time()  # wall-clock reference for the "Time taken" report at the end of the script

# Experiment configuration ---------------------------------------------------------------------------------------------
v = 5                                   # number of agents (loops below index T_agent / W0_agent by 0..v-1)
initialSize = 10                        # NOTE(review): not referenced in the visible code
independent_global_iteration_DBO = 5    # independent repetitions; each reads its own numbered data sub-folder
max_Iter = 5                            # outer iterations per repetition; each appends one query row per agent
K = 1000                                # presumably the step count of et_dbo_agent (tk keeps K + 1 entries per agent)
topology_list = ["fully-connected"]     # topologies passed to algebraic_graph_calculate

# Average regret per iteration (rows) for every repetition (columns).
DBO_regret_global = np.zeros((max_Iter, independent_global_iteration_DBO))

def _rff_features(x_values, s, b, M):
    """Return unit-norm random Fourier features, one row per scalar input.

    Vectorised equivalent of the per-sample expression
    ``sqrt(2 / M) * cos(x * s.T + b)`` followed by division by that row's
    Euclidean norm.

    Parameters
    ----------
    x_values : array_like
        Scalar inputs; flattened to 1-D.
    s, b : numpy.ndarray
        Frequencies and phase offsets loaded from ``s.npy`` / ``b.npy``.
        NOTE(review): assumes ``s`` holds M frequencies laid out so that
        ``s.T`` is (1, M) (e.g. ``s`` is (M, 1)) and that ``b`` broadcasts to
        (1, M) -- the shapes a per-row ``x * s.T + b`` needs to yield a single
        feature row. Confirm against the saved files.
    M : int
        Number of features per row.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(x_values), M)`` whose rows have unit L2 norm.
    """
    x_col = np.asarray(x_values, dtype=float).reshape(-1, 1)
    z = np.sqrt(2 / M) * np.cos(x_col * np.reshape(s, (1, -1)) + np.reshape(b, (1, -1)))
    return z / np.linalg.norm(z, axis=1, keepdims=True)


def _posterior_variance(features, precision):
    """Return ``diag(features @ inv(precision) @ features.T)`` as a 1-D array.

    ``features`` is (n, M) and ``precision`` is (M, M). A linear solve is used
    instead of forming ``inv(precision)`` explicitly.
    """
    return np.sum(features * np.linalg.solve(precision, features.T).T, axis=1)


# Settings and data shared by every topology and repetition ------------------------------------------------------------
# Root of the pre-generated datasets; repetition k reads <root>\k\{x,y,xx,yy}.npy and writes tk.npy there.
# NOTE(review): this root is fixed to the fully-connected dataset and does not follow `topology`;
# adjust it if more topologies are added to topology_list.
data_root = 'C:\\PycharmProjects\\pythonProject\\Levy\\overlapping\\fully_connected\\'
N_search = 2001          # number of candidate points on the search grid
M = 200                  # number of random Fourier features
m = 1
dim = 1
sigma = np.ones((1, v))  # per-agent regularisation weight added to Phi_i^T Phi_i

s = np.load('s.npy')
print("s.shape: ", s.shape)
b = np.load('b.npy')
print("b.shape: ", b.shape)

# Candidate grid: N_search evenly spaced points on [-10, 10] inclusive. linspace guarantees
# exactly N_search points, which np.arange with a floating-point step does not.
X_search = np.linspace(-10, 10, N_search).reshape(N_search, 1)
# Grid features depend only on s, b and M, so they are computed once for the whole script.
features_search = _rff_features(X_search[:, 0], s, b, M)

for i_topology in topology_list:
    topology = i_topology
    plt.close('all')
    print(topology)
    for iter_global in range(independent_global_iteration_DBO):
        print(iter_global)
        run_dir = os.path.join(data_root, str(iter_global))

        loaded = []
        for file_name in ('x.npy', 'y.npy', 'xx.npy', 'yy.npy'):
            file_path = os.path.join(run_dir, file_name)
            print(file_path)
            loaded.append(np.load(file_path))
        x, y, xx, yy = loaded

        rows_x, cols_x = x.shape
        print("Size of x: ", x.size)
        print("Number of rows wrt x: ", rows_x)
        print("Number of columns wrt x: ", cols_x)
        if cols_x != v:
            # One column per agent: the (1, v) query rows stacked onto X_train below require it.
            raise ValueError(f"x has {cols_x} columns but v = {v} agents are configured")
        # Initial samples per agent. This was hard-coded to 10, but the least-squares fit
        # below needs it to equal the number of rows of x / y.
        N_data = rows_x

        # Minimum of yy and its location in xx; min_yy is the reference value for the regret.
        yy_array = yy.flatten()
        sorted_yy_array, sorted_indexes_yy_array = sort_with_index(yy_array)
        min_yy = sorted_yy_array[0]
        xx_min_location = xx[sorted_indexes_yy_array[0]]
        print("min_yy: ", min_yy)
        print("xx_min_location: ", xx_min_location)

        X_train = x
        print("X_train: ", X_train.shape)
        Y_train = y
        print("Y_train: ", Y_train.shape)

        adjacent_matrix, degree_matrix, laplacian_matrix = algebraic_graph_calculate(v, topology)
        # The Laplacian is fixed within a run, so its largest eigenvalue and the step size
        # derived from it are computed once instead of once per iteration.
        lambdaN = np.max(np.linalg.eigvals(laplacian_matrix).real)
        print('lambdaN: ', lambdaN)
        Gamma = 0.625 / lambdaN
        print("Gamma: ", Gamma)

        identity = np.identity(m * dim * M)
        T_agent = np.zeros((v, M, M))
        W0_agent = np.zeros((M, v))
        tk = np.zeros((max_Iter, v, K + 1))

        for itr in range(max_Iter):
            # Per-agent feature matrices Phi_i stacked to shape (v, N_data + itr, M).
            Phi_agent = np.stack([_rff_features(X_train[:, i], s, b, M) for i in range(v)])
            gram = Phi_agent.transpose(0, 2, 1) @ Phi_agent  # Phi_i^T Phi_i, shape (v, M, M)
            # Each agent uses its own regulariser sigma[0, i] (previously sigma[0, 0] for all).
            T_agent[:] = gram + sigma[0][:, None, None] * identity

            # T_agent[i] is symmetric, so eigvalsh returns real eigenvalues (eigvals may return complex).
            eigenvalues_T_agent = np.linalg.eigvalsh(T_agent)  # shape (v, M)
            Xi = eigenvalues_T_agent.max()
            print("Xi: ", Xi)
            print("\n")
            xi = eigenvalues_T_agent.min()
            print("xi: ", xi)
            print("\n")

            for i in range(v):
                # Local regularised least-squares weights: T_i^{-1} Phi_i^T y_i.
                W0_agent[:, i] = np.linalg.solve(T_agent[i], Phi_agent[i].T @ Y_train[:, i])

            # DBO ------------------------------------------------------------------------------------------------------
            tk[itr, :, :], W, _, _, _, _ = et_dbo_agent(dim, adjacent_matrix, M, v, Gamma, T_agent, W0_agent, K,
                                                         topology, N_data)

            # NOTE(review): the prediction uses rows 0..M-1 of column K of W; confirm which agent's
            # weights these are in et_dbo_agent's output layout.
            Y_search = features_search @ W[np.arange(M), K]
            # Per-agent quadratic form phi^T (Phi_i^T Phi_i + I)^{-1} phi on the grid. The identity
            # weight is fixed at 1 here, independent of sigma, as in the original code.
            Variation_search = np.column_stack(
                [_posterior_variance(features_search, gram[i] + identity) for i in range(v)])

            alpha = 0.05 * np.log(1.01 * (itr + 1))
            X_train_append = np.zeros((1, v))
            Y_train_append = np.zeros((1, v))
            for i in range(v):
                # Each agent queries the grid point minimising Y_search - sqrt(alpha) * sqrt(variance).
                Y_pred = Y_search - np.sqrt(alpha) * np.sqrt(Variation_search[:, i])
                _, sorted_indexes_Y_pred = sort_with_index(Y_pred)
                pos = int(sorted_indexes_Y_pred[0])
                X_train_append[:, i] = X_search[pos, :]
                Y_train_append[:, i] = levy(X_train_append[:, i])

            X_train = np.vstack((X_train, X_train_append))
            Y_train = np.vstack((Y_train, Y_train_append))

        # Regret of the points queried at each iteration (the rows appended after the initial
        # N_data), summed over agents and averaged over the iterations completed so far.
        Regret = Y_train[N_data:N_data + max_Iter, :] - min_yy
        ave_Regret_DBO = np.cumsum(Regret.sum(axis=1)) / np.arange(1, max_Iter + 1)
        DBO_regret_global[:, iter_global] = ave_Regret_DBO

        np.save(os.path.join(run_dir, 'tk.npy'), tk)

    np.save('DBO_regret_global.npy', DBO_regret_global)
# ----------------------------------------------------------------------------------------------------------------------
# Report the total wall-clock duration of the experiment (all topologies and repetitions).
end_time = time.time()
time_cost = end_time - start_time
print(f"Time taken: {time_cost} seconds")
