import sys, torch
import numpy as np
import pandas as pd
import networkx as nx
import network_clustering as nc
from sklearn.cluster import DBSCAN
from sklearn.metrics import adjusted_mutual_info_score

# Pick the compute backend once at import time; run_subfigure uses it to choose
# between the GPU and CPU implementations of Q.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def _singletonize_noise(labels):
    """Return a copy of DBSCAN ``labels`` where every noise point (-1) gets its own new cluster id.

    New ids start right after the largest existing label, so they never collide
    with real clusters. If every point is noise, the ids are ``0..n-1``.
    """
    labels = np.asarray(labels).copy()
    noise = labels == -1
    start = labels.max() + 1
    labels[noise] = np.arange(start, start + noise.sum())
    return labels


def _best_dbscan_ami(distance_matrix, true_labels, num_eps=100):
    """Return the best AMI vs. ``true_labels`` over a DBSCAN eps sweep on a precomputed distance matrix.

    The sweep uses ``num_eps`` evenly spaced eps values. They run from the
    smallest positive distance to half of the mean distance.

    Raises:
        ValueError: if ``distance_matrix`` has no positive entries (no eps grid
            can be built), or if ``num_eps`` is 0.
    """
    # NOTE(review): assumes distance_matrix is a NumPy array — it is passed
    # straight to np.linspace / DBSCAN below; confirm nc.compute_distances' return type.
    positive = distance_matrix[distance_matrix > 0]
    if len(positive) == 0:
        raise ValueError("distance matrix has no positive entries; cannot build an eps grid")
    eps_grid = np.linspace(positive.min(), distance_matrix.mean() / 2, num=num_eps)
    scores = []
    for eps in eps_grid:
        # Keep only the labels, not the fitted model. Each fitted DBSCAN also stores
        # a copy of its core samples (components_), i.e. rows of the distance
        # matrix here, so retaining every model across the sweep wastes memory.
        labels = DBSCAN(eps=eps, min_samples=2, metric="precomputed").fit(distance_matrix).labels_
        # Noise points become singleton clusters instead of one shared "-1" cluster.
        scores.append(adjusted_mutual_info_score(_singletonize_noise(labels), true_labels))
    return max(scores)


def run_subfigure(dataset, index, runs=25, num_eps=100, method_names=None):
    """Cluster one dataset with every distance method and write AMI scores to ``fig_4{index}.csv``.

    For each of ``runs`` repetitions and each method, a node distance matrix is
    computed with ``nc.compute_distances``. DBSCAN is then swept over ``num_eps``
    eps values, and the best adjusted mutual information against the node labels
    is recorded. The output file has a tab-separated header of method names,
    followed by one tab-separated row of scores per run. Progress goes to stderr.

    Args:
        dataset: Dataset prefix. The function reads ``data/{dataset}_network.csv``
            (tab-separated edge list) and ``data/{dataset}_nodeattributes.csv``
            (tab-separated, must contain a ``label`` column).
        index: Sub-figure suffix used in the output file name.
        runs: Number of repetitions (default 25).
        num_eps: Number of eps values in each DBSCAN sweep (default 100).
        method_names: Methods to evaluate. Defaults to the module-level ``methods`` list.
    """
    if method_names is None:
        method_names = methods
    sys.stderr.write(f"{dataset}\n\nReading data...\n")
    # Integer edge weights are read only for the tivoli dataset.
    edge_data = [("weight", int),] if dataset == "tivoli" else False
    G = nx.read_edgelist(f"data/{dataset}_network.csv", delimiter="\t", nodetype=int, data=edge_data)
    O = pd.read_csv(f"data/{dataset}_nodeattributes.csv", sep="\t")
    tensor = nc._make_tensor(G, O.drop("label", axis=1).values, O["label"])
    # Q is the effective resistance matrix for the network embeddings, which needs to be computed only once
    sys.stderr.write("Caching Q...\n")
    Q = nc._ge_Q_gpu(G) if device == "cuda" else nc._ge_Q(G)
    sys.stderr.write("Clustering...\n")
    with open(f"fig_4{index}.csv", "w") as f:
        f.write("\t".join(method_names) + "\n")
        for run in range(runs):
            sys.stderr.write(f"Run #{run}\n")
            perfs = [
                _best_dbscan_ami(nc.compute_distances(tensor, method, Q=Q), tensor.y, num_eps)
                for method in method_names
            ]
            f.write("\t".join(str(x) for x in perfs) + "\n")
            # Flush after every run so completed rows reach the file even if a later run fails.
            f.flush()
    sys.stderr.write("\n\n")

# Distance methods evaluated for every dataset; run_subfigure reads this list by default.
methods = ["flat", "nvd", "emb", "n2v", "gcn", "nvd+emb", "gcn+emb", "n2v+emb", "nvd+gcn+emb"]

# Run the experiments only when executed as a script, so importing this module
# (e.g. to reuse run_subfigure or methods) does not start the full pipeline.
if __name__ == "__main__":
    for _dataset, _index in (("tradeatlas", "a"), ("littlesis", "b"), ("tivoli", "c")):
        run_subfigure(_dataset, _index)
