import random
import networkx as nx
import dotProductOracle
import time


# algorithms of our oracle


class ClusteringOracle:
    """Clustering oracle built on top of a spectral dot-product oracle.

    On construction, ``s`` vertices of ``G`` are sampled and a similarity
    graph ``H`` is built on them: two sampled vertices are joined when the
    value returned by ``dotProductOracle.SpectralDotProductOracle`` for the
    pair is at least ``theta``. Construction is retried until ``H`` has
    exactly ``k`` connected components. Queries then compare a vertex against
    every vertex of each component, in order.
    """

    def __init__(self, G, k, R_init, R_query, t, s, s_dot, theta, timeFlag, max_attempts=None):
        """Build the dot-product oracle and the similarity graph ``H``.

        Args:
            G: input graph whose ``nodes()`` are sampled (presumably a
                networkx graph -- confirm with callers).
            k: required number of connected components of ``H``; also the
                range ``[0, k-1]`` of indices returned by ``whichCluster``.
            R_init, R_query, t: forwarded unchanged to
                ``dotProductOracle.dotProductOracle``.
            s: number of vertices of ``G`` sampled to build ``H``.
            s_dot: sample size forwarded to ``dotProductOracle.dotProductOracle``.
            theta: threshold on the oracle value, used both for edges of ``H``
                and for membership tests in ``searchIndex``.
            timeFlag: if true, append the duration of the successful
                construction attempt to ``./Results/queryTime.txt``.
            max_attempts: optional cap on construction attempts. ``None``
                (the default) retries until construction succeeds.

        Raises:
            RuntimeError: if ``max_attempts`` is set and every attempt fails.
        """
        self.G = G
        self.k = k
        self.R_init = R_init
        self.R_query = R_query
        self.t = t
        self.s = s  # sample size of our clustering oracle
        self.theta = theta
        self.timeFlag = timeFlag

        # s_dot is used in dotProductOracle
        self.dotOracle = dotProductOracle.dotProductOracle(G, k, R_init, R_query, t, s_dot)

        self.H, elapsed = self._constructWithRetry(max_attempts)

        if self.timeFlag:  # record time
            # Open the results file only once there is something to write, and
            # let the context manager close it even if the write fails.
            with open("./Results/queryTime.txt", "a") as file:
                file.write("pre-processing time: " + str(elapsed) + "\n")

    def _constructWithRetry(self, max_attempts):
        """Call ``constructOracle`` until it returns a graph.

        Returns:
            ``(H, seconds)``, where ``seconds`` is the duration of the
            successful attempt only. Failed attempts are not counted.

        Raises:
            RuntimeError: if ``max_attempts`` is not ``None`` and that many
                consecutive attempts all return ``"fail"``.
        """
        attempts = 0
        while True:
            attempts += 1
            start_time = time.perf_counter()
            H = self.constructOracle()
            elapsed = time.perf_counter() - start_time
            if not (isinstance(H, str) and H == "fail"):
                return H, elapsed
            print("Construct oracle fail!!!")
            if max_attempts is not None and attempts >= max_attempts:
                raise RuntimeError(
                    "could not build a similarity graph with "
                    + str(self.k) + " components after "
                    + str(attempts) + " attempts"
                )

    # construct similarity graph
    def constructOracle(self):
        """Sample ``s`` vertices and build the thresholded similarity graph.

        Vertices are drawn with ``random.sample``, i.e. uniformly without
        replacement. This raises ``ValueError`` if ``s`` exceeds the number of
        nodes of ``G``. An edge ``(u, v)`` is added for every unordered sampled
        pair whose oracle value is ``>= theta``.

        Returns:
            The graph ``H`` if it has exactly ``k`` connected components,
            otherwise the string ``"fail"``. The components are printed on
            failure.
        """
        # sample
        S = random.sample(list(self.G.nodes()), self.s)

        # generate a similarity graph; nodes are added first so isolated
        # sampled vertices still form their own components
        H = nx.Graph()
        H.add_nodes_from(S)

        # add edges: each unordered pair once, in the same (i, j > i) order
        for i, u in enumerate(S):
            for v in S[i + 1:]:
                apx = self.dotOracle.SpectralDotProductOracle(u, v)
                if apx >= self.theta:
                    H.add_edge(u, v)

        # check if there are exactly k connected components
        components = list(nx.connected_components(H))
        if len(components) != self.k:
            print("There are " + str(len(components)) + " components")
            print(components)
            return "fail"
        return H

    def _components(self):
        """Return the connected components of ``self.H`` as a list.

        ``H`` is not changed after construction, so the traversal is done once
        and reused. The cache is keyed on the identity of ``self.H``. It is
        rebuilt if ``self.H`` is reassigned, but not if ``H`` is mutated in
        place.
        """
        cache = getattr(self, "_component_cache", None)
        if cache is None or cache[0] is not self.H:
            cache = (self.H, list(nx.connected_components(self.H)))
            self._component_cache = cache
        return cache[1]

    def searchIndex(self, x):
        """Return the index of the first component that accepts vertex ``x``.

        A component accepts ``x`` when no member ``u`` has
        ``SpectralDotProductOracle(u, x) < theta``. Members are checked in
        order and the scan stops at the first rejection.

        Returns:
            The component index, or the string ``"outlier"`` if no component
            accepts ``x``.
        """
        for index, component in enumerate(self._components()):
            # ``not any(... < theta)`` rather than ``all(... >= theta)`` keeps
            # the original treatment of incomparable values such as NaN.
            rejected = any(
                self.dotOracle.SpectralDotProductOracle(u, x) < self.theta
                for u in component
            )
            if not rejected:
                return index
        return "outlier"

    def whichCluster(self, x):
        """Return the cluster index of vertex ``x``.

        Outliers (no component accepts ``x``) are assigned a uniformly random
        index in ``[0, k-1]``.
        """
        index = self.searchIndex(x)
        if index == "outlier":
            return random.randint(0, self.k - 1)  # return a random index in [0, k-1]
        return index