import numpy as np
from .coordinator import Coordinator
from .worker import Worker
from numpy import linalg as LA

class DSPGD(object):
    def __init__(self, n_clusters, n_components, gamma, rank):
        """Record the solver configuration on the instance.

        :param n_clusters: stored as ``self.n_clusters``; not read by the
            methods visible in this part of the file.
        :param n_components: stored as ``self.n_components``; used by the
            methods of this class as the length of an initial Lanczos vector
            and as a cap on the size of the Lanczos basis.
        :param gamma: stored as ``self.gamma``; not read by the methods
            visible in this part of the file.
        :param rank: rank of the calling process. The methods of this class
            treat rank ``0`` as the coordinating process and every other
            rank as a worker.
        """
        self.n_clusters, self.n_components = n_clusters, n_components
        self.gamma, self.rank = gamma, rank

    def feature_transform(self, process, raw_feature, ite, lamb=None, pre_eigvecs=None, pre_eigvals=None):
        """Build the local feature block on a worker; rank 0 gets ``None``.

        On every rank other than 0 the raw features are first passed through
        ``process.rand_fourier`` and the result is handed to
        ``process.construction`` together with the step size ``1 / ite`` and
        the optional ``lamb`` / ``pre_eigvecs`` / ``pre_eigvals`` arguments.

        :param process: object providing ``rand_fourier`` and ``construction``
        :param raw_feature: input forwarded to ``process.rand_fourier``
        :param ite: iteration counter; ``1. / ite`` is forwarded as the step
        :param lamb: forwarded unchanged to ``process.construction``
        :param pre_eigvecs: forwarded unchanged to ``process.construction``
        :param pre_eigvals: forwarded unchanged to ``process.construction``
        :return: ``None`` on rank 0, otherwise whatever
            ``process.construction`` returns
        """
        # The coordinating rank holds no data of its own.
        if self.rank == 0:
            return None

        fourier_feature = process.rand_fourier(raw_feature)
        return process.construction(fourier_feature, 1. / ite, lamb, pre_eigvecs, pre_eigvals)

    def tridiagonal(self, diag1, diag2):
        """Assemble a symmetric tridiagonal matrix.

        :param diag1: entries of the main diagonal
        :param diag2: entries placed on both the first super-diagonal and the
            first sub-diagonal
        :return: 2-D array with ``diag1`` on the diagonal and ``diag2``
            directly above and below it
        """
        main = np.array(diag1)
        off = np.array(diag2)

        upper = np.diag(off, 1)
        lower = np.diag(off, -1)
        return np.diag(main) + upper + lower

    def distributed_power_iteration(self, process, local_feature, pre_feature, n_feature_set, comm_cost, vec, comm):
        if self.rank != 0:
            vec_set = None
        else:
            vec_set = []
            vec_set.append(vec)

        if self.rank == 0:
            start_idx = 0
            end_idx = 0
            for i in range(1, len(n_feature_set)):
                start_idx = start_idx + n_feature_set[i-1]
                end_idx = end_idx + n_feature_set[i]

                local_init_vec = vec[start_idx: end_idx]
                vec_set.append(local_init_vec)

        vec = comm.scatter(vec_set, root=0)

        if self.rank != 0:
            local_vec = np.matmul(local_feature.T, vec)
        else:
            local_vec = None
        local_vec_set = comm.gather(local_vec, root=0)
        if self.rank == 0:
            power_vec0 = process.aggregate(local_vec_set[1:])
            comm_cost[0] = comm_cost[0] + 1
            comm_cost[1] = comm_cost[1] + power_vec0.shape[0]
        else:
            power_vec0 = None

        power_vec0 = comm.bcast(power_vec0, root=0)
        if self.rank != 0:
            local_vec = np.matmul(local_feature, power_vec0)
        else:
            local_vec = None
        local_vec_set = comm.gather(local_vec, root=0)
        if self.rank == 0:
            power_vec = process.concatenation(local_vec_set[1:])
            comm_cost[2] = comm_cost[2] + power_vec.shape[0]
            if pre_feature is not None:
                add_vec = np.matmul(pre_feature, np.matmul(pre_feature.T, vec))
                power_vec = power_vec + add_vec
        else:
            power_vec = None

        return power_vec, comm_cost

    def Lanczos_full_reorthogonal_vanilla(self, process, local_feature, pre_feature, n_feature_set,
                                          init_vec, tol, comm, lamb=None, n_eigvectors=None):
        """
        Lanczos algorithm without modification for improving communication efficiency
        :param process:
        :param local_feature:
        :param init_vec:
        :param tol:
        :param comm:
        :param lamb:
        :param n_eigvectors:
        :return:
        """
        iteration_flag = True
        vec_q1 = init_vec
        # vec_q0 = np.zeros(shape=init_vec.shape[0])
        beta = []
        alpha = []
        n_eigs_old = 1
        eig_value = None
        eig_vector = None
        comm_cost = [0, 0, 0]

        # if self.rank != 0:
        #     n_local_feature = local_feature.shape[0]
        # else:
        #     n_local_feature = 0
        # n_feature_set = comm.gather(n_local_feature, root=0)
        # if self.rank == 0:
        #     start_idx = 0
        #     end_idx = 0
        #     for i in range(1, len(n_feature_set)):
        #         start_idx = start_idx + n_feature_set[i-1]
        #         end_idx = end_idx + n_feature_set[i]
        #
        #         local_init_vec = init_vec[start_idx: end_idx]
        #         init_vec_set.append(local_init_vec)
        #
        # init_vec = comm.scatter(init_vec_set, root=0)
        #
        # if self.rank != 0:
        #     local_vec = np.matmul(local_feature.T, init_vec)
        # else:
        #     local_vec = None
        # local_vec_set = comm.gather(local_vec, root=0)
        # if self.rank == 0:
        #     vec_z0 = process.aggregate(local_vec_set[1:])
        #     comm_cost[0] = comm_cost[0] + 1
        #     comm_cost[1] = comm_cost[1] + vec_z0.shape[0]
        # else:
        #     vec_z0 = None
        #
        # vec_z0 = comm.bcast(vec_z0, root=0)
        # if self.rank != 0:
        #     local_vec = np.matmul(local_feature, vec_z0)
        # else:
        #     local_vec = None
        # local_vec_set = comm.gather(local_vec, root=0)
        vec_z, comm_cost = self.distributed_power_iteration(process, local_feature, pre_feature, n_feature_set,
                                                            comm_cost, init_vec, comm)
        if self.rank == 0:
            if pre_feature is not None:
                max_rank = self.n_components + pre_feature.shape[1]
            else:
                max_rank = self.n_components

            alpha.append(np.dot(vec_q1, vec_z))
            # Use orthogonalization twice
            vec_z = vec_z - np.dot(vec_z, vec_q1) * vec_q1
            vec_z = vec_z - np.dot(vec_z, vec_q1) * vec_q1
            beta.append(LA.norm(vec_z))

            vec_q0 = np.copy(vec_q1)
            vec_q1 = vec_z / beta[-1]

            matrix_Q = np.concatenate((vec_q0.reshape(-1, 1), vec_q1.reshape(-1, 1)), axis=1)
            eig_value_old = alpha[-1]
        else:
            max_rank = None
            matrix_Q = None
            eig_value_old = None
            vec_q1 = None

        if n_eigvectors == None:
            while 1:
                iteration_flag = comm.bcast(iteration_flag, root=0)
                if not iteration_flag:
                    break
                vec_z, comm_cost = self.distributed_power_iteration(process, local_feature, pre_feature, n_feature_set,
                                                                    comm_cost, vec_q1, comm)

                if self.rank == 0:
                    alpha.append(np.dot(vec_q1, vec_z))

                    vec_z = vec_z - np.matmul(matrix_Q, np.matmul(matrix_Q.T, vec_z))
                    vec_z = vec_z - np.matmul(matrix_Q, np.matmul(matrix_Q.T, vec_z))
                    beta.append(LA.norm(vec_z))

                    # vec_q0 = np.copy(vec_q1)
                    vec_q1 = vec_z / beta[-1]

                    matrix_T = self.tridiagonal(alpha, beta[:-1])


                    U, S, V = LA.svd(matrix_T, full_matrices=False)

                    n_eigs = np.sum(S >= lamb)
                    eig_value = S[:n_eigs]
                    eig_vector = np.matmul(matrix_Q, U[:, :n_eigs])

                    if matrix_Q.shape[1] >= max_rank:
                        iteration_flag = False
                        continue

                    matrix_Q = np.concatenate((matrix_Q, vec_q1.reshape(-1, 1)), axis=1)

                    if n_eigs != n_eigs_old:
                        n_eigs_old = n_eigs
                        eig_value_old = np.copy(eig_value)
                    else:
                        n_converge = np.sum(np.abs(eig_value-eig_value_old)<=tol)
                        n_eigs_old = n_eigs
                        eig_value_old = np.copy(eig_value)

                        if n_converge == n_eigs:
                            iteration_flag = False
                            continue

        else:
            while 1:
                iteration_flag = comm.bcast(iteration_flag, root=0)
                if not iteration_flag:
                    break
                vec_z, comm_cost = self.distributed_power_iteration(process, local_feature, pre_feature, n_feature_set,
                                                                    comm_cost, vec_q1, comm)

                if self.rank == 0:
                    alpha.append(np.dot(vec_q1, vec_z))

                    vec_z = vec_z - np.matmul(matrix_Q, np.matmul(matrix_Q.T, vec_z))
                    vec_z = vec_z - np.matmul(matrix_Q, np.matmul(matrix_Q.T, vec_z))
                    beta.append(LA.norm(vec_z))

                    # vec_q0 = np.copy(vec_q1)
                    vec_q1 = vec_z / beta[-1]

                    matrix_T = self.tridiagonal(alpha, beta[:-1])

                    # print(matrix_T)

                    U, S, V = LA.svd(matrix_T, full_matrices=False)

                    # print('U:', U)

                    if S.shape[0] <= n_eigvectors:
                        eig_value_old = np.copy(S)
                        matrix_Q = np.concatenate((matrix_Q, vec_q1.reshape(-1, 1)), axis=1)
                        continue

                    eig_value = S[:n_eigvectors]
                    # print('Q:', matrix_Q)
                    eig_vector = np.matmul(matrix_Q, U[:, :n_eigvectors])

                    if matrix_Q.shape[1] >= max_rank:
                        iteration_flag = False
                        continue

                    matrix_Q = np.concatenate((matrix_Q, vec_q1.reshape(-1, 1)), axis=1)

                    n_converge = np.sum(np.abs(eig_value - eig_value_old) <= tol)
                    eig_value_old = np.copy(eig_value)
                    if n_converge == n_eigvectors:
                        iteration_flag = False
                        continue

        return eig_value, eig_vector, comm_cost

    def Lanczos_full_reorthogonal(self, process, local_feature, init_vec, tol, comm, lamb=None, n_eigvectors=None):
        """
        Lanczos algorithm utilized in CEM
        Use the change of eigenvalues as the convergence criterion
        :param process:
        :param local_feature:
        :param init_vec:
        :param tol:
        :param comm:
        :return:
        """

        iteration_flag = True
        vec_q1 = init_vec
        vec_q0 = np.zeros(shape=init_vec.shape[0])
        beta = []
        alpha = []
        n_eigs_old = 1
        eig_value = None
        eig_vector = None
        comm_cost = [0,0]

        max_rank = init_vec.shape[0]

        if self.rank != 0:
            local_vec, _ = process.local_computation(local_feature, vec_q1)
        else:
            local_vec = None

        local_vec_set = comm.gather(local_vec, root=0)
        if self.rank == 0:
            vec_z = process.aggregate(local_vec_set[1:])
            comm_cost[0] = comm_cost[0] + 1
            comm_cost[1] = comm_cost[1] + vec_z.shape[0]
            alpha.append(np.dot(vec_q1, vec_z))

            # Use orthogonalization twice
            vec_z = vec_z - np.dot(vec_z, vec_q1) * vec_q1
            vec_z = vec_z - np.dot(vec_z, vec_q1) * vec_q1
            beta.append(LA.norm(vec_z))

            vec_q0 = np.copy(vec_q1)
            vec_q1 = vec_z / beta[-1]

            matrix_Q = np.concatenate((vec_q0.reshape(-1, 1), vec_q1.reshape(-1, 1)), axis=1)
            eig_value_old = alpha[-1]
        else:
            matrix_Q = None
            eig_value_old = None

        if n_eigvectors == None:
            while 1:
                iteration_flag = comm.bcast(iteration_flag, root=0)
                if not iteration_flag:
                    break
                vec_q1 = comm.bcast(vec_q1, root=0)
                if self.rank != 0:
                    local_vec, _ = process.local_computation(local_feature, vec_q1)

                local_vec_set = comm.gather(local_vec, root=0)
                if self.rank == 0:
                    vec_z = process.aggregate(local_vec_set[1:])
                    comm_cost[0] = comm_cost[0] + 1
                    comm_cost[1] = comm_cost[1] + vec_z.shape[0]
                    alpha.append(np.dot(vec_q1, vec_z))

                    vec_z = vec_z - np.matmul(matrix_Q, np.matmul(matrix_Q.T, vec_z))
                    vec_z = vec_z - np.matmul(matrix_Q, np.matmul(matrix_Q.T, vec_z))
                    beta.append(LA.norm(vec_z))

                    # vec_q0 = np.copy(vec_q1)
                    vec_q1 = vec_z / beta[-1]

                    matrix_T = self.tridiagonal(alpha, beta[:-1])


                    U, S, V = LA.svd(matrix_T, full_matrices=False)

                    n_eigs = np.sum(S >= lamb)
                    eig_value = S[:n_eigs]
                    eig_vector = np.matmul(matrix_Q, U[:, :n_eigs])

                    if matrix_Q.shape[1] >= max_rank:
                        iteration_flag = False
                        continue

                    matrix_Q = np.concatenate((matrix_Q, vec_q1.reshape(-1, 1)), axis=1)

                    if n_eigs != n_eigs_old:
                        n_eigs_old = n_eigs
                        eig_value_old = np.copy(eig_value)
                    else:
                        n_converge = np.sum(np.abs(eig_value-eig_value_old)<=tol)
                        n_eigs_old = n_eigs
                        eig_value_old = np.copy(eig_value)

                        if n_converge == n_eigs:
                            iteration_flag = False
                            continue

        else:
            while 1:
                iteration_flag = comm.bcast(iteration_flag, root=0)
                if not iteration_flag:
                    break
                vec_q1 = comm.bcast(vec_q1, root=0)
                if self.rank != 0:
                    local_vec, _ = process.local_computation(local_feature, vec_q1)

                local_vec_set = comm.gather(local_vec, root=0)
                if self.rank == 0:
                    vec_z = process.aggregate(local_vec_set[1:])
                    comm_cost[0] = comm_cost[0] + 1
                    comm_cost[1] = comm_cost[1] + vec_z.shape[0]
                    alpha.append(np.dot(vec_q1, vec_z))

                    vec_z = vec_z - np.matmul(matrix_Q, np.matmul(matrix_Q.T, vec_z))
                    vec_z = vec_z - np.matmul(matrix_Q, np.matmul(matrix_Q.T, vec_z))
                    beta.append(LA.norm(vec_z))

                    # vec_q0 = np.copy(vec_q1)
                    vec_q1 = vec_z / beta[-1]

                    matrix_T = self.tridiagonal(alpha, beta[:-1])

                    # print(matrix_T)

                    U, S, V = LA.svd(matrix_T, full_matrices=False)

                    # print('U:', U)

                    if S.shape[0] <= n_eigvectors:
                        eig_value_old = np.copy(S)
                        matrix_Q = np.concatenate((matrix_Q, vec_q1.reshape(-1, 1)), axis=1)
                        continue

                    eig_value = S[:n_eigvectors]
                    # print('Q:', matrix_Q)
                    eig_vector = np.matmul(matrix_Q, U[:, :n_eigvectors])

                    if matrix_Q.shape[1] >= max_rank:
                        iteration_flag = False
                        continue

                    matrix_Q = np.concatenate((matrix_Q, vec_q1.reshape(-1, 1)), axis=1)

                    n_converge = np.sum(np.abs(eig_value - eig_value_old) <= tol)
                    eig_value_old = np.copy(eig_value)
                    if n_converge == n_eigvectors:
                        iteration_flag = False
                        continue

        comm_cost = comm.bcast(comm_cost, root=0)

        return eig_value, eig_vector, comm_cost

    def Lanczos_no_reorthogonal(self, process, local_feature, init_vec, tol, comm, lamb=None, n_eigvectors=None):
        """
        Lanczos algorithm utilized in CEM
        Use the change of eigenvalues as the convergence criterion
        :param process:
        :param local_feature:
        :param init_vec:
        :param tol:
        :param comm:
        :param lamb:
        :param n_eigvectors:
        :return:
        """

        iteration_flag = True
        vec_q1 = init_vec
        vec_q0 = np.zeros(shape=init_vec.shape[0])
        beta = []
        alpha = []
        n_eigs_old = 1
        eig_value = None
        eig_vector = None

        if self.rank != 0:
            local_vec, _ = process.local_computation(local_feature, vec_q1)
        else:
            local_vec = None

        local_vec_set = comm.gather(local_vec, root=0)
        if self.rank == 0:
            vec_z = process.aggregate(local_vec_set[1:])
            alpha.append(np.dot(vec_q1, vec_z))

            vec_z = vec_z - alpha[-1] * vec_q1
            beta.append(LA.norm(vec_z))

            vec_q0 = np.copy(vec_q1)
            vec_q1 = vec_z / beta[-1]

            matrix_Q = np.concatenate((vec_q0.reshape(-1,1), vec_q1.reshape(-1,1)), axis=1)
            eig_value_old = alpha[-1]
        else:
            matrix_Q = None
            eig_value_old = None

        if n_eigvectors == None:
            while 1:
                iteration_flag = comm.bcast(iteration_flag, root=0)
                if not iteration_flag:
                    break
                vec_q1 = comm.bcast(vec_q1, root=0)
                if self.rank != 0:
                    local_vec, _ = process.local_computation(local_feature, vec_q1)

                local_vec_set = comm.gather(local_vec, root=0)
                if self.rank == 0:
                    vec_z = process.aggregate(local_vec_set[1:])
                    alpha.append(np.dot(vec_q1, vec_z))

                    vec_z = vec_z - alpha[-1] * vec_q1 - beta[-1] * vec_q0
                    beta.append(LA.norm(vec_z))

                    vec_q0 = np.copy(vec_q1)
                    vec_q1 = vec_z / beta[-1]

                    matrix_T = self.tridiagonal(alpha, beta[:-1])


                    U, S, V = LA.svd(matrix_T, full_matrices=False)

                    n_eigs = np.sum(S >= lamb)
                    eig_value = S[:n_eigs]
                    eig_vector = np.matmul(matrix_Q, U[:, :n_eigs])

                    if matrix_Q.shape[1] >= self.n_components:
                        iteration_flag = False
                        continue

                    matrix_Q = np.concatenate((matrix_Q, vec_q1.reshape(-1, 1)), axis=1)

                    if n_eigs != n_eigs_old:
                        n_eigs_old = n_eigs
                        eig_value_old = np.copy(eig_value)
                    else:
                        n_converge = np.sum(np.abs(eig_value-eig_value_old)<=tol)
                        n_eigs_old = n_eigs
                        eig_value_old = np.copy(eig_value)

                        if n_converge == n_eigs:
                            iteration_flag = False
                            continue

        else:
            while 1:
                iteration_flag = comm.bcast(iteration_flag, root=0)
                if not iteration_flag:
                    break
                vec_q1 = comm.bcast(vec_q1, root=0)
                if self.rank != 0:
                    local_vec, _ = process.local_computation(local_feature, vec_q1)

                local_vec_set = comm.gather(local_vec, root=0)
                if self.rank == 0:
                    vec_z = process.aggregate(local_vec_set[1:])
                    alpha.append(np.dot(vec_q1, vec_z))

                    vec_z = vec_z - alpha[-1] * vec_q1 - beta[-1] * vec_q0
                    beta.append(LA.norm(vec_z))

                    vec_q0 = np.copy(vec_q1)
                    vec_q1 = vec_z / beta[-1]

                    matrix_T = self.tridiagonal(alpha, beta[:-1])

                    # print(matrix_T)

                    U, S, V = LA.svd(matrix_T, full_matrices=False)

                    print('U:', U)

                    if S.shape[0] <= n_eigvectors:
                        eig_value_old = np.copy(S)
                        matrix_Q = np.concatenate((matrix_Q, vec_q1.reshape(-1, 1)), axis=1)
                        continue

                    eig_value = S[:n_eigvectors]
                    print('Q:', matrix_Q)
                    eig_vector = np.matmul(matrix_Q, U[:, :n_eigvectors])

                    if matrix_Q.shape[1] >= self.n_components:
                        iteration_flag = False
                        continue

                    matrix_Q = np.concatenate((matrix_Q, vec_q1.reshape(-1, 1)), axis=1)

                    n_converge = np.sum(np.abs(eig_value - eig_value_old) <= tol)
                    eig_value_old = np.copy(eig_value)
                    if n_converge == n_eigvectors:
                        iteration_flag = False
                        continue

        return eig_value, eig_vector

    def DSPGD_update_CEM(self, process, feature, ite, dim, comm, tol,
                         lamb=None, pre_eigvals=None, pre_eigvecs=None, vec_size=None):
        if lamb is None:
            if self.rank == 0:
                generator = np.random.RandomState(0)
                init_vec = generator.normal(size=self.n_components)
                init_vec = init_vec / LA.norm(init_vec)

            else:
                generator = None
                init_vec = None

            init_vec = comm.bcast(init_vec, root=0)
            local_feature = self.feature_transform(process, feature, ite)

            eig_value, eig_vector, comm_cost = self.Lanczos_full_reorthogonal(process, local_feature, init_vec, tol,
                                                                              comm, n_eigvectors=dim)

            eig_value = comm.bcast(eig_value, root=0)
            eig_vector = comm.bcast(eig_vector, root=0)

            if self.rank == 0:
                local_eigvecs = None
            else:
                local_eigvecs = np.matmul(np.matmul(local_feature, eig_vector), np.diag(1/np.sqrt(eig_value)))

            return eig_value, local_eigvecs, comm_cost

        else:
            local_feature = self.feature_transform(process, feature, ite, lamb, pre_eigvecs, pre_eigvals)
            if self.rank == 0:
                generator = np.random.RandomState(0)
                init_vec = generator.normal(size=vec_size)
                init_vec = init_vec / LA.norm(init_vec)
            else:
                generator = None
                init_vec = None

            init_vec = comm.bcast(init_vec, root=0)
            eig_value, eig_vector, comm_cost = self.Lanczos_full_reorthogonal(process, local_feature, init_vec, tol,
                                                                              comm, lamb=lamb)

            eig_value = comm.bcast(eig_value, root=0)
            eig_vector = comm.bcast(eig_vector, root=0)

            if self.rank == 0:
                local_eigvecs = None
            else:
                local_eigvecs = np.matmul(np.matmul(local_feature, eig_vector), np.diag(1 / np.sqrt(eig_value)))

            return eig_value, local_eigvecs, comm_cost

    def DSPGD_update_wo_CEM(self, process, feature, ite, dim, comm, tol,
                            lamb=None, pre_eigvalue=None, pre_eigvector=None):
        # Construct an indicator vector for the division of the initial vector
        if self.rank != 0:
            n_local_feature = feature.shape[0]
        else:
            n_local_feature = 0
        n_feature_set = comm.gather(n_local_feature, root=0)
        if self.rank == 0:
            n_total_feature = 0
            for n_feature in n_feature_set[1:]:
                n_total_feature = n_total_feature + n_feature
        else:
            n_total_feature = None


        if lamb is None:
            if self.rank == 0:
                generator = np.random.RandomState(0)
                init_vec = generator.normal(size=n_total_feature)
                init_vec = init_vec / LA.norm(init_vec)
            else:
                generator = None
                init_vec = None

            local_feature = self.feature_transform(process, feature, ite)
            pre_feature = None
            eig_value, eig_vector, comm_cost = self.Lanczos_full_reorthogonal_vanilla(process, local_feature,
                                                                                      pre_feature, n_feature_set,
                                                                                      init_vec, tol, comm,
                                                                                      n_eigvectors=dim)

            return eig_value, eig_vector, comm_cost

        else:
            local_feature = self.feature_transform(process, feature, ite)
            if self.rank == 0:
                pre_feature = np.sqrt(1.0 - (1. / ite)) * np.matmul(pre_eigvector,
                                                                    np.diag(np.sqrt(pre_eigvalue - (1. / ite) * lamb)))
            else:
                pre_feature = None

            if self.rank == 0:
                generator = np.random.RandomState(0)
                init_vec = generator.normal(size=n_total_feature)
                init_vec = init_vec / LA.norm(init_vec)
            else:
                generator = None
                init_vec = None

            eig_value, eig_vector, comm_cost = self.Lanczos_full_reorthogonal_vanilla(process, local_feature,
                                                                                      pre_feature, n_feature_set,
                                                                                      init_vec, tol, comm,
                                                                                      lamb=lamb)


            return eig_value, eig_vector, comm_cost

    # def DSPGD_update_oracle(self, process, feature, ite, dim, comm, lamb=None, pre_eigvals=None, pre_eigvecs=None):
    #     if lamb is None:
    #         if self.rank == 0:
    #             generator = np.random.RandomState(0)
    #             init_vec = generator.normal(size=2 * self.n_components)
    #             init_vec = init_vec / LA.norm(init_vec)
    #
    #             seed = np.random.randint(65536, size=1)
    #         else:
    #             generator = None
    #             init_vec = None
    #
    #         init_vec = comm.bcast(init_vec, root=0)
    #         local_feature = self.feature_transform(process, feature, ite)
    #
    #         if self.rank == 0:
    #             local_mat = None
    #         else:
    #             local_mat = np.matmul(local_feature.T, local_feature)
    #
    #         local_mat_set = comm.gather(local_mat, root=0)
    #         if self.rank == 0:
    #             central_mat = process.aggregate(local_mat_set[1:])
    #             U, S, V = LA.svd(central_mat, full_matrices=False)
    #
    #             eigvecs_old, eigvals_old, comm_cost = self.oracle_DPM(central_mat, dim, 1.0)
    #
    #             # eigvals_old = S[:dim]
    #             # eigvecs_old = U[:,:dim]
    #         else:
    #             eigvals_old = None
    #             eigvecs_old = None
    #             comm_cost = None
    #
    #         eigvals_old = comm.bcast(eigvals_old, root=0)
    #         eigvecs_old = comm.bcast(eigvecs_old, root=0)
    #
    #         if self.rank == 0:
    #             local_eigvecs = None
    #         else:
    #             local_eigvecs = np.matmul(np.matmul(local_feature, eigvecs_old), np.diag(1/np.sqrt(eigvals_old)))
    #
    #         return eigvals_old, local_eigvecs, comm_cost
    #     else:
    #
    #         local_feature = self.feature_transform(process, feature, ite, lamb, pre_eigvecs, pre_eigvals)
    #
    #         if self.rank == 0:
    #             local_mat = None
    #         else:
    #             local_mat = np.matmul(local_feature.T, local_feature)
    #
    #         local_mat_set = comm.gather(local_mat, root=0)
    #         if self.rank == 0:
    #             central_mat = process.aggregate(local_mat_set[1:])
    #             U, S, V = LA.svd(central_mat, full_matrices=False)
    #             dim = np.sum(S > lamb)
    #
    #             eigvecs_old, eigvals_old, comm_cost = self.oracle_DPM(central_mat, dim, 1e-1)
    #
    #             # eigvals_old = S[:dim]
    #             # eigvecs_old = U[:, :dim]
    #         else:
    #             eigvals_old = None
    #             eigvecs_old = None
    #             comm_cost = None
    #
    #         eigvals_old = comm.bcast(eigvals_old, root=0)
    #         eigvecs_old = comm.bcast(eigvecs_old, root=0)
    #
    #         if self.rank == 0:
    #             local_eigvecs = None
    #         else:
    #             local_eigvecs = np.matmul(np.matmul(local_feature, eigvecs_old), np.diag(1/np.sqrt(eigvals_old)))
    #
    #         return eigvals_old, local_eigvecs, comm_cost
    #
    #
    # def DSPGD_update_wo_CE(self, process, feature, ite, dim, comm, lamb=None, pre_eigvals=None, pre_eigvecs=None, vec_size=None):
    #     if lamb is None:
    #         if self.rank == 0:
    #             generator = np.random.RandomState(0)
    #             init_vec = generator.normal(size=self.n_components)
    #             init_vec = init_vec / LA.norm(init_vec)
    #
    #         else:
    #             generator = None
    #             init_vec = None
    #
    #         init_vec = comm.bcast(init_vec, root=0)
    #         local_feature = self.feature_transform(process, feature, ite)
    #
    #         eigval, eigvec = self.first_eigvec(process, local_feature, init_vec, 1e-2, comm)
    #
    #         if self.rank == 0:
    #             # print('first eigval:', eigval, flush=True)
    #             eigvals_old = np.array([eigval])
    #             eigvecs_old = eigvec.reshape((-1, 1))
    #
    #             init_vec = generator.normal(size=self.n_components)
    #
    #             init_vec = init_vec - np.matmul(np.matmul(eigvecs_old, eigvecs_old.T), init_vec)
    #             init_vec = init_vec / LA.norm(init_vec)
    #
    #         else:
    #             eigvals_old = None
    #             eigvecs_old = None
    #             init_vec = None
    #         init_vec = comm.bcast(init_vec, root=0)
    #
    #         eigval, eigvec = self.second_eigvec(process, local_feature, eigvec, eigval, init_vec, 1e-2, comm)
    #         if self.rank == 0:
    #             # print('second eigval:', eigval, flush=True)
    #             eigvals_old = np.append(eigvals_old, eigval)
    #             eigvecs_old = np.concatenate((eigvecs_old, eigvec.reshape((-1, 1))), axis=1)
    #
    #             init_vec = generator.normal(size=self.n_components)
    #             init_vec = init_vec - np.matmul(np.matmul(eigvecs_old, eigvecs_old.T), init_vec)
    #             init_vec = init_vec / LA.norm(init_vec)
    #         else:
    #             eigvec = None
    #             eigvals_old = None
    #             eigvecs_old = None
    #
    #         for i in range(2, dim):
    #             init_vec = comm.bcast(init_vec, root=0)
    #             eigval, eigvec = self.rest_eigvec(process, local_feature, eigvals_old,
    #                                               eigvecs_old, init_vec, 1e-2, comm)
    #             if self.rank == 0:
    #                 # print('rest eigvals:', eigval, flush=True)
    #                 eigvals_old = np.append(eigvals_old, eigval)
    #                 eigvecs_old = np.concatenate((eigvecs_old, eigvec.reshape((-1, 1))), axis=1)
    #
    #                 init_vec = generator.normal(size=self.n_components)
    #                 init_vec = init_vec - np.matmul(np.matmul(eigvecs_old, eigvecs_old.T), init_vec)
    #                 init_vec = init_vec / LA.norm(init_vec)
    #             else:
    #                 eigvec = None
    #                 eigvals_old = None
    #                 eigvecs_old = None
    #
    #         eigvals_old = comm.bcast(eigvals_old, root=0)
    #         eigvecs_old = comm.bcast(eigvecs_old, root=0)
    #
    #         if self.rank == 0:
    #             local_eigvecs = None
    #         else:
    #             local_eigvecs = np.matmul(np.matmul(local_feature, eigvecs_old), np.diag(1/np.sqrt(eigvals_old)))
    #
    #         return eigvals_old, local_eigvecs
    #     else:
    #         local_feature = self.feature_transform(process, feature, ite, lamb, pre_eigvecs, pre_eigvals)
    #         iteration_flag = True
    #         if self.rank == 0:
    #             generator = np.random.RandomState(0)
    #             init_vec = generator.normal(size=vec_size)
    #             init_vec = init_vec / LA.norm(init_vec)
    #         else:
    #             generator = None
    #             init_vec = None
    #
    #         init_vec = comm.bcast(init_vec, root=0)
    #
    #         eigval, eigvec = self.first_eigvec(process, local_feature, init_vec, 1e-1, comm)
    #
    #         if self.rank == 0:
    #             # print('first eigval:', eigval, flush=True)
    #             eigvals_old = np.array([eigval])
    #             eigvecs_old = eigvec.reshape((-1, 1))
    #
    #             init_vec = generator.normal(size=vec_size)
    #
    #             init_vec = init_vec - np.matmul(np.matmul(eigvecs_old, eigvecs_old.T), init_vec)
    #             init_vec = init_vec / LA.norm(init_vec)
    #
    #         else:
    #             eigvals_old = None
    #             eigvecs_old = None
    #             init_vec = None
    #         init_vec = comm.bcast(init_vec, root=0)
    #
    #         eigval, eigvec = self.second_eigvec(process, local_feature, eigvec, eigval, init_vec, 1e-2, comm)
    #         if self.rank == 0:
    #             # print('second eigval:', eigval, flush=True)
    #             eigvals_old = np.append(eigvals_old, eigval)
    #             eigvecs_old = np.concatenate((eigvecs_old, eigvec.reshape((-1, 1))), axis=1)
    #
    #             init_vec = generator.normal(size=vec_size)
    #             init_vec = init_vec - np.matmul(np.matmul(eigvecs_old, eigvecs_old.T), init_vec)
    #             init_vec = init_vec / LA.norm(init_vec)
    #         else:
    #             eigvec = None
    #             eigvals_old = None
    #             eigvecs_old = None
    #
    #         while 1:
    #             iteration_flag = comm.bcast(iteration_flag, root=0)
    #             if not iteration_flag:
    #                 break
    #             init_vec = comm.bcast(init_vec, root=0)
    #             eigval, eigvec = self.rest_eigvec(process, local_feature, eigvals_old,
    #                                               eigvecs_old, init_vec, 1e-2, comm)
    #             if self.rank == 0:
    #                 # print('rest eigvals:', eigval, flush=True)
    #                 if eigval > lamb:
    #                     eigvals_old = np.append(eigvals_old, eigval)
    #                     eigvecs_old = np.concatenate((eigvecs_old, eigvec.reshape((-1, 1))), axis=1)
    #
    #                     init_vec = generator.normal(size=vec_size)
    #                     init_vec = init_vec - np.matmul(np.matmul(eigvecs_old, eigvecs_old.T), init_vec)
    #                     init_vec = init_vec / LA.norm(init_vec)
    #                 else:
    #                     iteration_flag = False
    #             else:
    #                 eigvec = None
    #                 eigvals_old = None
    #                 eigvecs_old = None
    #
    #         eigvals_old = comm.bcast(eigvals_old, root=0)
    #         eigvecs_old = comm.bcast(eigvecs_old, root=0)
    #
    #         if self.rank == 0:
    #             local_eigvecs = None
    #         else:
    #             local_eigvecs = np.matmul(np.matmul(local_feature, eigvecs_old), np.diag(1/np.sqrt(eigvals_old)))
    #
    #         return eigvals_old, local_eigvecs
    #
    #
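    # NOTE(review): the method below is disabled (commented out). As written it runs
    # a block power iteration on a single in-memory `matrix` without using `comm`:
    # multiply the current basis by `matrix`, re-orthonormalize with an SVD, take the
    # diagonal of eigvecs.T @ matrix @ eigvecs as eigenvalue estimates, and stop once
    # all `n_vec` estimates change by less than `tol`. It reports
    # ite * n_dim * n_vec as `comm_cost`.
    # def oracle_DPM(self, matrix, n_vec, tol):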
    #     n_dim = matrix.shape[1]
    #     generator = np.random.RandomState(0)
    #     eigvecs = generator.normal(size=(n_dim, n_vec))
    #     eigvals_old = np.zeros(n_vec)
    #
    #     update_eigvecs = np.matmul(matrix, eigvecs)
    #
    #     U, _, _ = LA.svd(update_eigvecs, full_matrices=False)
    #
    #     eigvecs = U[:,:n_vec]
    #
    #     eigval_mat = np.matmul(eigvecs.T, np.matmul(matrix, eigvecs))
    #     eigvals = np.diagonal(eigval_mat)
    #
    #     ite = 1
    #
    #     while np.sum(np.abs(eigvals - eigvals_old) < tol) < n_vec:
    #         eigvals_old = np.copy(eigvals)
    #         update_eigvecs = np.matmul(matrix, eigvecs)
    #
    #         U, _, _ = LA.svd(update_eigvecs, full_matrices=False)
    #
    #         eigvecs = U[:, :n_vec]
    #
    #         eigval_mat = np.matmul(eigvecs.T, np.matmul(matrix, eigvecs))
    #         eigvals = np.diagonal(eigval_mat)
    #         ite = ite + 1
    #
    #     comm_cost = ite * n_dim * n_vec
    #
    #     return eigvecs, eigvals, comm_cost




