from utils import tour_length, ns, gap, get_opt_value
import time
from networkx.algorithms.approximation import christofides
from tqdm import tqdm
import pickle
import pandas as pd
from utils import beam_search
import os


# Predictors whose outputs are decoded with beam search, in CSV column order.
PREDICTORS = ('soft_dist', 'DIFUSCO', 'GNNGLS', 'GNNAR')
# Decoders reported in the LaTeX table; the results file only holds "BS" columns.
DECODERS = ('BS',)


def csv_header(predictors=PREDICTORS, decoder='BS'):
    """Return the header line (newline-terminated) of the results CSV.

    The fixed columns are followed by one ``<predictor>+<decoder>`` gap column
    and one ``<predictor>+<decoder>_time`` column per predictor, so the header
    always matches the order in which the rows are written.
    """
    fixed = "n,k_str,opt,christofides,christofides_time"
    per_predictor = [f"{p}+{decoder},{p}+{decoder}_time" for p in predictors]
    return ",".join([fixed] + per_predictor) + "\n"


def format_latex_cell(mean_gap, std_gap, mean_runtime, std_runtime):
    """Format one table entry as ``gap ± std & runtime ± std`` in LaTeX.

    All four values are rendered with two decimals. The backslash of ``\\pm``
    is escaped explicitly: a bare ``\\p`` inside a non-raw string literal is an
    invalid escape sequence and triggers a warning on recent Python versions.
    """
    return (f"{mean_gap:.2f} $\\pm$ {std_gap:.2f} & "
            f"{mean_runtime:.2f} $\\pm$ {std_runtime:.2f}")


def run_experiments(filename, beam_size):
    """Run Christofides and every beam-search predictor, writing a CSV.

    ``filename`` is overwritten. For every ``(n, sample_size)`` pair in ``ns``
    the instances ``data/tsp_uniform/{n}_{k}.pkl`` are loaded and one CSV row
    per instance is appended. The file is reopened for each ``n`` so that the
    rows of completed sizes are on disk even if a later size fails.
    """
    out_dir = os.path.dirname(filename)
    if out_dir:
        # Avoid failing on a fresh checkout where the output folder is missing.
        os.makedirs(out_dir, exist_ok=True)

    with open(filename, "w") as out:
        out.write(csv_header())

    for n, sample_size in ns:
        print("\nProcessing n =", n)
        # The context manager closes the file even if an instance raises.
        with open(filename, "a") as out:
            for k in tqdm(range(sample_size)):
                k_str = str(k).zfill(3)
                # NOTE: pickle executes arbitrary code on load; only use it on
                # dataset files you generated or otherwise trust.
                with open(f"data/tsp_uniform/{n}_{k_str}.pkl", "rb") as f:
                    G = pickle.load(f)

                # Get optimal value back
                opt = get_opt_value(G)

                # Run Christofides algorithm (the timing includes the tour
                # length evaluation).
                start_ch = time.time()
                tour_christofides = christofides(G)
                cost_christofides = tour_length(G, tour_christofides)
                time_christofides = time.time() - start_ch
                gap_christofides = 100 * gap(opt, cost_christofides)

                row = [str(n), k_str, opt, gap_christofides, time_christofides]
                for prediction in PREDICTORS:
                    tour, runtime = beam_search(
                        G, prediction_key=prediction, beam_width=beam_size)
                    if tour is not None:
                        cost = tour_length(G, tour)
                        row += [100 * gap(opt, cost), runtime]
                    else:
                        # No tour returned: record an infinite gap and a
                        # sentinel runtime. These rows still enter the
                        # mean/std computed for the LaTeX table.
                        row += [float('inf'), -1]

                out.write(",".join(map(str, row)) + "\n")


def print_latex_table(filename):
    """Print LaTeX table rows summarising the results CSV ``filename``.

    For each decoder/predictor pair, one cell per instance size ``n`` (taken
    from ``ns``) reports the mean and standard deviation of the gap and
    runtime columns over the rows with that ``n``.
    """
    df = pd.read_csv(filename)
    n_list = [n for n, _ in ns]

    print("n", " & ".join(map(str, n_list)))
    for decoder in DECODERS:
        for q, method in enumerate(PREDICTORS):
            cells = []
            for n in n_list:
                rows = df[df['n'] == n]
                gaps = rows[f'{method}+{decoder}']
                runtimes = rows[f'{method}+{decoder}_time']
                cells.append(format_latex_cell(
                    gaps.mean(), gaps.std(), runtimes.mean(), runtimes.std()))

            if q == 0:
                # The first row of a decoder carries the \multirow label that
                # spans one row per predictor.
                print(f"\\multirow{{{len(PREDICTORS)}}}{{*}}{{{decoder}}} & {method} & ", end="")
            else:
                print(f" & {method} & ", end="")
            print(" & ".join(cells))
            print(" \\\\")
        print("\\hline")


if __name__ == "__main__":
    beam_size = 50
    filename = f"output/results_U_BS_{beam_size}.csv"

    run = False  # Set to True to run experiments, False to generate the LaTeX table

    if run:
        _ = input("Press any key to continue and overwrite the file " + filename + " ...")
        run_experiments(filename, beam_size)
    else:
        print("Generating LaTeX table from " + filename)
        # Raise explicitly: an assert would be stripped when running with -O.
        if not os.path.exists(filename):
            raise FileNotFoundError(
                "File " + filename + " does not exist. Set run = True to generate it.")
        print_latex_table(filename)