import numpy as np
from numpy import linalg as LA
from numpy.linalg import matrix_rank, pinv
from sklearn.decomposition import TruncatedSVD

class Coordinator(object):
    """Central coordinator that aggregates per-node data and tracks two vector sets.

    ``set0`` and ``set1`` are column-stacked matrices built by ``init_set`` /
    ``add_element``; column ``j`` of ``set1`` is always added together with
    column ``j`` of ``set0``.  NOTE(review): the pairing presumably stores
    inputs and outputs of a linear operator, since ``early_stopping`` estimates
    ``set1 @ pinv(set0)`` -- confirm against callers.
    """

    def __init__(self, n_nodes, max_rank):
        """Create an empty coordinator.

        Args:
            n_nodes: number of leading entries of ``data`` combined by
                ``aggregate`` and ``concatenation``.
            max_rank: cap on the rank of ``set0`` enforced by the
                independence checks.
        """
        self.n_nodes = n_nodes
        self.max_rank = max_rank
        self.set0 = None  # matrix with one column per stored vec0
        self.set1 = None  # matrix whose column j was stored with set0[:, j]
        self.coefficient = None  # last value passed to update_coefficient

    def aggregate(self, data):
        """Return the element-wise sum ``data[0] + ... + data[n_nodes - 1]``.

        ``data[0]`` is copied first, so the result never aliases caller memory
        (even when ``n_nodes == 1``).
        """
        result = np.copy(data[0])

        for i in range(1, self.n_nodes):
            result = result + data[i]

        return result

    def check_indepdence(self, vec, tol=1e-5):
        """Decide whether ``vec`` should be added to ``set0``.

        Args:
            vec: vector with as many entries as ``set0`` has rows.
            tol: unused; kept for backward compatibility with existing callers.

        Returns:
            ``(state, coefficient)``.  ``state`` is True only when ``vec``
            increases the rank of ``set0`` *and* ``set0`` is still below
            ``max_rank``.  ``coefficient = pinv(set0) @ vec`` is the
            least-squares expansion of ``vec`` in the columns of ``set0``.
        """
        coefficient = np.matmul(pinv(self.set0), vec)

        rank = matrix_rank(self.set0)
        X = np.concatenate((self.set0, vec.reshape((-1, 1))), axis=1)
        rank_X = matrix_rank(X)

        # New direction present and room left under the cap.
        state = bool(rank_X > rank and rank < self.max_rank)
        return state, coefficient

    def check_indepdence_new(self, vec):
        """Return True when ``vec`` should be appended to ``set0``.

        Returns False when appending ``vec`` does not raise the rank above the
        current column count of ``set0`` (i.e. ``vec`` is dependent, assuming
        ``set0`` has full column rank), or when the rank after appending would
        reach or exceed ``max_rank``.
        """
        X = np.concatenate((self.set0, vec.reshape((-1, 1))), axis=1)

        rank = LA.matrix_rank(X)
        # `>=` rather than `==`: when set0 already holds max_rank (or more)
        # independent columns, rank(X) can overshoot max_rank, and an equality
        # test would then let the set keep growing past the cap.
        if rank == self.set0.shape[1] or rank >= self.max_rank:
            return False
        return True

    def coeffcient(self, vec):
        """Return ``V.T @ diag(1/S) @ U.T @ vec`` from a thin SVD of ``set0``.

        This is the pseudo-inverse expansion of ``vec`` in the columns of
        ``set0``.  No singular-value cutoff is applied, so a zero singular
        value produces inf/nan.  Prints shape warnings, ``vec`` and the
        absolute reconstruction residual as diagnostics.
        """
        U, S, V = LA.svd(self.set0, full_matrices=False)

        # With full_matrices=False, U has min(rows, cols) columns; this differs
        # from the column count of set0 when set0 is wide.
        if U.shape[1] != self.set0.shape[1]:
            print('U error, please debug')

        if V.shape[0] != V.shape[1]:
            print('V error, please debug')

        projection = np.matmul(V.T, np.matmul(np.diag(1/S), U.T))
        coefficient = np.matmul(projection, vec)

        print(vec)
        print(np.abs(vec-np.matmul(self.set0, coefficient)))

        return coefficient

    def early_stopping(self, eigvecs_old, eigvals_old, tol, dim, lamb=None):
        """Estimate the leading singular pairs of ``set1 @ pinv(set0)`` and test convergence.

        Args:
            eigvecs_old: unused; kept for backward compatibility.
            eigvals_old: singular values from the previous call (1-D array).
            tol: relative tolerance; value ``i`` has converged when
                ``|new_i - old_i| <= tol * old_i``.
            dim: number of leading singular values/vectors to keep.
            lamb: if given, overrides ``dim`` with the count of singular
                values ``>= lamb``.

        Returns:
            ``(stop_flag, eigvecs, eigvals)`` where ``eigvals``/``eigvecs`` are
            the first ``dim`` singular values / left singular vectors.
            ``stop_flag`` is True when every kept value converged against
            ``eigvals_old`` or ``set0`` has at least as many columns as rows.
        """
        pinv_mat = LA.pinv(self.set0)
        low_rank = np.matmul(self.set1, pinv_mat)

        U, S, _ = LA.svd(low_rank, full_matrices=False)
        if lamb is not None:
            dim = np.sum(S >= lamb)

        eigvals = S[:dim]
        eigvecs = U[:, :dim]

        print(eigvals, flush=True)

        # S may hold fewer than `dim` values; only compare when both the new
        # and the old spectra actually have `dim` entries (previously a short
        # S caused a shape-mismatch error or a silent broadcast here).
        if eigvals.shape[0] == dim and dim == eigvals_old.shape[0]:
            diff = np.abs(eigvals - eigvals_old)
            converge_flag = bool(np.all(diff <= tol * eigvals_old))
        else:
            converge_flag = False

        stop_flag = converge_flag or self.set0.shape[0] <= self.set0.shape[1]
        return stop_flag, eigvecs, eigvals

    def init_set(self, vec0, vec1):
        """Start ``set0``/``set1`` as single-column matrices from ``vec0``/``vec1``."""
        self.set0 = vec0.reshape((-1, 1))
        self.set1 = vec1.reshape((-1, 1))

    def add_element(self, vec0, vec1):
        """Append ``vec0`` as a new column of ``set0`` and ``vec1`` of ``set1``."""
        self.set0 = np.concatenate((self.set0, vec0.reshape((-1, 1))), axis=1)
        self.set1 = np.concatenate((self.set1, vec1.reshape((-1, 1))), axis=1)

    def update_coefficient(self, coefficient):
        """Store a private copy of ``coefficient`` on the instance."""
        self.coefficient = np.copy(coefficient)

    def frist_eigval(self, vec0, vec1):
        """Return ``dot(vec0, vec1)`` (name typo kept for compatibility)."""
        eigval = np.dot(vec0, vec1)
        return eigval

    def rest_eigvals(self, vec0, vec1, exvec0, exvec1):
        """Return ``(sqrt(dot(vec0, vec1) + dot(exvec0, exvec1)), True)``.

        Returns ``(None, False)`` when that sum is negative (or NaN), since
        no real square root exists.
        """
        eigval_square = np.dot(vec0, vec1)+np.dot(exvec0, exvec1)
        if eigval_square >= 0:
            eigval = np.sqrt(eigval_square)
            eigval_flag = True
        else:
            eigval = None
            eigval_flag = False
        return eigval, eigval_flag

    def central_update(self, coefficient):
        """Return ``set1 @ coefficient``, using only the first ``len(coefficient)`` columns.

        Truncation lets a coefficient computed against an older, smaller
        ``set0`` be applied after ``set1`` has grown.
        """
        set1_size = self.set1.shape[1]
        coefficient_size = coefficient.shape[0]

        if coefficient_size == set1_size:
            update_vec = np.matmul(self.set1, coefficient)
        else:
            update_vec = np.matmul(self.set1[:, :coefficient_size], coefficient)
        return update_vec

    def remove_part(self, vec, eigval, eigvec):
        """Return the projection ``B @ B.T @ vec`` of ``vec`` onto the columns of ``eigvec``.

        When ``eigval`` has a single entry, ``eigvec`` is treated as one
        column vector.  The projector is applied factor-by-factor instead of
        materialising the n-by-n matrix ``B @ B.T``.
        """
        basis = eigvec.reshape((-1, 1)) if eigval.shape[0] == 1 else eigvec
        return np.matmul(basis, np.matmul(basis.T, vec))

    def first_vec_update(self, state, data, coefficient, vec_dn):
        """Return the sum of node ``data`` if ``state`` else ``central_update(coefficient)``.

        ``vec_dn`` is unused; kept for backward compatibility.
        """
        if state:
            vec_d = self.aggregate(data)
        else:
            vec_d = self.central_update(coefficient)

        return vec_d

    def rest_vecs_update(self, state, data, coefficient, vec_dn, vec_en, pre_vecs0, pre_vecs1):
        """Compute ``(vec_d, vec_e)`` for a later iteration.

        ``vec_d0`` is obtained as in ``first_vec_update``; then
        ``vec_d = vec_d0 - pre_vecs1 @ vec_en`` and
        ``vec_e = pre_vecs1.T @ vec_dn - pre_vecs0.T @ pre_vecs1``
        (plain products / dots when ``vec_en`` is a scalar).
        NOTE(review): in the matrix branch the second term of ``vec_e`` is a
        matrix while the first is presumably a vector, so the result
        broadcasts -- confirm this is intended.
        """
        vec_d0 = self.first_vec_update(state, data, coefficient, vec_dn)
        if np.isscalar(vec_en):
            vec_d = vec_d0 - vec_en * pre_vecs1
            vec_e = np.dot(pre_vecs1, vec_dn) - np.dot(pre_vecs0, pre_vecs1)
        else:
            vec_d = vec_d0 - np.matmul(pre_vecs1, vec_en)
            vec_e = np.matmul(pre_vecs1.T, vec_dn) - np.matmul(pre_vecs0.T, pre_vecs1)

        return vec_d, vec_e

    def concatenation(self, data):
        """Concatenate ``data[0] ... data[n_nodes - 1]`` along axis 0 (``data[0]`` is copied)."""
        result = np.copy(data[0])

        for i in range(1, self.n_nodes):
            result = np.concatenate((result, data[i]), axis=0)

        return result


