import syntheticData
import dotProductOracle
import networkx as nx
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt


# Find suitable parameters for dotProductOracle
# and find a suitable theta for our oracle.


def _intra_cluster_pairs(n, k, samples=5):
    """Return sampled node pairs (i, j), i < j, lying in the same cluster.

    Assumes nodes are numbered cluster by cluster, so cluster ``c`` owns the
    contiguous id range ``[c * n, (c + 1) * n)``. For every cluster, each of
    its first ``samples`` nodes is paired with up to ``samples`` following
    nodes; pairs never cross the cluster boundary.
    """
    pairs = []
    for c in range(k):
        start = c * n
        end = start + n
        # Clamp to `end` so a cluster smaller than `samples` cannot leak
        # nodes of the next cluster into the "intra-cluster" sample.
        for i in range(start, min(start + samples, end)):
            for j in range(i + 1, min(i + samples + 1, end)):
                pairs.append((i, j))
    return pairs


def _inter_cluster_pairs(n, k, samples=5):
    """Return sampled node pairs (i, j) whose endpoints lie in adjacent clusters.

    Same node-numbering assumption as ``_intra_cluster_pairs``. Each of the
    first ``samples`` nodes of cluster ``c`` is paired with each of the first
    ``samples`` nodes of cluster ``c + 1``; the last cluster has no successor
    and therefore contributes no pairs.
    """
    total = k * n
    pairs = []
    for c in range(k):
        start = c * n
        end = start + n
        # Clamp `i` to its own cluster (see _intra_cluster_pairs).
        for i in range(start, min(start + samples, end)):
            for j in range(end, min(end + samples, total)):
                pairs.append((i, j))
    return pairs


def drawSBM(n, k, p, q, R_init, R_query, t, s, samples=5, save_path=None):
    """Plot intra- vs inter-cluster dot-product distributions for one setting.

    Calls ``syntheticData.SBM(n, k, p, q)``, then reads the edge-list CSV for
    those parameters from ``./SyntheticData/`` and builds an undirected graph
    on nodes ``0 .. k*n - 1``. It then constructs a ``dotProductOracle`` on
    that graph and overlays histogram + KDE plots of
    ``SpectralDotProductOracle`` values for sampled same-cluster pairs and
    cross-cluster pairs.

    Args:
        n: number of nodes per cluster.
        k: number of clusters.
        p, q: SBM probabilities forwarded to ``syntheticData.SBM``
            (presumably intra- / inter-cluster edge probability — confirm).
        R_init, R_query, t, s: ``dotProductOracle`` parameters; also shown in
            the plot title.
        samples: number of leading nodes per cluster used to form pairs
            (default 5, the value previously hard-coded).
        save_path: if given, the figure is also written to this path before
            being shown.
    """
    N = k * n
    # syntheticData.SBM is expected to write the CSV that is read back below.
    syntheticData.SBM(n, k, p, q)

    data_path = f"./SyntheticData/n={n}_k={k}_p={p}_q={q}.csv"
    E = pd.read_csv(data_path)

    G = nx.Graph()
    # Add every node explicitly so nodes without edges are still in the graph.
    G.add_nodes_from(range(N))
    # Each CSV row is expected to hold exactly one (u, v) edge.
    G.add_edges_from([(u, v) for u, v in E.itertuples(index=False, name=None)])

    oracle = dotProductOracle.dotProductOracle(G, k, R_init, R_query, t, s)

    # Query order (all intra pairs first, then inter) matches the original,
    # in case the oracle's answers depend on query history.
    same_dot = [float(oracle.SpectralDotProductOracle(i, j))
                for i, j in _intra_cluster_pairs(n, k, samples)]
    differ_dot = [float(oracle.SpectralDotProductOracle(i, j))
                  for i, j in _inter_cluster_pairs(n, k, samples)]

    # draw theta_graph
    print("draw density")
    # Fresh figure per call: with non-interactive backends (e.g. Agg)
    # plt.show() does nothing, so successive calls would draw onto one plot.
    fig = plt.figure()
    # NOTE: sns.distplot is deprecated since seaborn 0.11 (histplot/kdeplot
    # replace it); kept to preserve the current plot appearance.
    sns.distplot(same_dot,
                 bins=15,
                 hist_kws={'color': 'seagreen', 'histtype': 'bar', 'alpha': 0.7},
                 kde_kws={'color': 'seagreen', 'linestyle': '-', 'linewidth': 3, 'alpha': 0.7},
                 label='intra-cluster'
                 )
    sns.distplot(differ_dot,
                 bins=15,
                 hist_kws={'color': 'indianred', 'histtype': 'bar', 'alpha': 0.7},
                 kde_kws={'color': 'indianred', 'linestyle': '-', 'linewidth': 3, 'alpha': 0.7},
                 label='inter-cluster'
                 )
    plt.legend()
    plt.title(f"R_i={R_init},R_q={R_query},t={t},s={s}")
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
    # Free the figure; a long parameter sweep would otherwise keep them all.
    plt.close(fig)





# SBM shape: k clusters with n nodes each.
n = 1000
k = 3

# SBM probabilities to sweep.
q_list = [0.002]
p_list = [0.02, 0.025, 0.03, 0.035, 0.04, 0.05, 0.06, 0.07]
# For a quick single run, set p_list = [0.03] instead.

# dotProductOracle parameter grid.
R_init_list = [1000, 2000]
R_query_list = [150, 200]
t_list = [25, 40]
s_list = [10, 20]

# Every combination of the grids above, in nested-loop order:
# R_init varies slowest, q varies fastest.
sweep = (
    (R_init, R_query, t, s, p, q)
    for R_init in R_init_list
    for R_query in R_query_list
    for t in t_list
    for s in s_list
    for p in p_list
    for q in q_list
)
for R_init, R_query, t, s, p, q in sweep:
    drawSBM(n, k, p, q, R_init, R_query, t, s)
