import numpy as np
from sklearn.metrics import pairwise_distances
from scipy.special import loggamma

from concept_processing.labels import calc_log_evidence_ratio

def find_pairwise_nearest(prox_mtx):
    """For every row ``i``, return the column of the smallest entry, ignoring column ``i``.

    Column ``i`` is skipped so that a row's own (diagonal) entry is never
    chosen. Ties are resolved by ``np.argmin``, i.e. the lowest column wins.

    Parameters
    ----------
    prox_mtx : ndarray
        2-D proximity/distance matrix.

    Returns
    -------
    ndarray of int
        ``neighbours[i]`` is the chosen column index for row ``i``.
    """
    n_rows = prox_mtx.shape[0]
    neighbours = np.empty(n_rows, dtype=int)
    for idx in range(n_rows):
        others = np.delete(prox_mtx[idx], idx)
        best = int(np.argmin(others))
        # np.delete shifted every column after `idx` one place left; undo it.
        neighbours[idx] = best + 1 if best >= idx else best
    return neighbours
    
    
def display_best_closest_neighbours(neighbours, prox_mtx, concepts, C=None, K=None):
    """Print the ``K`` concept/neighbour pairs with the smallest distances.

    For each of the first ``C`` concepts the distance ``prox_mtx[i, neighbours[i]]``
    is looked up. The ``K`` smallest distances are printed in ascending order,
    each followed by the concept and its neighbour. Equal distances keep
    ascending concept-index order.

    Parameters
    ----------
    neighbours : sequence of int
        ``neighbours[i]`` is the index of concept ``i``'s neighbour
        (e.g. the output of ``find_pairwise_nearest``).
    prox_mtx : ndarray
        Distance matrix indexed as ``prox_mtx[i, j]``.
    concepts : sequence
        Concept labels, indexed consistently with ``prox_mtx``.
    C : int, optional
        Number of leading concepts to consider. Defaults to ``len(concepts)``.
    K : int, optional
        Number of pairs to print. Defaults to ``C - 1``. Values ``>= C`` print
        all ``C`` pairs; values ``<= 0`` print nothing.
    """
    if C is None:
        C = len(concepts)
    if K is None:
        K = C - 1
    # Distance from each concept to its neighbour. The comprehension also
    # handles C == 0, giving an empty float array.
    mins = np.array([prox_mtx[i, neighbours[i]] for i in range(C)], dtype=float)
    # A stable full sort avoids argpartition's "kth out of bounds" error when
    # K >= C and gives deterministic order among equal distances. Slicing
    # clamps K to C; max(K, 0) stops a negative K from meaning "all but |K|".
    best_sorted = np.argsort(mins, kind="stable")[:max(K, 0)]
    for index in best_sorted:
        partnerindex = neighbours[index]
        concept = concepts[index]
        partnerconcept = concepts[partnerindex]
        dist = mins[index]
        print(f"dist: {dist}:\n\t{concept}\n\t{partnerconcept}")
        
        
def calc_prox_mtx_embedding(embeds, metric='manhattan', p=768):
    """Return the pairwise distance matrix between the rows of ``embeds``.

    Parameters
    ----------
    embeds : array-like
        Embedding vectors, one per row, passed to
        ``sklearn.metrics.pairwise_distances``.
    metric : {'manhattan', 'cosine', 'minkowski', 'chebyshev'}
        Distance metric to use.
    p : float
        Order of the Minkowski distance. Only used when ``metric == 'minkowski'``.

    Returns
    -------
    ndarray
        Square matrix of pairwise distances. For ``'minkowski'``, any
        non-finite entries are replaced by 1.1 times the largest finite entry.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """
    if metric in ('manhattan', 'cosine', 'chebyshev'):
        return pairwise_distances(embeds, metric=metric)
    if metric == 'minkowski':
        prox_mtx_embedding = pairwise_distances(embeds, metric=metric, p=p)
        # For minkowski the non-finite values must be cleaned up. With a large
        # order p the intermediate powers can overflow. Replace such entries
        # with a value just beyond the largest finite distance so they still
        # rank as "farthest".
        finite = np.isfinite(prox_mtx_embedding)
        if not finite.all():
            dummy_dist = np.max(prox_mtx_embedding[finite]) * 1.1
            prox_mtx_embedding[~finite] = dummy_dist
        return prox_mtx_embedding
    raise ValueError(f"Unrecognised embedding metric {metric}")
    
def calc_prox_mtx_labels_evidence_ratio(label_counts_mtx, alpha=0.5):
    """Build the all-pairs matrix of exponentiated log evidence ratios.

    Entry ``[i, j]`` is ``exp(calc_log_evidence_ratio(row_i, row_j, alpha))``,
    where rows are taken by iterating ``label_counts_mtx``.

    Parameters
    ----------
    label_counts_mtx : iterable of rows
        Per-concept label counts; each row is passed to
        ``calc_log_evidence_ratio``.
    alpha : float
        Passed through unchanged to ``calc_log_evidence_ratio``.

    Returns
    -------
    ndarray
        Matrix built with ``np.array`` from the nested list of entries.
    """
    rows = []
    for first in label_counts_mtx:
        rows.append([np.exp(calc_log_evidence_ratio(first, second, alpha))
                     for second in label_counts_mtx])
    return np.array(rows)
    


def calc_prox_mtx_labels(label_counts_mtx, alpha=0.5, labelmetric='evidence_ratio'):
    """Dispatch to the label-proximity builder selected by ``labelmetric``.

    Parameters
    ----------
    label_counts_mtx : iterable of rows
        Per-concept label counts, forwarded unchanged.
    alpha : float
        Forwarded unchanged to the selected builder.
    labelmetric : {'evidence_ratio', 'beta_ratio'}
        Selects ``calc_prox_mtx_labels_evidence_ratio`` or
        ``calc_prox_mtx_labels_beta_ratio``.

    Raises
    ------
    ValueError
        If ``labelmetric`` is not one of the names above.
    """
    if labelmetric == 'beta_ratio':
        return calc_prox_mtx_labels_beta_ratio(label_counts_mtx, alpha)
    if labelmetric != 'evidence_ratio':
        raise ValueError(f"Unrecognised label metric {labelmetric}")
    return calc_prox_mtx_labels_evidence_ratio(label_counts_mtx, alpha)
        
def calc_prox_mtx_labels_beta_ratio(label_counts_mtx, alpha=0.5):
    """Build the all-pairs matrix of exponentiated log Beta ratios.

    Entry ``[i, j]`` is ``exp(calc_log_beta_ratio(row_i, row_j, alpha))``,
    i.e. ``B(row_i + alpha) * B(row_j + alpha) / B(row_i + row_j + alpha)``,
    where ``B`` is the multivariate Beta function.

    Parameters
    ----------
    label_counts_mtx : iterable of rows
        Per-concept label counts; presumably numeric arrays so that
        ``row_i + row_j`` is element-wise (TODO confirm with callers).
    alpha : float
        Pseudo-count added inside each Beta term.

    Returns
    -------
    ndarray
        Matrix built with ``np.array`` from the nested list of entries.
    """
    def _pair_ratio(left, right):
        return np.exp(calc_log_beta_ratio(left, right, alpha))

    return np.array([[_pair_ratio(left, right) for right in label_counts_mtx]
                     for left in label_counts_mtx])


def calc_log_beta_ratio(n, m, alpha):
    """Return ``log B(n + alpha) + log B(m + alpha) - log B(n + m + alpha)``.

    ``B`` is the multivariate Beta function as computed by
    ``calc_log_multidim_beta``. Note that the joint term adds ``alpha`` once
    to ``n + m``, not twice.
    """
    log_beta_n = calc_log_multidim_beta(n, alpha)
    log_beta_m = calc_log_multidim_beta(m, alpha)
    log_beta_joint = calc_log_multidim_beta(n + m, alpha)
    return log_beta_n + log_beta_m - log_beta_joint
    
def calc_log_multidim_beta(n, alpha):
    """Return the log of the multivariate Beta function at ``n + alpha``.

    With ``a = n + alpha`` this is ``sum(loggamma(a)) - loggamma(sum(a))``,
    i.e. ``log(prod(Gamma(a_k)) / Gamma(sum(a_k)))``.
    """
    shifted = n + alpha
    return np.sum(loggamma(shifted)) - loggamma(np.sum(shifted))
    

