import networkx as nx 
import numpy as np 
from code import *
import sklearn
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt


# Load graph #1; its clustering is used further down to form the oracle.
print('loading graphs')
# The edge list is read as a directed graph and then converted to undirected,
# so the resulting adjacency matrix is symmetric.
g1 = nx.read_edgelist("datasets/graph1.txt",create_using=nx.DiGraph(), nodetype = int).to_undirected()
# to_numpy_array returns a plain ndarray. The former to_numpy_matrix was
# removed in NetworkX 3.0, and the np.matrix it produced is rejected by
# recent scikit-learn releases.
adj_matrix1 = nx.to_numpy_array(g1)

# Load graph #5, the graph the algorithms are evaluated on.
g5 = nx.read_edgelist("datasets/graph5.txt",create_using=nx.DiGraph(), nodetype = int).to_undirected()
adj_matrix5 = nx.to_numpy_array(g5)

# Compute a 2-dimensional spectral embedding of each graph.
# `import sklearn` alone does not load the `manifold` submodule; import it
# explicitly instead of relying on another import pulling it in.
import sklearn.manifold
print('getting graph embeddings')
# affinity='precomputed': the adjacency matrices are passed directly as the
# affinity matrices.
embedding = sklearn.manifold.SpectralEmbedding(affinity = 'precomputed',n_components=2)
# fit_transform is called separately for each graph, so each embedding is
# computed from that graph's own adjacency matrix.
output1 = embedding.fit_transform(adj_matrix1)
output5 = embedding.fit_transform(adj_matrix5)

# Baseline: k-means++ seeding cost on the graph #5 embedding.
# 20 independent seedings are run and each cost is weighted by 0.05 (= 1/20),
# so kpp_cost ends up as the mean cost over the runs.
k, d = 10, 2
num_trials = 20
kpp_cost = 0.0
for _trial in range(num_trials):
	seed_centers = kpp2(output5, k, d)
	# k_means_cost returns a pair; the cost is its second entry.
	trial_cost = k_means_cost(output5, seed_centers)[1]
	kpp_cost += 0.05*(trial_cost)


# Build the nearest-neighbor predictor (oracle) for graph #5.
# distances5[i, j] is the Euclidean distance between row i of output5 and
# row j of output1.
distances5 = euclidean_distances(output5, output1)

# Reference clustering of graph #1 from scikit-learn's KMeans (k-means++
# seeding followed by Lloyd iterations).
kmeans_scikit_train1 = KMeans(n_clusters=k).fit(output1)
train_labels1 = kmeans_scikit_train1.labels_

print('getting predictor labels')
# Each graph #5 point takes the cluster label of its nearest graph #1 point.
# One argmin over axis 1 replaces the per-row Python loop; ties resolve to
# the first minimum, exactly as a row-wise argmin does.
nearest_train_idx = np.argmin(distances5, axis=1)
# Stored as floats to match the original np.zeros-based array; consumers
# cast with .astype(int) where integer labels are required.
oracle_labels5 = train_labels1[nearest_train_idx].astype(float)

# Cost of the oracle-based algorithm: run algo1 for 15 evenly spaced values
# of p in [.01, .15] and keep the smallest clustering cost.
# The list is seeded with +inf so min() uses the same strict "<" comparison
# against an infinite starting value as an explicit running minimum.
candidate_costs = [float('inf')]
for p in np.linspace(.01, .15, 15):
	# The int cast is redone on every pass so algo1 always gets a fresh array.
	centers_for_p = algo1(output5, oracle_labels5.astype(int), k, p)
	# k_means_cost returns a pair; the cost is its second entry.
	candidate_costs.append(k_means_cost(output5, centers_for_p)[1])
min_cost = min(candidate_costs)

# Costs are reported relative to the k-means++ baseline, which is
# therefore printed as 1.0.
relative_algo_cost = min_cost/kpp_cost
print('kmeans++ cost:', 1.0, 'algorithm cost:', relative_algo_cost)


# Perturbing oracle by noise - uncomment to iterate over noise values (for example Figure 1c)


# print('generating figure from paper')
# print('perturbing oracle by noise')
# # perturb predictor labels to generate Figure 5
# d = 2
# k = 10
# noise_vals = np.linspace(.01, .25, 50)
# oracle_cost_noise = []
# oracle_cost_noise_std = []
# algo_cost_noise = []
# algo_cost_noise_std = []
# for noise in noise_vals:
# 	curr_oracle = []
# 	curr_algo= []
# 	for j in range(10):
# 		curr_oracle_labels5 = np.zeros(output5.shape[0])
# 		for i in range(output5.shape[0]):
# 			curr_label = train_labels1[np.argmin(distances5[i,:])]
# 			if np.random.random() < 1-noise:
# 				oracle_labels5[i] = curr_label
# 			else:
# 				oracle_labels5[i] = np.random.randint(0,k)
# 		curr_oracle.append(kmeans_cost_label(output5, oracle_labels5, d, k)[1]/kpp_cost)
		

# 		min_cost = float('inf')
# 		for p in np.linspace(.01, .15, 15):
# 			algo_centers = algo1(output5, oracle_labels5.astype(int), k, p)
# 			algo_labels, cost = k_means_cost(output5, algo_centers)
# 			if cost < min_cost:
# 				min_cost = cost
# 		curr_algo.append(min_cost/kpp_cost)
# 	algo_cost_noise.append(np.median(curr_algo))
# 	algo_cost_noise_std.append(np.std(curr_algo))
# 	oracle_cost_noise.append(np.median(curr_oracle))
# 	oracle_cost_noise_std.append(np.std(curr_oracle))

# oracle_cost_noise = np.array(oracle_cost_noise)
# oracle_cost_noise_std = np.array(oracle_cost_noise_std)
# algo_cost_noise = np.array(algo_cost_noise)
# algo_cost_noise_std = np.array(algo_cost_noise_std)

# plt.figure(figsize=(10,7))
# plt.title("Dataset: Oregon Spectral Clustering, Graph #5, k=10", fontsize=20)
# plt.plot(noise_vals, algo_cost_noise, '.--', label="Our Alg", linewidth=2.5, markersize=5)
# plt.plot(noise_vals, oracle_cost_noise , '.--', label="Predictor", c='C1', linewidth=2.5, markersize=5)
# plt.plot(noise_vals, [1]*50, '-', label="kmeans++", color='C3', linewidth=2.5, markersize=5)
# plt.fill_between(noise_vals,  algo_cost_noise + algo_cost_noise_std,  algo_cost_noise - algo_cost_noise_std, alpha=.2)
# plt.fill_between(noise_vals,  oracle_cost_noise + oracle_cost_noise_std,  oracle_cost_noise - oracle_cost_noise_std, alpha=.2)
# plt.xlabel("Corruption %", fontsize=20)
# plt.ylabel("Clustering Cost", fontsize=20)
# plt.xticks(fontsize=20)
# plt.yticks(fontsize=20)
# plt.legend(fontsize=20)
# plt.show()
