import os
import random

import networkx as nx
import pandas as pd

import misError
import syntheticData


# Evaluate the misclassification error and the query time of our oracle on synthetic SBM graphs.



# n_list = [1000]
# k_list = [3]
# p_list = [0.07, 0.06, 0.05, 0.04, 0.03, 0.0275, 0.025, 0.0225, 0.02]
# q_list = [0.002]
# theta_list = [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.00053, 0.0006]


# n_list = [1000]
# k_list = [3]
# p_list = [0.02, 0.025, 0.03, 0.035, 0.04, 0.05, 0.06, 0.07]
# q_list = [0.002]
# theta_list = [0.0006, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005]

# ---------------------------------------------------------------------------
# Experiment configuration. Every combination of n, k, p and q is run.
# theta_list is paired element-wise with p_list: theta_list[i] is the theta
# value used together with p_list[i].
# ---------------------------------------------------------------------------
n_list = [1000]
k_list = [3]
p_list = [0.02, 0.025, 0.03, 0.035, 0.04, 0.05, 0.06, 0.07]
q_list = [0.002]
theta_list = [0.0006, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005]

# Parameters forwarded unchanged to misError.ourAlgorithm.
R_init = 2000
R_query = 200
t = 25
s_dot = 20
s = 21

# timeFlag = True means that we record the pre-processing time and average query time
timeFlag = True

# Output locations. Both logs are opened in append mode, so results from
# successive runs accumulate in the same files.
RESULTS_DIR = "Results"
ERROR_LOG = os.path.join(RESULTS_DIR, "error.txt")
QUERY_TIME_LOG = os.path.join(RESULTS_DIR, "queryTime.txt")

# theta is paired with p by position, so the two lists must line up. Without
# this check a short theta_list fails with an IndexError part-way through the
# sweep, and a long one has its trailing values silently ignored.
if len(theta_list) != len(p_list):
    raise ValueError(
        f"theta_list has {len(theta_list)} entries but p_list has {len(p_list)}; "
        "they must be paired one-to-one"
    )

# open(..., "a") does not create missing parent directories.
os.makedirs(RESULTS_DIR, exist_ok=True)


def _append(path, text):
    """Append *text* to the file at *path*, closing the file even if the write fails."""
    with open(path, "a", encoding="utf-8") as log:
        log.write(text)


# synthetic data
for n in n_list:
    for k in k_list:
        for p, theta in zip(p_list, theta_list):
            for q in q_list:
                # Total vertex count: k planted clusters of n vertices each
                # (see the ground-truth labelling below).
                N = k * n

                # Generate the synthetic graph; SBM presumably writes the CSV
                # at data_path, which is read back below.
                syntheticData.SBM(n, k, p, q)
                data_path = f"./SyntheticData/n={n}_k={k}_p={p}_q={q}.csv"

                # Identical header written to both logs for this configuration.
                run_header = (
                    f"n={n}_k={k}_p={p}_q={q}\n"
                    f"theta={theta}\n"
                    f"R_init={R_init}_R_query={R_query}"
                    f"_t={t}_s_dot={s_dot}_s={s}\n"
                )

                # write txt
                _append(ERROR_LOG, run_header)

                # read data: each row is unpacked as one edge (u, v) below
                E = pd.read_csv(data_path)

                # Build the graph. Nodes 0..N-1 are added explicitly so that
                # vertices with no incident edge are still present.
                G = nx.Graph()
                G.add_nodes_from(range(N))
                G.add_edges_from((u, v) for _, u, v in E.itertuples())

                # Sampling N out of range(N) gives a random permutation of all
                # vertices, so accuracy is measured over the whole graph.
                random_list = random.sample(range(N), N)
                misError.defineRandomList(random_list)

                # Ground truth: vertex u belongs to planted cluster u // n
                # (exact integer division instead of int(u / n)).
                plantedClusters = misError.toClusteringSets([u // n for u in random_list])

                # our: header written just before the run; ourAlgorithm may
                # append its timings here when timeFlag is True (unverified).
                _append(QUERY_TIME_LOG, run_header)

                inCluster = misError.ourAlgorithm(G, k, R_init, R_query, t, s, s_dot, theta, timeFlag)
                clustering = misError.toClusteringSets(inCluster)

                # accuracy
                accuracy = misError.getMatching(clustering, plantedClusters)
                print("Accuracy of our algorithm: " + str(accuracy))

                # write txt
                _append(
                    ERROR_LOG,
                    "accuracy is: " + str(accuracy) + "\n"
                    + "error is: " + str(1 - accuracy) + "\n"
                    + "-------------------------------\n\n",
                )
