import time
import numpy as np
from DBO_helper_funcs import sort_with_index
from algebraic_graph_calculate import algebraic_graph_calculate
from matplotlib import pyplot as plt
from DBO_helper_funcs import fitness_func_ionosphere
# Close any matplotlib figure windows that are still open before the run starts.
plt.close('all')

start_time = time.time()  # wall-clock start, reported at the end of the script

# Number of columns of each agent's sample matrix, and initial samples per agent.
dim, N_data = 2, 5

# Number of ADMM sweeps, and the two factors whose product is the number of
# candidate points scored by the surrogate in every round.
K = 500
num1, num2 = 500, 200
N_search = num1 * num2

# Independent repetitions (each reads its own data folder) and outer rounds per repetition.
independent_global_iteration_ADMM, max_Iter = 5, 3
topology_list = ["fully-connected"]

def _rff_features(X, s, b, M):
    """Map every entry of ``X`` to a unit-norm cosine feature vector.

    Each scalar ``X[j, c]`` becomes ``sqrt(2/M) * cos(X[j, c] * s + b)``, which is
    then scaled to unit Euclidean norm. The per-column vectors are concatenated,
    so row ``j`` of the result is ``[f(X[j, 0]), f(X[j, 1]), ...]``.

    Args:
        X: 2-D array of shape (n_rows, n_cols).
        s, b: arrays holding the M multipliers / offsets; both are flattened, so
            any layout with M elements works (b may also be a scalar).
        M: number of features produced per scalar.

    Returns:
        Array of shape (n_rows, n_cols * M).
    """
    X = np.asarray(X, dtype=float)
    feats = X[:, :, None] * np.ravel(s)  # (n_rows, n_cols, M)
    feats += np.ravel(b)
    np.cos(feats, out=feats)
    feats *= np.sqrt(2 / M)
    # Per-vector norm along the feature axis without materializing feats ** 2.
    norms = np.sqrt(np.einsum('ijk,ijk->ij', feats, feats))
    feats /= norms[:, :, None]
    return feats.reshape(X.shape[0], -1)


def _consensus_admm(H_blocks, targets, beta0, lambda_reg, gamma, n_iter):
    """Run consensus ADMM for a ridge-regularised least-squares fit across agents.

    The updates minimise sum_k (1/2)||H_blocks[k] @ beta_k - targets[:, k]||^2
    + (lambda_reg/2)||z||^2 subject to beta_k == z for every agent k.

    Args:
        H_blocks: list of per-agent feature matrices, each of shape (n_k, p).
        targets: 2-D array whose column k holds agent k's targets (length n_k).
        beta0: (p, n_agents) warm start for the local estimates; not modified.
        lambda_reg: ridge weight on the consensus variable z.
        gamma: ADMM penalty parameter.
        n_iter: number of ADMM sweeps.

    Returns:
        (p, n_agents) array with the local estimates after ``n_iter`` sweeps.
    """
    n_agents = len(H_blocks)
    p = beta0.shape[0]
    identity = np.identity(p)
    # Each agent's beta-update matrix is constant across sweeps: invert it once.
    inv_mats = [np.linalg.inv(H.T @ H + gamma * identity) for H in H_blocks]
    rhs_data = [H.T @ targets[:, k] for k, H in enumerate(H_blocks)]

    beta = np.array(beta0, dtype=float)  # copy, so the warm start is left intact
    t = np.zeros((p, n_agents))  # dual variables (Lagrange multipliers)
    z = np.zeros(p)  # consensus variable
    for _ in range(n_iter):
        for k in range(n_agents):
            # Agent k uses ITS OWN feature block (the original reused the last agent's).
            beta[:, k] = inv_mats[k] @ (rhs_data[k] - t[:, k] + gamma * z)
        # Averages of the *current* iterates, recomputed every sweep
        # (the original accumulated them across all sweeps).
        beta_hat = beta.mean(axis=1)
        t_hat = t.mean(axis=1)
        z = (gamma * beta_hat + t_hat) / (lambda_reg / n_agents + gamma)
        t += gamma * (beta - z[:, None])
    return beta


# Quantities that do not change across topologies, repetitions or rounds are
# computed once here instead of inside the loops.
v = 5  # number of agents
M = 200  # features per scalar input
m = 1
s = np.load('s.npy')
b = np.load('b.npy')
X_search = np.load("xx.npy")
# Features of the fixed candidate set, shape (N_search, dim * M).
features_search = _rff_features(X_search[:N_search, :dim], s, b, M)

for i_topology in topology_list:
    topology = i_topology
    plt.close('all')
    print(topology)
    for iter_global in range(independent_global_iteration_ADMM):
        print("iter_global: ", iter_global)
        inf_value = float('inf')

        # NOTE(review): the folder is hard-coded to 'fully_connected' regardless of
        # `topology` and is Windows-specific; confirm before adding topologies.
        path1 = 'C:\\PycharmProjects\\pythonProject\\Ionosphere\\overlapping\\fully_connected\\'
        path2 = str(iter_global)
        path3 = '\\'
        run_dir = path1 + path2 + path3

        # agent_train[k]: samples evaluated so far by agent k (starts with N_data rows).
        agent_train = [np.load(run_dir + f'X{k + 1}.npy') for k in range(v)]
        # allow_pickle=True: these files are assumed to be trusted local outputs
        # of this project — never point this at untrusted data.
        testdata = np.load(run_dir + 'testdata.npy', allow_pickle=True)
        testlabel = np.load(run_dir + 'testlabel.npy', allow_pickle=True)

        # fitness[i, k]: fitness of agent k's i-th initial sample on agent k's data.
        fitness = inf_value * np.ones((N_data, v))
        for i in range(N_data):
            for k in range(v):
                fitness[i, k] = fitness_func_ionosphere(testdata[k], testlabel[k], agent_train[k][i, :], k)

        X_train0 = np.stack(agent_train)
        print("X_train: ", X_train0.shape)
        X_train = np.vstack(agent_train)
        Y_train = fitness
        print("Y_train: ", Y_train.shape)
        adjacent_matrix, degree_matrix, laplacian_matrix = algebraic_graph_calculate(v, topology)
        sigma = np.ones((1, v))
        print("s.shape: ", s.shape)
        print("b.shape: ", b.shape)

        # The graph is fixed for this run, so its largest eigenvalue is computed once.
        lambdaN = np.max(np.linalg.eigvals(laplacian_matrix))
        Gamma = 0.625 * 1 / lambdaN

        T_agent = np.zeros((v, dim * M, dim * M))
        W0_agent = np.zeros((dim * M, v))
        MCRs_total = np.zeros((max_Iter, v))
        best_parameter = np.zeros((max_Iter, dim * v))
        for itr in range(max_Iter):
            print("itr: ", itr)
            # Per-agent feature matrices, each (N_data + itr, dim * M); Phi stacks them by agent.
            Phi_blocks = [_rff_features(X_k, s, b, M) for X_k in agent_train]
            Phi = np.vstack(Phi_blocks)
            print('lambdaN: ', lambdaN)

            eigenvalues_T_agent = np.zeros((m * dim * M, v))
            for i in range(v):
                T_agent[i] = Phi_blocks[i].T @ Phi_blocks[i] + sigma[0, i] * np.identity(m * dim * M)
                # T_agent[i] is symmetric by construction, so eigvalsh gives real eigenvalues.
                eigenvalues_T_agent[:, i] = np.linalg.eigvalsh(T_agent[i])

            Xi = np.max(eigenvalues_T_agent)
            print("Xi: ", Xi)
            print("\n")
            xi = np.min(eigenvalues_T_agent)
            print("xi: ", xi)
            print("\n")
            print("Gamma: ", Gamma)

            # Local ridge solution of each agent, used as the ADMM warm start.
            for i in range(v):
                W0_agent[:, i] = np.linalg.solve(T_agent[i], Phi_blocks[i].T @ Y_train[:, i])

            W_ADMM = _consensus_admm(Phi_blocks, Y_train, W0_agent,
                                     lambda_reg=sigma[0, 0] * v, gamma=0.05, n_iter=K)

            # Surrogate prediction for every candidate and agent: (N_search, v).
            Y_search = features_search @ W_ADMM
            X_train_append = np.zeros((1, dim * v))
            Y_train_append = np.zeros((1, v))
            for i in range(v):
                # Take the first index returned by sort_with_index (presumably the
                # lowest predicted value — confirm its sort order).
                _, sorted_indexes_Y_pred = sort_with_index(Y_search[:, i])
                cols = slice(i * dim, (i + 1) * dim)
                X_train_append[0, cols] = X_search[int(sorted_indexes_Y_pred[0]), :dim]
                Y_train_append[0, i] = fitness_func_ionosphere(testdata[i], testlabel[i], X_train_append[0, cols].copy(), i)

            for k in range(v):
                agent_train[k] = np.vstack([agent_train[k], X_train_append[:, k * dim:(k + 1) * dim]])
            X_train = np.vstack(agent_train)
            Y_train = np.vstack((Y_train, Y_train_append))

            # NOTE(review): this re-evaluates exactly the points scored in
            # Y_train_append; if fitness_func_ionosphere is deterministic the
            # values coincide and the second evaluation could be dropped.
            MCRs = np.zeros((1, v))
            for i in range(v):
                MCRs[:, i] = fitness_func_ionosphere(testdata[i], testlabel[i], X_train_append[0, i * dim:(i + 1) * dim].copy(), i)
                print("MCR of HO optimized SVM for agent ", i, "is ", MCRs[:, i])

            MCRs_total[itr, :] = MCRs[0]
            best_parameter[itr, :] = X_train_append.ravel()

        print("Hello!")
        np.save(run_dir + 'best_parameter_ADMM.npy', best_parameter)
        np.save(run_dir + 'MCRs_total_ADMM.npy', MCRs_total)
        np.save(run_dir + 'X_train_ADMM.npy', X_train)
        np.save(run_dir + 'Y_train_ADMM.npy', Y_train)

# ----------------------------------------------------------------------------------------------------------------------
# Report the total wall-clock time of the whole script (start_time is set at the top).
end_time = time.time()
time_cost = end_time - start_time
print("Time taken: {} seconds".format(time_cost))
