import clusteringOracle
import networkx as nx
import cps15clusteringOracle
from networkx.algorithms import bipartite
import time


# Evaluate our clustering oracle and the [CPS15] oracle by their classification accuracy (1 - mis-classification error)


def defineRandomList(random_list):
    """Install ``random_list`` as the module-wide sequence of query nodes.

    ``ourAlgorithm``, ``cps15Algorithm`` and ``toClusteringSets`` read the
    module global ``my_random_list``; this function must be called before
    any of them. The list object itself is stored (no copy is made).
    """
    global my_random_list
    my_random_list = random_list


def _queryAll(oracle, nodes, timed):
    """Ask ``oracle.whichCluster`` for the cluster of every node in ``nodes``, in order.

    Prints the query counter every 50 queries as a progress indicator.

    Returns:
        A pair ``(inCluster, totalTime)``: the list of returned cluster indices,
        and the total seconds spent inside ``whichCluster`` (0.0 when ``timed``
        is falsy).
    """
    inCluster = []
    totalTime = 0.0
    for i, u in enumerate(nodes):
        # every 50 queries print progress
        if i % 50 == 0:
            print(i)
        if timed:
            # perf_counter is the monotonic high-resolution clock meant for intervals
            start_time = time.perf_counter()
            index = oracle.whichCluster(u)
            totalTime += time.perf_counter() - start_time
        else:
            index = oracle.whichCluster(u)
        inCluster.append(index)
    return inCluster, totalTime


def ourAlgorithm(G, k, R_init, R_query, t, s, s_dot, theta, timeFlag):
    """Classify every node of the global ``my_random_list`` with our clustering oracle.

    All arguments are forwarded unchanged to
    ``clusteringOracle.ClusteringOracle`` (see that class for their meaning).

    If ``timeFlag`` is truthy, each ``whichCluster`` call is timed and the
    average query time is appended to ``./Results/queryTime.txt``. An empty
    query list is reported as an average of 0.0.

    Returns:
        List whose i-th entry is the cluster index returned for ``my_random_list[i]``.
    """
    oracle = clusteringOracle.ClusteringOracle(G, k, R_init, R_query, t, s, s_dot, theta, timeFlag)

    if not timeFlag:
        inCluster, _ = _queryAll(oracle, my_random_list, timed=False)
        return inCluster

    # Open the log before running the (slow) queries so a missing ./Results
    # directory fails fast; `with` guarantees the file is closed on error.
    with open("./Results/queryTime.txt", "a") as file:
        inCluster, totalTime = _queryAll(oracle, my_random_list, timed=True)
        numQueries = len(my_random_list)
        avg_query_time = totalTime / numQueries if numQueries else 0.0
        file.write("average query time of " + str(numQueries) + " queries is: " +
                   str(avg_query_time) + "\n\n\n")
    return inCluster




def cps15Algorithm(G, s, k, t, r, theta):
    """Classify every node of the global ``my_random_list`` with the [CPS15] oracle.

    The arguments are forwarded unchanged to
    ``cps15clusteringOracle.ClusteringOracle``. A progress counter is printed
    every 50 queries.

    Returns:
        List whose i-th entry is the cluster index returned for ``my_random_list[i]``.
    """
    cps15Oracle = cps15clusteringOracle.ClusteringOracle(G, s, k, t, r, theta)

    labels = []
    for position, node in enumerate(my_random_list):
        if not position % 50:
            print(position)
        labels.append(cps15Oracle.whichCluster(node))
    return labels

def toClusteringSets(inCluster, nodes=None):
    """Group query nodes into one set per cluster index.

    Args:
        inCluster: list where ``inCluster[i]`` is the cluster index assigned to
            ``nodes[i]`` (not to node ``i``). Indices are expected to be
            non-negative ints.
        nodes: the queried nodes, aligned with ``inCluster``. Defaults to the
            module global ``my_random_list`` (the list the oracles query).

    Returns:
        List of ``max(inCluster) + 1`` sets; entry ``c`` holds the nodes assigned
        to cluster ``c`` (possibly empty). An empty ``inCluster`` yields ``[]``.
    """
    # Check emptiness first: max() of an empty list raises ValueError.
    if not inCluster:
        return []
    if nodes is None:
        nodes = my_random_list

    clusterNum = max(inCluster) + 1
    ClusteringSets = [set() for _ in range(clusterNum)]
    # NOTE(review): a negative index (e.g. -1 for "unassigned") would silently
    # land in the last set via Python negative indexing; confirm whichCluster
    # never returns one.
    for position, clusterIndex in enumerate(inCluster):
        ClusteringSets[clusterIndex].add(nodes[position])
    return ClusteringSets


def getMatching(clustering, plantedClusters):
    """Score a found clustering against the planted (ground-truth) clusters.

    Each planted cluster is paired with at most one found cluster by a
    bipartite matching that maximizes the total overlap. The score is the
    summed overlap of matched pairs divided by the number of nodes in
    ``clustering``. The matching and a per-planted-cluster accuracy are printed.

    Args:
        clustering: list of iterables of nodes (the clusters found by an oracle).
        plantedClusters: list of sets of nodes (ground truth); each entry must
            support ``.intersection``.

    Returns:
        Fraction of the nodes in ``clustering`` that fall in their matched
        planted cluster; 0.0 when ``clustering`` contains no nodes.
    """
    k = len(plantedClusters)
    N = sum(len(x) for x in clustering)
    # Convert each found cluster to a set once, not once per planted cluster.
    foundSets = [set(cluster) for cluster in clustering]

    bigraph = nx.Graph()
    leftNodes = [f'l{i}' for i in range(k)]
    rightNodes = [f'r{j}' for j in range(len(foundSets))]
    bigraph.add_nodes_from(leftNodes, bipartite=0)
    bigraph.add_nodes_from(rightNodes, bipartite=1)
    for i, A in enumerate(plantedClusters):
        for j, B in enumerate(foundSets):
            # Negated overlap: a *minimum*-weight matching maximizes total overlap.
            bigraph.add_edge(f'l{i}', f'r{j}', weight=-len(A.intersection(B)))

    if leftNodes and rightNodes:
        matching = bipartite.matching.minimum_weight_full_matching(bigraph, leftNodes, 'weight')
    else:
        # Nothing to match on one side; avoid calling the solver on an empty side.
        matching = {}
    print(matching)

    value = 0
    for i, A in enumerate(plantedClusters):
        # A "full" matching only covers min(k, len(clustering)) pairs, so with
        # fewer found clusters than planted ones some planted clusters are
        # unmatched; they contribute zero overlap instead of raising KeyError.
        rj = matching.get(f'l{i}')
        overlap = len(A.intersection(foundSets[int(rj[1:])])) if rj is not None else 0
        value += overlap
        accuracy = overlap / len(A) if A else 0.0
        print(f'\tgot accuracy {accuracy} for cluster with {len(A)} vertices')
    return value / N if N else 0.0
