import networkx as nx
import numpy as np
from scipy.cluster.hierarchy import DisjointSet
import time


##############################################################
# Parameters commonly shared between scripts
##############################################################

# Valid values of the `sample_style` argument of add_prediction_to_G.
_SAMPLE_STYLE = ["sample", "top_n", "as_it_is"]
# Display names (LaTeX-style strings), presumably aligned index-by-index with
# _SAMPLE_STYLE for plot legends -- TODO confirm against the plotting scripts.
_SAMPLE_STYLE_NAME = ["Alg1", "Alg1Top", r"CHR$^+$"]

# File names of the TSPLIB benchmark instances.
TSPLIB_INSTANCES  =  ['eil51.tsp', 'berlin52.tsp', 'st70.tsp', 'eil76.tsp', 'pr76.tsp', 'rat99.tsp', 'kroA100.tsp', 'kroB100.tsp', 'kroC100.tsp', 'kroD100.tsp', 'kroE100.tsp', 'rd100.tsp', 'eil101.tsp', 'lin105.tsp', 'pr107.tsp', 'pr124.tsp', 'bier127.tsp', 'ch130.tsp', 'pr136.tsp', 'pr144.tsp', 'ch150.tsp', 'kroA150.tsp', 'kroB150.tsp', 'pr152.tsp', 'u159.tsp', 'rat195.tsp', 'd198.tsp', 'kroA200.tsp', 'kroB200.tsp', 'ts225.tsp', 'tsp225.tsp', 'pr226.tsp', 'gil262.tsp', 'pr264.tsp', 'a280.tsp', 'pr299.tsp', 'lin318.tsp', 'rd400.tsp', 'fl417.tsp', 'pr439.tsp', 'pcb442.tsp', 'd493.tsp', 'u574.tsp', 'rat575.tsp', 'p654.tsp', 'd657.tsp', 'u724.tsp', 'rat783.tsp', 'pr1002.tsp', 'u1060.tsp', 'vm1084.tsp', 'pcb1173.tsp', 'd1291.tsp']

# (instance size, sample size) pairs. The second entry is presumably the
# number of instances used for each size -- TODO confirm with the scripts.
ns = [(20, 100), (50, 100), (100, 100), (200, 100), (300, 100), (500, 100), (1000, 50)]

# Small constant added to edge weights in denominators (weight + TOL) so that
# zero-weight edges do not cause a division by zero.
TOL = 1e-6
##############################################################


def tour_length(G, tour):
    """
    Return the total weight of the walk described by `tour`.

    Parameters
    ----------
    G : networkx.Graph
        Complete weighted undirected graph; each edge stores its cost under
        the 'weight' key.
    tour : list
        Nodes in visiting order. To measure a full cycle, the last node must
        repeat the first one.

    Returns
    -------
    length : float
        Sum of the 'weight' of every consecutive pair (tour[k], tour[k + 1]);
        0 for a tour with fewer than two nodes.
    """
    total = 0
    for src, dst in zip(tour, tour[1:]):
        total += G[src][dst]['weight']
    return total

def gap(opt, approx):
    """
    Relative gap of an approximate value with respect to the optimum.

    Parameters
    ----------
    opt : float
        Optimal value (must be non-zero).
    approx : float
        Approximate value.

    Returns
    -------
    gap : float
        (approx - opt) / opt, e.g. 0.05 when approx is 5% above the optimum.
        It is negative when approx < opt and can exceed 1.
    """
    excess = approx - opt
    return excess / opt

def get_opt_value(G):
    """
    Cost of the optimal tour stored in the graph.

    Parameters
    ----------
    G : networkx.Graph
        Complete weighted undirected graph. Each edge stores its cost under
        'weight' and a 0/1 flag under 'in_solution' that is 1 when the edge
        belongs to the optimal tour.

    Returns
    -------
    opt_value : float
        Sum over all edges of weight * in_solution.
    """
    total = 0
    for u, v in G.edges():
        edge_data = G[u][v]
        total += edge_data['weight'] * edge_data["in_solution"]
    return total

def add_prediction_to_G(G, sample_style, prediction_name, seed = None):
    """
    Set a "prediction" edge attribute on G derived from the scores stored in
    the edge attribute `prediction_name`.

    Parameters
    ----------
    G : networkx.Graph
        Graph whose edges store scores under `prediction_name`. It is modified
        in place and also returned.
    sample_style : str
        One of _SAMPLE_STYLE:
        - "sample": draw n = G.number_of_nodes() distinct edges without
          replacement, with probability proportional to their score (scores
          must be non-negative). Drawn edges get prediction 1 and every other
          edge of G gets 0.
        - "top_n": the n edges with the highest score get prediction 1 and the
          other scored edges get 0. Ties keep edge-iteration order.
        - "as_it_is": copy the score into "prediction" unchanged.
    prediction_name : str
        Edge attribute holding the (unnormalized) scores.
    seed : int, optional
        Seed of the random draw. Required when sample_style == "sample".

    Returns
    -------
    G : networkx.Graph
        The same graph object, with the "prediction" edge attribute set.

    Raises
    ------
    ValueError
        If sample_style is not in _SAMPLE_STYLE, or if sample_style is
        "sample" and no seed is given.
    """
    # TODO: this function could probably be archived somewhere else
    if sample_style not in _SAMPLE_STYLE:
        raise ValueError(f"sample_style must be one of {_SAMPLE_STYLE}")
    n = G.number_of_nodes()

    # Edge -> score, for every edge carrying `prediction_name`
    prob_distribution = nx.get_edge_attributes(G, prediction_name)

    if sample_style == "sample":
        if seed is None:
            raise ValueError("You must provide a seed for sampling")
        # Sample directly over the scored edges instead of assuming the graph
        # is complete (n * (n - 1) / 2 edges).
        scored_edges = list(prob_distribution)
        total = sum(prob_distribution.values())
        probs = np.array([v / total for v in prob_distribution.values()])

        # A private RandomState seeded with `seed` draws exactly the same sample
        # as np.random.seed(seed) followed by np.random.choice. It does so
        # without overwriting the caller's global NumPy RNG state.
        rng = np.random.RandomState(seed)
        chosen = rng.choice(len(scored_edges), size=n, replace=False, p=probs)
        # Set lookup instead of scanning the sample array once per edge
        selected = {scored_edges[k] for k in chosen}

        # prediction = 1 if the edge was drawn, 0 otherwise
        for edge in G.edges():
            G.edges[edge]["prediction"] = 1 if edge in selected else 0

    elif sample_style == "top_n":
        # Sort edges by score, highest first (stable, so ties keep their order)
        ranked = sorted(prob_distribution.items(), key=lambda item: item[1], reverse=True)

        # Exactly n edges (ranks 0 .. n - 1) get prediction 1, matching the
        # n edges drawn by the "sample" style.
        for rank, (edge, _) in enumerate(ranked):
            G.edges[edge]["prediction"] = 1 if rank < n else 0

    elif sample_style == "as_it_is":
        # Just copy the prediction_name attribute to prediction
        for edge in G.edges():
            G.edges[edge]["prediction"] = G.edges[edge][prediction_name]

    return G

def get_optimal_tour_as_list(G):
    """
    Read the optimal tour stored on the nodes of G and return it as a closed
    list of nodes.

    Parameters
    ----------
    G: nx.Graph
        Graph whose nodes carry an 'opt_tour' attribute giving their position
        in the optimal tour.

    Returns
    -------
    nodes_sorted: list
        Nodes ordered by 'opt_tour', with the first node (0) repeated at the
        end to close the cycle.
    """
    position_of = nx.get_node_attributes(G, 'opt_tour')
    # Sorting the keys by their value is stable, so ties keep insertion order
    ordered = sorted(position_of, key=position_of.get)
    assert len(ordered) == G.number_of_nodes(), "Something went wrong, the tour is shorter than the number of nodes"
    assert ordered[0] == 0, "The tour does not start with 0"
    return ordered + [ordered[0]]

def greedy_with_probabilities_nearest_neighbor(G, prediction_key):
    """
    Run the greedy algorithm using the probabilities stored in prediction_key.

    Starting from node 0, repeatedly move to the unvisited neighbor j of the
    current node i with the largest score p_ij / c_ij. When the prediction is
    a tuple, the score is (sum of the tuple) / c_ij. The tour is then closed
    back to node 0.

    Parameters
    ----------
    G: nx.Graph
        The graph with the probabilities stored in prediction_key and the edge
        costs stored in 'weight'.
    prediction_key: str
        The key in the edge attributes where the probabilities are stored.
        Values must be numbers (Python or NumPy scalars) or tuples of numbers.

    Returns
    -------
    tour: list
        The tour found by the greedy algorithm, starting and ending at node 0.
        If the current node has no unvisited neighbor, the tour is closed early.
    running_time: float
        Seconds spent building the tour (score computation excluded).

    Raises
    ------
    ValueError
        If no edge carries prediction_key, or if its values are neither
        numbers nor tuples.
    """
    # Get all the probabilities
    prob_distribution = nx.get_edge_attributes(G, prediction_key)
    if not prob_distribution:
        raise ValueError(f"No edge of G has the attribute {prediction_key!r}")

    # Turn every prediction into a score: tuples are summed first, then the
    # value is divided by the edge cost (TOL guards against zero-weight edges).
    test = next(iter(prob_distribution.values()))
    if isinstance(test, tuple):
        prob_distribution = {k: sum(v) / (G[k[0]][k[1]]["weight"] + TOL) for k, v in prob_distribution.items()}
    elif isinstance(test, (int, float, np.integer, np.floating)):
        prob_distribution = {k: v / (G[k[0]][k[1]]["weight"] + TOL) for k, v in prob_distribution.items()}
    else:
        raise ValueError(f"The prediction_key must contain either float or tuple values, this is {type(test)}")

    # Scores are not normalized: only their relative order matters below.

    def edge_score(u, v):
        # An undirected edge is stored under a single orientation, which is
        # not necessarily (min, max): try (u, v) first, then (v, u).
        key = (u, v) if (u, v) in prob_distribution else (v, u)
        return prob_distribution[key]

    start = time.perf_counter()
    n_nodes = G.number_of_nodes()
    tour = [0]
    visited = {0}  # O(1) membership test instead of scanning the tour list

    while len(tour) < n_nodes:
        current_node = tour[-1]
        candidates = [j for j in G.neighbors(current_node) if j not in visited]
        if not candidates:
            break
        # Highest score wins; on ties the first candidate in neighbor order wins
        next_node = max(candidates, key=lambda j: edge_score(current_node, j))
        tour.append(next_node)
        visited.add(next_node)

    # Return to the starting node
    tour.append(tour[0])
    return tour, time.perf_counter() - start

def from_edges_to_tour(tour_edges):
    """
    Convert a list of edges representing a tour into a list of nodes
    representing the same tour.

    The walk starts at node 0, or at the first node that appears in
    tour_edges if 0 is absent. At each step it moves to the first
    not-yet-visited neighbor, where neighbors are ordered by the first
    appearance of their edge in tour_edges.

    Parameters
    ----------
    tour_edges: list
        List of edges (u, v) representing the tour. Items beyond the two
        endpoints (e.g. an attribute dict) are ignored.

    Returns
    -------
    tour: list
        List of nodes representing the tour, with the first node repeated at
        the end.

    Raises
    ------
    ValueError
        If tour_edges is empty, or if the walk gets stuck before visiting
        every node (the edges do not form a single cycle/path through all of
        them, e.g. two disjoint sub-tours).
    """
    # Undirected adjacency preserving insertion order (node -> ordered neighbors)
    adjacency = {}
    for u, v, *_ in tour_edges:
        adjacency.setdefault(u, {})
        adjacency.setdefault(v, {})
        adjacency[u][v] = None
        adjacency[v][u] = None

    if not adjacency:
        raise ValueError("tour_edges is empty")

    start = 0 if 0 in adjacency else next(iter(adjacency))
    tour = [start]
    visited = {start}  # O(1) membership test instead of scanning the tour list
    while len(tour) < len(adjacency):
        next_node = next((v for v in adjacency[tour[-1]] if v not in visited), None)
        if next_node is None:
            raise ValueError("tour_edges do not form a single cycle through all nodes")
        tour.append(next_node)
        visited.add(next_node)

    tour.append(tour[0])
    return tour

def greedy_with_probabilities_edge(G, prediction_key):
    """
    Run the greedy algorithm constructed as follows:
    Sort all the possible edges (i, j) in decreasing order of (p_ij + p_ji) / c_ij or p_ij / c_ij.
    For each edge (i, j) in the list:
        - If inserting (i, j) into the graph results in a complete tour, insert (i, j) and terminate.
        - If inserting (i, j) results in a graph with cycles (of length < N), continue.
        - Otherwise, insert (i, j) into the tour.
    Return the extracted tour.

    (Adapted from Sun and Yang, "DIFUSCO: Graph-based Diffusion Solvers for Combinatorial Optimization")

    Concretely, an edge is accepted while both endpoints have degree < 2 and
    they lie in different union-find components. After N - 1 accepted edges
    (a Hamiltonian path), the two degree-1 endpoints are joined to close the tour.

    Parameters
    ----------
    G: nx.Graph
        The graph with the probabilities stored in prediction_key and the edge
        costs stored in 'weight'.
    prediction_key: str
        The key in the edge attributes where the probabilities are stored.
        Values must be numbers (Python or NumPy scalars) or tuples of numbers.

    Returns
    -------
    tour: list
        The tour found by the greedy algorithm, as returned by from_edges_to_tour.
    running_time: float
        Seconds spent sorting the edges, selecting them and building the node
        list (score computation excluded).

    Raises
    ------
    ValueError
        If no edge carries prediction_key, or if its values are neither
        numbers nor tuples.
    AssertionError
        If the accepted edges do not leave exactly two degree-1 endpoints
        (e.g. not enough edges were available to build a Hamiltonian path).
    """
    # Get all the probabilities
    prob_distribution = nx.get_edge_attributes(G, prediction_key)
    if not prob_distribution:
        raise ValueError(f"No edge of G has the attribute {prediction_key!r}")

    # Turn every prediction into a score: tuples are summed first, then the
    # value is divided by the edge cost (TOL guards against zero-weight edges).
    test = next(iter(prob_distribution.values()))
    if isinstance(test, tuple):
        prob_distribution = {k: sum(v) / (G[k[0]][k[1]]["weight"] + TOL) for k, v in prob_distribution.items()}
    elif isinstance(test, (int, float, np.integer, np.floating)):
        prob_distribution = {k: v / (G[k[0]][k[1]]["weight"] + TOL) for k, v in prob_distribution.items()}
    else:
        raise ValueError(
            f"The prediction_key must contain either float or tuple values, this is {type(test)}")

    # We don't normalize the probabilities to sum up to 1

    n_nodes = G.number_of_nodes()
    tour_edges = []
    degree = {i: 0 for i in G.nodes()}

    # Start timing before the sort: ranking the edges is part of the greedy
    # decoding itself (and dominates its cost).
    start_time = time.perf_counter()
    sorted_edges = sorted(prob_distribution.items(), key=lambda item: item[1], reverse=True)

    # Union-find over the actual node labels, used to reject edges that would
    # close a sub-tour.
    disjoint_set = DisjointSet(G.nodes())

    # Stop once n - 1 edges (a Hamiltonian path) have been accepted
    edge_count = 0

    for (a_i, b_i), _ in sorted_edges:
        a = disjoint_set[a_i]
        b = disjoint_set[b_i]

        if degree[a_i] < 2 and degree[b_i] < 2 and a != b:
            tour_edges.append((a_i, b_i))
            degree[a_i] += 1
            degree[b_i] += 1
            disjoint_set.merge(a, b)
            edge_count += 1

        if edge_count == n_nodes - 1:
            break

    # Add the last edge, joining the two endpoints of the path, to complete the tour
    remaining_nodes = [node for node, deg in degree.items() if deg == 1]
    assert len(remaining_nodes) == 2, "There should be exactly two nodes with degree 1"
    tour_edges.append((remaining_nodes[0], remaining_nodes[1]))

    tour = from_edges_to_tour(tour_edges)

    return tour, time.perf_counter() - start_time

def beam_search(G, prediction_key, beam_width=5, start_node=0):
    """
    Perform beam search to find a TSP tour using edge probabilities.

    Each edge gets the score p_ij / c_ij, or (sum of the tuple) / c_ij for
    tuple predictions. All scores are divided by the maximum when that maximum
    exceeds 1. Partial tours are extended one node at a time, keeping the
    `beam_width` partial tours with the largest sum of log-scores. Among the
    surviving tours, the shortest closed tour (by 'weight') is returned.

    Node labels are assumed to be the integers 0 .. n - 1, because they index
    a dense n x n matrix.

    Parameters
    ----------
    G : networkx.Graph
        G should be a complete weighted undirected graph. Edge data key
        corresponding to the edge weight should be 'weight'.
    prediction_key : str
        The key in the edge attributes where the probabilities are stored.
        Values must be numbers (Python or NumPy scalars) or tuples of numbers.
    beam_width : int, optional
        The beam width for the search (must be >= 1). Default is 5.
    start_node : int, optional
        The starting node for the tour. Default is 0.

    Returns
    -------
    final_tour : list or None
        The best tour found by the beam search, starting and ending at
        start_node. It is None when no edge carries prediction_key.
    runtime : float
        The time taken to compute the tour in seconds, or -1 when the search
        was not run.

    Raises
    ------
    ValueError
        If beam_width < 1, or if the predictions are neither numbers nor tuples.
    """
    # Some predictors do not produce this key at all: report "no run" (-1)
    prob_distribution = nx.get_edge_attributes(G, prediction_key)
    if not prob_distribution:
        return None, -1
    if beam_width < 1:
        raise ValueError("beam_width must be at least 1")

    # Turn every prediction into a score: tuples are summed first, then the
    # value is divided by the edge cost (TOL guards against zero-weight edges).
    test = next(iter(prob_distribution.values()))
    if isinstance(test, tuple):
        prob_distribution = {k: sum(v) / (G[k[0]][k[1]]["weight"] + TOL) for k, v in prob_distribution.items()}
    elif isinstance(test, (int, float, np.integer, np.floating)):
        prob_distribution = {k: v / (G[k[0]][k[1]]["weight"] + TOL) for k, v in prob_distribution.items()}
    else:
        raise ValueError(
            f"The prediction_key must contain either float or tuple values, this is {type(test)}")

    # If the maximum is greater than 1, rescale so every score is <= 1 (log <= 0)
    max_prob = max(prob_distribution.values())
    if max_prob > 1.0:
        prob_distribution = {k: v / max_prob for k, v in prob_distribution.items()}

    n = len(G.nodes)

    start = time.perf_counter()
    # Dense matrix of log-scores. Pairs without a score stay at -inf so that a
    # missing edge is never preferred; a 0 entry would act like probability 1.
    heatmap = np.full((n, n), -np.inf)
    with np.errstate(divide="ignore"):  # log(0) = -inf is intended
        for (i, j), prob in prob_distribution.items():
            log_prob = np.log(prob)
            heatmap[i][j] = log_prob
            heatmap[j][i] = log_prob  # undirected graph

    # A beam entry is (cumulative log-score, partial tour starting at start_node)
    beam = [(0.0, [start_node])]

    # Add the n - 1 remaining nodes one at a time
    for _ in range(n - 1):
        # Each candidate is (score, index of the parent beam entry, appended
        # node). Extended tours are materialized only for the survivors, so an
        # O(n) list is not copied for each of the O(beam_width * n) candidates.
        candidates = []
        for beam_idx, (score, tour) in enumerate(beam):
            log_row = heatmap[tour[-1]]
            visited = set(tour)
            for next_node in range(n):
                if next_node not in visited:
                    # Score is cumulative (sum of log-scores)
                    candidates.append((score + log_row[next_node], beam_idx, next_node))

        # Stable sort on the score only: ties keep (parent beam, node) enumeration order
        candidates.sort(key=lambda cand: cand[0], reverse=True)
        beam = [(score, beam[beam_idx][1] + [next_node])
                for score, beam_idx, next_node in candidates[:beam_width]]

    # Close every surviving path back to its start node. The final choice uses
    # the actual tour length rather than the score: the beam supplies a few
    # high-score candidates and a different metric picks among them, similar
    # to re-ranking beam outputs by another metric in NLP tasks.
    closed_tours = [tour + [tour[0]] for _, tour in beam]
    tour_lengths = [tour_length(G, candidate) for candidate in closed_tours]
    final_tour = closed_tours[int(np.argmin(tour_lengths))]

    runtime = time.perf_counter() - start

    assert final_tour[0] == final_tour[-1], "The tour is not a cycle"
    return final_tour, runtime


def two_opt(G, tour = None):
    """
    Performs 2-opt refinement on a tour until no further
    improving swaps can be found (Local Optimum).

    Pseudo code from here
    https://en.wikipedia.org/wiki/2-opt

    The search uses first improvement. Positions i are scanned in increasing
    order; for each i, the first j whose swap shortens the tour is applied and
    the scan restarts from i = 0. For a fixed i, all candidate j are evaluated
    at once with NumPy. The chosen swap is the same one a scalar scan over
    increasing j would pick.

    Parameters
    ----------
    G : networkx.Graph
        G should be a complete weighted undirected graph. Edge data key
        corresponding to the edge weight should be 'weight'.
    tour : list, optional
        Initial tour visiting every node of G exactly once. It may be closed
        (first node repeated at the end, n + 1 entries) or open (n entries),
        in which case it is closed here. If None, the tour
        [0, 1, ..., n - 1, 0] is used, which assumes nodes labelled 0 .. n - 1.

    Returns
    -------
    best_tour : list
        The improved closed tour after applying 2-opt. Its first (and last)
        node is the same as in the input tour.
    cont_swaps : int
        The number of swaps performed to reach the local optimum.

    Raises
    ------
    ValueError
        If, after closing an open tour, tour is not n distinct nodes followed
        by its first node.
    """
    nodes = list(G.nodes())
    n = len(nodes)

    if tour is None:
        tour = list(range(n)) + [0]
    tour = list(tour)
    if n > 0 and len(tour) == n:
        # Open tour: close it by returning to its first node
        tour.append(tour[0])
    if len(tour) != n + 1 or tour[0] != tour[-1] or len(set(tour[:-1])) != n:
        raise ValueError("tour must visit every node of G exactly once (closed or open form)")

    # Distance matrix for quick lookup, indexed by position in `nodes`
    adj = nx.to_numpy_array(G, nodelist=nodes, weight='weight')

    # Map node labels to indices for the adjacency matrix
    node_to_idx = {node: i for i, node in enumerate(nodes)}
    best_tour = np.array([node_to_idx[node] for node in tour])

    cont_swaps = 0
    improved = True

    while improved:
        improved = False
        for i in range(n - 1):  # First edge (i, i+1)
            # Replace edges (i, i+1) and (j, j+1) with (i, j) and (i+1, j+1),
            # for every j = i+2 .. n-1 at once: c[k], d[k] are the endpoints
            # of the second edge for j = i + 2 + k.
            a, b = best_tour[i], best_tour[i + 1]
            c = best_tour[i + 2:n]
            d = best_tour[i + 3:n + 1]

            # Delta = (new edges cost) - (old edges cost); negative means shorter
            delta = (adj[a, c] + adj[b, d]) - (adj[a, b] + adj[c, d])
            improving = np.flatnonzero(delta < -1e-9)  # epsilon for float precision

            if improving.size:
                j = i + 2 + int(improving[0])  # first improving j, as in a scalar scan
                # Reverse the segment in place
                best_tour[i + 1: j + 1] = best_tour[i + 1: j + 1][::-1]
                improved = True
                cont_swaps += 1
                # First improvement: restart the scan after every swap
                break

    # Convert indices back to original node labels
    final_tour = [nodes[idx] for idx in best_tour]
    return final_tour, cont_swaps
