import torch
from tqdm import tqdm
from pytorch_metric_learning.distances import BaseDistance as PmlDistance


class BaseDivergence(PmlDistance):
    """Base class for divergence measures on top of pytorch-metric-learning's
    ``BaseDistance``.

    Subclasses are expected to implement ``compute_mat`` (a query-by-reference
    matrix), ``pairwise_distance`` (row-aligned values) and optionally
    ``compute_full_mat`` (a matrix over inputs too large for a single tensor
    op). ``batch_compute_mat`` is provided as a generic chunked fallback built
    on ``compute_mat``.

    Args:
        phi: Stored as ``self.phi`` for use by subclasses; this base class does
            not read it. (Presumably the generating function of the
            divergence -- TODO confirm against subclasses.)
        mat_batch_size: Number of query rows processed per chunk in
            ``batch_compute_mat``. ``None`` (or any falsy value) means 64.
        **kwargs: Forwarded to ``BaseDistance.__init__``. If the caller does
            not pass ``normalize_embeddings``, it is set to ``False`` here.
    """

    def __init__(self, phi=None, mat_batch_size=None, **kwargs):
        # Only supply a default; an explicit caller value (even None) is kept.
        kwargs.setdefault('normalize_embeddings', False)
        super().__init__(**kwargs)
        self.phi = phi
        self.mat_batch_size = mat_batch_size

    def compute_mat(self, query_emb, ref_emb):
        """Return a matrix where ``mat[j, k]`` is the distance/similarity
        between ``query_emb[j]`` and ``ref_emb[k]``.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def pairwise_distance(self, query_emb, ref_emb):
        """Return a tensor where ``output[j]`` is the distance/similarity
        between ``query_emb[j]`` and ``ref_emb[j]``.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def compute_full_mat(self, query_emb, ref_emb):
        """Compute the matrix ``mat[j, k] = d(query_emb[j], ref_emb[k])`` for
        inputs as large as an entire dataset, where a single dense tensor
        computation may not fit in memory.

        Implementations may use PyKeops or iterate over batches
        (``batch_compute_mat`` is a ready-made batched fallback).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def batch_compute_mat(self, query_emb, ref_emb, show_progress=True):
        """Compute the Q x R divergence matrix between queries and references
        by calling ``compute_mat`` on chunks of query rows.

        This is the default workaround for ``compute_full_mat``. Only one
        chunk's worth of ``compute_mat`` output is held at a time; the full
        Q x R result is still materialised on ``query_emb.device``.

        Args:
            query_emb: 2-D tensor of query embeddings (Q rows).
            ref_emb: 2-D tensor of reference embeddings (R rows).
            show_progress: Whether to display a tqdm progress bar over the
                chunks. Defaults to ``True``, the previous behaviour.

        Returns:
            A (Q, R) tensor on ``query_emb.device`` whose dtype matches what
            ``compute_mat`` returns. If Q == 0, an empty (0, R) tensor of the
            default dtype is returned.

        Raises:
            ValueError: If ``mat_batch_size`` is negative (previously this
                silently produced an all-zero matrix), or if either input is
                not 2-D.
        """
        batch_size = self.mat_batch_size or 64
        if batch_size < 1:
            raise ValueError(
                f"mat_batch_size must be a positive integer, "
                f"got {self.mat_batch_size!r}"
            )

        with torch.no_grad():
            # Tuple unpacking doubles as a check that both inputs are 2-D.
            num_query, _ = query_emb.shape
            num_ref, _ = ref_emb.shape

            dist_mat = None
            starts = range(0, num_query, batch_size)
            for start in tqdm(starts, disable=not show_progress):
                stop = start + batch_size  # slicing clamps at num_query
                dist = self.compute_mat(query_emb[start:stop], ref_emb)
                if dist_mat is None:
                    # Allocate lazily so the output keeps compute_mat's dtype
                    # (e.g. float64 is no longer truncated to float32). Every
                    # row is written by the loop, so `empty` is safe here.
                    dist_mat = torch.empty(
                        (num_query, num_ref),
                        dtype=dist.dtype,
                        device=query_emb.device,
                    )
                # Slice assignment copies the values; no clone() needed.
                dist_mat[start:stop, :] = dist

            if dist_mat is None:
                # No query rows: nothing was computed.
                dist_mat = torch.zeros(
                    (num_query, num_ref), device=query_emb.device
                )
        return dist_mat
