from tqdm import tqdm
import pickle
from utils import get_opt_value, tour_length
from utils import greedy_with_probabilities_edge, greedy_with_probabilities_nearest_neighbor
from utils import gap, ns
from networkx.algorithms.approximation import christofides
import networkx as nx
from chrp import chrp
import time

def _prepare_alps_scores(G, prediction):
    """Copy the ``prediction`` edge scores of G into the ``'prediction'`` edge attribute for ALPS.

    Tuple-valued scores are summed into a single number; real-valued scores are
    used as-is. Scores are deliberately NOT divided by the edge weight, because
    ALPS accounts for the weight itself. They are also not normalized to sum to 1.
    However, if the largest score exceeds 1, every score is divided by it, because
    ALPS may otherwise produce negative costs.

    Raises:
        ValueError: if the scores are neither tuples nor real numbers.
    """
    import numbers

    prob_distribution = nx.get_edge_attributes(G, prediction)

    sample = next(iter(prob_distribution.values()))
    if isinstance(sample, tuple):
        prob_distribution = {edge: sum(value) for edge, value in prob_distribution.items()}
    elif not isinstance(sample, numbers.Real):
        # numbers.Real also accepts numpy scalar types such as float32, not only float/int
        raise ValueError(
            f"The prediction_key must contain either float or tuple values, this is {type(sample)}")

    max_value = max(prob_distribution.values())
    if max_value > 1.0:
        prob_distribution = {edge: value / max_value for edge, value in prob_distribution.items()}

    # 'prediction' is shared by every predictor. Clear values left over from a
    # previous call so that edges without a score for THIS predictor are not
    # silently decoded with another predictor's scores.
    for _, _, data in G.edges(data=True):
        data.pop('prediction', None)
    for (i, j), value in prob_distribution.items():
        G[i][j]['prediction'] = value


def select_decoding_strategy(prediction, decoding, G):
    """Decode a tour from the ``prediction`` edge scores of G and return its cost.

    Args:
        prediction: name of the edge attribute holding the predictor's scores.
        decoding: ``"ALPS"`` (via ``chrp``), ``"G1"``
            (``greedy_with_probabilities_nearest_neighbor``) or ``"G2"``
            (``greedy_with_probabilities_edge``).
        G: the instance graph. Edge (0, 1) must exist, because it is used to
            probe whether the predictor's scores are present.

    Returns:
        ``(cost, runtime)``, where ``runtime`` is the value reported by the decoder:
        - ``(tour_length(G, tour), runtime)`` for a complete closed tour;
        - ``(inf, runtime)`` when the decoder returns an empty tour;
        - ``(inf, -1)`` when edge (0, 1) has no ``prediction`` attribute. In this
          case the decoder is not run, and -1 signals that.

    Raises:
        ValueError: for an unknown ``decoding``, for non-numeric ALPS scores, or
            when the decoder returns a tour that is neither empty nor of length
            ``number_of_nodes() + 1``.
    """
    if decoding not in ("ALPS", "G1", "G2"):
        raise ValueError(f"Unknown decoding strategy {decoding!r}; expected 'ALPS', 'G1' or 'G2'")

    # Some predictors do not provide scores for every instance; only edge (0, 1) is probed.
    if prediction not in G[0][1]:
        return float('inf'), -1

    if decoding == "ALPS":
        _prepare_alps_scores(G, prediction)
        tour, runtime = chrp(G, prediction='prediction', normalize='none')
    elif decoding == "G1":
        tour, runtime = greedy_with_probabilities_nearest_neighbor(G, prediction)
    else:  # "G2"
        tour, runtime = greedy_with_probabilities_edge(G, prediction)

    # A valid tour lists every node once and returns to its start node.
    if len(tour) == G.number_of_nodes() + 1:
        return tour_length(G, tour), runtime
    elif len(tour) == 0:
        return float('inf'), runtime
    else:
        raise ValueError(f"The returned tour has length {len(tour)}, expected {G.number_of_nodes() + 1} or 0")


if __name__ == "__main__":
    import os

    OUTPUT_PATH = "output/results_U.csv"
    PREDICTIONS = ['soft_dist', 'DIFUSCO', 'GNNGLS', 'GNNAR']
    METHODS = ['G1', 'G2', 'ALPS']

    # Build the header from the same lists that drive the loop below, so the
    # columns cannot drift out of sync with the values written in each row.
    columns = ["n", "k_str", "opt", "christofides", "christofides_time"]
    for prediction in PREDICTIONS:
        for method in METHODS:
            columns += [f"{prediction}+{method}", f"{prediction}+{method}_time"]

    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

    # A single `with` block guarantees the file is closed, and buffered rows are
    # written out, even if an instance raises part-way through the run.
    with open(OUTPUT_PATH, "w", encoding="utf-8") as out:
        out.write(",".join(columns) + "\n")

        for n, sample_size in ns:
            print("\nProcessing n =", n)
            for k in tqdm(range(sample_size)):
                k_str = str(k).zfill(3)
                # NOTE: pickle is only safe on trusted files; these are local benchmark instances.
                with open(f"data/tsp_uniform/{n}_{k_str}.pkl", "rb") as f:
                    G = pickle.load(f)

                opt = get_opt_value(G)

                # Time only the Christofides construction. This keeps the column
                # comparable with the decoder runtimes, which select_decoding_strategy
                # receives before it calls tour_length.
                start_ch = time.perf_counter()
                tour_christofides = christofides(G)
                time_christofides = time.perf_counter() - start_ch
                cost_christofides = tour_length(G, tour_christofides)
                gap_christofides = 100 * gap(opt, cost_christofides)

                to_write = [str(n), k_str, opt, gap_christofides, time_christofides]
                for prediction in PREDICTIONS:
                    for method in METHODS:
                        cost_prediction_method, runtime = select_decoding_strategy(prediction, method, G)
                        to_write += [100 * gap(opt, cost_prediction_method), runtime]

                out.write(",".join(map(str, to_write)) + "\n")
                # Instances are slow to solve, so flushing per row costs nothing and
                # keeps partial results on disk if the process is killed.
                out.flush()
