"""
Evaluate different decoding strategies and different predictions on the instances of TSPLIB.
"""
import pickle
from utils import get_opt_value, tour_length, beam_search
from utils import greedy_with_probabilities_edge, greedy_with_probabilities_nearest_neighbor
from utils import gap, TSPLIB_INSTANCES, TOL
from networkx.algorithms.approximation import christofides
import networkx as nx
from chrp import chrp

def _chrp_probabilities(G, prediction):
    """Return the ``prediction`` edge scores of ``G`` as one scalar per edge, ready for CHRP.

    Tuple-valued scores are collapsed to the sum of their components. If the
    largest resulting score exceeds 1, every score is divided by that maximum.

    Raises
    ------
    ValueError
        If the scores are neither tuples nor real numbers.
    """
    import numbers

    prob_distribution = nx.get_edge_attributes(G, prediction)

    # The value type is inspected on a single edge; all edges are assumed to share it.
    test = next(iter(prob_distribution.values()))
    if isinstance(test, tuple):
        prob_distribution = {k: sum(v) for k, v in prob_distribution.items()}
    elif not isinstance(test, numbers.Real):
        # numbers.Real also accepts NumPy scalar types (np.float32, np.int64, ...),
        # which a plain float/int check would reject.
        raise ValueError(
            f"The prediction_key must contain either float or tuple values, this is {type(test)}")

    # Rescale only when the scores are not already within [.., 1].
    max_prob = max(prob_distribution.values())
    if max_prob > 1:
        prob_distribution = {k: v / max_prob for k, v in prob_distribution.items()}
    return prob_distribution


def select_decoding_strategy(prediction, decoding, G, beam_size=None):
    """
    Decode a tour from the edge predictions stored in G and return its length.

    Parameters
    ----------
    prediction : str
        The edge-attribute key in G that contains the prediction scores.
    decoding : str
        The decoding strategy to use. Options are "CHRP", "G1", "G2", or "BS".
    G : networkx.Graph
        The input graph representing the TSP instance. For "CHRP" the
        normalized scores are written to every scored edge under the
        attribute ``'prediction'`` (G is modified in place).
    beam_size : int, optional
        The beam size to use for beam search decoding. Required if decoding is "BS".

    Returns
    -------
    float
        The length of the decoded tour, or ``float('inf')`` if the prediction
        key is absent from G or the decoder returned an empty tour.

    Raises
    ------
    ValueError
        If ``decoding`` is not a known strategy, ``beam_size`` is missing for
        "BS", the prediction values have an unsupported type, or the decoder
        returns a tour of unexpected length.
    """
    valid_decodings = ("CHRP", "G1", "G2", "BS")
    if decoding not in valid_decodings:
        raise ValueError(
            f"Unknown decoding strategy {decoding!r}; expected one of {valid_decodings}")

    # Check if the prediction exists. It is probed on edge (0, 1) only, which
    # assumes nodes 0 and 1 exist and are adjacent (e.g. a complete graph with
    # 0-based labels). Some predictors are known to be missing for some instances.
    if prediction not in G[0][1]:
        return float('inf')

    if decoding == "CHRP":
        for (i, j), prob in _chrp_probabilities(G, prediction).items():
            G[i][j]['prediction'] = prob
        tour, _ = chrp(G, prediction='prediction', normalize='none')

    elif decoding == "G1":
        tour, _ = greedy_with_probabilities_nearest_neighbor(G, prediction)

    elif decoding == "G2":
        tour, _ = greedy_with_probabilities_edge(G, prediction)

    else:  # "BS"
        # Explicit check: an assert would be stripped under `python -O`.
        if beam_size is None:
            raise ValueError("Beam size must be provided for beam search decoding")
        tour, _ = beam_search(G, prediction_key=prediction, beam_width=beam_size)

    # A valid tour is closed (first node repeated at the end); an empty tour
    # signals that the decoder failed.
    expected_len = G.number_of_nodes() + 1
    if len(tour) == expected_len:
        return tour_length(G, tour)
    if len(tour) == 0:
        return float('inf')
    raise ValueError(f"The returned tour has length {len(tour)}, expected {expected_len} or 0")


if __name__ == "__main__":
    import os

    # Predictors (edge-attribute keys on each pickled graph) and decoders
    # evaluated for every instance; these also define the CSV columns.
    PREDICTIONS = ['soft_dist', 'DIFUSCO', 'GNNGLS', 'GNNAR']
    DECODINGS = ['G1', 'G2', 'BS', 'CHRP']
    OUTPUT_PATH = "output/results_T.csv"

    beam_size = 50

    # Create the output directory if needed so the first open() does not fail.
    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

    # Truncate the results file and write the header. The header is generated
    # from the same lists that drive the loop below, so columns and values
    # cannot drift apart.
    header = ["n", "instance_name", "opt", "christofides"]
    header += [f"{prediction}+{method}" for prediction in PREDICTIONS for method in DECODINGS]
    with open(OUTPUT_PATH, "w") as F:
        F.write(",".join(header) + "\n")

    for instance in TSPLIB_INSTANCES:
        instance_name = instance.split(".")[0]

        print("\nProcessing instance =", instance_name)
        # NOTE: pickle.load can execute arbitrary code; only load trusted local files.
        with open(f"data/tsplib/{instance_name}.pkl", "rb") as f:
            G = pickle.load(f)

        n = G.number_of_nodes()

        # Get optimal value back
        opt = get_opt_value(G)

        # Baseline: Christofides algorithm, reported as a percentage gap to the optimum
        tour_christofides = christofides(G)
        cost_christofides = tour_length(G, tour_christofides)
        gap_christofides = 100 * gap(opt, cost_christofides)

        to_write = [str(n), instance_name, opt, gap_christofides]
        for prediction in PREDICTIONS:
            for method in DECODINGS:
                cost_prediction_method = select_decoding_strategy(prediction, method, G, beam_size=beam_size)
                to_write.append(100 * gap(opt, cost_prediction_method))

        # Append one row per instance and close the file right away, so rows
        # for finished instances reach disk even if a later instance fails.
        with open(OUTPUT_PATH, "a") as F:
            F.write(",".join(map(str, to_write)) + "\n")
