"""
This script runs no experiments; it only reads the results CSV and prints the mean/std cells used to write the LaTeX table in the paper.
"""
import pandas as pd
from utils import ns

def format_mean_std(values):
    """Format a pandas Series as its mean and std joined by a LaTeX plus-minus.

    Both numbers are rendered with two decimals. The backslash of the LaTeX
    macro is written as an escaped backslash so that no invalid escape
    sequence appears in the string literal.
    """
    return f"{values.mean():.2f} $\\pm$ {values.std():.2f}"


def summary_cell(df, n, column):
    """Return the two LaTeX cells (value and runtime) for one problem size.

    Rows of ``df`` whose ``'n'`` column equals ``n`` are selected; the first
    cell summarises ``column`` and the second summarises ``column + '_time'``.
    The cells are separated by `` & ``.
    """
    rows = df[df['n'] == n]
    value_part = format_mean_std(rows[column])
    runtime_part = format_mean_std(rows[column + '_time'])
    return f"{value_part} & {runtime_part}"


if __name__ == "__main__":

    df = pd.read_csv('./output/results_U.csv')

    # Target layout of the printed table:
    #
    #   n 20 50 100 200 300 500 1000
    #         soft_dist
    #         DIFUSCO
    #   G1    GNNGLS
    #         GNNAR
    #
    #         soft_dist
    #         DIFUSCO
    #   G2    GNNGLS
    #         GNNAR
    #
    #         soft_dist
    #         DIFUSCO
    #   ALPS  GNNGLS
    #         GNNAR

    # Each entry of `ns` is unpacked as a pair; only its first item is used.
    n_list = [x for x, _ in ns]

    decoders = ['G1', 'G2', 'ALPS']
    methods = ['soft_dist', 'DIFUSCO', 'GNNGLS', 'GNNAR']

    # table_data[decoder][method] holds one pre-formatted cell pair per n,
    # read from the CSV columns "<method>+<decoder>" and
    # "<method>+<decoder>_time".
    table_data = {
        decoder: {
            method: [summary_cell(df, n, f'{method}+{decoder}') for n in n_list]
            for method in methods
        }
        for decoder in decoders
    }

    # Now print the table in the desired format
    print("n", " & ".join(map(str, n_list)))
    for decoder in decoders:
        for q, method in enumerate(methods):
            if q == 0:
                # Only the first row of a decoder group carries the
                # \multirow label spanning the four method rows.
                print(f"\\multirow{{4}}{{*}}{{{decoder}}} & {method} & ", end="")
            else:
                print(f" & {method} & ", end="")
            print(" & ".join(table_data[decoder][method]))
            print(" \\\\")
        print("\\hline")

    # Print also the average gap of Christofides & time.
    # Every cell pair is followed by " & " (including the last one) and no
    # newline is emitted.
    # NOTE(review): the sizes are hard-coded here instead of using n_list;
    # presumably they are meant to match -- confirm against utils.ns.
    for n in [20, 50, 100, 200, 300, 500, 1000]:
        print(summary_cell(df, n, 'christofides'), end=" & ")