"""
This script reproduces the results of the section "Our algorithm behaves smoothly as the prediction error decreases."
"""
import numpy as np
import pickle
from utils import get_optimal_tour_as_list, get_opt_value, greedy_with_probabilities_edge, tour_length
from networkx.algorithms.approximation import christofides
import networkx as nx
from chrp import chrp
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from utils import TOL

def sample_edges_until_weight(candidates, G, budget, start_weight=0.0):
    """Randomly draw edges without replacement until their weight exceeds *budget*.

    Edges are drawn uniformly (via the global NumPy RNG, so seed it with
    ``np.random.seed`` for reproducibility) from *candidates*. Drawing stops as
    soon as the accumulated weight is strictly greater than *budget*, or when
    no candidate is left.

    Args:
        candidates: Iterable of ``(u, v)`` edges to draw from. It is copied,
            never mutated.
        G: Mapping such that ``G[u][v]['weight']`` is the weight of edge
            ``(u, v)`` (e.g. a networkx graph).
        budget: Weight threshold; sampling continues while the accumulated
            weight is ``<= budget``.
        start_weight: Initial value of the accumulated weight. A tiny positive
            value guarantees that nothing is drawn when ``budget == 0``.

    Returns:
        The list of drawn edges, in drawing order.
    """
    pool = list(candidates)
    picked = []
    total = start_weight
    # The `pool` guard avoids np.random.randint(0, 0), which raises ValueError.
    while total <= budget and pool:
        # pop(index) removes exactly the sampled edge, which avoids resampling.
        edge = pool.pop(np.random.randint(0, len(pool)))
        picked.append(edge)
        total += G[edge[0]][edge[1]]['weight']
    return picked


if __name__ == "__main__":
    # Set n
    n = 500

    # For 1000, in dataset d, you have 50 samples, else 100;
    n_instances = 100

    # For every instance and every epsilon, do the sampling with different seeds
    number_of_sampling = 5
    seed = 0

    # Set the epsilon values
    epsilon_list = np.arange(0, 1, 0.05)

    # Dictionary to store the gaps
    gaps = {}

    # Probability assigned to every edge that is predicted to be in the tour
    p = 0.9

    # With greedy?
    with_greedy = False

    # Single place where the results file name is built (read and written below)
    results_path = f"./output/smoothness_with_epsilon_{n}_{n_instances}_{number_of_sampling}.pkl"

    # Check if you have the dictionary already
    if not os.path.exists(results_path):
        print("Generating outputs...", flush=True)
        # Make sure the output directory exists before the first checkpoint
        os.makedirs("./output", exist_ok=True)

        for s in range(n_instances):
            print("Instance ", s, flush=True)
            # NOTE: pickle can execute arbitrary code on load; only use trusted dataset files.
            with open(f"data/tsp_uniform/{n}_{str(s).zfill(3)}.pkl", "rb") as f:
                G = pickle.load(f)

            # get the optimal tour
            optimal_tour = get_optimal_tour_as_list(G)

            # Get the optimal value
            opt = get_opt_value(G)

            # Run also plain Christofides for comparison
            tour_christofides = christofides(G)
            tour_christofides_cost = nx.path_weight(G, tour_christofides, weight='weight')
            gap_christofides = (tour_christofides_cost - opt) / opt * 100
            print("Christofides gap: ", gap_christofides, flush=True)

            # gaps[s][key][idx_epsilon][idx_sampling]
            # is the gap of instance s, for epsilon epsilon_list[idx_epsilon], and the sample number idx_sampling.
            # The key 'alps' is kept so that previously saved result files stay readable.
            gaps[s] = {'christofides': gap_christofides,
                       'alps': np.zeros((len(epsilon_list), number_of_sampling))}
            if with_greedy:
                gaps[s]['greedy'] = np.zeros((len(epsilon_list), number_of_sampling))

            # These lists only depend on the instance, so build them once per instance.
            # The sampling helper works on copies and never mutates them.
            edges_in_optimal_tour = [(i, j) for i, j in G.edges() if G[i][j]['opt_tour'] == 1]
            edges_not_in_optimal_tour = [(i, j) for i, j in G.edges() if G[i][j]['opt_tour'] == 0]

            for idx_epsilon, epsilon in tqdm(enumerate(epsilon_list), total=len(epsilon_list)):
                for sample in range(number_of_sampling):
                    # Set the seed
                    np.random.seed(seed)
                    # Increase the seed for the next sampling
                    seed += 1

                    # Define H-: optimal-tour edges that the prediction misses.
                    # Starting from TOL means that, when epsilon = 0, H- is empty.
                    H_minus = sample_edges_until_weight(
                        edges_in_optimal_tour, G, epsilon * opt, start_weight=TOL)

                    # Define H+: non-optimal edges that the prediction wrongly includes.
                    H_plus = sample_edges_until_weight(
                        edges_not_in_optimal_tour, G, epsilon * opt, start_weight=TOL)

                    if epsilon == 0:
                        assert len(H_minus) == len(H_plus) == 0, "H- and H+ should be empty when epsilon = 0"

                    # Predicted edges = (optimal tour minus H-) plus H+.
                    # Sets give O(1) membership tests in the loop over all edges below.
                    H_minus_set = set(H_minus)
                    prob_1_edges = {e for e in edges_in_optimal_tour if e not in H_minus_set}
                    prob_1_edges.update(H_plus)

                    # Add predictions
                    for u, v in G.edges():
                        G[u][v]['perturbed_prediction'] = p if (u, v) in prob_1_edges else 0

                    # Run the learning-augmented algorithm on perturbed predictions.
                    # NOTE(review): the original code called `alps`, which is neither imported nor
                    # defined in this file; `chrp` is the only algorithm imported (and was unused),
                    # so it is presumably the renamed entry point -- confirm its signature.
                    tour_perturbed, _ = chrp(G, prediction="perturbed_prediction", normalize='none')
                    tour_perturbed_cost = tour_length(G, tour_perturbed)
                    gap_perturbed = (tour_perturbed_cost - opt) / opt * 100
                    gaps[s]['alps'][idx_epsilon][sample] = gap_perturbed

                    if with_greedy:
                        # Run greedy on perturbed predictions
                        tour_greedy_perturbed, _ = greedy_with_probabilities_edge(G, prediction_key='perturbed_prediction')
                        tour_greedy_perturbed_cost = nx.path_weight(G, tour_greedy_perturbed, weight='weight')
                        gap_greedy_perturbed = (tour_greedy_perturbed_cost - opt) / opt * 100
                        gaps[s]['greedy'][idx_epsilon][sample] = gap_greedy_perturbed

            # Checkpoint the results after every instance.
            # NOTE: the existence check above treats any saved file as complete, so after an
            # interrupted run the file must be deleted before the generation can be restarted.
            with open(results_path, "wb") as g:
                pickle.dump(gaps, g)

    else:
        print("Just do the plotting...", flush=True)
        markersize = 10
        with open(results_path, "rb") as g:
            gaps = pickle.load(g)

        # Christofides baseline: one gap per instance
        gaps_christofides = [gaps[s]['christofides'] for s in gaps]
        mean_gaps_christofides = np.mean(gaps_christofides)
        std_gaps_christofides = np.std(gaps_christofides)

        plt.figure()

        #plt.title("Smoothness with ε, n = " + str(n) + " p = " + str(p))
        # Draw a horizontal line at mean_gaps_christofides
        plt.hlines(mean_gaps_christofides, min(epsilon_list), max(epsilon_list), color='b', linestyle='--', label='CHR')

        # For every epsilon, get the mean and std over all instances and all samples.
        # Stacked shape: (instances, epsilons, samples) -> reduce over axes 0 and 2.
        all_gaps_alps = np.array([gaps[s]['alps'] for s in gaps])
        mean_gaps_alps = all_gaps_alps.mean(axis=(0, 2))
        std_gaps_alps = all_gaps_alps.std(axis=(0, 2))

        # Plot mean lines
        plt.plot(epsilon_list, mean_gaps_alps, color='orange', label=r'CHR$^+$')

        if with_greedy:
            all_gaps_greedy = np.array([gaps[s]['greedy'] for s in gaps])
            mean_gaps_greedy = all_gaps_greedy.mean(axis=(0, 2))
            std_gaps_greedy = all_gaps_greedy.std(axis=(0, 2))

            # Plot mean lines
            plt.plot(epsilon_list, mean_gaps_greedy, color='green', label='G2')

        # Shaded error regions (mean ± std)
        plt.fill_between(
            epsilon_list,
            mean_gaps_alps - std_gaps_alps,
            mean_gaps_alps + std_gaps_alps,
            color='orange',
            alpha=0.25
        )

        if with_greedy:
            plt.fill_between(
                epsilon_list,
                mean_gaps_greedy - std_gaps_greedy,
                mean_gaps_greedy + std_gaps_greedy,
                color='green',
                alpha=0.25
            )

        # Add also a shaded error region for christofides
        plt.fill_between(
            epsilon_list,
            mean_gaps_christofides - std_gaps_christofides,
            mean_gaps_christofides + std_gaps_christofides,
            color='blue',
            alpha=0.25
        )

        plt.xlabel('Synthetic prediction error ε')
        plt.ylabel("Optimality Gap (/%)")

        # Restrict the x axis to [0, 0.3]
        plt.xlim(0, 0.3)

        # Restrict the y axis to [0, 20]
        plt.ylim(0, 20)

        plt.legend()

        # Savefig (make sure the target directory exists first)
        os.makedirs("figures", exist_ok=True)
        plt.savefig(f'figures/smoothness_with_epsilon_{n}_{n_instances}_{number_of_sampling}.png', dpi=300)
        plt.show()