"""
This script evaluates the impact of applying 2-opt improvement to tours generated by different decoding strategies (and to the Christofides baseline).
"""
from tqdm import tqdm
import pickle
from utils import get_opt_value, tour_length
from utils import greedy_with_probabilities_edge, beam_search, two_opt
from utils import gap, ns
from networkx.algorithms.approximation import christofides
import networkx as nx
from chrp import chrp
import time
import pandas as pd

def select_decoding_strategy(prediction, decoding, G, beam_size=None):
    """
    Decode a TSP tour from the edge predictions stored on ``G``.

    Parameters
    ----------
    prediction : str
        Name of the edge attribute of ``G`` that holds the prediction scores.
        For "CHRP" the values may be scalars (int/float) or tuples; tuples are
        collapsed to a single score by summing their elements.
    decoding : str
        The decoding strategy to use: "CHRP", "BS" (beam search) or "G2"
        (greedy on edge probabilities).
    G : networkx.Graph
        The input graph representing the TSP instance. Whether ``prediction``
        is available is checked on edge (0, 1), so that edge is assumed to
        exist (e.g. a complete graph).
    beam_size : int, optional
        The beam width used for beam search. Required if decoding is "BS".

    Returns
    -------
    list or float
        The closed tour (``G.number_of_nodes() + 1`` entries), or
        ``float('inf')`` when ``prediction`` is not stored on ``G`` or the
        decoder returned an empty tour.

    Raises
    ------
    ValueError
        If ``decoding`` is unknown, ``beam_size`` is missing for "BS", the
        prediction values have an unsupported type ("CHRP"), or the decoder
        returns a tour of unexpected length.

    Notes
    -----
    For "CHRP" the (possibly rescaled) scores are written back onto ``G``
    under the edge attribute ``'prediction'``, i.e. ``G`` is modified in place.
    """
    # Some predictors were not computed for every instance. Signal that with an
    # infinite cost, the same sentinel used below for an empty tour.
    if prediction not in G[0][1]:
        return float('inf')

    if decoding == "CHRP":
        prob_distribution = nx.get_edge_attributes(G, prediction)

        # Tuple-valued predictions are collapsed to one value by summing.
        # We deliberately do NOT divide by the edge weight: per the original
        # author, ALPS already takes the weight into account.
        sample = next(iter(prob_distribution.values()))
        if isinstance(sample, tuple):
            prob_distribution = {k: sum(v) for k, v in prob_distribution.items()}
        elif not isinstance(sample, (int, float)):
            raise ValueError(
                f"The prediction_key must contain either float or tuple values, this is {type(sample)}")

        # Scores need not sum to 1, but ALPS requires them to be at most 1
        # (otherwise costs may become negative), so rescale when necessary.
        max_value = max(prob_distribution.values())
        if max_value > 1.0:
            prob_distribution = {k: v / max_value for k, v in prob_distribution.items()}

        # chrp reads the scores from a fixed attribute name; this mutates G.
        for (i, j), value in prob_distribution.items():
            G[i][j]['prediction'] = value

        tour, _ = chrp(G, prediction='prediction', normalize='none')
    elif decoding == "BS":
        if beam_size is None:
            raise ValueError("Beam size must be specified for beam search decoding")
        tour, _ = beam_search(G, prediction_key=prediction, beam_width=beam_size)
    elif decoding == "G2":
        tour, _ = greedy_with_probabilities_edge(G, prediction)
    else:
        raise ValueError(
            f"Unknown decoding strategy {decoding!r}; expected 'CHRP', 'BS' or 'G2'")

    # A valid tour visits every node and returns to the start.
    expected_length = G.number_of_nodes() + 1
    if len(tour) == expected_length:
        return tour
    if len(tour) == 0:
        return float('inf')
    raise ValueError(f"The returned tour has length {len(tour)}, expected {expected_length} or 0")


if __name__ == "__main__":
    # run=True: (re)compute every instance and write the results CSV.
    # run=False: only aggregate an existing CSV into LaTeX table rows.
    run = False
    beam_size = 50

    filename = "output/results_U_2_opt_no_BS.csv"
    predictors = ['soft_dist', 'DIFUSCO']
    decodings = ['CHRP', 'G2', 'BS']

    if run:
        # Build the header with the same loop order used to fill each row,
        # so column names and values cannot drift apart.
        columns = ["n", "k_str", "opt",
                   "christofides", "christofides_2_opt", "christofides_2_opt_swaps"]
        for decoding in decodings:
            for predictor in predictors:
                name = f"{predictor}+{decoding}"
                columns += [name, f"{name}_2_opt", f"{name}_2_opt_swaps"]

        with open(filename, "w") as F:
            F.write(",".join(columns) + "\n")

        for n, sample_size in ns:
            # Re-open in append mode per n so rows for finished sizes are
            # persisted (and the file is closed) even if a later size fails.
            with open(filename, "a") as F:
                print("\nProcessing n =", n)
                for k in tqdm(range(sample_size)):
                    k_str = str(k).zfill(3)
                    # NOTE: pickle.load can execute arbitrary code; only load trusted instance files.
                    with open(f"data/tsp_uniform/{n}_{k_str}.pkl", "rb") as f:
                        G = pickle.load(f)

                    opt = get_opt_value(G)

                    # Christofides baseline, before and after 2-opt
                    tour_christofides = christofides(G)
                    cost_christofides = tour_length(G, tour_christofides)
                    gap_christofides = 100 * gap(opt, cost_christofides)

                    improved_tour_christofides, swaps_christofides = two_opt(G, tour_christofides)
                    cost_christofides_2_opt = tour_length(G, improved_tour_christofides)
                    gap_christofides_2_opt = 100 * gap(opt, cost_christofides_2_opt)

                    to_write = [str(n), k_str, opt,
                                gap_christofides,
                                gap_christofides_2_opt, swaps_christofides]

                    for decoding in decodings:
                        for predictor in predictors:
                            tour = select_decoding_strategy(predictor, decoding, G, beam_size=beam_size)

                            # A float (inf) sentinel, or anything that is not a
                            # closed tour over all nodes, means the predictor is
                            # missing or decoding failed: record an infinite gap
                            # and no swap count instead of crashing in tour_length.
                            if isinstance(tour, float) or len(tour) != G.number_of_nodes() + 1:
                                to_write += [float('inf'), float('inf'), float('nan')]
                                continue

                            cost_tour = tour_length(G, tour)
                            gap_prediction_method = 100 * gap(opt, cost_tour)

                            # Apply 2-opt to the predicted tour
                            improved_tour, swaps_prediction_method = two_opt(G, tour)
                            cost_improved_tour = tour_length(G, improved_tour)
                            assert cost_improved_tour <= cost_tour, "2-opt did not improve the tour????"
                            gap_prediction_method_2_opt = 100 * gap(opt, cost_improved_tour)

                            to_write += [gap_prediction_method,
                                         gap_prediction_method_2_opt,
                                         swaps_prediction_method]

                    F.write(",".join(map(str, to_write)) + "\n")

    else:
        df = pd.read_csv(filename)

        n_list = [x for x, _ in ns]

        def format_cells(column, n):
            """Return the LaTeX cells 'gap mean ± std & swaps mean ± std' after 2-opt for size n."""
            rows = df[df['n'] == n]
            gaps = rows[f'{column}_2_opt']
            swaps = rows[f'{column}_2_opt_swaps']
            # Raw f-strings keep the LaTeX backslashes without invalid escapes.
            return (rf"{gaps.mean():.2f} $\pm$ {gaps.std():.2f} & "
                    rf"{swaps.mean():.0f} $\pm$ {swaps.std():.0f}")

        table_data = {
            decoder: {
                predictor: [format_cells(f'{predictor}+{decoder}', n) for n in n_list]
                for predictor in predictors
            }
            for decoder in decodings
        }

        # Now print the table in the desired format
        print("n", " & ".join(map(str, n_list)))

        # Christofides row: same cell count as the decoder rows (no trailing '&').
        christofides_cells = " & ".join(format_cells('christofides', n) for n in n_list)
        print("Christofides &  & " + christofides_cells + " \\\\")

        print("\\hline")
        for decoder in decodings:
            for q, method in enumerate(predictors):
                if q == 0:
                    # The decoder label spans one row per predictor.
                    lead = rf"\multirow{{{len(predictors)}}}{{*}}{{{decoder}}} & {method} & "
                else:
                    lead = f" & {method} & "
                print(lead + " & ".join(table_data[decoder][method]) + " \\\\")
            print("\\hline")