"""Get batch predictions given ensemble model outputs.
"""
import abc
import itertools

import numpy as np
import scipy as sp
import scipy.stats as sst

import multiprocessing as mp

# Registry of MLE classes keyed by lower-cased class name; filled in by
# `MLEMeta` whenever a class using that metaclass is created.
SUPPORTED_MLES = {}
# Small constant added to sample values and denominators in this module to
# keep powers and divisions away from exact zero.
EPS = 1e-7


class MLEMeta(abc.ABCMeta):
  """Metaclass that registers every MLE model class it creates.

  Each new class, except the one named `BaseMLE`, is stored in
  `SUPPORTED_MLES` under its lower-cased class name.
  """

  def __new__(cls, name, bases, attr):
    new_class = super().__new__(cls, name, bases, attr)
    # The abstract base itself is not a usable model, so it is not registered.
    if new_class.__name__ == 'BaseMLE':
      return new_class
    registry_key = new_class.__name__.lower()
    SUPPORTED_MLES[registry_key] = new_class
    return new_class


class BaseMLE(metaclass=MLEMeta):
  """Base class for all MLE models.

  Every MLE model must inherit from this class. Inheriting classes are
  automatically added to `SUPPORTED_MLES` (keyed by lower-cased class name)
  by the `MLEMeta` metaclass.
  """

  # Exponent `beta` used to re-weight each sample value `x` by `|x|**beta`;
  # the prediction is the median of the re-weighted empirical distribution.
  _METRIC_EXPONENTS = {'MAPE': -1, 'MAE': 0, 'RE': 1}

  def __init__(self, config, params, error_metric):
    """Initialize class with parameters.

    Args:
      config: dict configuration parameters for this distribution. The keys
        read here are 'num_samples' (default 10000), 'quantile' (default
        0.5) and 'min_val' (default 1.0).
      params: list parameters of the distribution; the layout is defined by
        the inheriting class.
      error_metric: str error metric to optimize for
    """
    self.config = config
    self.params = params
    self.num_samples = config.get('num_samples', 10000)
    self.quantile = config.get('quantile', 0.5)
    self.min_val = config.get('min_val', 1.0)
    self.error_metric = error_metric

  def sample(self, num_samples):
    """Get samples from MLE.

    Args:
      num_samples: int number of samples

    Raises:
      NotImplementedError: always; inheriting classes must override this.
    """
    raise NotImplementedError('sample must be implemented by subclasses')

  def predict(self):
    """Predict the statistic that optimises the error metric.

    Draws `self.num_samples` samples and returns
      * 'MSE': the sample mean,
      * 'QUANTILE': the `self.quantile` sample quantile,
      * 'MAPE' / 'MAE' / 'RE': the median of the samples re-weighted by
        `|x + EPS|**beta` with beta = -1 / 0 / 1. For 'MAPE' only samples
        that are at least `self.min_val` are kept.

    Returns:
      The predicted statistic, or 0 when no samples are left to use.

    Raises:
      NotImplementedError: if the error metric is not one of the above.
    """
    samples = np.asarray(self.sample(num_samples=self.num_samples))
    if self.error_metric == 'MSE':
      return np.mean(samples)
    if self.error_metric == 'QUANTILE':
      return np.percentile(samples, q=100 * self.quantile)
    if self.error_metric not in self._METRIC_EXPONENTS:
      raise NotImplementedError('error_metric not implemented')
    beta = self._METRIC_EXPONENTS[self.error_metric]
    samples = samples.reshape(-1)
    if self.error_metric == 'MAPE':
      samples = samples[samples >= self.min_val]
    if len(samples) == 0:
      return 0
    weights = np.power(np.abs(samples + EPS), beta)
    total = np.sum(weights)
    if total <= 0.0 or not np.isfinite(total):
      # Degenerate weights: fall back to the plain (unweighted) median.
      weights = np.ones(shape=weights.shape)
    # Monte Carlo samples almost always contain repeated values, and
    # `rv_discrete` rejects duplicated support points, so the weight of each
    # distinct value is accumulated before the distribution is built.
    points, inverse = np.unique(samples, return_inverse=True)
    pmf = np.bincount(
        inverse.reshape(-1), weights=weights, minlength=len(points))
    pmf = pmf / np.sum(pmf)
    rv = sp.stats.rv_discrete(values=(points, pmf))
    return rv.median()

  @classmethod
  def inverse_link_function(cls, model_outputs, config=None):
    """Process model outputs and also the mean from the distribution.

    Args:
      model_outputs: np array of raw model outputs
      config: dict optional mle config; accepted so that the base signature
        matches how `run_inference` calls this hook.

    Raises:
      NotImplementedError: always; inheriting classes must override this.
    """
    raise NotImplementedError(
        'inverse_link_function must be implemented by subclasses')


def _softplus(xin):
  """Overflow-safe softplus, log(1 + exp(xin)), implemented with numpy.

  The exponential is only ever taken of a non-positive number; the linear
  part max(xin, 0) is added back separately.
  """
  decaying_part = np.exp(-np.abs(xin))
  linear_part = np.maximum(xin, 0)
  return np.log(1 + decaying_part) + linear_part


def _sigmoid(xin):
  """Numerically stable logistic sigmoid, 1 / (1 + exp(-xin)).

  Uses `scipy.special.expit` so that large negative inputs do not overflow
  inside `exp` (which would emit a RuntimeWarning, or raise when numpy is
  configured to raise on overflow).

  Args:
    xin: scalar or array-like of real values

  Returns:
    Sigmoid of `xin`, element-wise.
  """
  return sp.special.expit(xin)


def _softmax(xin, axis=1):
  """Apply softmax along `axis` (delegates to `scipy.special.softmax`)."""
  normalized = sp.special.softmax(xin, axis=axis)
  return normalized


def _convert_to_positive(num):
  """Map real numbers to strictly positive numbers.

  Positive inputs map to `num + 1`; all other inputs map to `1 / (1 - num)`,
  so both branches meet at 1 around zero.

  Args:
    num: scalar or array-like of real values

  Returns:
    np array with the transformed values.
  """
  num = np.asarray(num)
  is_positive = num > 0
  # `np.where` evaluates both branches for every element, so the reciprocal
  # branch would divide by zero at num == 1. Swap in a harmless denominator
  # wherever that branch is not the one selected.
  safe_denominator = np.where(is_positive, 1.0, 1.0 - num)
  return np.where(is_positive, num + 1.0, 1.0 / safe_denominator)


class ZeroInflatedNBinomMLE(BaseMLE):
  """Tools for inference from the zero-inflated NBinom distribution.

  `params` is laid out as `[zprob, n, p]`: the probability of the extra zero
  component followed by the `n` and `p` arguments of `scipy.stats.nbinom`.
  """

  def __init__(self, config, params, error_metric):
    """Initialize class with parameters.

    Args:
      config: dict configuration parameters for this distribution;
        'max_val' (default 10000) is the exclusive upper end of the grid on
        which the pmf is evaluated.
      params: list parameters for this distribution [zprob, n, p]
      error_metric: str error metric to optimize for
    """
    super().__init__(config, params, error_metric)
    self.max_val = config.get('max_val', 10000)

  def sample(self, num_samples):
    """Sample from distribution.

    Args:
      num_samples: int number of samples

    Returns:
      #num_samples samples from the zero-inflated nbinom.
    """
    nbinom_samples = sst.nbinom.rvs(
        n=self.params[1], p=self.params[2], size=num_samples)
    # 1 keeps the nbinom draw, 0 replaces it with an inflated zero.
    keep_mask = sst.bernoulli.rvs(p=1.0 - self.params[0], size=num_samples)
    return keep_mask * nbinom_samples

  def predict(self):
    """Returns the statistic optimizing the error metric.

    'MAPE', 'MAE' and 'RE' use the exact (truncated) pmf. 'MSE' and
    'QUANTILE' have no closed form here and fall back to the sampling based
    implementation of the base class.

    Raises:
      NotImplementedError: for any other error metric.
    """
    if self.error_metric in ('MSE', 'QUANTILE'):
      return super().predict()
    return self._create_rv().median()

  @classmethod
  def inverse_link_function(cls, model_outputs, config=None):
    """Process model outputs and also the mean from the distribution.

    Args:
      model_outputs: np array array #examples * #num_outputs; columns 0, 1
        and 2 are read as the raw zprob, n and p outputs.
      config: dict optional mle config (unused by this class)

    Returns:
      params which are input to distribution and
      mean predictions from MLE.
    """
    params = np.zeros(shape=model_outputs.shape)
    params[:, 0] = _sigmoid(model_outputs[:, 0])
    params[:, 1] = _convert_to_positive(model_outputs[:, 1])
    params[:, 2] = 1 - _sigmoid(model_outputs[:, 2])
    # The sigmoid saturates to exactly 1.0 for large outputs, which makes p
    # exactly 0; bound the denominator so the mean stays finite.
    safe_p = np.maximum(params[:, 2], EPS)
    mean_out = (
        (1 - params[:, 0]) * params[:, 1] * (1 - params[:, 2]) / safe_p)
    return params, mean_out

  def _create_rv(self):
    """Create the re-weighted discrete RV whose median is the prediction.

    The pmf is evaluated on the integer grid up to (excluding) `max_val`,
    multiplied by `points**beta` (beta = -1 / 0 / 1 for MAPE / MAE / RE) and
    re-normalized.

    Raises:
      NotImplementedError: if the error metric is not MAPE, MAE or RE.
    """
    if self.error_metric == 'MAPE':
      # Zero is excluded from the support for MAPE.
      start, beta = 1, -1.0
    elif self.error_metric == 'MAE':
      start, beta = 0, 0.0
    elif self.error_metric == 'RE':
      start, beta = 0, 1.0
    else:
      raise NotImplementedError('Error metric not implemented')
    points = 0. + np.arange(start, self.max_val)
    pmf = sst.nbinom.pmf(points, self.params[1], self.params[2])
    if self.error_metric != 'MAPE':
      # Mix in the zero inflation. For MAPE the grid starts at 1, so the
      # inflation is only a constant factor that the normalization removes.
      pmf *= 1.0 - self.params[0]
      pmf[0] += self.params[0]
    pmf *= (points**beta)
    spm = np.sum(pmf)
    if spm <= 0.0 or np.isnan(spm):
      # Degenerate weights: put all the mass on the first grid point.
      pmf[0] = 1.0
      pmf[1::] = 0.0
    else:
      pmf /= spm
    return sp.stats.rv_discrete(values=(points, pmf))


class NBinomMLE(BaseMLE):
  """Negative binomial MLE; `params` is laid out as `[n, p]`."""

  def __init__(self, config, params, error_metric):
    """Initialize class with parameters.

    Args:
      config: dict configuration parameters for this distribution
      params: list parameters for this distribution [n, p]
      error_metric: str error metric to optimize for
    """
    # The base initializer stores config, params, num_samples, quantile,
    # min_val and error_metric.
    super().__init__(config, params, error_metric)

  def sample(self, num_samples):
    """Draw `num_samples` negative binomial samples as a flat array."""
    n_param = self.params[0]
    p_param = self.params[1]
    draws = sst.nbinom.rvs(n=n_param, p=p_param, size=num_samples)
    return draws.reshape(-1)

  @classmethod
  def inverse_link_function(cls, model_outputs, config=None):
    """Turn raw model outputs into `[n, p]` and the distribution mean.

    Args:
      model_outputs: np array array #examples * #num_outputs; columns 0 and
        1 are read as the raw n and p outputs.
      config: dict optional mle config (unused by this class)

    Returns:
      params which are input to distribution and
      mean predictions from MLE.
    """
    params = np.zeros(shape=model_outputs.shape)
    params[:, 0] = _convert_to_positive(model_outputs[:, 0])
    params[:, 1] = 1.0 - _sigmoid(model_outputs[:, 1])
    n_vals = params[:, 0]
    p_vals = params[:, 1]
    # EPS keeps the denominator away from zero.
    nb_mean = n_vals * (1 - p_vals) / (p_vals + EPS)
    return params, nb_mean


class MixZeroNBinomPareto(BaseMLE):
  """Mixture of zero, nbinom and pareto.

  `params` is laid out as
  `[w_zero, w_nbinom, w_pareto, n, p, pareto_scale, alpha]`.
  """

  def __init__(self, config, params, error_metric):
    """Initialize class with parameters.

    Args:
      config: dict configuration parameters for this distribution; 'alpha'
        (default 4.0) is the Pareto shape used when sampling.
      params: list parameters for this distribution, see the class
        docstring for the layout.
      error_metric: str error metric to optimize for
    """
    super().__init__(config, params, error_metric)
    self.alpha = config.get('alpha', 4.0)

  def sample(self, num_samples):
    """Get samples from MLE.

    Args:
      num_samples: int number of samples

    Returns:
      Flat float array with #num_samples draws from the mixture.
    """
    mix_weights = np.array(self.params[0:3], dtype=float)
    mix_weights[mix_weights < 0.0] = 0.0
    msum = np.sum(mix_weights)
    # Validate the total before dividing by it so that a zero, NaN or
    # infinite sum falls back to uniform weights instead of producing NaNs.
    if not np.isfinite(msum) or msum <= 0.0:
      mix_weights = np.ones(shape=mix_weights.shape) / 3
    else:
      mix_weights = mix_weights / msum
    mix_ids = np.random.choice(
        a=[0, 1, 2], p=mix_weights, replace=True, size=num_samples)

    n, p = self.params[3], self.params[4]
    if np.isnan(n):
      n = 1.0
    elif n <= 0:
      n = 1.0
      p = 1.0
    if np.isnan(p) or p >= 1.0:
      p = 1.0
    elif p <= 0.0:
      # p == 0 is outside the domain of nbinom, so use the smallest
      # probability this module works with instead.
      p = EPS

    # Column 0 stays zero (the point mass at zero); columns 1 and 2 hold the
    # nbinom and the scaled Pareto draws.
    component_draws = np.zeros(shape=(num_samples, 3))
    component_draws[:, 1] = sst.nbinom.rvs(n=n, p=p, size=num_samples)
    component_draws[:, 2] = (
        np.random.pareto(self.alpha, size=num_samples) + 1) * self.params[5]
    # Pick, for every draw, the value of the selected component. Indexing
    # (rather than multiplying by a one-hot mask) keeps a non-finite value
    # in an unselected component from leaking into the result.
    samples = component_draws[np.arange(num_samples), mix_ids]
    return samples.reshape(-1)

  @classmethod
  def inverse_link_function(cls, model_outputs, config=None):
    """Process model outputs and also the mean from the distribution.

    Args:
      model_outputs: np array array #examples * #num_outputs; columns 0-5
        are read and column 6 of the result is written, so at least 7
        columns are required.
      config: dict optional mle config; 'alpha' (default 4.0) is the Pareto
        shape. The Pareto mean used below is only valid for alpha > 1.

    Returns:
      params which are input to distribution and
      mean predictions from MLE.
    """
    alpha = 4.0 if config is None else config.get('alpha', 4.0)
    params = np.zeros(shape=model_outputs.shape)
    params[:, 0:3] = _softmax(model_outputs[:, 0:3], axis=1)
    params[:, 3] = _convert_to_positive(model_outputs[:, 3])
    params[:, 4] = 1.0 - _sigmoid(model_outputs[:, 4])
    params[:, 5] = _softplus(model_outputs[:, 5]) + 1e-3
    params[:, 6] = alpha
    nb_mean = params[:, 1] * params[:, 3] * (1 - params[:, 4]) / (
        params[:, 4] + EPS)
    multiplier = alpha / (alpha - 1)
    g_mean = params[:, 2] * params[:, 5] * multiplier
    return params, nb_mean + g_mean


def _inference_helper(inference_class, params, config, error_metric):
  """Build the requested MLE model for one example and return its prediction.

  `inference_class` is looked up case-insensitively in `SUPPORTED_MLES`.
  """
  mle_class = SUPPORTED_MLES[inference_class.lower()]
  estimator = mle_class(
      params=params, config=config, error_metric=error_metric)
  return estimator.predict()


def _reseed_worker_rng():
  """Pool initializer: give each worker its own numpy global RNG state.

  Workers created by forking start with a copy of the parent's numpy global
  RNG state, so without re-seeding every worker would produce the same
  random draws.
  """
  np.random.seed()


def run_inference(inference_class, model_outputs, config, error_metric,
                  num_jobs):
  """Main function to run inference in parallel.

  Args:
    inference_class: str class for MLE inference (case-insensitive key into
      `SUPPORTED_MLES`)
    model_outputs: np array numpy array of model outputs
    config: dict mle config
    error_metric: str error metric; for 'MSE' the mean returned by the
      inverse link function is used directly and no per-example inference
      is run.
    num_jobs: int number of parallel jobs; 1 runs serially in this process,
      any other value is passed to `multiprocessing.Pool`.

  Returns:
    final predictions for the metric as a flat np array
  """
  mle_class = SUPPORTED_MLES[inference_class.lower()]
  param_array, mean_output = mle_class.inverse_link_function(
      model_outputs, config)
  if error_metric == 'MSE':
    return mean_output.reshape(-1)
  # Only needed for per-example inference, so built after the MSE shortcut.
  param_list = param_array.tolist()
  if num_jobs == 1:
    final_preds = [
        _inference_helper(inference_class, param, config, error_metric)
        for param in param_list
    ]
  else:
    with mp.Pool(
        processes=num_jobs, initializer=_reseed_worker_rng) as pool:
      final_preds = pool.starmap(
          _inference_helper,
          zip(
              itertools.repeat(inference_class),
              param_list,
              itertools.repeat(config),
              itertools.repeat(error_metric),
          ))
  return np.array(final_preds).reshape(-1)

