from utils import _SAMPLE_STYLE, _SAMPLE_STYLE_NAME, ns
from utils import get_opt_value, add_prediction_to_G, gap, tour_length
import os
from tqdm import tqdm
import pickle
from chrp import chrp
import matplotlib.pyplot as plt
import numpy as np


# Identifier of the prediction to evaluate. It is passed to add_prediction_to_G
# and is also embedded in the names of the cached results file and the figure.
prediction_name = "soft_dist"
# Human-readable name of the prediction (not referenced in the visible part of this script).
prediction_name_nice = "SoftDist"

# Number of independent sampling seeds evaluated per instance when the
# sample style is "sample" (every other style is evaluated once per instance).
number_of_samples = 10

# The gaps are cached on disk: they are computed only if the cache file does not
# exist yet, otherwise the stored results are loaded. Delete the file to force a
# recomputation (e.g. after changing the seeding scheme or number_of_samples).
gaps_path = f"output/gaps_compare_methods_for_generating_P_{prediction_name}.pkl"

if not os.path.exists(gaps_path):
    # Create the output directory up front so that the (long) computation below
    # cannot be lost to a missing directory when the results are finally saved.
    os.makedirs(os.path.dirname(gaps_path), exist_ok=True)

    # gaps[n][style] collects the gaps (in percent) of all runs for instance size n.
    gaps = {n: {style: [] for style in _SAMPLE_STYLE} for n, _ in ns}

    for n, seed_max in ns:
        print("\nProcessing n =\n", n, flush=True)
        for seed in tqdm(range(0, seed_max)):
            k = str(seed).zfill(3)
            with open(f"data/tsp_uniform/{n}_{k}.pkl", "rb") as f:
                G = pickle.load(f)

            # Get the optimal value
            opt_val = get_opt_value(G)

            for sample_style in _SAMPLE_STYLE:
                if sample_style == "sample":
                    # One run per sampling seed. The seed is a mixed-radix
                    # combination of (n, instance seed, sampling seed), so it is
                    # distinct for every (instance, sample) pair of a given n.
                    # The previous product n * seed * seed_sampling collapsed to 0
                    # whenever seed == 0 or seed_sampling == 0 and collided for
                    # many other pairs, so runs meant to be independent reused
                    # the same seed.
                    runs = [
                        {"seed": (n * seed_max + seed) * number_of_samples + seed_sampling}
                        for seed_sampling in range(number_of_samples)
                    ]
                else:
                    # Every other style is evaluated once, without passing a seed.
                    runs = [{}]

                for extra_kwargs in runs:
                    # Add prediction to G according to sample_style
                    G = add_prediction_to_G(G, sample_style, prediction_name, **extra_kwargs)

                    # Run chrp (CHR^+) on this G using the attached prediction
                    tour_chrp, _ = chrp(G, prediction="prediction")
                    chrp_val = tour_length(G, tour_chrp)

                    # Store the gap in percentage!
                    # Note: gaps[n]["sample"] holds number_of_samples entries per
                    # instance, ordered instance by instance (all samples of
                    # instance 0, then all samples of instance 1, ...); every
                    # other style holds exactly one entry per instance.
                    gaps[n][sample_style].append(100 * gap(opt_val, chrp_val))

    # Save gaps to a file
    with open(gaps_path, "wb") as f:
        pickle.dump(gaps, f)

else:
    print("Loading precomputed gaps...", flush=True)
    with open(gaps_path, "rb") as f:
        gaps = pickle.load(f)

# Plot the mean gap per instance size for every sample style.
plt.figure()

# The x positions are the indices into ns (equally spaced); they are labelled
# with the actual n values via xticks below.
x = range(len(ns))

for i, sample_style in enumerate(_SAMPLE_STYLE):
    y = np.array([np.mean(gaps[n][sample_style]) for n, _ in ns])
    y_std = np.array([np.std(gaps[n][sample_style]) for n, _ in ns])

    plt.plot(x, y, 'o-', label=_SAMPLE_STYLE_NAME[i])

    # Shaded band of +/- one standard deviation around the mean
    plt.fill_between(x, y - y_std, y + y_std, alpha=0.2)

plt.xticks(x, [n for n, _ in ns])
plt.xlabel("n")
plt.ylabel("Gap (%)")
plt.legend()
plt.grid()

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("./figures", exist_ok=True)
plt.savefig("./figures/compare_methods_for_generating_P_" + prediction_name + ".png")
plt.show()
