#!/usr/bin/env python

"""
Nonconformity functions.
"""

from __future__ import division
import torch
import abc
import numpy as np
import sklearn.base
from nonconformist.base import ClassifierAdapter, RegressorAdapter

# Device that torch inputs are moved to before a user-supplied model is run:
# the first CUDA GPU when one is visible, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# -----------------------------------------------------------------------------
# Error functions
# -----------------------------------------------------------------------------


class RegressionErrFunc(object):
	"""Abstract interface for regression nonconformity (error) functions.

	A concrete error function provides two operations:

	* ``apply`` turns predictions and true targets into nonconformity scores;
	* ``apply_inverse`` turns calibration scores and a significance level
	  into the offsets that widen a prediction into an interval.
	"""

	# NOTE(review): this is Python 2 metaclass syntax. Under Python 3 it is
	# only a class attribute, so the ``abc.abstractmethod`` markers below are
	# not enforced and this class can be instantiated directly.
	__metaclass__ = abc.ABCMeta

	def __init__(self):
		super(RegressionErrFunc, self).__init__()

	@abc.abstractmethod
	def apply(self, prediction, y):
		"""Compute nonconformity scores.

		Parameters
		----------
		prediction : numpy array of shape [n_samples] or [n_samples, n_outputs]
			Model output for each sample (point predictions, or lower/upper
			quantile columns for the quantile-based subclasses).

		y : numpy array of shape [n_samples]
			True output of each sample.

		Returns
		-------
		nc : numpy array
			Nonconformity score(s) of each sample.
		"""

	@abc.abstractmethod
	def apply_inverse(self, nc, significance):
		"""Invert the nonconformity function to obtain interval offsets.

		Parameters
		----------
		nc : numpy array of shape [n_calibration_samples] (or [n_calibration_samples, 2])
			Nonconformity scores of the calibration set.

		significance : float
			Significance level in (0, 1).

		Returns
		-------
		offsets : numpy array of shape [2, 1]
			Row 0 is subtracted from the lower prediction and row 1 is added
			to the upper prediction (this is how the subclasses in this module
			and ``RegressorNc.predict`` use it).
		"""



class AbsErrorErrFunc(RegressionErrFunc):
	r"""Calculates absolute error nonconformity for regression problems.

	For each correct output in ``y``, nonconformity is defined as

	.. math::
		| y_i - \hat{y}_i |
	"""

	def __init__(self):
		super(AbsErrorErrFunc, self).__init__()

	def apply(self, prediction, y):
		"""Return the element-wise absolute residual ``|prediction - y|``."""
		return np.abs(prediction - y)

	def apply_inverse(self, nc, significance):
		"""Return the calibration score used as a symmetric interval half-width.

		Parameters
		----------
		nc : numpy array of shape [n_calibration_samples]
			Nonconformity scores of the calibration set.

		significance : float
			Significance level in (0, 1).

		Returns
		-------
		numpy array of shape [2, 1]
			The same half-width twice: row 0 is the offset below the
			prediction and row 1 the offset above it.
		"""
		nc = np.sort(nc)[::-1]  # descending order
		# Pick the floor(significance * (n + 1))-th largest score, i.e.
		# roughly the (1 - significance) empirical quantile.
		border = int(np.floor(significance * (nc.size + 1))) - 1
		# TODO: should probably warn against too few calibration examples
		# Clamp so very small significance levels (or small calibration sets)
		# fall back to the largest/smallest score rather than indexing out of
		# range.
		border = min(max(border, 0), nc.size - 1)
		return np.vstack([nc[border], nc[border]])


class SignErrorErrFunc(RegressionErrFunc):
	r"""Calculates signed error nonconformity for regression problems.

	For each correct output in ``y``, nonconformity is defined as

	.. math::
		\hat{y}_i - y_i

	(``apply`` returns ``prediction - y``). ``apply_inverse`` bounds the two
	tails separately: the scores themselves bound the lower side and their
	negation, :math:`y_i - \hat{y}_i`, bounds the upper side.

	References
	----------
	.. [1] Linusson, Henrik, Ulf Johansson, and Tuve Lofstrom.
		Signed-error conformal regression. Pacific-Asia Conference on Knowledge
		Discovery and Data Mining. Springer International Publishing, 2014.
	"""

	def __init__(self):
		super(SignErrorErrFunc, self).__init__()

	def apply(self, prediction, y):
		"""Return the signed residual ``prediction - y``."""
		return (prediction - y)

	def apply_inverse(self, nc, significance):
		"""Return separate lower and upper offsets for the prediction interval.

		Each tail is allotted ``significance / 2``.

		Parameters
		----------
		nc : numpy array of shape [n_calibration_samples]
			Signed calibration scores (``prediction - y``).

		significance : float
			Significance level in (0, 1).

		Returns
		-------
		numpy array of shape [2, 1]
			Row 0: offset below the prediction (quantile of ``prediction - y``);
			row 1: offset above it (quantile of ``y - prediction``).
		"""
		lower_side = np.reshape(nc, (nc.shape[0], 1))  # prediction - y
		# Column 0 bounds the lower side, column 1 (= y - prediction) the
		# upper side; each column is sorted independently.
		scores = np.sort(np.hstack((lower_side, -lower_side)), axis=0)
		index = int(np.ceil((1 - significance / 2) * (scores.shape[0] + 1))) - 1
		index = min(max(index, 0), scores.shape[0] - 1)
		return np.vstack([scores[index, 0], scores[index, 1]])

# CQR symmetric error function
class QuantileRegErrFunc(RegressionErrFunc):
	r"""Calculates the conformalized quantile regression (CQR) error.

	For each correct output in ``y``, nonconformity is defined as

	.. math::
		\max\{\hat{q}_{low} - y, \; y - \hat{q}_{high}\}

	where :math:`\hat{q}_{low}` and :math:`\hat{q}_{high}` are the first and
	last columns of ``prediction``.
	"""
	def __init__(self):
		super(QuantileRegErrFunc, self).__init__()

	def apply(self, prediction, y):
		"""Return the symmetric CQR score of each sample.

		Parameters
		----------
		prediction : numpy array of shape [n_samples, n_outputs]
			Column 0 is the lower quantile, column -1 the upper quantile.

		y : numpy array of shape [n_samples]
			True outputs.

		Returns
		-------
		numpy array of shape [n_samples]
			Positive when ``y`` falls outside the predicted band.
		"""
		y_lower = prediction[:, 0]
		y_upper = prediction[:, -1]
		error_low = y_lower - y
		error_high = y - y_upper
		return np.maximum(error_high, error_low)

	def apply_inverse(self, nc, significance):
		"""Return the symmetric CQR correction used on both interval ends.

		Parameters
		----------
		nc : numpy array of shape [n_calibration_samples]
			Calibration nonconformity scores.

		significance : float
			Significance level in (0, 1).

		Returns
		-------
		numpy array of shape [2, 1]
			The ceil((1 - significance) * (n + 1))-th smallest score, twice
			(clamped to the valid index range).
		"""
		nc = np.sort(nc, 0)
		index = int(np.ceil((1 - significance) * (nc.shape[0] + 1))) - 1
		index = min(max(index, 0), nc.shape[0] - 1)
		return np.vstack([nc[index], nc[index]])

# CQR asymmetric error function 
class QuantileRegAsymmetricErrFunc(RegressionErrFunc):
	r"""Calculates the asymmetric conformalized quantile regression error.

	For each correct output in ``y``, two nonconformity scores are computed,
	one per side of the interval:

	.. math::
		E_{low} = \hat{q}_{low} - y, \qquad E_{high} = y - \hat{q}_{high}

	where :math:`\hat{q}_{low}` and :math:`\hat{q}_{high}` are the first and
	last columns of ``prediction``.
	"""
	def __init__(self):
		super(QuantileRegAsymmetricErrFunc, self).__init__()

	def apply(self, prediction, y):
		"""Return an [n_samples, 2] array with columns (E_low, E_high)."""
		y_lower = prediction[:, 0]
		y_upper = prediction[:, -1]
		n_samples = y_upper.shape[0]
		err_low = np.reshape(y_lower - y, (n_samples, 1))
		err_high = np.reshape(y - y_upper, (n_samples, 1))
		return np.hstack((err_low, err_high))

	def apply_inverse(self, nc, significance):
		"""Return per-side offsets, each at level ``significance / 2``.

		Parameters
		----------
		nc : numpy array of shape [n_calibration_samples, 2]
			Calibration scores as returned by ``apply``.

		significance : float
			Significance level in (0, 1); split evenly between the two tails.

		Returns
		-------
		numpy array of shape [2, 1]
			Row 0: lower-side offset (from E_low); row 1: upper-side offset
			(from E_high).
		"""
		nc = np.sort(nc, 0)  # sort each column independently
		index = int(np.ceil((1 - significance / 2) * (nc.shape[0] + 1))) - 1
		index = min(max(index, 0), nc.shape[0] - 1)
		return np.vstack([nc[index, 0], nc[index, 1]])
	
# -----------------------------------------------------------------------------
# Base nonconformity scorer
# -----------------------------------------------------------------------------
class BaseScorer(sklearn.base.BaseEstimator):
	"""Abstract interface for nonconformity scorers.

	Concrete scorers implement ``fit`` (train whatever model they wrap) and
	``score`` (produce per-sample values for ``x``).
	"""

	# NOTE(review): Python 2 metaclass syntax; under Python 3 this is a plain
	# class attribute, so the abstract methods below are not enforced.
	__metaclass__ = abc.ABCMeta

	def __init__(self):
		super(BaseScorer, self).__init__()

	@abc.abstractmethod
	def fit(self, x, y):
		"""Fit the scorer on inputs ``x`` and targets ``y``."""

	@abc.abstractmethod
	def score(self, x, y=None):
		"""Return scores for ``x`` (``y`` is optional)."""


class RegressorNormalizer(BaseScorer):
	"""Normalizer that learns the typical residual size of a base model.

	Parameters
	----------
	base_model : object with ``predict``
		Model whose residuals are modelled.

	normalizer_model : object with ``fit`` and ``predict``
		Model trained to predict the absolute residual of ``base_model``.

	err_func : RegressionErrFunc
		Error function used to compute the residuals.
	"""
	def __init__(self, base_model, normalizer_model, err_func):
		super(RegressorNormalizer, self).__init__()
		self.base_model = base_model
		self.normalizer_model = normalizer_model
		self.err_func = err_func

	def fit(self, x, y):
		"""Train ``normalizer_model`` on ``|err_func(base_model.predict(x), y)|``.

		The original implementation instead regressed on the log of this
		residual (after adding 1e-5 to avoid log(0)) and exponentiated the
		output in ``score``; here the absolute residual is used directly.
		"""
		base_predictions = self.base_model.predict(x)
		target = np.abs(self.err_func.apply(base_predictions, y))
		self.normalizer_model.fit(x, target)

	def score(self, x, y=None):
		"""Return the absolute value of ``normalizer_model``'s prediction.

		``y`` is accepted for interface compatibility and ignored.
		"""
		return np.abs(self.normalizer_model.predict(x))


class BaseModelNc(BaseScorer):
	"""Base class for nonconformity scorers based on an underlying model.

	Parameters
	----------
	model : ClassifierAdapter or RegressorAdapter
		Underlying model used for calculating nonconformity scores.

	err_func : ClassificationErrFunc or RegressionErrFunc
		Error function object.

	normalizer : BaseScorer
		Normalization model.

	beta : float
		Normalization smoothing parameter. As the beta-value increases,
		the normalized nonconformity function approaches a non-normalized
		equivalent.
	"""
	def __init__(self, model, err_func, normalizer=None, beta=1e-6):
		super(BaseModelNc, self).__init__()
		self.err_func = err_func
		self.model = model
		self.normalizer = normalizer
		self.beta = beta

		# If we use sklearn.base.clone (e.g., during cross-validation),
		# object references get jumbled, so we need to make sure that the
		# normalizer has a reference to the proper model adapter, if applicable.
		if (self.normalizer is not None and
			hasattr(self.normalizer, 'base_model')):
			self.normalizer.base_model = self.model

		self.last_x, self.last_y = None, None
		self.last_prediction = None
		self.clean = False

	def fit(self, x, y, dataset_name, seed):
		"""Fits the underlying model (and the normalizer, if any).

		Parameters
		----------
		x : numpy array of shape [n_samples, n_features]
			Inputs of examples for fitting the underlying model.

		y : numpy array of shape [n_samples]
			Outputs of examples for fitting the underlying model.

		dataset_name, seed
			Forwarded unchanged to ``self.model.fit``; their meaning is defined
			by the adapter (presumably a dataset identifier and a random seed —
			confirm against the adapter).

		Returns
		-------
		None
		"""
		self.model.fit(x, y, dataset_name, seed)
		if self.normalizer is not None:
			self.normalizer.fit(x, y)
		self.clean = False

	def _get_prediction(self, x, model=None, bias=0):
		"""Return predictions for ``x`` shifted by ``bias``.

		When ``model`` is None, ``self.model.predict(x)`` is used. Otherwise
		``model`` is called as a torch module: it is switched to eval mode
		(a persistent side effect on the caller's model), ``x`` is converted
		with ``torch.from_numpy`` and moved to the module-level ``device``, and
		the output must have at least two columns. Columns 0 and 1 are then
		overwritten with the row-wise minimum and maximum so lower <= upper.
		"""
		if model is None:
			prediction = self.model.predict(x)
		else:
			model.eval()
			with torch.no_grad():
				outputs = model(torch.from_numpy(x).to(device))
			outputs = outputs.detach().cpu().numpy()
			# Repair quantile crossing. Both extremes are computed from the raw
			# outputs before either column is written: writing column 0 first
			# loses the maximum whenever it sat in column 0 (e.g. [5, 3] would
			# become [3, 3] instead of [3, 5]).
			row_min = np.min(outputs, axis=1)
			row_max = np.max(outputs, axis=1)
			outputs[:, 0] = row_min
			outputs[:, 1] = row_max
			prediction = outputs
		return prediction + bias  # manually supplied shift

	def score(self, x, y=None, model=None, bias=0):
		"""Calculates the nonconformity score of a set of samples.

		Parameters
		----------
		x : numpy array of shape [n_samples, n_features]
			Inputs of examples for which to calculate a nonconformity score.

		y : numpy array of shape [n_samples]
			Outputs of examples for which to calculate a nonconformity score.

		model : torch module, optional
			Used instead of ``self.model.predict`` when given; see
			``_get_prediction``.

		bias : float or numpy array, optional
			Added to the predictions before scoring.

		Returns
		-------
		nc : numpy array
			Nonconformity scores of samples: shape [n_samples], or
			[n_samples, 2] for error functions that score each side
			separately. Scores are divided by the normalizer output only
			when the predictions are 1-D.
		"""
		prediction = self._get_prediction(x, model, bias)
		nc = self.err_func.apply(prediction, y)
		if prediction.ndim > 1:
			# Quantile (CQR) predictions are scored without normalization, so
			# the normalizer is not evaluated at all on this path.
			return nc
		if self.normalizer is not None:
			norm = self.normalizer.score(x) + self.beta
		else:
			norm = np.ones(x.shape[0])
		return nc / norm


# -----------------------------------------------------------------------------
# Classification nonconformity scorers
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# Regression nonconformity scorers
# -----------------------------------------------------------------------------
class RegressorNc(BaseModelNc):
	"""Nonconformity scorer using an underlying regression model.

	Parameters
	----------
	model : RegressorAdapter
		Underlying regression model used for calculating nonconformity scores.

	err_func : RegressionErrFunc
		Error function object.

	normalizer : BaseScorer
		Normalization model.

	beta : float
		Normalization smoothing parameter. As the beta-value increases,
		the normalized nonconformity function approaches a non-normalized
		equivalent.

	Attributes
	----------
	model : RegressorAdapter
		Underlying model object.

	err_func : RegressionErrFunc
		Scorer function used to calculate nonconformity scores.

	See also
	--------
	ProbEstClassifierNc, NormalizedRegressorNc
	"""
	# The default AbsErrorErrFunc instance is created once and shared; it
	# stores no state, so sharing it between scorers is harmless.
	def __init__(self, model, err_func=AbsErrorErrFunc(), normalizer=None, beta=1e-6):
		super(RegressorNc, self).__init__(model, err_func, normalizer, beta)

	def predict(self, x, nc, significance=None, model=None, bias=0):
		"""Constructs prediction intervals for a set of test examples.

		Predicts the output of each test pattern using the underlying model
		(or ``model``, if given), and applies the inverse nonconformity
		function to each prediction, resulting in a prediction interval for
		each test pattern.

		Parameters
		----------
		x : numpy array of shape [n_samples, n_features]
			Inputs of patterns for which to predict output values.

		nc : numpy array
			Calibration nonconformity scores, as produced by ``score``.

		significance : float
			Significance level (maximum allowed error rate) of predictions.
			Should be a float between 0 and 1. If ``None`` (or any falsy
			value), then intervals for all significance levels
			(0.01, 0.02, ..., 0.99) are output in a 3d-matrix.

		model : torch module, optional
			If given, used instead of ``self.model.predict``. It is switched
			to eval mode, and its output (at least two columns) is reordered
			so that column 0 holds the row minimum and column 1 the row
			maximum.

		bias : float or numpy array, optional
			Added to the predictions before intervals are built.

		Returns
		-------
		p : numpy array of shape [n_samples, 2] or [n_samples, 2, 99]
			If significance is ``None``, then p contains the interval (minimum
			and maximum boundaries) for each test pattern, and each significance
			level (0.01, 0.02, ..., 0.99). If significance is a float between
			0 and 1, then p contains the prediction intervals (minimum and
			maximum boundaries) for the set of test patterns at the chosen
			significance level.
		"""
		n_test = x.shape[0]
		if model is None:
			prediction = self.model.predict(x)
		else:
			model.eval()
			with torch.no_grad():
				outputs = model(torch.from_numpy(x).to(device))
			outputs = outputs.detach().cpu().numpy()
			# Repair quantile crossing. Both extremes are computed from the raw
			# outputs before either column is written: writing column 0 first
			# loses the maximum whenever it sat in column 0 (e.g. [5, 3] would
			# become [3, 3] instead of [3, 5]).
			row_min = np.min(outputs, axis=1)
			row_max = np.max(outputs, axis=1)
			outputs[:, 0] = row_min
			outputs[:, 1] = row_max
			prediction = outputs
		prediction = prediction + bias  # manually supplied shift

		if self.normalizer is not None:
			norm = self.normalizer.score(x) + self.beta
		else:
			norm = np.ones(n_test)

		if significance:
			return self._build_interval(prediction, nc, significance, norm)

		# All levels share the single-level construction, so the CQR path and
		# the asymmetric (per-side) offsets are handled identically here.
		levels = np.arange(0.01, 1.0, 0.01)
		intervals = np.zeros((n_test, 2, levels.size))
		for i, level in enumerate(levels):
			intervals[:, :, i] = self._build_interval(prediction, nc, level, norm)
		return intervals

	def _build_interval(self, prediction, nc, significance, norm):
		"""Return the [n_samples, 2] (lower, upper) interval at one level.

		The error functions in this module return a [2, 1] array from
		``apply_inverse``: row 0 is subtracted from the lower prediction and
		row 1 added to the upper one. For 2-D (quantile / CQR) predictions the
		lower and upper predictions are columns 0 and -1 and no normalization
		is applied; for 1-D predictions both offsets are scaled by ``norm``.
		"""
		err_dist = self.err_func.apply_inverse(nc, significance)
		intervals = np.zeros((len(prediction), 2))
		if prediction.ndim > 1:  # CQR
			intervals[:, 0] = prediction[:, 0] - err_dist[0]
			intervals[:, 1] = prediction[:, -1] + err_dist[1]
		else:  # regular (optionally normalized) conformal prediction
			scaled = err_dist * norm  # (2, 1) * (n,) -> (2, n)
			intervals[:, 0] = prediction - scaled[0]
			intervals[:, 1] = prediction + scaled[1]
		return intervals
