import logging
from typing import List

import torch
from sklearn.metrics.pairwise import euclidean_distances

# Module-level logger used by compute_precision_recall to report results.
# It uses the fixed name 'custom' rather than __name__, presumably so output is
# routed through a handler configured elsewhere in the project — confirm.
logger = logging.getLogger('custom')


def compute_precision_recall(data, mod, ks=(2, 3)):
    """Compute k-NN precision, recall and F1 for one or more neighbourhood sizes.

    Implements the improved precision/recall metric of
    https://arxiv.org/abs/1904.06991
    Every metric is reported in percent and also logged at INFO level.

    :param data: mapping of embeddings, forwarded unchanged to `_run`; the keys it
        must contain depend on `mod`
    :param mod: modality to evaluate, 'x1' or 'x2'
    :param ks: neighbourhood sizes k to evaluate (default (2, 3), as before)
    :return: dict mapping 'k_<k>' to the per-entry metric dicts produced by `_run`
    :raises ValueError: propagated from `_run` for an unknown `mod`
    """
    results = {}
    for k in ks:
        cur_results = _run(data=data, mod=mod, k=k)
        results[f'k_{k}'] = cur_results
        # Lazy %-style arguments: formatting only happens if INFO is enabled.
        logger.info('\nk=%s:', k)
        for entry_name, metrics in cur_results.items():
            logger.info(' %s:', entry_name)
            for metric_name, value in metrics.items():
                logger.info('   %s: %.1f%%', metric_name, value)
    return results


def _run(data, mod, k):
    """Evaluate one modality's approximation against its ground-truth embeddings.

    :param data: mapping that must contain, for mod='x1', the keys 'x1_emb' (truth)
        and 'x1|x2_emb' (approximation); for mod='x2', the keys 'x2' (truth) and
        'x2|x1' (approximation)
    :param mod: modality to evaluate, 'x1' or 'x2'
    :param k: neighbourhood size for the k-NN manifold estimate
    :return: {entry_name: {'precision': ..., 'recall': ..., 'f1': ...}}
    :raises ValueError: if `mod` is neither 'x1' nor 'x2'
    """
    if mod == 'x1':
        return {'x1_emb': __run(data['x1_emb'], data['x1|x2_emb'], k)}
    if mod == 'x2':
        return {'x2': __run(data['x2'], data['x2|x1'], k)}
    raise ValueError(f"mod must be 'x1' or 'x2', got {mod!r}")


def __run(truth, approximation, k):
    """Score `approximation` against `truth` with k-NN precision, recall and F1.

    Precision is the share of `approximation` samples that fall inside the k-NN
    hyperspheres built around `truth`. Recall is the converse. Both are returned
    by `_f` in percent, so the F1 score is on the same scale.

    :param truth: ground-truth samples
    :param approximation: samples to evaluate against `truth`
    :param k: neighbourhood size forwarded to `_f`
    :return: dict with keys 'precision', 'recall' and 'f1'
    """
    precision = _f([approximation, truth], k)
    recall = _f([truth, approximation], k)
    # Harmonic mean of the two; defined as 0 when both vanish to avoid 0 / 0.
    both_zero = precision == 0 and recall == 0
    f1 = 0 if both_zero else 2 * (precision * recall) / (precision + recall)
    return {'precision': precision, 'recall': recall, 'f1': f1}


def _f(phi: List[torch.Tensor], k):
    """ Eq. (1) from the paper, averaged over the query set.

    Each support sample spans a hypersphere whose radius is the distance to its
    k-th nearest *other* support sample. A query sample is a hit if it lies inside
    (or on the boundary of) at least one of these hyperspheres.

    :param phi: two feature matrices with rows as samples
        phi[0]: query, shape N_q x D
        phi[1]: support (later represented by hyperspheres), shape N_s x D
        Precision: phi = [approximation, truth]
        Recall: phi = [truth, approximation]
    :param k: k-nearest neighbors for precision/recall measure; must satisfy
        1 <= k < N_s
    :return: percentage (0-100) of query samples that are hits
    :raises ValueError: if k is out of range for the support set
    """
    query, support = phi[0], phi[1]
    n_support = len(support)
    if not 1 <= k < n_support:
        raise ValueError(
            f'k must satisfy 1 <= k < {n_support} (number of support samples), got k={k}')

    # Compute hyperspheres spanned by k-nearest neighbors.
    d_support = torch.as_tensor(euclidean_distances(support, support))  # N_s x N_s
    # The smallest entry of each row is the sample itself (distance 0), so the
    # k-th nearest other sample is the (k + 1)-th smallest value (kthvalue is
    # 1-based). Taking the radius from this matrix keeps it consistent with the
    # query distances below, which are computed the same way.
    radius = d_support.kthvalue(k + 1, dim=1).values  # N_s,

    # Euclidean distance from every query sample to every support sample.
    d_query = torch.as_tensor(euclidean_distances(query, support))  # N_q x N_s

    # radius broadcasts along the support axis: inside[i, j] <=> query i is in sphere j.
    inside = torch.le(d_query, radius)  # N_q x N_s
    hit = inside.any(dim=1)  # One hit is sufficient, shape N_q,
    # Fraction of query samples with at least one hit, in percent.
    return hit.float().mean().item() * 100
