from __future__ import division

import numpy as np


def compute_average_precision(groundtruth, predictions):
	"""
	Computes average precision for a binary problem. This is based off of the
	PASCAL VOC implementation.

	Samples with tied scores are treated as a single threshold: precision is
	only evaluated at the end of each run of equal scores, so the result does
	not depend on how ties happen to be ordered by the sort.

	Args:
		groundtruth (array-like): Binary vector indicating whether each sample
			is positive or negative.
		predictions (array-like): Contains scores for each sample. Must have
			the same number of elements as `groundtruth`.
	Returns:
		Average precision. NaN if `groundtruth` is empty or contains no
		positive samples (recall is undefined in that case).
	Raises:
		ValueError: If either input is not 1-dimensional after squeezing, or
			if the two inputs have different lengths.
	"""
	# atleast_1d keeps a single-sample input from collapsing to a 0-d array
	# after squeeze().
	predictions = np.atleast_1d(np.asarray(predictions).squeeze())
	groundtruth = np.atleast_1d(np.asarray(groundtruth, dtype=float).squeeze())
	if predictions.ndim != 1:
		raise ValueError('Predictions vector should be 1 dimensional. '
						 'For multiple labels, use `compute_multiple_aps`.')
	if groundtruth.ndim != 1:
		raise ValueError('Groundtruth vector should be 1 dimensional. '
						 'For multiple labels, use `compute_multiple_aps`.')
	num_examples = predictions.shape[0]
	if groundtruth.shape[0] != num_examples:
		raise ValueError('Groundtruth and predictions should have the same '
						 'length, but got %d and %d.'
						 % (groundtruth.shape[0], num_examples))
	if num_examples == 0:
		return np.nan

	# Rank samples by descending score.
	sorted_indices = np.argsort(predictions)[::-1]
	predictions = predictions[sorted_indices]
	groundtruth = groundtruth[sorted_indices]
	# The false positives are all the negative groundtruth instances, since we
	# assume all instances were 'retrieved'. Ideally, these will be low scoring
	# and therefore in the end of the vector.
	false_positives = 1 - groundtruth

	tp = np.cumsum(groundtruth)      # tp[i] = # of positive examples up to i
	fp = np.cumsum(false_positives)  # fp[i] = # of false positives up to i

	num_positives = tp[-1]
	if num_positives == 0:
		# Recall is 0/0 everywhere; report NaN explicitly instead of via a
		# division warning.
		return np.nan

	precisions = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
	recalls = tp / num_positives

	# Prepend the (recall=0, precision=0) start point, so that index k of
	# these arrays describes the first k ranked samples.
	precisions = np.concatenate(([0.], precisions))
	recalls = np.concatenate(([0.], recalls))

	# Boundaries between runs of equal scores, expressed as "number of top
	# ranked samples". 0 and num_examples are always boundaries; the values
	# are already sorted and distinct, and have an integer dtype for indexing.
	score_changes = np.flatnonzero(predictions[1:] != predictions[:-1]) + 1
	boundaries = np.concatenate(([0], score_changes, [num_examples]))

	# Precision at the end of each run of tied scores.
	precisions = precisions[boundaries[1:]]

	# Set precisions[i] = max(precisions[j] for j >= i)
	# This is because (for j > i), recall[j] >= recall[i], so we can always use
	# a lower threshold to get the higher recall and higher precision at j.
	precisions = np.maximum.accumulate(precisions[::-1])[::-1]

	# This is calculating the AUC: the change in recall over each segment of
	# constant precision, times that precision value.
	recall_gains = recalls[boundaries[1:]] - recalls[boundaries[:-1]]
	ap = np.sum(recall_gains * precisions)

	return ap


def compute_multiple_aps(groundtruth, predictions):
	"""Convenience function to compute APs for multiple labels.

	Each column is scored independently with `compute_average_precision`.

	Args:
		groundtruth (np.array): Shape (num_samples, num_labels)
		predictions (np.array): Shape (num_samples, num_labels)
	Returns:
		aps_per_label (np.array, shape (num_labels,)): Contains APs for each
			label. NOTE: If a label does not have positive samples in the
			groundtruth, the AP is set to -1.
	Raises:
		ValueError: If either input is not 2-dimensional, or if the two
			inputs have different shapes.
	"""
	predictions = np.asarray(predictions)
	groundtruth = np.asarray(groundtruth)
	if predictions.ndim != 2:
		raise ValueError('Predictions should be 2-dimensional,'
						 ' but has shape %s' % (predictions.shape, ))
	if groundtruth.ndim != 2:
		raise ValueError('Groundtruth should be 2-dimensional,'
						 ' but has shape %s' % (groundtruth.shape, ))
	if groundtruth.shape != predictions.shape:
		# Without this, extra prediction columns or rows would be silently
		# ignored, or would fail later with an opaque IndexError.
		raise ValueError('Groundtruth and predictions should have the same '
						 'shape, but got %s and %s'
						 % (groundtruth.shape, predictions.shape))

	num_labels = groundtruth.shape[1]
	# -1 marks labels with no positive groundtruth (AP is undefined there).
	aps = np.full(num_labels, -1.0)
	for label in range(num_labels):
		if groundtruth[:, label].any():
			aps[label] = compute_average_precision(groundtruth[:, label],
												   predictions[:, label])
	return aps

def compute_mAP(aps, labels=None):
	"""Computes the mean of per-label APs, ignoring labels marked with -1.

	Args:
		aps (array-like): Per-label average precisions. Entries equal to -1
			(the marker `compute_multiple_aps` uses for labels without
			positive groundtruth) are excluded from the mean.
		labels (sequence, optional): Names indexed like `aps`. If given, each
			label's AP is printed before the mean is computed.
	Returns:
		Mean AP over the remaining labels, or NaN if no label has a valid AP.
	"""
	if labels is not None:
		for i, ap in enumerate(aps):
			print(f'{labels[i]} AP: {ap}')
	valid_aps = [ap for ap in aps if ap != -1]
	if not valid_aps:
		# np.mean([]) would return NaN but also emit a RuntimeWarning.
		return np.nan
	return np.mean(valid_aps)


def run_test():
	"""Smoke test: prints the mAP of two small hand-checked examples.

	Both `pred` and `gt` are of shape [num_samples, num_labels].
	"""
	# All scores tie, so the single positive is ranked together with both
	# negatives: AP = 1/3.
	pred_1 = np.array([[0.0], [0.0], [0.0]])
	gt_1 = np.array([[0], [0], [1]])
	aps_1 = compute_multiple_aps(gt_1, pred_1)
	map_1 = compute_mAP(aps_1)  # expected 0.333
	print(map_1)

	# Per-label APs are 1/3, 5/6, 5/6 and 1, so mAP = 3/4.
	pred_2 = np.array([[0.0, 0.3, 0.7, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
	gt_2 = np.array([[0, 1, 1, 1], [0, 0, 1, 0], [1, 1, 0, 0]])
	aps_2 = compute_multiple_aps(gt_2, pred_2)
	# compute_mAP (unlike a plain np.mean) skips -1 "no positives" labels.
	map_2 = compute_mAP(aps_2)  # expected 0.75
	print(map_2)


if __name__ == '__main__':
	run_test()