import torch
from collections import defaultdict
from sklearn.metrics import f1_score, recall_score, precision_score
import numpy as np


class Metric:
	"""Accumulate per-batch loss/predictions/labels and report aggregate scores.

	Call ``update`` once per batch, read the aggregates (``loss``, ``f1``,
	``precision``, ``recall`` or ``dict``), and ``clear`` to start over.

	Args:
		average: value forwarded as ``average=`` to scikit-learn's
			``precision_score`` / ``recall_score`` / ``f1_score``. Defaults to
			``'macro'``, the value this class always used.
	"""

	def __init__(self, average='macro'):
		self.average = average
		# defaultdict(list) instead of defaultdict(lambda: []): identical
		# behaviour, but a lambda default factory makes the instance
		# unpicklable (e.g. when stored in a checkpoint or sent to a worker).
		self._data = defaultdict(list)

	def update(self, loss, pred, label):
		"""Record one batch.

		Args:
			loss: scalar tensor; it is multiplied by the batch size so that
				``loss()`` yields a batch-size-weighted mean (presumably this is
				the batch-mean loss — confirm against callers).
			pred: tensor with at least two dims; ``pred.shape[0]`` is taken as
				the batch size and the predicted class is the argmax over dim 1.
			label: tensor of targets, later passed to scikit-learn as ``y_true``
				(presumably integer class ids aligned with ``pred``).
		"""
		n = pred.shape[0]
		# Convert everything first so that a failure part-way through cannot
		# leave the per-key lists with different lengths.
		batch_loss = loss.detach().cpu().numpy() * n
		batch_pred = np.argmax(pred.detach().cpu().numpy(), 1)
		batch_label = label.detach().cpu().numpy()

		self._data['_n'].append(n)
		self._data['loss'].append(batch_loss)
		self._data['pred'].append(batch_pred)
		self._data['label'].append(batch_label)

	def dict(self):
		"""Return ``{'loss', 'f1', 'precision', 'recall'}`` for the recorded data.

		Raises:
			ValueError: from ``np.concatenate`` if no batch has been recorded.
		"""
		# Concatenate once here rather than twice inside every scorer call.
		y_true, y_pred = self.label, self.pred
		return {
			'loss': self.loss(),
			'f1': self._score(f1_score, y_true, y_pred),
			'precision': self._score(precision_score, y_true, y_pred),
			'recall': self._score(recall_score, y_true, y_pred),
		}

	@property
	def count(self):
		"""Total number of samples recorded so far (0 when empty)."""
		return sum(self._data['_n'])

	def clear(self):
		"""Drop all recorded batches; the ``average`` setting is kept."""
		self._data = defaultdict(list)

	def loss(self):
		"""Return the batch-size-weighted mean loss, or NaN if nothing was recorded."""
		n = self.count
		if n == 0:
			# Explicit NaN instead of numpy's 0/0, which also gives NaN but
			# emits a RuntimeWarning.
			return float('nan')
		return float(np.sum(self._data['loss'])) / n

	@property
	def pred(self):
		"""All predicted class indices, concatenated across batches."""
		return np.concatenate(self._data['pred'])

	@property
	def label(self):
		"""All labels, concatenated across batches."""
		return np.concatenate(self._data['label'])

	def _score(self, score_fn, y_true, y_pred):
		"""Apply a scikit-learn scorer with this metric's ``average`` mode."""
		return score_fn(y_true, y_pred, average=self.average)

	def precision(self):
		"""Precision over all recorded samples, using ``self.average``."""
		return self._score(precision_score, self.label, self.pred)

	def f1(self):
		"""F1 score over all recorded samples, using ``self.average``."""
		return self._score(f1_score, self.label, self.pred)

	def recall(self):
		"""Recall over all recorded samples, using ``self.average``."""
		return self._score(recall_score, self.label, self.pred)

