import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt
import itertools
from scipy.linalg import subspace_angles



class SBM2:
    def __init__(self, k, n, a, b, m):
        """Store the model parameters on the instance.

        ``k``, ``n`` and ``m`` are kept exactly as passed; ``a`` and ``b`` are
        coerced to ``float``.  Elsewhere in this class ``n * k`` is used as the
        matrix dimension, ``a / n`` and ``b / n`` as the within-block and
        cross-block thresholds, and ``m`` as the number of sampled matrices.
        """
        self.k, self.n, self.m = k, n, m
        self.a, self.b = float(a), float(b)

    def run(self, do_deletion=True, do_red_blue=False, do_naive=False, average_u1=False, average_perp_vector=False):
        """Sample the model, run the spectral partition and score the result.

        Parameters:
            do_deletion: forwarded to ``spectral_partition``.
            do_red_blue: when true, only the "red" half of a random edge split
                is handed to ``spectral_partition``.
            do_naive: forwarded to ``spectral_partition``.
            average_u1: when false, the ``m`` sampled matrices are averaged and
                eigenvectors are taken from the average; when true,
                eigenvectors are taken per sample and then averaged.
            average_perp_vector: accepted for interface compatibility; it is
                not read by this method.

        Returns:
            ``(deviation_norm, sin_angle, gamma, sin_vector)`` where
            ``deviation_norm`` is the spectral norm of (averaged sample minus
            threshold matrix) -- fixed to 0 in ``average_u1`` mode --,
            ``sin_angle`` is the sine of the largest principal angle between
            span(U1, U2) and span(w1, w2), ``gamma`` is the misclassified
            fraction (at most 0.5) and ``sin_vector`` is the sine of the angle
            between the returned perpendicular vector and ``w2``.

        NOTE(review): the scoring uses ``correct_partition[1]`` and divides by
        ``n``, which looks like it assumes two blocks (k == 2) -- confirm.
        """
        A_array, threshold = self.generate_model()

        if not average_u1:
            # Built-in sum adds the matrices one by one; np.sum() over a
            # generator is deprecated.
            A_orig = sum(A_array) / self.m
            A, threshold_reordered, correct_partition, _ = self.reorder_matrix(A_orig, threshold, None)

            # Measured on the un-permuted pair so both operands share the same
            # node order.
            deviation_norm = np.linalg.norm(A_orig - threshold, ord=2)

            U1, U2 = self.get_eigenvectors_from_A(A)
            w1, w2 = self.get_eigenvectors_from_A(threshold_reordered)

        else:
            # The reordering is drawn ONCE and applied to every sample, so that
            # entry i of each eigenvector refers to the same node and the
            # average below is meaningful.
            reordering = np.random.permutation(self.n * self.k)

            U1_array = []
            U2_array = []
            for sample in A_array:
                A, threshold_reordered, correct_partition, _ = self.reorder_matrix(A=sample, threshold=threshold, reordering=reordering)
                U1, U2 = self.get_eigenvectors_from_A(A)
                U1_array.append(U1)
                U2_array.append(U2)

            # NOTE(review): the reference vectors are simply the last sample's
            # eigenvectors and the deviation norm is a placeholder, exactly as
            # before -- confirm this is intended.
            w1, w2 = U1, U2
            deviation_norm = 0

            # NOTE(review): eigenvectors are only defined up to sign, so a
            # plain average can partially cancel -- confirm.
            U1 = sum(U1_array) / self.m
            U2 = sum(U2_array) / self.m

        angles = subspace_angles(np.transpose([U1, U2]), np.transpose([w1, w2]))
        sin_angle = np.sin(angles[0])

        if do_red_blue:
            red, blue = self.split_red_blue_edges(A)
        else:
            red = A

        [partition, perp_vector] = self.spectral_partition(red, do_deletion=do_deletion, do_naive=do_naive, U1=U1, U2=U2)

        correct = len(set(partition[1]).intersection(correct_partition[1]))
        correct_fraction = correct / self.n
        # The two labels are interchangeable, so score the better assignment.
        correct_fraction = max(correct_fraction, 1 - correct_fraction)
        gamma = 1 - correct_fraction

        cosine = np.dot(perp_vector, w2) / np.linalg.norm(perp_vector) / np.linalg.norm(w2)
        # Rounding can push the cosine marginally outside [-1, 1], which would
        # make arccos return NaN.  np.real is a no-op for real input and guards
        # against a complex dtype coming back from the eigen-solver.
        cosine = np.clip(np.real(cosine), -1.0, 1.0)
        sin_vector = np.sin(np.arccos(cosine))

        return deviation_norm, sin_angle, gamma, sin_vector

    def get_eigenvectors_from_A(self, A):
        """Return the two leading eigenvectors of ``A`` as ``(U1, U2)``.

        "Leading" means largest eigenvalue in absolute value: ``U1`` belongs
        to the largest one, ``U2`` to the second largest.

        ``np.linalg.eig`` does not return eigenvalues in any particular order,
        so picking columns 0 and 1 of its output does not select the leading
        pair.  The eigenvalues are therefore sorted explicitly.  Symmetric
        (Hermitian) input goes through ``np.linalg.eigh``, which yields real
        eigenvalues and orthonormal eigenvectors; anything else falls back to
        ``np.linalg.eig``.

        Parameters:
            A: square 2-D array with at least two rows.

        Returns:
            Tuple ``(U1, U2)`` of 1-D arrays.
        """
        A = np.asarray(A)

        if np.allclose(A, A.conj().T):
            eigenvalues, eigenvectors = np.linalg.eigh(A)
        else:
            eigenvalues, eigenvectors = np.linalg.eig(A)

        # Largest magnitude first; a stable sort keeps the result
        # deterministic when magnitudes tie.
        order = np.argsort(-np.abs(eigenvalues), kind="stable")

        U1 = eigenvectors[:, order[0]]
        U2 = eigenvectors[:, order[1]]

        return U1, U2


    def measure_deviation_matrix(self):
        """Measure how far the averaged sample is from the threshold matrix.

        Draws ``m`` samples, averages them, and returns

        * the spectral norm of (average minus threshold matrix), and
        * the sine of the largest principal angle between the span of the two
          eigenvectors taken from the average and the span of the two taken
          from the threshold matrix.

        The average and the threshold matrix are permuted with the SAME random
        reordering before they are compared, so both always refer to the same
        node order.

        Returns:
            Tuple ``(deviation_norm, sin_angle)``.
        """
        A_array, threshold = self.generate_model()

        # Built-in sum adds the matrices one by one; np.sum() over a generator
        # is deprecated.
        A = sum(A_array) / self.m

        # reorder_matrix takes (A, threshold, reordering) and returns four
        # values; the partition and the permutation are not needed here.
        A, threshold, _, _ = self.reorder_matrix(A, threshold, None)

        deviation_norm = np.linalg.norm(A - threshold, ord=2)

        # Same eigenvector selection as run() uses.
        U1, U2 = self.get_eigenvectors_from_A(A)
        w1, w2 = self.get_eigenvectors_from_A(threshold)

        angles = subspace_angles(np.transpose([U1, U2]), np.transpose([w1, w2]))
        sin_angle = np.sin(angles[0])

        return deviation_norm, sin_angle



    def generate_model(self):
        """Draw ``m`` independent samples via ``generate_model_single``.

        Returns:
            Tuple ``(A_array, threshold)``: the list of ``m`` sampled matrices
            and the threshold matrix returned with the last sample.

        Raises:
            ValueError: if ``m`` is smaller than 1.  Without this check the
                loop body never runs and ``threshold`` is unbound at the
                return statement.
        """
        if self.m < 1:
            raise ValueError("m must be at least 1, got " + str(self.m))

        A_array = []
        threshold = None
        for _ in range(self.m):
            this_A, threshold = self.generate_model_single()
            A_array.append(this_A)

        return A_array, threshold

    def generate_model_single(self):
        """Sample one symmetric 0/1 matrix together with its threshold matrix.

        The threshold matrix is ``(n*k) x (n*k)`` with ``a / n`` on the ``k``
        diagonal ``n x n`` blocks and ``b / n`` everywhere else.  Each entry
        of the upper triangle (diagonal included) is set to 1 when an
        independent uniform draw on [0, 1) is positive and below the
        corresponding threshold; the result is then mirrored to make it
        symmetric.

        NOTE(review): diagonal entries are sampled too, so self-loops can
        occur -- confirm this is intended.

        Returns:
            Tuple ``(A, threshold)`` of dense float arrays.
        """
        size = self.n * self.k

        # Dense uniform draws; only the upper triangle (with diagonal) is kept.
        draws = np.triu(np.random.rand(size, size))

        threshold = np.full((size, size), self.b / self.n)
        for i in range(self.k):
            block = slice(i * self.n, (i + 1) * self.n)
            threshold[block, block] = self.a / self.n

        # draws > 0 excludes the zeroed lower triangle.
        upper = ((draws > 0) & (draws < threshold)).astype(float)

        # ensure symmetry - construct adjacency matrix (diagonal stays 0/1)
        A = np.maximum(upper, upper.T)
        return A, threshold

    def split_red_blue_edges(self, A):
        """Randomly split the edges of ``A`` into two symmetric 0/1 matrices.

        Every positive entry in the upper triangle of ``A`` (diagonal
        included) is assigned independently, with probability 1/2 each, to
        either the first ("red") or the second ("blue") returned matrix, which
        are then mirrored to be symmetric.

        A diagonal entry (self-loop) stays at 1 after mirroring; adding the
        transpose would count it twice and produce a 2, unlike the 0/1
        matrices produced by ``generate_model_single``.

        Parameters:
            A: square array of side ``n * k``; entries > 0 count as edges.

        Returns:
            Tuple ``(red, blue)`` of ``(n*k) x (n*k)`` float arrays.
        """
        size = self.n * self.k

        # find edge idxs
        rows, cols = np.where(np.triu(A) > 0)

        # one fair coin per edge
        coin = np.random.rand(len(rows))
        is_red = coin <= .5
        is_blue = ~is_red

        red = np.zeros((size, size))
        red[rows[is_red], cols[is_red]] = 1

        blue = np.zeros((size, size))
        blue[rows[is_blue], cols[is_blue]] = 1

        # construct adjacency matrix over the random split
        red = np.maximum(red, red.T)
        blue = np.maximum(blue, blue.T)
        return red, blue

    def reorder_matrix(self, A, threshold, reordering):
        """Apply one permutation to the rows and columns of two matrices.

        Parameters:
            A: square matrix to permute.
            threshold: second square matrix, permuted identically.
            reordering: index permutation to apply, or ``None`` to draw a
                fresh random permutation of ``range(n * k)``.

        Returns:
            ``(A_permuted, threshold_permuted, correct_partition, reordering)``.
            ``correct_partition`` is a two-element list: the new positions of
            original indices ``0 .. n-1``, and the new positions of all
            remaining original indices.
        """
        perm = np.random.permutation(self.n * self.k) if reordering is None else reordering

        shuffled = A[perm, :][:, perm]
        shuffled_threshold = threshold[perm, :][:, perm]

        # positions[j] is where original index j ended up after the shuffle.
        positions = np.argsort(perm)
        first_block = positions[:self.n]
        remainder = positions[self.n:]

        return shuffled, shuffled_threshold, [first_block, remainder], perm

    def create_bipartite_graph(self):
        """Randomly split ``range(k * n)`` into two index sets.

        One uniform draw on [0, 1) is made per index: indices whose draw
        exceeds 0.5 go to the first array, all others to the second.

        Returns:
            Tuple ``(Y, Z)`` of ascending index arrays.
        """
        column = np.random.rand(self.k * self.n, 1)[:, 0]
        above_half = np.flatnonzero(column > .5)
        at_most_half = np.flatnonzero(column <= .5)
        return (above_half, at_most_half)

    def spectral_partition(self, A, do_deletion=False, do_naive=False, U1=None, U2=None):
        """Split the nodes into two groups using two eigenvectors.

        Parameters:
            A: square matrix.  When ``do_deletion`` is true it is modified IN
                PLACE: rows/columns whose sum exceeds 20 times the degree
                estimate are zeroed.
            do_deletion: enable the high-degree row/column removal.
            do_naive: when true, sort the nodes by ``U2`` directly; otherwise
                sort by the vector in span(U1, U2) that is orthogonal to the
                projection of the all-ones vector onto that span.
            U1, U2: optional precomputed eigenvectors.  If either is missing,
                both are obtained from ``get_eigenvectors_from_A(A)``.

        Returns:
            ``[[partition_1, partition_2], vector]``: ``partition_1`` holds
            the ``n`` nodes with the smallest entries of the sorting vector,
            ``partition_2`` the rest; ``vector`` is the sorting vector itself
            (``U2`` in naive mode).

        Raises:
            ValueError: in non-naive mode, if both eigenvectors are orthogonal
                to the projected all-ones vector, so no direction can be
                chosen.
        """
        #d = self.a + self.b
        # for fairness, we're not supposed to know a or b, but need to estimate a+b
        d = A.sum() / self.n / 2
        if do_deletion:
            row_degrees = np.sum(A, axis=1)
            col_degrees = np.sum(A, axis=0)

            bad_rows = np.where(row_degrees > 20 * d)[0]
            bad_cols = np.where(col_degrees > 20 * d)[0]
            if len(bad_rows) > 0:
                A[bad_rows, :] = 0
                print("deleting " + str(len(bad_rows)) + " bad rows")

            if len(bad_cols) > 0:
                # A scalar broadcasts to any number of selected columns.
                A[:, bad_cols] = 0
                print("deleting " + str(len(bad_cols)) + " bad cols")

        if U1 is None or U2 is None:
            U1, U2 = self.get_eigenvectors_from_A(A)

        if do_naive:
            sort_vector = U2
        else:
            U = np.column_stack((U1, U2))

            # Project the all-ones vector onto span(U1, U2).  The vector is
            # sized from U so it always matches the eigenvectors, and the
            # N x N matrix U U^T is never formed.
            all_ones = np.ones(U.shape[0])
            proj_v1 = np.dot(U, np.dot(U.T, all_ones))
            v1 = np.dot(U1, proj_v1)
            v2 = np.dot(U2, proj_v1)

            # They can't both be zero: the perpendicular direction would be
            # undetermined.
            if v1 == 0 and v2 == 0:
                raise ValueError("The eigen vectors are degenerate. Subspace W only has rank 1")

            # We want t1 and t2 such that t1*v1 + t2*v2 = 0, t1, t2 not both zero
            if v1 == 0:
                t1, t2 = 1, 0
            elif v2 == 0:
                t1, t2 = 0, 1
            else:
                t1, t2 = -v2, v1

            # Vector perpendicular to proj_v1 but belonging to W is t1*U1 + t2*U2
            sort_vector = t1 * U1 + t2 * U2

        ordering = np.argsort(sort_vector)
        partition_1 = ordering[0:self.n]
        partition_2 = ordering[self.n:]

        return [[partition_1, partition_2], sort_vector]

    def plot_output(self, A, A_shuffled, recovered):
        """Show the three matrices as images, one named figure each.

        The figures are titled "original blocked", "shuffled input" and
        "recovered output", each with a colorbar, and are all shown by a
        single ``plt.show()`` call at the end.
        """
        panels = (
            ("original blocked", A),
            ("shuffled input", A_shuffled),
            ("recovered output", recovered),
        )
        for title, image in panels:
            plt.figure(title)
            plt.imshow(image)
            plt.colorbar()
        plt.show()

    def merge(self, outCxs, A, Y, blue):
        """Attach the vertices in ``Y`` to clusters and return ``A`` reordered by cluster.

        Parameters:
            outCxs: sequence of index arrays, one per cluster.  Presumably NumPy
                integer arrays with ``k`` entries -- they are fancy-indexed and
                addressed by ``range(self.k)`` below; verify against the caller.
            A: square matrix, indexed by the cluster indices.
            Y: index array of the vertices to attach (assumed to be a NumPy
                array, since it is indexed with an ``np.where`` result).
            blue: matrix used to score how strongly each ``Y`` vertex connects
                to each cluster.

        Returns:
            ``A`` restricted to, and reordered by, the concatenation of the
            clusters after each ``Y`` vertex has been added to the cluster with
            the largest ``blue`` column sum.

        NOTE(review): the first half of this method builds ``newOutCxs`` but
        nothing reads it afterwards; the second half works on the original
        ``outCxs``.  Confirm whether the corrected clusters were meant to be
        used.
        """
        # Cut-off on cross-cluster sums above which a vertex is flagged.
        dZZ = 1.5 * (self.a + self.b) / 4
        # bad_xy[(i, j)]: members of cluster i whose sum of entries towards
        # cluster j exceeds dZZ.
        bad_xy = {}
        for block1 in range(self.k):
            for block2 in range(block1 + 1, self.k):
                # Rows: members of block2, columns: members of block1.
                Cx = A[outCxs[block2], :][:, outCxs[block1]]
                # Column sums: one value per member of block1.
                degZ = np.sum(Cx, axis=0)
                bad12 = outCxs[block1][np.where(degZ > dZZ)[0]]
                bad_xy[(block1, block2)] = bad12
                # Row sums: one value per member of block2.
                degZ = np.sum(Cx, axis=1)
                bad21 = outCxs[block2][np.where(degZ > dZZ)[0]]
                bad_xy[(block2, block1)] = bad21

        # For every cluster: drop its own flagged members (xy) and take in the
        # members of other clusters that were flagged towards it (yx).
        newOutCxs = []
        for block1 in range(self.k):
            xy = []
            yx = []
            for block2 in range(self.k):
                if block2 == block1:
                    continue

                xy = np.union1d(xy, bad_xy[(block1, block2)])
                yx = np.union1d(yx, bad_xy[(block2, block1)])

            outCx = np.setdiff1d(outCxs[block1], xy).astype(int)
            outCx = np.union1d(outCx, yx).astype(int)
            newOutCxs.append(outCx)

        # One row per cluster: for each vertex in Y, the sum of its ``blue``
        # entries towards that cluster.
        dxs = []

        for outCx in outCxs:
            C1 = blue[outCx, :][:, Y]
            dxs.append(np.sum(C1, axis=0))

        # For every Y vertex, the cluster with the largest sum (first one on ties).
        idY = np.argmax(dxs, axis=0)
        indices = np.array([])
        for i in range(len(outCxs)):
            # union1d returns the sorted, de-duplicated union.
            outCx = np.union1d(outCxs[i], Y[np.where(idY == i)])
            # np.random.shuffle(outCx)
            indices = np.concatenate((indices, outCx))
        # ``indices`` started as an empty float array; cast back for indexing.
        indices = indices.astype(int)
        return A[indices, :][:, indices]

    def correction(self, outCxs, A, Z):
        """Assign the not-yet-clustered vertices of ``Z`` to a cluster.

        Vertices of ``Z`` that already belong to some cluster are ignored.
        Every remaining vertex joins the cluster towards which its entries in
        ``A`` sum to the largest value (the first such cluster on ties).

        Parameters:
            outCxs: sequence of index arrays, one per cluster.
            A: square matrix indexed by vertex.
            Z: index array of candidate vertices.

        Returns:
            List with one sorted integer index array per cluster, each
            extended by the vertices assigned to it.
        """
        leftover = Z
        for members in outCxs:
            leftover = np.setdiff1d(leftover, members).astype(int)

        # One row per cluster, one column per leftover vertex.
        scores = [np.sum(A[members, :][:, leftover], axis=0) for members in outCxs]
        best_cluster = np.argmax(scores, axis=0)

        return [
            np.union1d(outCxs[idx], leftover[np.where(best_cluster == idx)[0]]).astype(int)
            for idx in range(len(outCxs))
        ]


if __name__ == "__main__":
    # The class defined in this file is SBM2 and its constructor takes five
    # arguments (k, n, a, b, m).  The partition/scoring code in run() splits
    # the nodes into exactly two groups, so the demo uses k=2 with a single
    # sample (m=1).
    s = SBM2(2, 1000, 50, 5, 1)
    deviation_norm, sin_angle, gamma, sin_vector = s.run()
    print("deviation_norm = " + str(deviation_norm))
    print("sin_angle = " + str(sin_angle))
    print("gamma = " + str(gamma))
    print("sin_vector = " + str(sin_vector))

