import pickle
import sys
from tqdm import tqdm
import numpy as np
import networkx as nx
from scipy.spatial.distance import cdist
import os
from concorde import run_concorde

def soft_dist(G, name_prediction = "soft_dist", name_weight="weight_norm", tau=None):
    """Store the SoftDist heatmap (https://arxiv.org/pdf/2406.03503) on G's edges.

    For every node ``i`` a softmax over its neighbours ``j != i`` is computed:

        p_ij = exp(-w_ij / tau) / sum_k exp(-w_ik / tau),   w_ij = G[i][j][name_weight]

    and written to ``G[i][j][name_prediction]``.

    Args:
        G: networkx-style graph; modified in place.
        name_prediction: edge attribute that receives the probabilities.
        name_weight: edge attribute holding the distances (normally the
            max-normalized distance ``weight_norm``).
        tau: softmax temperature. When None, it is taken from the reference
            table {500: 0.066, 1000: 0.0051, 10000: 0.0018}, using the entry
            whose graph size is closest to ``G.number_of_nodes()``.

    Returns:
        The same graph object ``G``.

    NOTE: on an undirected graph ``G[i][j]`` and ``G[j][i]`` are the same
    attribute dict. Each edge therefore keeps the value computed for whichever
    endpoint comes later in ``G.nodes()`` order. On a directed graph each
    direction keeps its own row-normalized value.
    """
    if tau is None:
        tau_dict = {500: 0.066, 1000: 0.0051, 10000: 0.0018}
        n = G.number_of_nodes()
        # Closest reference size; ties resolve to the first key in the table.
        n_bar = min(tau_dict, key=lambda size: abs(n - size))
        tau = tau_dict[n_bar]

    for i in G.nodes():
        neighbours = [j for j in G[i] if j != i]
        if not neighbours:
            continue
        weights = np.array([G[i][j][name_weight] for j in neighbours], dtype=float)
        # Softmax is invariant to a constant shift. Subtracting the row minimum
        # makes the largest term exp(0) = 1, so the denominator cannot underflow
        # to 0 (which would turn the whole row into nan for large weights).
        scores = np.exp(-(weights - weights.min()) / tau)
        probs = scores / scores.sum()
        for j, p in zip(neighbours, probs):
            G[i][j][name_prediction] = p

    return G

def _sample_unique_coords(n, coord_max):
    """Draw ``n`` distinct integer 2-D points from ``[0, coord_max)^2``.

    Uses the global NumPy RNG, so seed it beforehand. Duplicate draws are
    rejected and redrawn, so the accepted points depend only on the RNG state.

    Returns:
        list of length-2 integer ``np.ndarray`` points, in acceptance order.

    Raises:
        ValueError: if fewer than ``n`` distinct points exist in the grid
            (the rejection loop would never terminate).
    """
    if n > coord_max * coord_max:
        raise ValueError("cannot sample {} distinct points from a {}x{} grid".format(n, coord_max, coord_max))
    points = []
    seen = set()  # O(1) duplicate check instead of scanning ``points`` each draw
    while len(points) < n:
        coords = np.random.randint(0, coord_max, (2, ))
        key = tuple(coords.tolist())
        if key not in seen:
            seen.add(key)
            points.append(coords)
    return points


def _build_complete_graph(points):
    """Build a complete undirected graph over ``points`` with distance attributes.

    Node ``i`` gets ``coord=points[i]``. Edge ``(i, j)`` gets ``rounded_weight``
    (``round`` of the distance), ``weight`` (Euclidean distance) and
    ``weight_norm`` (distance divided by the largest pairwise distance).
    """
    G = nx.Graph()
    for i, coord in enumerate(points):
        G.add_node(i, coord=coord)

    X = np.array(points)
    C = cdist(X, X)
    c_max = np.max(C)

    n = len(points)
    for i in range(n):
        for j in range(i + 1, n):
            G.add_edge(i, j,
                       rounded_weight=round(C[i, j]),
                       weight=C[i, j],
                       weight_norm=C[i, j] / c_max)
    return G


def _annotate_optimal_tour(G, tour, edges):
    """Attach the optimal-tour labels from Concorde to ``G`` in place.

    Node attribute ``opt_tour`` is the node's first position in ``tour``. Edge
    attributes ``opt_tour`` and ``in_solution`` are 1 when the edge appears in
    ``edges`` (in either orientation) and 0 otherwise.
    """
    # First occurrence of each node, equivalent to tour.index(node) but O(n) total.
    first_pos = {}
    for pos, node in enumerate(tour):
        first_pos.setdefault(node, pos)

    # Assumes ``edges`` is an iterable of (u, v) pairs. A set of unordered
    # pairs gives O(1) orientation-agnostic lookups instead of list scans.
    solution_edges = {frozenset(e) for e in edges}

    n = G.number_of_nodes()
    for i in range(n):
        G.nodes[i]["opt_tour"] = first_pos[i]
        for j in range(i + 1, n):
            label = int(frozenset((i, j)) in solution_edges)
            G[i][j]["opt_tour"] = label
            G[i][j]["in_solution"] = label


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit("usage: {} <concorde_path>".format(sys.argv[0]))
    concorde_path = sys.argv[1]

    output_folder = "./tsp_uniform/"
    os.makedirs(output_folder, exist_ok=True)

    # (number of nodes, first seed, one-past-last seed) for each instance size
    ns = [(20, 0, 100), (50, 0, 100), (100, 0, 100), (200, 0, 100), (300, 0, 100), (500, 0, 100), (1000, 0, 50)]

    for n, seed_min, seed_max in ns:
        print("Generating data for n = {}".format(n), flush=True)
        for seed in tqdm(range(seed_min, seed_max)):
            np.random.seed(seed)

            # Integer coordinates are drawn from [0, 2n) on each axis.
            points = _sample_unique_coords(n, coord_max=2 * n)
            G = _build_complete_graph(points)

            _opt_value, _bb_nodes, _runtime, tour, edges = run_concorde(
                G, cost=None, concorde_path=concorde_path, get_tour=True, get_edges=True)

            # The closed tour lists the start node again at the end.
            if len(tour) != n + 1:
                raise RuntimeError(
                    "Concorde returned a tour of length {} for n = {} (expected {})".format(len(tour), n, n + 1))

            _annotate_optimal_tour(G, tour, edges)
            G = soft_dist(G, "soft_dist", "weight_norm")

            seed_str = str(seed).zfill(3)
            out_path = os.path.join(output_folder, "{}_{}.pkl".format(n, seed_str))
            with open(out_path, "wb") as f:
                pickle.dump(G, f)