import itertools
import sys

import network_clustering as nc

def run(n_nodes, d_out, n_obs, noise, method_names = None):
   """Score every distance method on one freshly generated benchmark dataset.

   Generates data once with ``nc.benchmark_data``, then for each method
   computes a distance matrix with ``nc.compute_distances`` and clusters it
   with ``nc.cluster`` against the ground truth ``tensor.y``.

   Parameters
   ----------
   n_nodes, d_out, n_obs, noise
      Forwarded positionally to ``nc.benchmark_data``.
   method_names : iterable of str, optional
      Methods to evaluate, in order. Defaults to the module-level ``methods``
      list, looked up at call time (it is defined further down the file).

   Yields
   ------
   The value returned by ``nc.cluster(..., benchmark = True, ...)``, one per
   method, in ``method_names`` order. NOTE(review): the caller treats these
   as AMI scores; confirm against ``network_clustering``.
   """
   if method_names is None:
      method_names = methods
   # The generated graph itself is not needed here; only the tensor and Q are.
   _, tensor, Q = nc.benchmark_data(n_nodes, d_out, n_obs, noise)
   # Every method is scored on the same dataset, so values from one call are
   # directly comparable.
   for method in method_names:
      distance_matrix = nc.compute_distances(tensor, method, Q = Q)
      yield nc.cluster(distance_matrix, benchmark = True, ground_truth = tensor.y)

def run_batch(dimension, repeat = 10):
   """Sweep one benchmark parameter and write per-method results to a file.

   Every parameter except the one named by ``dimension`` is held at a fixed
   value; the chosen one takes a range of values. Each parameter combination
   is run ``repeat`` times. Each repetition calls ``run``, which regenerates
   the benchmark data, and appends one row to ``fig_3<suffix>.csv``, where the
   suffix comes from ``dimension2figure``. Progress is logged to stderr.

   Despite the ``.csv`` extension, the file is TAB-separated. The columns are
   n_nodes, d_out, n_obs, noise, then one column per entry of ``methods``.

   Parameters
   ----------
   dimension : str
      One of ``"obsnoise"``, ``"netnoise"``, ``"netsize"``, ``"obscount"``.
   repeat : int
      Number of repetitions per parameter combination.

   Raises
   ------
   KeyError
      If ``dimension`` is not a key of ``dimension2figure``.
   """
   # Unknown dimensions raise KeyError here, before any benchmarking happens.
   out_path = f"fig_3{dimension2figure[dimension]}.csv"

   n_nodes = (200,) if dimension != "netsize" else (100, 150, 200, 250, 300, 350, 400, 450, 500)
   d_outs = (2,) if dimension != "netnoise" else (0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5)
   n_obs = (300,) if dimension != "obscount" else (100, 200, 300, 400, 500, 600, 700, 800, 900)
   noises = (1.0,) if dimension != "obsnoise" else (0.0, 0.5, 1.0, 1.25, 1.5, 1.75, 2.0, 2.5, 3.0)

   header = ["n_nodes", "d_out", "n_obs", "noise", *methods]
   # product() varies its last argument fastest, reproducing the original
   # nested-loop order (n_node outermost, repetition innermost).
   grid = itertools.product(n_nodes, d_outs, n_obs, noises, range(repeat))

   with open(out_path, "w", encoding = "utf-8") as f:
      f.write("\t".join(header) + "\n")
      for n_node, d_out, n_ob, noise, rep in grid:
         sys.stderr.write(f"{n_node}\t{d_out}\t{n_ob}\t{noise}\t{rep}\n")
         row = [n_node, d_out, n_ob, noise]
         row.extend(run(n_node, d_out, n_ob, noise))
         f.write("\t".join(map(str, row)) + "\n")
         # Flush per row, presumably so partial results reach disk if a long
         # sweep is interrupted.
         f.flush()

# Swept dimension -> suffix appended to "fig_3" in run_batch's output filename.
dimension2figure = dict(
   obsnoise = "ab",
   netnoise = "cd",
   netsize = "ef",
   obscount = "gh",
)

# Distance methods evaluated by run(), in output-column order.
methods = [
   "flat",
   "nvd",
   "emb",
   "n2v",
   "gcn",
   "nvd+emb",
   "nvd+gcn+emb",
   "gcn+emb",
   "n2v+emb",
]

# Regenerate every figure panel, one swept dimension at a time.
for sweep in ("obsnoise", "netnoise", "netsize", "obscount"):
   run_batch(sweep)
