import numpy as np
from methods.reup import bayesian_utils
from sklearn.neighbors import NearestNeighbors, kneighbors_graph, radius_neighbors_graph
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import dijkstra, shortest_path
import networkx as nx

def bayesian_build_graph(data, pos_Sigma, pos_m, n_neighbors, diag=True):
    """Build a k-nearest-neighbour distance graph under a quadratic form.

    The pairwise cost between two rows ``x_i`` and ``x_j`` of ``data`` is
    ``(x_i - x_j)^T S (x_i - x_j)`` where ``S`` is ``pos_Sigma`` (or only its
    diagonal when ``diag`` is true). No square root is taken.

    Args:
        data: Samples to connect, one sample per row.
        pos_Sigma: Square matrix defining the quadratic form.
        pos_m: Unused by this function; kept so existing callers keep working.
        n_neighbors: Number of neighbours requested from ``NearestNeighbors``.
        diag: When true, only the diagonal of ``pos_Sigma`` is used.

    Returns:
        The dense array obtained from
        ``kneighbors_graph(data, mode='distance').toarray()``. The result is
        returned as-is and is not symmetrised.
    """
    if diag:
        # Keep only the diagonal entries of pos_Sigma.
        eval_Sigma = np.eye(pos_Sigma.shape[0])
        np.fill_diagonal(eval_Sigma, pos_Sigma.diagonal())
    else:
        eval_Sigma = pos_Sigma

    def expected_mahalanobis(x_i, x_j):
        # Work on flat vectors so the quadratic form collapses to a true
        # scalar; a (1, 1) array would force scikit-learn to rely on NumPy's
        # deprecated size-1-array-to-scalar conversion.
        diff = np.ravel(x_i) - np.ravel(x_j)
        return float(diff @ eval_Sigma @ diff)

    # NOTE(review): the value above is a squared form, which need not satisfy
    # the triangle inequality that tree-based searches presumably rely on --
    # confirm that 'ball_tree' is appropriate here.
    nbrs = NearestNeighbors(
        n_neighbors=n_neighbors,
        algorithm='ball_tree',
        metric=expected_mahalanobis,
    ).fit(data)

    return nbrs.kneighbors_graph(data, mode='distance').toarray()

#------------------------------------------------------------------------------------------------
def mahalanobis_dist(x, y, A):
    """Return ``sqrt((x - y)^T A (x - y))`` for the given matrix ``A``."""
    delta = x - y
    quad_form = delta.T @ A @ delta
    return np.sqrt(quad_form)

def build_graph(data, A_opt, is_knn, n, cost="mahalanobis"):
    """Build a neighbourhood distance graph over the rows of ``data``.

    Args:
        data: Samples to connect, one sample per row.
        A_opt: Matrix of the quadratic form used by the Mahalanobis cost.
        is_knn: If true, connect each sample to its ``n`` nearest neighbours;
            otherwise connect samples lying within radius ``n`` of each other.
        n: Number of neighbours (k-NN mode) or radius (radius mode).
        cost: ``"mahalanobis"`` selects ``sqrt((x - y)^T A_opt (x - y))``;
            any other value selects the L1 norm of ``x - y``.

    Returns:
        A dense array of edge weights (``mode="distance"`` followed by
        ``.toarray()``) in both modes.
    """
    def dist_mahalanobis(x, y):
        diff = x - y
        return np.sqrt(diff.T @ A_opt @ diff)

    def dist_l1(x, y):
        return np.linalg.norm(x - y, ord=1)

    dist = dist_mahalanobis if cost == "mahalanobis" else dist_l1

    if is_knn:
        # The k-NN model is only fitted in this branch: in radius mode ``n``
        # is a radius and is not a valid ``n_neighbors`` value in general.
        nbrs = NearestNeighbors(n_neighbors=n, algorithm='ball_tree', metric=dist).fit(data)
        return nbrs.kneighbors_graph(data, mode="distance").toarray()

    # The callable is passed directly as ``metric``; ``radius_neighbors_graph``
    # has no ``func`` parameter. ``mode="distance"`` + ``toarray()`` keeps the
    # return value consistent with the k-NN branch.
    return radius_neighbors_graph(
        data, radius=n, mode="distance", metric=dist, n_jobs=-1
    ).toarray()

def shortest_path_graph(graph, positive_data_idx):
    """Find the node of ``positive_data_idx`` closest to node 0 in ``graph``.

    ``graph`` is converted to a networkx graph with ``nx.from_numpy_array``
    and Bellman-Ford is run from node 0.

    Returns:
        ``(length, node, path)`` for the reachable node in
        ``positive_data_idx`` with the smallest path length from node 0, or
        ``None`` when no such node is reachable.
    """
    nx_graph = nx.from_numpy_array(graph)
    lengths, paths = nx.single_source_bellman_ford(nx_graph, 0)

    # Visit reachable nodes from nearest to farthest; the first hit wins.
    for node, length in sorted(lengths.items(), key=lambda pair: pair[1]):
        if node in positive_data_idx:
            return length, node, paths[node]

    return None
    
def eval_cost(A, data, path, weight=1.0, diag=True, cost="mahalanobis"):
    """Evaluate the cost of walking ``path`` through the rows of ``data``.

    Args:
        A: Square cost matrix.
        data: Samples, indexed by the entries of ``path``.
        path: Sequence of row indices into ``data``; consecutive entries form
            the edges whose costs are accumulated.
        weight: Multiplier applied to every edge cost.
        diag: When true, only the diagonal of ``A`` is used.
        cost: ``"mahalanobis"`` uses the quadratic form
            ``(x_b - x_a)^T A (x_b - x_a)`` per edge; any other value uses the
            L1 norm of ``A x_b - A x_a``.

    Returns:
        The square root of the weighted sum of edge costs (``0.0`` for a path
        with fewer than two nodes).
    """
    if diag:
        # Keep only the diagonal entries of A.
        eval_A = np.eye(A.shape[0])
        np.fill_diagonal(eval_A, A.diagonal())
    else:
        eval_A = A

    # Decide the mode once, and keep per-edge values in their own variable so
    # the ``cost`` selector is never overwritten while iterating.
    use_mahalanobis = cost == "mahalanobis"

    total = 0.0
    for src, dst in zip(path[:-1], path[1:]):
        if use_mahalanobis:
            step = data[dst] - data[src]
            edge_cost = step.T @ eval_A @ step
        else:
            # Both endpoints are transformed with the same (possibly
            # diagonal-only) matrix.
            edge_cost = np.linalg.norm(eval_A @ data[dst] - eval_A @ data[src], ord=1)
        total = total + weight * edge_cost

    return np.sqrt(total)

if __name__ == '__main__':
    # Smoke test on random 2-D points.
    data = np.random.rand(100, 2)
    A = np.random.rand(2, 2)
    A = A @ A.T  # symmetric product, used as the Mahalanobis matrix

    graph = build_graph(data, A, True, 15)

    # shortest_path_graph needs the set of target indices; the second half of
    # the samples is an arbitrary choice for this demo.
    positive_data_idx = set(range(50, 100))
    print(shortest_path_graph(graph, positive_data_idx))
