import torch
import time
import numpy as np
import warnings
import torch.nn.functional as F
import torch
from joblib import Parallel, delayed
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import floyd_warshall


def compute_distance_matrix(input1, input2, metric='euclidean'):
    """Compute the pairwise distance matrix between two sets of row vectors.

    NumPy arrays are wrapped with ``torch.from_numpy`` before validation.

    Args:
        input1 (torch.Tensor or np.ndarray): 2-D feature matrix of shape (m, d).
        input2 (torch.Tensor or np.ndarray): 2-D feature matrix of shape (n, d).
        metric (str, optional): "euclidean" (squared Euclidean) or "cosine".
            Default is "euclidean".

    Returns:
        torch.Tensor: (m, n) distance matrix.

    Raises:
        AssertionError: if the inputs are not 2-D tensors of equal width.
        ValueError: if ``metric`` is not one of the supported names.
    """
    input1, input2 = (
        torch.from_numpy(mat) if isinstance(mat, np.ndarray) else mat
        for mat in (input1, input2)
    )
    for mat in (input1, input2):
        assert isinstance(mat, torch.Tensor)
    for mat in (input1, input2):
        assert mat.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(
            mat.dim()
        )
    assert input1.size(1) == input2.size(1)

    if metric == 'cosine':
        return cosine_distance(input1, input2)
    if metric == 'euclidean':
        return euclidean_squared_distance(input1, input2)
    raise ValueError(
        'Unknown distance metric: {}. '
        'Please choose either "euclidean" or "cosine"'.format(metric)
    )


def euclidean_squared_distance(input1, input2):
    """Compute pairwise squared Euclidean distances between two sets of rows.

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b and clamps tiny negative
    values produced by floating-point cancellation up to zero.

    Args:
        input1 (torch.Tensor): 2-D feature matrix of shape (m, d).
        input2 (torch.Tensor): 2-D feature matrix of shape (n, d).

    Returns:
        torch.Tensor: (m, n) matrix of squared distances.
    """
    rows, cols = input1.size(0), input2.size(0)
    sq_norm_rows = input1.pow(2).sum(dim=1, keepdim=True).expand(rows, cols)
    sq_norm_cols = input2.pow(2).sum(dim=1, keepdim=True).t().expand(rows, cols)
    # Fold the -2 * a.b cross term into the summed norms in one call.
    dist = torch.addmm(sq_norm_rows + sq_norm_cols, input1, input2.t(), beta=1, alpha=-2)
    return dist.clamp(min=0)


def cosine_distance(input1, input2):
    """Compute pairwise cosine distance (1 - cosine similarity) between rows.

    Rows are L2-normalised first; the result is clamped to the valid range
    [0, 2] to absorb rounding error.

    Args:
        input1 (torch.Tensor): 2-D feature matrix of shape (m, d).
        input2 (torch.Tensor): 2-D feature matrix of shape (n, d).

    Returns:
        torch.Tensor: (m, n) distance matrix.
    """
    unit1 = F.normalize(input1, p=2, dim=1)
    unit2 = F.normalize(input2, p=2, dim=1)
    similarity = unit1 @ unit2.t()
    return (1 - similarity).clamp(min=0, max=2)


def jensen_shannon_divergence(sparse_query_feature, sparse_gallery_feature, device=torch.device('cuda'), n_jobs=16):
    '''
    Compute the Jensen-Shannon divergence between every query/gallery pair.

    For rows p (query) and q (gallery) this evaluates
    0.5 * (sum_k p_k log2(2 p_k / (p_k + q_k)) + sum_k q_k log2(2 q_k / (p_k + q_k))),
    where each sum runs only over the support of its leading distribution.

    Args:
        sparse_query_feature (np.ndarray): (query_num, dim) non-negative features
            (presumably L1-normalised rows -- TODO confirm at call sites).
        sparse_gallery_feature (np.ndarray): (gallery_num, dim) features.
        device (torch.device or str): any CUDA device (e.g. 'cuda', 'cuda:1')
            uses the compiled ``reranker.sparse_divergence`` extension; 'cpu'
            uses a joblib-parallel NumPy implementation.
        n_jobs (int): number of joblib workers for the CPU path.

    Returns:
        (query_num, gallery_num) divergence matrix: whatever the CUDA extension
        returns on CUDA, an np.ndarray on CPU.

    Raises:
        ValueError: if ``device`` is neither a CUDA nor a CPU device.
    '''
    query_num = sparse_query_feature.shape[0]
    gallery_num = sparse_gallery_feature.shape[0]
    assert sparse_query_feature.shape[1]==sparse_gallery_feature.shape[1]

    # Dispatch on the device *type*: exact torch.device equality treats
    # 'cuda:0' != 'cuda', which previously fell through and returned None.
    device = torch.device(device)

    if device.type == 'cuda':
        from reranker.sparse_divergence import sparse_divergence
        query_tensor = torch.from_numpy(sparse_query_feature).to(device)
        gallery_tensor = torch.from_numpy(sparse_gallery_feature).to(device)
        print('Computing Jensen Shannon Divergence')
        start_time = time.time()
        JS_divergence = sparse_divergence.jensen_shannon(query_tensor, gallery_tensor)
        torch.cuda.synchronize(device)
        end_time = time.time()
        print(f'===> Time cost {end_time-start_time}')
        return JS_divergence

    if device.type != 'cpu':
        raise ValueError(f'Unsupported device for jensen_shannon_divergence: {device}')

    def half_divergence(source, target, i):
        # sum_k p_k * log2(2 p_k / (p_k + q_k)) over the support of p = source[i],
        # evaluated against every row q of ``target`` at once; shape (1, len(target)).
        target_num = target.shape[0]
        support = np.nonzero(source[i])[0]
        expanded = np.tile(source[i, support], (target_num, 1))
        mixture = expanded + target[:, support]
        return np.dot(source[i, support].reshape(1, -1), np.log2(2*expanded/mixture).reshape(target_num, -1).T)

    print('Computing Jensen Shannon Divergence')
    start_time = time.time()
    JS_divergence1 = np.concatenate(Parallel(n_jobs=n_jobs)(
        delayed(half_divergence)(sparse_query_feature, sparse_gallery_feature, i) for i in range(query_num)))
    JS_divergence2 = np.concatenate(Parallel(n_jobs=n_jobs)(
        delayed(half_divergence)(sparse_gallery_feature, sparse_query_feature, i) for i in range(gallery_num)))
    # JS_divergence2 is (gallery_num, query_num); transpose to align with JS_divergence1.
    JS_divergence = 0.5 * (JS_divergence1 + JS_divergence2.T)
    end_time = time.time()
    print(f'===> Time cost {end_time-start_time}')

    return JS_divergence

def all_pairs_shortest_paths(adjacency_matrix: torch.Tensor, max_threshold=50, device=torch.device('cuda')) -> torch.Tensor:
    """Compute all-pairs shortest-path distances over a weighted graph.

    The adjacency matrix is preprocessed on a copy: the diagonal is zeroed,
    it is symmetrised with an element-wise minimum, and weights are clamped
    to [0, max_threshold]. On CUDA the compiled
    ``reranker.geodesic_distance.square_bellman_ford`` extension is used when
    importable; otherwise (and on CPU) a dense Floyd-Warshall runs in torch.

    Args:
        adjacency_matrix (torch.Tensor): square (V, V) edge-weight matrix.
            It is never modified in place.
        max_threshold (float): cap applied to edge weights and path lengths.
        device (torch.device or str): any CUDA device or 'cpu'.

    Returns:
        torch.Tensor: (V, V) distance matrix on CPU, clamped to
        [0, max_threshold].

    Raises:
        ValueError: if the matrix is not square 2-D, or ``device`` is neither
            CUDA nor CPU.
    """
    if adjacency_matrix.dim() != 2 or adjacency_matrix.shape[0] != adjacency_matrix.shape[1]:
        raise ValueError(f'Expected a square 2-D adjacency matrix, got shape {tuple(adjacency_matrix.shape)}')
    device = torch.device(device)
    if device.type not in ('cuda', 'cpu'):
        raise ValueError(f'Unsupported device for all_pairs_shortest_paths: {device}')

    # Work on a private copy so the caller's tensor is left untouched.
    adjacency = adjacency_matrix.to(device=device, copy=True)
    adjacency.fill_diagonal_(0)
    adjacency = torch.minimum(adjacency, adjacency.t())
    adjacency = torch.clamp(adjacency, min=0.0, max=max_threshold)

    if device.type == 'cuda':
        try:
            from reranker.geodesic_distance import geodesic_distance
        except ImportError:
            distance_matrix = _floyd_warshall(adjacency)
        else:
            distance_matrix = geodesic_distance.square_bellman_ford(5, adjacency)
    else:
        distance_matrix = _floyd_warshall(adjacency)

    distance_matrix = torch.clamp(distance_matrix, min=0.0, max=max_threshold)
    return distance_matrix.cpu()


def _floyd_warshall(distance_matrix: torch.Tensor) -> torch.Tensor:
    """Dense Floyd-Warshall relaxation; returns a new tensor, input untouched."""
    for k in range(distance_matrix.shape[0]):
        # D[i, j] <- min(D[i, j], D[i, k] + D[k, j]) for all i, j via broadcasting.
        distance_matrix = torch.minimum(distance_matrix, distance_matrix[:, k:k + 1] + distance_matrix[k:k + 1, :])
    return distance_matrix


def transport_cost(all_num, original_dist, initial_rank, topk=4, max_threshold=50, confident_matrix=None, device=torch.device('cuda')):
    """Build the geodesic ground-cost matrix for the optimal-transport step.

    Items i and j are linked in any of these cases:
      * j is among the first ``topk`` entries of ``initial_rank[i]``;
      * i is among the first ``topk`` entries of ``initial_rank[j]``;
      * ``confident_matrix`` marks the pair in either direction.

    Linked pairs are weighted by ``original_dist`` and unlinked pairs get
    ``max_threshold``. The result is the all-pairs shortest-path matrix over
    that graph.

    Args:
        all_num (int): number of items (side length of ``original_dist``).
        original_dist (torch.Tensor): (all_num, all_num) distances; moved to CPU.
        initial_rank: (all_num, >= topk) ranking indices (tensor or array).
        topk (int): neighbours per item used to build the graph.
        max_threshold (float): weight of unlinked pairs and path-length cap.
        confident_matrix (array-like, optional): (all_num, all_num) 0/1 matrix
            of extra links, as a NumPy array or tensor.
        device: device forwarded to ``all_pairs_shortest_paths``.

    Returns:
        torch.Tensor: (all_num, all_num) cost matrix from
        ``all_pairs_shortest_paths``.
    """
    original_dist = original_dist.cpu()

    adjacency_connection = torch.zeros((all_num, all_num))
    if confident_matrix is not None:
        # confident_candidate_expansion returns a NumPy array; convert explicitly
        # instead of relying on tensor + ndarray operator dispatch.
        adjacency_connection += torch.as_tensor(
            confident_matrix, dtype=adjacency_connection.dtype, device=adjacency_connection.device)
    neighbour_cols = torch.as_tensor(initial_rank[:, :topk]).reshape(-1).cpu()
    adjacency_connection[torch.repeat_interleave(torch.arange(all_num), topk), neighbour_cols] = 1

    # For a 0/1 matrix A, A + A^T - A*A^T is the symmetric union (A OR A^T):
    # mutual links are counted once and one-way links are kept.
    linked = adjacency_connection + adjacency_connection.t() - adjacency_connection * adjacency_connection.t()
    adjacency_matrix = linked * original_dist + (1 - linked) * max_threshold

    print('Computing All Pairs of Shortest Paths')
    start_time = time.time()
    cost_matrix = all_pairs_shortest_paths(adjacency_matrix, max_threshold=max_threshold, device=device)
    end_time = time.time()
    print(f'===> Time cost {end_time-start_time}')

    return cost_matrix

from reranker.matrix_utils import matrix_utils
def sub_func(all_num, all_nonzero_index, sparse_query_feature, sparse_gallery_feature, cusum_index, nonzero_num, cost_matrix, i, device=torch.device('cuda'), epsilon_value=1, iter_num=20):
    """Sinkhorn transport cost from query ``i`` to every gallery column.

    Same computation as the per-query body of ``optimal_distance``, but the
    cost rows are gathered through the compiled ``matrix_utils.deepcopy``
    extension.

    Args:
        all_num (int): number of gallery columns (and feature dimension).
        all_nonzero_index (torch.Tensor): flattened non-zero column indices of
            all query rows, grouped by row.
        sparse_query_feature (torch.Tensor): (query_num, dim) query features.
        sparse_gallery_feature (torch.Tensor): (all_num, dim) gallery features.
        cusum_index: row boundaries into ``all_nonzero_index`` (length query_num + 1).
        nonzero_num: non-zero count per query row.
        cost_matrix (torch.Tensor): ground cost between feature dimensions.
        i (int): query row to process.
        device: device for the allocated tensors.
        epsilon_value (float): entropic regularisation strength.
        iter_num (int): number of Sinkhorn iterations (was hard-coded to 20).

    Returns:
        torch.Tensor: (all_num,) transport cost to each gallery column.
    """
    nonzero_index = all_nonzero_index[cusum_index[i]:cusum_index[i + 1]]
    M = torch.empty((nonzero_num[i], all_num), device=device)
    # NOTE(review): presumably copies rows ``nonzero_index`` of cost_matrix
    # into M inside the compiled extension -- TODO confirm in reranker.matrix_utils.
    M = matrix_utils.deepcopy(cost_matrix, nonzero_index.float(), M)

    K = torch.exp(-M / epsilon_value)
    # u is overwritten on the first iteration; initialised so iter_num=0 works.
    u = torch.ones((nonzero_index.shape[0], all_num), device=device) / nonzero_index.shape[0]
    v = torch.ones((all_num, all_num), device=device) / all_num
    r = torch.tile(sparse_query_feature[i, nonzero_index], (all_num, 1)).t()
    c = sparse_gallery_feature.t()

    for _ in range(iter_num):
        u = r / torch.mm(K, v)
        v = c / torch.mm(K.t(), u)
    return torch.sum(torch.multiply(u, torch.mm(K * M, v)), dim=0)

def optimal_distance(cost_matrix, all_num, initial_rank, sparse_query_feature, sparse_gallery_feature, epsilon_value=1, iter_num=20, device=torch.device('cuda')):
    """Entropic optimal-transport (Sinkhorn) cost from each query to every gallery item.

    For query i, only its non-zero feature dimensions form the source support.
    All gallery items are solved together, one column per gallery item.

    Args:
        cost_matrix (torch.Tensor): ground cost between feature dimensions; rows
            are gathered by query support and must have ``all_num`` columns.
        all_num (int): number of gallery items; the Sinkhorn shapes below also
            require the feature dimension to equal ``all_num``.
        initial_rank: unused; kept for API compatibility.
        sparse_query_feature (np.ndarray): (query_num, all_num) query features.
        sparse_gallery_feature (np.ndarray): (all_num, all_num) gallery features.
        epsilon_value (float): entropic regularisation strength.
        iter_num (int): number of Sinkhorn iterations.
        device (torch.device or str): computation device (CUDA or CPU).

    Returns:
        torch.Tensor: (query_num, all_num) transport costs on ``device``.
    """
    query_num = sparse_query_feature.shape[0]
    assert sparse_query_feature.shape[1]==sparse_gallery_feature.shape[1]
    device = torch.device(device)
    is_cuda = device.type == 'cuda'

    def _sync():
        # CUDA kernels run asynchronously; wait so the timing is meaningful.
        # Skipped on CPU, where torch.cuda.synchronize() would raise without CUDA.
        if is_cuda:
            torch.cuda.synchronize(device)

    print('Applying Geodesic Transport')

    sparse_query_feature = torch.from_numpy(sparse_query_feature).to(device)
    sparse_gallery_feature = torch.from_numpy(sparse_gallery_feature).to(device)
    optimal_dist = torch.zeros((query_num, all_num), device=device)
    cost_matrix = cost_matrix.to(device=device)
    # Column indices of every non-zero query entry, grouped row by row; the
    # cumulative counts give each row's slice boundaries.
    all_nonzero_index = torch.nonzero(sparse_query_feature, as_tuple=True)[1]
    nonzero_num = torch.count_nonzero(sparse_query_feature, dim=1)
    cusum_index = [0] + torch.cumsum(nonzero_num, dim=0).tolist()
    c = sparse_gallery_feature.t()  # loop-invariant target marginals
    run_time = 0
    for i in range(query_num):
        nonzero_index = all_nonzero_index[cusum_index[i]:cusum_index[i + 1]]
        M = cost_matrix[nonzero_index]  # advanced indexing already copies
        _sync()

        tmp_time = time.time()
        K = torch.exp(-M / epsilon_value)
        # u is overwritten on the first iteration; initialised so iter_num=0 works.
        u = torch.ones((nonzero_index.shape[0], all_num), device=device) / nonzero_index.shape[0]
        v = torch.ones((all_num, all_num), device=device) / all_num
        r = torch.tile(sparse_query_feature[i, nonzero_index], (all_num, 1)).t()
        _sync()

        for _ in range(iter_num):
            u = r / torch.mm(K, v)
            v = c / torch.mm(K.t(), u)
        optimal_dist[i] = torch.sum(torch.multiply(u, torch.mm(K * M, v)), dim=0)
        _sync()
        run_time = run_time + time.time() - tmp_time
    print(f'===> Time cost {run_time}')

    return optimal_dist


def similarity_diffusion(all_num, confident_matrix, similarity_matrix, initial_rank,
                         knn=10, mu=0.23, max_iter=10, lambda_value=2, threshold=1e-2, device=torch.device('cuda')):
    """Diffuse pairwise similarities over a symmetric kNN graph.

    Steps:
      1. Keep each item's similarities to ranks 1..knn-1 of ``initial_rank``
         (rank 0 is skipped).
      2. Scale entries marked in ``confident_matrix`` by ``lambda_value``.
      3. Symmetrise W <- (W + W^T) / 2.
      4. Form the normalised affinity S = D^-1/2 W D^-1/2.
      5. Starting from F = W, run ``max_iter`` iterations of
         F <- alpha * S F S^T + (1 - alpha) * (I + S), with alpha = 1 / (1 + mu).

    Args:
        all_num (int): number of items.
        confident_matrix (np.ndarray): (all_num, all_num) confident-pair mask.
        similarity_matrix (torch.Tensor): (all_num, all_num) similarities.
        initial_rank (torch.Tensor): (all_num, >= knn) ranking indices.
        knn (int): neighbourhood size (including rank 0).
        mu (float): diffusion regulariser.
        max_iter (int): number of diffusion iterations.
        lambda_value (float): multiplier applied to confident entries.
        threshold (float): currently unused.
        device: device for the computation.

    Returns:
        torch.Tensor: (all_num, all_num) diffused similarity matrix on ``device``.
    """
    alpha = 1 / (1 + mu)
    confident = torch.from_numpy(confident_matrix).to(device=device)

    row_idx = torch.repeat_interleave(torch.arange(all_num), knn - 1)
    col_idx = initial_rank[:, 1:knn].reshape(-1)
    weights = torch.zeros((all_num, all_num), device=device)
    weights[row_idx, col_idx] = similarity_matrix[row_idx, col_idx]
    # W += (lambda_value - 1) * (C * W): confident entries end up scaled by lambda_value.
    weights.add_(confident * weights, alpha=lambda_value - 1)
    weights = (weights + weights.t()) / 2

    inv_sqrt_degree = torch.sum(weights, dim=1, keepdim=True).sqrt().reciprocal_()
    normalized = weights * (inv_sqrt_degree * inv_sqrt_degree.t())
    target = torch.eye(all_num, device=device) + normalized

    diffused = weights.clone()
    for _ in range(max_iter):
        diffused = alpha * (normalized @ diffused @ normalized.t()) + (1 - alpha) * target
    return diffused


def wasserstein_barycenter(cost_matrix, feature_matrix, lambda_value, lambda_truncation, epsilon=0.1, iter_num=40, device=torch.device('cuda')):
    """Entropic Wasserstein barycenters of consecutive groups of distributions.

    Rows of ``feature_matrix`` are split into consecutive groups; group ``i``
    has ``lambda_truncation[i]`` rows, each weighted by the matching entry of
    ``lambda_value``. Each group's barycenter is found with iterative Bregman
    (Sinkhorn-style) scaling, using a weighted geometric mean per group.

    Args:
        cost_matrix (torch.Tensor): (dim, dim) ground cost.
        feature_matrix (torch.Tensor): (N, dim) input distributions.
        lambda_value (torch.Tensor): (N,) per-row barycentric weights.
        lambda_truncation (torch.Tensor): (center_num,) int64 group sizes.
        epsilon (float): entropic regularisation strength.
        iter_num (int): number of scaling iterations.
        device: computation device.

    Returns:
        torch.Tensor: (center_num, dim) barycenters on CPU.
    """
    center_num = lambda_truncation.shape[0]
    dimension = feature_matrix.shape[1]

    cost_matrix = cost_matrix.to(device=device)
    feature_matrix = feature_matrix.to(device=device)
    lambda_value = lambda_value.to(device=device)
    lambda_truncation = lambda_truncation.to(device=device)

    K_matrix = torch.exp(-cost_matrix / epsilon)
    barycenter = torch.ones((center_num, dimension), device=device)
    u_vector = torch.ones_like(feature_matrix, device=device)
    # Per-row exponent broadcast across the feature dimension (loop-invariant).
    exponent = lambda_value.repeat(dimension, 1).t()
    # Python-int group sizes so slicing does not round-trip through tensors.
    group_sizes = lambda_truncation.tolist()
    for _ in range(iter_num):
        v_vector = feature_matrix / torch.clamp(torch.mm(u_vector, K_matrix), min=1e-12)
        projected = torch.mm(v_vector, K_matrix.t())  # used twice below
        weighted = torch.pow(projected, exponent)
        start = 0
        for i, size in enumerate(group_sizes):
            # Weighted geometric mean of the group's projections.
            barycenter[i] = torch.prod(weighted[start:start + size], dim=0)
            start += size
        u_vector = torch.repeat_interleave(barycenter, lambda_truncation, dim=0) / torch.clamp(projected, min=1e-12)

    return barycenter.cpu()


def barycenter_refinery(all_num, sparse_feature, cost_matrix, initial_rank, all_features, epsilon=0.1, iter_num=40, k1=6, k2=100, lv=0, weight_mode='average', refine_mode='mean', confident_matrix=None):
    """Refine each item's sparse feature by aggregating its neighbours' features.

    Neighbours of item i are the first ``k1`` entries of ``initial_rank[i]``,
    plus any pair marked in ``confident_matrix``. The aggregate is then
    truncated to its ``k2`` largest entries per row and L1-normalised.

    Args:
        all_num (int): number of items.
        sparse_feature (np.ndarray): (all_num, dim) sparse features.
        cost_matrix (torch.Tensor): ground cost for ``refine_mode='barycenter'``.
        initial_rank: (all_num, >= k1) ranking indices.
        all_features (torch.Tensor): (all_num, d) raw features, used only for
            ``weight_mode='cosine'``.
        epsilon, iter_num: forwarded to ``wasserstein_barycenter``.
        k1 (int): neighbours per item.
        k2 (int): entries kept per row after aggregation (capped at dim).
        lv (float): extra weight on confident neighbours.
        weight_mode (str): 'average' (uniform) or 'cosine' neighbour weights.
        refine_mode (str): 'mean' (sum of neighbour features) or 'barycenter'.
        confident_matrix (array-like, optional): (all_num, all_num) 0/1 mask.

    Returns:
        np.ndarray: (all_num, dim) refined, row-L1-normalised features.

    Raises:
        ValueError: on an unknown ``weight_mode`` or ``refine_mode``.
    """
    if weight_mode not in ('average', 'cosine'):
        raise ValueError(f'Unknown weight_mode: {weight_mode}')
    if refine_mode not in ('mean', 'barycenter'):
        raise ValueError(f'Unknown refine_mode: {refine_mode}')
    sparse_feature = torch.from_numpy(sparse_feature)

    neighbor_matrix = torch.zeros((all_num, all_num))
    confident = None
    if confident_matrix is not None:
        confident = torch.as_tensor(confident_matrix, dtype=neighbor_matrix.dtype, device=neighbor_matrix.device)
        neighbor_matrix += confident
    neighbor_matrix[torch.repeat_interleave(torch.arange(all_num), k1), initial_rank[:, :k1].reshape(-1)] = 1
    reweight_matrix = neighbor_matrix if confident is None else neighbor_matrix + lv * confident

    # (item, neighbour) edges in row-major order. Every per-edge quantity below
    # (features, weights, reweights) follows this same order so they stay aligned.
    edge_rows, edge_cols = torch.nonzero(neighbor_matrix, as_tuple=True)
    feature_matrix = sparse_feature[edge_cols]
    lambda_truncation = torch.count_nonzero(neighbor_matrix, dim=1)  # group size per item
    reweight_value = reweight_matrix[edge_rows, edge_cols]

    if weight_mode == 'cosine':
        # Cosine similarity of each edge, computed on the same (row, col) pairs
        # as feature_matrix rather than on initial_rank order.
        lambda_value = torch.sum(F.normalize(all_features[edge_rows], p=2, dim=1) * F.normalize(all_features[edge_cols], p=2, dim=1), dim=1)
    else:
        lambda_value = torch.ones(feature_matrix.shape[0])
    lambda_value = lambda_value * reweight_value  # enhance confident neighbours
    # Normalise weights to sum to 1 within each item's neighbour group.
    normalize_sum = torch.zeros(all_num, dtype=lambda_value.dtype, device=lambda_value.device).index_add_(0, edge_rows, lambda_value)
    lambda_value = lambda_value / normalize_sum[edge_rows]

    if refine_mode == 'mean':
        mixed_feature = torch.mm(neighbor_matrix, sparse_feature)
    else:
        mixed_feature = wasserstein_barycenter(cost_matrix, feature_matrix, lambda_value, lambda_truncation, epsilon=epsilon, iter_num=iter_num)

    # Keep only the k2 largest entries per row.
    top = min(k2, mixed_feature.shape[1])
    keep = torch.argsort(mixed_feature, descending=True)[:, :top]
    truncate_feature = torch.zeros_like(mixed_feature).scatter_(1, keep, mixed_feature.gather(1, keep))

    return F.normalize(truncate_feature, p=1, dim=1).numpy()


def average_barycenter_refinery(all_num, sparse_feature, cost_matrix, initial_rank, confident_matrix=None, epsilon=0.1, iter_num=10, k1=7, k2=80, k3=7, baryweight=0.3, lv=0, device=torch.device('cuda')):
    """Blend entropic barycenters with neighbour means, then apply a Markov expansion.

    Steps:
      1. Compute each item's uniform-weight entropic Wasserstein barycenter
         over its first ``k1`` ranked neighbours (``iter_num`` scaling steps).
      2. Truncate it to the top ``k2`` entries and L1-normalise.
      3. Average the sparse features of the first ``k3`` neighbours,
         optionally blended with a confident-neighbour mean weighted by ``lv``.
      4. Mix: baryweight * barycenter + (1 - baryweight) * mean.
      5. Square the mixed matrix, truncate to ``k2``, and square again;
         every stage is L1-normalised.

    Args:
        all_num (int): number of items.
        sparse_feature (np.ndarray): (all_num, dim) sparse features; the Markov
            squaring requires dim == all_num.
        cost_matrix (torch.Tensor): (dim, dim) ground cost.
        initial_rank: (all_num, >= max(k1, k3)) ranking indices.
        confident_matrix (array-like, optional): (all_num, all_num) mask; when
            None the confident blend is skipped.
        epsilon (float): entropic regularisation strength.
        iter_num (int): barycenter scaling iterations (previously ignored).
        k1, k2, k3 (int): barycenter neighbours, kept entries, mean neighbours.
        baryweight (float): weight of the barycenter term in the mix.
        lv (float): weight of the confident-neighbour mean.
        device: computation device.

    Returns:
        np.ndarray: (all_num, all_num) row-L1-normalised features.
    """
    center_num = all_num
    lambda_value = 1 / k1  # uniform barycentric weight per neighbour
    sparse_feature = torch.from_numpy(sparse_feature).to(device=device)

    dimension = sparse_feature.shape[1]
    cost_matrix = cost_matrix.to(device=device)
    feature_matrix = sparse_feature[initial_rank[:, :k1].reshape(-1)]

    K_matrix = torch.exp(-cost_matrix / epsilon)
    barycenter = torch.ones((center_num, dimension), device=device)
    u_vector = torch.ones_like(feature_matrix, device=device) / k1

    print('Applying Barycenter Refinery')
    start_time = time.time()
    for _ in range(iter_num):
        v_vector = feature_matrix / torch.clamp(torch.mm(u_vector, K_matrix), min=1e-12)
        projected = torch.mm(v_vector, K_matrix.t())  # used twice below
        # Geometric mean over each item's k1 neighbours (rows are grouped contiguously).
        barycenter = torch.prod(torch.pow(projected, exponent=lambda_value).reshape(center_num, k1, dimension), dim=1)
        u_vector = torch.repeat_interleave(barycenter, k1, dim=0) / torch.clamp(projected, min=1e-12)
    end_time = time.time()
    print(f'===> Time cost {end_time-start_time}')

    top = min(k2, barycenter.shape[1])
    keep = torch.argsort(barycenter, descending=True)[:, :top]
    barycenter = F.normalize(torch.zeros_like(barycenter).scatter_(1, keep, barycenter.gather(1, keep)), p=1, dim=1)

    mean_center = torch.zeros((all_num, dimension), device=device)
    for i in range(k3):
        mean_center.add_(sparse_feature[initial_rank[:, i], :])
    mean_center = mean_center / k3
    if confident_matrix is not None:
        confident = torch.as_tensor(confident_matrix, device=device)
        mean_center = (mean_center + lv * torch.mm(F.normalize(confident, p=1, dim=1), sparse_feature)) / (1 + lv)
    mean_center = F.normalize(mean_center, p=1, dim=1)

    mixed_feature = F.normalize(baryweight * barycenter + (1 - baryweight) * mean_center, p=1, dim=1)

    markov_feature = torch.mm(mixed_feature, mixed_feature)
    top = min(k2, markov_feature.shape[1])
    keep = torch.argsort(markov_feature, descending=True)[:, :top]
    truncate_markov_feature = torch.zeros_like(markov_feature).scatter_(1, keep, markov_feature.gather(1, keep))
    truncate_markov_feature = F.normalize(truncate_markov_feature, p=1, dim=1)

    expand_sparse_feature = F.normalize(torch.mm(truncate_markov_feature, truncate_markov_feature), p=1, dim=1)

    return expand_sparse_feature.cpu().numpy()


def markov_expansion(all_num, sparse_feature, initial_rank,
                     k1=6, k2=20, iterations=1, lambda_value=2, confident_matrix=None, device=torch.device('cuda')):
    """Expand sparse features with a truncated Markov (squared-transition) step.

    Steps:
      1. Average each item's first ``k1`` ranked neighbours' features.
      2. If ``confident_matrix`` is given, blend in the mean over confident
         neighbours, weighted by ``lambda_value``.
      3. L1-normalise the result.
      4. Repeat ``iterations`` times: square the matrix, keep the top ``k2``
         entries per row, L1-normalise, and add it back scaled by
         ``lambda_value``.
      5. Return the square of the last truncated matrix, L1-normalised.

    Args:
        all_num (int): number of items.
        sparse_feature (np.ndarray): (all_num, all_num) sparse features.
        initial_rank: (all_num, >= k1) ranking indices.
        k1 (int): neighbours averaged in step 1.
        k2 (int): entries kept per row after each squaring.
        iterations (int): number of expansion iterations; must be >= 1.
        lambda_value (float): blend weight for confident means and expansions.
        confident_matrix (np.ndarray, optional): (all_num, all_num) mask.
        device: computation device.

    Returns:
        np.ndarray: (all_num, all_num) expanded, row-L1-normalised features.

    Raises:
        ValueError: if ``iterations`` < 1 (the result would be undefined).
    """
    if iterations < 1:
        raise ValueError(f'iterations must be at least 1, got {iterations}')
    sparse_feature = torch.from_numpy(sparse_feature).to(device=device)
    mixed_feature = torch.zeros(all_num, all_num, device=device)

    print('Applying Markov Sparse Feature Expansion')
    start_time = time.time()
    # Average the features according to the initial ranking list.
    for i in range(k1):
        mixed_feature.add_(sparse_feature[initial_rank[:, i], :])
    mixed_feature = mixed_feature / k1
    if confident_matrix is not None:
        # Enhance the most confident neighbours during query expansion.
        confident_matrix = torch.from_numpy(confident_matrix).to(device=device)
        mixed_feature = (mixed_feature + lambda_value * torch.mm(F.normalize(confident_matrix, p=1, dim=1), sparse_feature)) / (1 + lambda_value)
    mixed_feature = F.normalize(mixed_feature, p=1, dim=1)

    row_index = torch.repeat_interleave(torch.arange(all_num), k2)
    for _ in range(iterations):
        markov_feature = torch.mm(mixed_feature, mixed_feature)
        markov_rank = torch.argsort(markov_feature, descending=True)
        col_index = markov_rank[:, :k2].reshape(-1)
        truncate_markov_feature = torch.zeros(all_num, all_num, device=device)
        truncate_markov_feature[row_index, col_index] = markov_feature[row_index, col_index]
        truncate_markov_feature = F.normalize(truncate_markov_feature, p=1, dim=1)
        mixed_feature = mixed_feature + lambda_value * truncate_markov_feature

    expand_sparse_feature = torch.mm(truncate_markov_feature, truncate_markov_feature)
    expand_sparse_feature = F.normalize(expand_sparse_feature, p=1, dim=1)

    end_time = time.time()
    print(f'===> Time cost {end_time-start_time}')

    return expand_sparse_feature.cpu().numpy()


def confident_candidate_expansion(all_num, initial_rank, k=5, device=torch.device('cuda')):
    '''
    Build the confident-candidate (mutual k-nearest-neighbour) indicator matrix.

    Entry (i, j) is 1 exactly when j is among the first ``k`` entries of
    ``initial_rank[i]`` AND i is among the first ``k`` entries of
    ``initial_rank[j]``; every other entry is 0.

    Returns:
        np.ndarray: (all_num, all_num) 0/1 matrix on the host, in torch's
        default floating dtype.
    '''
    knn_indicator = torch.zeros((all_num, all_num), device=device)
    rows = torch.repeat_interleave(torch.arange(all_num), k)
    knn_indicator[rows, initial_rank[:, :k].reshape(-1)] = 1
    mutual = knn_indicator * knn_indicator.t()
    return mutual.cpu().numpy()


def candidate_expansion(all_num, initial_rank, k=20, neighbor_threshold=0, device=torch.device('cuda')):
    '''
    Construct the candidate matrix by expanding mutual nearest neighbours.

    1. ``candidate_matrix`` starts as the mutual (k+1)-NN indicator: (i, j) is 1
       when each appears in the other's first k+1 ranked entries.
       ``expansion_matrix`` is the same with the first k//2+1 entries.
    2. For each rank position i in 0..k and each row r, let n = initial_rank[r, i].
       If n is one of r's initial mutual candidates and more than 2/3 of n's
       expansion set already lies in r's initial candidate set, n's expansion
       set is OR-ed into r's candidate row.
    3. Rows left with at most ``neighbor_threshold`` candidates get their first
       k ranked entries marked as candidates.

    Args:
        all_num (int): number of items.
        initial_rank (torch.Tensor): (all_num, >= k+1) ranking indices.
        k (int): candidate neighbourhood size.
        neighbor_threshold (int): minimum candidate count before falling back
            to the raw top-k ranking.
        device: device on which the matrices are built.

    Returns:
        torch.Tensor: (all_num, all_num) 0/1 candidate matrix on ``device``.
    '''
    candidate_matrix = torch.zeros((all_num, all_num), device=device)
    expansion_matrix = torch.zeros((all_num, all_num), device=device)
    print('Applying Candidate Expansion')
    start_time = time.time()
    # Initialize method 1
    # for i in range(all_num):
    #     candidate_matrix[i, initial_rank[i,:k+1]] = 1
    #     expansion_matrix[i, initial_rank[i,:int(np.around(k/2.))+1]] = 1
    # Initialize method 2
    # candidate_matrix[np.repeat(np.arange(all_num), k+1), initial_rank[:,:k+1].reshape(-1)] = 1
    # expansion_matrix[np.repeat(np.arange(all_num), int(np.around(k/2.0))+1), initial_rank[:,:int(np.around(k/2.0))+1].reshape(-1)] = 1
    # Directed kNN indicators (note: k//2 here, not the rounded k/2 of the old methods above).
    candidate_matrix[torch.repeat_interleave(torch.arange(all_num), k+1), initial_rank[:,:k+1].reshape(-1)] = 1
    expansion_matrix[torch.repeat_interleave(torch.arange(all_num), k//2+1), initial_rank[:,:k//2+1].reshape(-1)] = 1
    # Keep only mutual neighbours: A * A^T.
    candidate_matrix = torch.multiply(candidate_matrix, torch.transpose(candidate_matrix, dim0=0, dim1=1))
    expansion_matrix = torch.multiply(expansion_matrix, torch.transpose(expansion_matrix, dim0=0, dim1=1))
    # Snapshot before expansion; the overlap test below always uses the
    # initial sets, while candidate_matrix accumulates the expansions.
    init_candidate_matrix = torch.clone(candidate_matrix)

    for i in range(k+1):
        # Row r holds the expansion set of its i-th ranked neighbour.
        mask_expansion_matrix = expansion_matrix[initial_rank[:,i],:]
        mask_expansion_sum = torch.sum(mask_expansion_matrix, dim=1)
        # Size of the overlap between that expansion set and r's initial candidates.
        intersect_matrix = torch.multiply(mask_expansion_matrix, init_candidate_matrix)
        intersect_sum = torch.sum(intersect_matrix, dim=1)
        intersect_sum = torch.multiply(intersect_sum, init_candidate_matrix[torch.arange(all_num), initial_rank[:,i]]) # zero the overlap when the neighbour itself is not an initial candidate of r
        mask = torch.where(intersect_sum > 2/3*mask_expansion_sum)[0]
        # a + b - a*b is element-wise OR for 0/1 values.
        candidate_matrix[mask] = candidate_matrix[mask] + mask_expansion_matrix[mask] - torch.multiply(candidate_matrix[mask], mask_expansion_matrix[mask])
    # Deal with the situations where there are only a few candidate neighbors:
    # fall back to the first k ranked entries for those rows.
    candidate_num = torch.sum(candidate_matrix, dim=1)
    mask = torch.where(candidate_num<=neighbor_threshold)[0]
    initial_rank = initial_rank.to(device=device)
    candidate_matrix[torch.repeat_interleave(mask, k), initial_rank[mask, :k].reshape(-1)] = 1

    end_time = time.time()
    print(f'===> Time cost {end_time-start_time}')
    # candidate_matrix = candidate_matrix.to(torch.device('cpu')).numpy()
    
    return candidate_matrix


def Sparse_Feature_Construction(all_num, original_dist, initial_rank,
                                confident_k=5, candidate_k=20, trans_k=20, sigma=0.4, beta=0, lambda_value=2, mu=0.23,
                                device=torch.device('cuda')):
    """Build L1-normalised sparse features from diffused similarities.

    Steps:
      1. Convert distances to similarities with exp(-d / sigma^2).
      2. Diffuse the similarities over the confident mutual-kNN graph.
      3. Mask the result with the expanded candidate matrix.
      4. L1-normalise each row.

    Args:
        all_num (int): number of items.
        original_dist (torch.Tensor): (all_num, all_num) distance matrix.
        initial_rank (torch.Tensor): (all_num, all_num) ranking indices.
        confident_k (int): k for confident candidates and diffusion neighbourhood.
        candidate_k (int): k for candidate expansion.
        trans_k, beta: currently unused; kept for API compatibility.
        sigma (float): similarity kernel bandwidth.
        lambda_value (float): confident-pair boost used in diffusion.
        mu (float): diffusion regulariser.
        device: device forwarded to every helper (previously the helpers
            always used their CUDA default).

    Returns:
        tuple: (sparse_feature np.ndarray (all_num, all_num),
                confident_matrix np.ndarray (all_num, all_num),
                A_matrix torch.Tensor, the diffused similarities on ``device``).
    """
    similarity_matrix = torch.exp(-original_dist/sigma**2)
    confident_matrix = confident_candidate_expansion(all_num, initial_rank, k=confident_k, device=device)
    candidate_matrix = candidate_expansion(all_num, initial_rank, k=candidate_k, device=device)

    print('Applying Similarity Diffusion')
    start_time = time.time()
    A_matrix = similarity_diffusion(all_num, confident_matrix, similarity_matrix, initial_rank, knn=confident_k, lambda_value=lambda_value, mu=mu, device=device) # R-GeM
    end_time = time.time()
    print(f'===> Time cost {end_time-start_time}')
    sparse_feature = F.normalize(torch.multiply(candidate_matrix, A_matrix), p=1, dim=1)
    sparse_feature = sparse_feature.cpu().numpy()

    return sparse_feature, confident_matrix, A_matrix


def ngt_reranking(query_features,
                  gallery_features,
                  metric='euclidean',
                  mode='normal',
                  k1=6,
                  k2=60,
                  k3=70,
                  k4=7,
                  k5=80,
                  lambda_value=0.3,
                  lv=2,
                  sigma=0.3,
                  mu=0.23,
                  beta=0,
                  baryweight=0.1,
                  device=torch.device('cuda'),
                  mask=None):
    """Compute the re-ranking distance between queries and gallery items.

    Pipeline:
      1. Pairwise distances, normalised by each column's maximum, plus the
         initial ranking.
      2. Sparse features (similarity diffusion + candidate expansion).
      3. Geodesic ground cost over a kNN graph.
      4. Barycenter/mean feature refinement.
      5. Sinkhorn transport distance.
      6. Blend with the original distance:
         (1 - lambda_value) * transport + lambda_value * original.

    Args:
        query_features (np.ndarray): (query_num, d) features; anything
            ``np.concatenate`` accepts works.
        gallery_features (np.ndarray): (gallery_num, d) features.
        metric (str): "euclidean" or "cosine" for the initial distances.
        mode, mask: currently unused; kept for API compatibility.
        k1 (int): confident-candidate k.
        k2 (int): candidate-expansion k.
        k3 (int): forwarded as ``trans_k`` (currently unused downstream).
        k4 (int): refinery neighbour counts.
        k5 (int): refinery truncation size.
        lambda_value (float): weight of the original distance in the blend.
        lv (float): currently unused (the refinery is called with lv=0).
        sigma, mu, beta: forwarded to ``Sparse_Feature_Construction``.
        baryweight (float): barycenter weight in the refinery (now honoured).
        device: device used by every stage (now forwarded to all helpers).

    Returns:
        np.ndarray: (query_num, gallery_num) re-ranking distance matrix.
    """
    query_num = query_features.shape[0]
    gallery_num = gallery_features.shape[0]
    all_num = query_num + gallery_num

    features = np.concatenate((query_features, gallery_features), axis=0).astype(np.float32)
    features = torch.from_numpy(features)

    original_dist = compute_distance_matrix(features, features, metric)
    # Divide each column by its maximum, then transpose so row i holds item i's distances.
    original_dist = torch.transpose(original_dist/torch.max(original_dist, dim=0)[0], dim0=0, dim1=1)
    initial_rank = torch.argsort(original_dist)
    original_dist = original_dist.to(device=device)

    sparse_feature, confident_matrix, _ = Sparse_Feature_Construction(
        all_num, original_dist, initial_rank, confident_k=k1, candidate_k=k2, trans_k=k3,
        beta=beta, sigma=sigma, mu=mu, device=device)

    cost_matrix = transport_cost(all_num, original_dist, initial_rank, topk=4, device=device)
    # Refinery ground cost: 0 on the diagonal, 1 everywhere else.
    refine_cost = torch.ones(all_num, all_num)
    refine_cost.fill_diagonal_(0)
    sparse_feature = average_barycenter_refinery(
        all_num, sparse_feature, refine_cost, initial_rank, confident_matrix=confident_matrix,
        k1=k4, k2=k5, k3=k4, baryweight=baryweight, lv=0, device=device)

    # iter_num=20 for a more stable result.
    modified_dist = optimal_distance(cost_matrix, all_num, initial_rank, sparse_feature[:query_num], sparse_feature, iter_num=20, device=device)

    final_dist = modified_dist*(1-lambda_value) + original_dist[:query_num, :]*lambda_value
    return final_dist[:, query_num:].cpu().numpy()