import random
import networkx as nx



# the clustering oracle in [CPS15]



class ClusteringOracle:
    """Random-walk based clustering oracle (see [CPS15]).

    A set ``S`` of ``s`` vertices is sampled. For each vertex, ``r`` lazy
    random walks of length ``t`` are run and their endpoints counted. Two
    sampled vertices are joined in a similarity graph ``H`` when their
    estimated squared l2 distance is at most ``theta``. Construction is
    accepted only when ``H`` has exactly ``k`` connected components.
    Otherwise it is retried.

    Parameters
    ----------
    G : networkx graph
        The input graph. Nodes are required to be the integers 0..N-1:
        they are used directly as list indices for the neighbour table
        and the endpoint count vectors.
    s : int
        Number of vertices sampled into ``S``.
    k : int
        Required number of connected components of the similarity graph.
    t : int
        Length of each lazy random walk.
    r : int
        Number of random walks per vertex.
    theta : float
        Distance threshold for adding a similarity edge.

    NOTE: ``__init__`` retries ``constructOracle`` until it succeeds. If
    ``k`` components are never obtained for these parameters, it loops
    forever.
    """

    def __init__(self, G, s, k, t, r, theta):
        self.G = G
        self.s = s
        self.k = k
        self.t = t
        self.r = r
        self.theta = theta

        self.N = len(G.nodes())
        self.degrees = dict(G.degree())
        # default=0 keeps an empty graph from crashing here; randomWalk
        # treats max_degree == 0 as "no vertex can move".
        self.max_degree = max(self.degrees.values(), default=0)
        # neighbor[v] lists the neighbours of vertex v (nodes must be 0..N-1).
        self.neighbor = [list(self.G.neighbors(i)) for i in range(self.N)]

        # Retry until the similarity graph has exactly k components.
        self.structure = self.constructOracle()
        while self.structure == "fail":
            print("Construct oracle fail!!!")
            self.structure = self.constructOracle()
        self.H = self.structure[0]
        self.end_list = self.structure[1]
        self.S = self.structure[2]

    def randomWalk(self, x, t):
        """Run a lazy random walk of ``t`` steps from ``x`` and return its endpoint.

        At each step the walk moves to a uniformly random neighbour with
        probability ``deg / (2 * max_degree)`` and otherwise stays put.
        On a graph with no edges the walk never moves.
        """
        curr_vtx = x
        if self.max_degree == 0:
            # No edges anywhere: every move probability is 0. Also avoids
            # dividing by zero below.
            return curr_vtx
        denom = 2 * self.max_degree
        for _ in range(t):
            # Strict '<' gives exactly probability deg/denom for
            # random.random() in [0, 1). It also means an isolated vertex
            # (deg == 0) can never try to pick from an empty neighbour list.
            if random.random() < self.degrees[curr_vtx] / denom:
                curr_vtx = random.choice(self.neighbor[curr_vtx])
        return curr_vtx

    def _endpoint_counts(self, x):
        """Return a length-N list counting the endpoints of r walks of length t from x."""
        counts = [0] * self.N
        for _ in range(self.r):
            counts[self.randomWalk(x, self.t)] += 1
        return counts

    @staticmethod
    def _l2_estimate(counts_a, counts_b, r):
        """Estimate the squared l2 distance between two endpoint distributions.

        Computes ``sum((a - b)^2 - a - b) / r^2`` over paired counts. The
        ``- a - b`` terms approximately remove the sampling variance that
        ``(a - b)^2`` would otherwise add.
        """
        total = sum((a - b) * (a - b) - a - b for a, b in zip(counts_a, counts_b))
        return total / (r * r)

    # construct similarity graph
    def constructOracle(self):
        """Sample ``S``, build the similarity graph and check its component count.

        Returns ``[H, end_list, S]`` when ``H`` has exactly ``k`` connected
        components, and the string ``"fail"`` otherwise. On success the
        component list is also appended to ``./Results/cps15error.txt``.
        """
        S = random.sample(list(self.G.nodes()), self.s)

        # end_list[i] is the endpoint count vector for the walks from S[i].
        end_list = [self._endpoint_counts(vtx) for vtx in S]

        # Similarity graph on the sampled vertices.
        H = nx.Graph()
        H.add_nodes_from(S)
        for i, u in enumerate(S):
            for j in range(i + 1, len(S)):
                if self._l2_estimate(end_list[i], end_list[j], self.r) <= self.theta:
                    H.add_edge(u, S[j])

        components = list(nx.connected_components(H))

        if len(components) != self.k:
            print("There are " + str(len(components)) + " components")
            print(components)
            return "fail"

        print("success")
        print(components)
        # 'with' guarantees the log file is closed even if the write fails.
        with open("./Results/cps15error.txt", "a", encoding="utf-8") as file:
            file.write(str(components))
            file.write("\n")

        return [H, end_list, S]

    def searchIndex(self, x):
        """Return the index of the first component of ``H`` whose every vertex is close to ``x``.

        "Close" means the estimated squared l2 distance is not above
        ``theta``. Components are ordered as ``nx.connected_components``
        yields them. Returns ``"outlier"`` if no component qualifies.
        """
        x_list = self._endpoint_counts(x)
        # O(1) lookup of each sampled vertex's row in end_list.
        position = {v: idx for idx, v in enumerate(self.S)}

        components = list(nx.connected_components(self.H))

        # check each component in order
        for i, component in enumerate(components):
            if not any(
                self._l2_estimate(self.end_list[position[u]], x_list, self.r) > self.theta
                for u in component
            ):
                return i
        return "outlier"

    def whichCluster(self, x):
        """Return the cluster index of ``x``.

        Outliers are assigned a uniformly random index in [0, k-1].
        """
        tmp = self.searchIndex(x)
        if tmp == "outlier":
            return random.randint(0, self.k - 1)  # return a random index in [0, k-1]
        return tmp

