import random
import numpy as np
from numpy.linalg import eig


# dotProductOracle from the paper "Spectral Clustering Oracles in Sublinear Time":
# estimates the spectral dot product <f_x, f_y> of two vertices x and y.


class dotProductOracle:
    """Estimate spectral dot products <f_x, f_y> between vertices of a graph.

    Implements the dot-product oracle of "Spectral Clustering Oracles in
    Sublinear Time".

    - Preprocessing (``InitializeOracle``) samples ``s`` vertices. It then
      builds a sketch ``[psi, Q_list]`` from lazy random walks started at
      those vertices.
    - Queries (``SpectralDotProductOracle``) run fresh walks from the two
      query vertices and combine them with the sketch.

    ``G`` is only accessed through ``nodes()``, ``degree()`` and
    ``neighbors(v)`` (presumably a ``networkx`` graph -- TODO confirm).
    Vertices must be labelled ``0 .. N-1``, because walk endpoints are used
    directly as list indices.
    """

    def __init__(self, G, k, R_init, R_query, t, s, num=20):
        """Store the parameters and run the preprocessing step.

        Args:
            G: Input graph whose vertices are ``0 .. N-1``.
            k: Number of largest eigenpairs of the scaled collision matrix
                kept in the sketch.
            R_init: Random walks per start vertex during preprocessing.
            R_query: Random walks per query vertex, per repetition, when
                answering a query.
            t: Number of steps of every lazy random walk.
            s: Number of start vertices sampled uniformly with replacement.
            num: Number of independent repetitions combined by a
                coordinate-wise median (the "median trick"). Defaults to 20.

        After construction ``self.structure`` is ``[psi, Q_list]``. It is
        ``None`` when one of the ``k`` largest eigenvalues is exactly zero.
        """
        self.num = num  # repetitions for the median trick

        self.G = G
        self.k = k
        self.R_init = R_init
        self.R_query = R_query
        self.t = t
        self.s = s
        self.N = len(G.nodes())

        self.degrees = dict(G.degree())
        self.max_degree = max(self.degrees.values())
        # Adjacency lists indexed by vertex id; assumes ids are 0 .. N-1.
        self.neighbor = [list(G.neighbors(v)) for v in range(self.N)]

        self.structure = self.InitializeOracle()

    def RunRandomWalks(self, R, t, x):
        """Run ``R`` lazy random walks of ``t`` steps each, starting at ``x``.

        At every step a walk at vertex ``v`` does one of two things:

        - with probability ``deg(v) / (2 * max_degree)`` it moves to a
          uniformly random neighbour;
        - otherwise it stays where it is.

        Returns:
            A list of length ``N`` whose entry ``v`` is the fraction of the
            ``R`` walks that ended at ``v`` (the empirical distribution m_x).
        """
        counts = [0] * self.N
        two_max_degree = 2 * self.max_degree
        for _ in range(R):
            curr_vtx = x
            for _ in range(t):
                deg = self.degrees[curr_vtx]
                # Draw on every step, including at isolated vertices, so the
                # number of random draws per step is fixed.
                temp = random.random()
                # random() lies in [0, 1), so a strict '<' moves with
                # probability exactly deg / (2 * max_degree).
                # The `deg` guard has two effects:
                #   - an isolated vertex never reaches random.choice([]);
                #   - an edgeless graph (max_degree == 0) never divides by zero.
                if deg and temp < deg / two_max_degree:
                    curr_vtx = random.choice(self.neighbor[curr_vtx])
            counts[curr_vtx] += 1
        return [c / R for c in counts]

    def EstimateTransitionMatrix(self, S, R, t):
        """Return the N x len(S) matrix whose column j is the walk distribution from ``S[j]``."""
        columns = [self.RunRandomWalks(R, t, x) for x in S]
        return np.array(columns).T  # size: N x s

    def EstimateCollisionMatrix(self, S, R, t):
        """Estimate the s x s collision matrix of walks started at ``S``.

        Each repetition does the following:

        1. Draw two independent transition estimates Q and P.
        2. Form the symmetric matrix (P^T Q + Q^T P) / 2.

        The coordinate-wise median over ``self.num`` repetitions is returned.
        """
        estimates = []
        for _ in range(self.num):
            Q = self.EstimateTransitionMatrix(S, R, t)  # N x s
            P = self.EstimateTransitionMatrix(S, R, t)  # N x s
            cross = np.dot(P.T, Q)
            # cross.T equals Q^T P; adding it keeps the estimate exactly symmetric.
            estimates.append((cross + cross.T) / 2)
        return np.median(estimates, axis=0)  # s x s

    def InitializeOracle(self):
        """Build the query-time sketch ``[psi, Q_list]``.

        Steps:
          1. Sample ``s`` vertices uniformly at random, with replacement.
          2. Draw ``self.num`` independent N x s transition estimates Q_i.
          3. Estimate the collision matrix G and scale it by N / s.
          4. Write (N/s) G = W diag(lambda) W^T and keep the eigenpairs of the
             ``k`` largest eigenvalues. Then set
             psi = (N/s) * W_k diag(lambda_k)^{-2} W_k^T   (s x s).

        Returns:
            ``[psi, Q_list]``. Returns ``None`` instead if one of the ``k``
            largest eigenvalues is exactly zero, because it cannot be inverted.
        """
        # The `% N` guards against N * random() rounding up to N.
        I_S = [int((self.N * random.random()) % self.N) for _ in range(self.s)]

        Q_list = [
            self.EstimateTransitionMatrix(I_S, self.R_init, self.t)  # N x s
            for _ in range(self.num)
        ]

        scale = self.N / self.s
        mathcal_G = self.EstimateCollisionMatrix(I_S, self.R_init, self.t) * scale

        # The matrix is symmetric, so eigh gives real eigenvalues and
        # orthonormal eigenvectors. It returns them in ascending order, so
        # the k largest are selected explicitly. (eig's output order is
        # unspecified, so its first k columns are arbitrary.)
        vals, vecs = np.linalg.eigh(mathcal_G)
        top = np.argsort(vals)[::-1][:self.k]
        top_vals = vals[top]
        top_vecs = vecs[:, top]

        # Only the kept eigenvalues are inverted, so only they must be nonzero.
        if np.any(top_vals == 0):
            return None

        # Dividing column j by lambda_j^2 equals right-multiplying by diag(lambda_k)^{-2}.
        psi = scale * np.dot(top_vecs / top_vals ** 2, top_vecs.T)
        return [psi, Q_list]

    def _median_projection(self, v, Q_list):
        """Coordinate-wise median of Q_i^T m_v over Q_list.

        Each Q_i is paired with its own fresh walk estimate m_v, so the
        result has one entry per sampled vertex (length s).
        """
        projections = [
            np.dot(Q.T, self.RunRandomWalks(self.R_query, self.t, v))
            for Q in Q_list
        ]
        return np.median(projections, axis=0)

    def SpectralDotProductOracle(self, x, y):
        """Estimate <f_x, f_y> as alpha_x^T psi alpha_y.

        ``alpha_x`` and ``alpha_y`` are median projections of fresh walks
        from ``x`` and from ``y``. Requires ``self.structure`` to be
        non-None; otherwise unpacking raises TypeError.
        """
        psi, Q_list = self.structure
        alpha_x = self._median_projection(x, Q_list)
        alpha_y = self._median_projection(y, Q_list)
        return np.dot(np.dot(alpha_x, psi), alpha_y)
