import numpy as np

from sklearn.metrics import euclidean_distances
from joblib import Parallel, delayed
from sklearn.utils import check_array

from pyemd import emd


class WordMoversDistance(object):
    """Word Mover's Distance from a query document to a set of fitted documents.

    Documents are sequences of integer token IDs. The IDs are used to index
    rows of ``embeddings``; the ground distance between two tokens is the
    Euclidean distance between their embedding rows, and the distance between
    two documents is the Earth Mover's Distance between their normalized
    token-count histograms.
    """

    def __init__(self, embeddings, n_jobs=-1, verbose=0):
        """
        Parameters
        ----------
        embeddings : array-like
            Embedding table indexed by token ID (one row per token).
        n_jobs : int
            Forwarded to ``joblib.Parallel`` when computing pairwise distances.
        verbose : int
            Forwarded to ``joblib.Parallel``.
        """
        self.embeddings = embeddings
        self.n_jobs = n_jobs
        self.verbose = verbose

        # Populated by fit().
        self.source_ids = None
        self.source_tokens = None
        self.n_samples = None

    def fit(self, source_ids, source_tokens):
        """Store the reference documents that queries are compared against.

        Parameters
        ----------
        source_ids : sequence
            Identifiers of the reference documents. Stored as-is; not used in
            the distance computation itself.
        source_tokens : sequence of sequences of int
            Token IDs of each reference document.

        Returns
        -------
        self
            The fitted instance, so calls can be chained.
        """
        self.source_ids = source_ids
        self.source_tokens = source_tokens
        self.n_samples = len(self.source_tokens)
        return self

    @staticmethod
    def _nbow(tokens, vocab):
        """Return the normalized count histogram of ``tokens`` over ``vocab``.

        ``vocab`` must be sorted, unique, and contain every value in
        ``tokens``; ``tokens`` must be non-empty.
        """
        # vocab is sorted and contains every token, so searchsorted yields the
        # exact position of each token inside vocab.
        positions = np.searchsorted(vocab, tokens)
        counts = np.bincount(positions, minlength=len(vocab))
        return counts.astype(np.float64) / counts.sum()

    def _wmd(self, row, target_tokens):
        """Compute the distance between fitted document ``row`` and a query.

        Parameters
        ----------
        row : int
            Index into the fitted ``source_tokens``.
        target_tokens : sequence of int
            Token IDs of the query document.

        Returns
        -------
        float
            The Earth Mover's Distance between the two documents, or ``inf``
            when either document has no tokens (its histogram cannot be
            normalized, so the distance is undefined).
        """
        src_tokens = np.asarray(self.source_tokens[row]).ravel()
        tgt_tokens = np.asarray(target_tokens).ravel()

        # An empty document would lead to a division by zero during
        # normalization and NaN input to the solver; report it explicitly.
        if src_tokens.size == 0 or tgt_tokens.size == 0:
            return float("inf")

        # np.union1d already returns the sorted unique values.
        union_idx = np.union1d(tgt_tokens, src_tokens)
        token_embeddings = self.embeddings[union_idx]

        source_hist = self._nbow(src_tokens, union_idx)
        target_hist = self._nbow(tgt_tokens, union_idx)

        # The histograms are already float64; the distance matrix is cast to
        # float64 as well before being handed to the solver.
        embed_dist = euclidean_distances(token_embeddings).astype(np.float64)

        return emd(source_hist, target_hist, embed_dist)

    def _pairwise_wmd(self, target_tokens):
        """Compute the distance from the query to every fitted document.

        Returns a 1-D array with one entry per fitted document, in the order
        the documents were passed to ``fit``.
        """
        dist = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
            delayed(self._wmd)(i, target_tokens)
            for i in range(self.n_samples)
        )

        return np.array(dist)

    def predict(self, target_tokens):
        """Return the distances from ``target_tokens`` to all fitted documents.

        ``fit`` must have been called first.

        Parameters
        ----------
        target_tokens : sequence of int
            Token IDs of the query document.

        Returns
        -------
        numpy.ndarray
            One distance per fitted document.
        """
        return self._pairwise_wmd(target_tokens)