import statistics
import misError
import syntheticData
import random
import pandas as pd
import networkx as nx


# Evaluate the misclassification error (1 - matching accuracy) of the oracle from [CPS15].
# For a new graph, suitable values of t, r and theta have to be found empirically:
# run several settings of t, r and theta and keep the ones that give a small error.


# Stochastic-block-model parameters to sweep.
n_list = [1000]
k_list = [3]
q_list = [0.002]

# Tuned CPS15 oracle parameters for each p, one row per p value:
#     (p, t, r, theta)
# Keeping each p together with its t, r and theta in a single table means the
# derived parallel lists below can never drift out of alignment.
_p_tuning = [
    # p,      t,  r,    theta
    (0.08,   20, 1000, 0.00034),
    (0.07,   20, 1000, 0.00028),
    (0.06,   20, 1000, 0.00025),
    (0.05,   20, 1000, 0.00020),
    (0.04,   20, 1100, 0.00015),
    (0.0375, 20, 950,  0.000183),
    (0.035,  20, 1100, 0.00018),
    (0.0325, 20, 1150, 0.000145),
    (0.03,   22, 1000, 0.000113),
    (0.025,  25, 1100, 0.000035),
]
p_list = [row[0] for row in _p_tuning]
t_list = [row[1] for row in _p_tuning]
r_list = [row[2] for row in _p_tuning]
theta_list = [row[3] for row in _p_tuning]

# Alternative (earlier) sweep, kept for reference:
# n_list = [1000]
# k_list = [3]
# p_list = [0.025, 0.03, 0.035, 0.04, 0.05, 0.06, 0.07]
# q_list = [0.002]
#
# t_list = [25, 22, 20, 20, 20, 20, 20]
# r_list = [1100, 1000, 1100, 1100, 1000, 1000, 1000]
# theta_list = [0.000035, 0.000113, 0.00018, 0.00015, 0.00020, 0.00025, 0.00028]

# Passed unchanged to misError.cps15Algorithm and recorded in the results file;
# its meaning is defined by that function.
s = 21


# synthetic data
# Results of every configuration and every run are appended to this file.
results_path = "./Results/cps15error.txt"
# Number of independent runs per configuration.
repeat = 30


def _append_to_results(*lines):
    """Append the given (newline-terminated) strings to the results file.

    The file is opened in append mode and closed by the ``with`` statement,
    even if a write raises.
    """
    with open(results_path, "a") as results_file:
        results_file.writelines(lines)


# theta, t and r are tuned per p, so the four lists are walked in lockstep;
# fail loudly instead of letting zip silently drop unmatched entries.
if not (len(p_list) == len(theta_list) == len(t_list) == len(r_list)):
    raise ValueError("p_list, theta_list, t_list and r_list must have the same length")

for n in n_list:
    for k in k_list:
        for p, theta, t, r in zip(p_list, theta_list, t_list, r_list):
            for q in q_list:
                N = k * n  # total number of nodes: k clusters of n nodes each

                # Generate the SBM instance; the CSV path below is expected to
                # match the file syntheticData.SBM writes (assumption, not visible here).
                syntheticData.SBM(n, k, p, q)
                config_label = "n=" + str(n) + "_k=" + str(k) + "_p=" + str(p) + "_q=" + str(q)
                data_path = "./SyntheticData/" + config_label + ".csv"

                # Record the configuration and the oracle parameters used.
                _append_to_results(
                    config_label + "\n",
                    "s=" + str(s) + "\n",
                    "t=" + str(t) + "\n",
                    "r=" + str(r) + "\n",
                    "theta=" + str(theta) + "\n",
                )

                # Read the edge list (assumes exactly two columns: u, v).
                E = pd.read_csv(data_path)

                # Build the graph with every node present, even isolated ones.
                G = nx.Graph()
                G.add_nodes_from(range(N))
                G.add_edges_from([(u, v) for _, u, v in E.itertuples()])

                acc = []
                err = []
                for _ in range(repeat):
                    # A fresh random permutation of all N nodes for this run
                    # (sampling size N covers the whole graph).
                    random_list = random.sample(range(N), N)
                    misError.defineRandomList(random_list)

                    # Ground-truth labels in random_list order: node u is taken to
                    # belong to planted cluster u // n (nodes numbered block by block).
                    plantedClusters = misError.toClusteringSets([u // n for u in random_list])

                    # Clustering produced by the CPS15 algorithm.
                    inCluster = misError.cps15Algorithm(G, s, k, t, r, theta)
                    clustering = misError.toClusteringSets(inCluster)

                    accuracy = misError.getMatching(clustering, plantedClusters)
                    print("s=" + str(s) + "---Accuracy of cps15 algorithm: " + str(accuracy))

                    _append_to_results(
                        "accuracy is: " + str(accuracy) + "\n",
                        "error is: " + str(1 - accuracy) + "\n",
                        "------------------------------------------------------\n",
                    )

                    acc.append(accuracy)
                    err.append(1 - accuracy)

                # Summary statistics over all runs of this configuration.
                avg_accuracy = sum(acc) / repeat
                print("avg accuracy: " + str(avg_accuracy))
                _append_to_results(
                    str(acc) + "\n",
                    str(err) + "\n",
                    "min-error: " + str(min(err)) + "\n",
                    "max-error: " + str(max(err)) + "\n",
                    "median-error: " + str(statistics.median(err)) + "\n",
                    "average accuracy is: " + str(avg_accuracy) + "\n",
                    "average error is: " + str(1 - avg_accuracy) + "\n",
                    "***********************************************************************\n\n\n",
                )