"""Local Interpretable Model-agnostic Explanations (LIME).

LIME was proposed in the following paper:

> "Why Should I Trust You?": Explaining the Predictions of Any Classifier
> Marco Tulio Ribeiro, Sameer Singh, Carlos Guestrin
> https://arxiv.org/abs/1602.04938

LIME explains classifiers by returning a feature attribution score
for each input feature. It works as follows:

1) Sample perturbation masks. First the number of masked features is sampled
   (uniform, at least 1), and then that number of features are randomly chosen
   to be masked out (without replacement).
2) Get predictions from the model for those perturbations. Use these as labels.
3) Fit a linear model to associate the input positions indicated by the binary
   mask with the resulting predicted label.

The resulting feature importance scores are the linear model coefficients for
the requested output class or (in case of regression) the output score.

This is a reimplementation of the original https://github.com/marcotcr/lime
and is tested for compatibility. This version supports applying LIME to text
input, also in case of regression and binary-classification where the
prediction function only outputs a scalar for each input sentence.
"""

from collections.abc import Callable, Iterable, Sequence
import functools
from typing import Any, Optional
import attr
import numpy as np
from sklearn import linear_model
from sklearn import metrics



DEFAULT_MASK_TOKEN = '[MASK]'
DEFAULT_NUM_SAMPLES = 3000
DEFAULT_SOLVER = 'cholesky'
DEFAULT_KERNEL_WIDTH = 25

def exponential_kernel(
    distance: float, kernel_width: float = DEFAULT_KERNEL_WIDTH) -> np.ndarray:
  """The exponential kernel."""
  return np.sqrt(np.exp(-(distance**2) / kernel_width**2))

@attr.s(auto_attribs=True)
class PosthocExplanation:
  """Represents a post-hoc explanation with feature importance scores.

  Attributes:
    features: the names of the features to attribute to;
      typically these are tokens.
    feature_importance: Feature importance scores for each input feature. These
      are the coefficients of the linear model that was fitted to mimic the
      behavior of a (black-box) prediction function.
    intercept: The intercept of the fitted linear model. This is the independent
      term that is added to make a prediction.
    model: The fitted linear model. An explanation only contains this if it was
      explicitly requested from the explanation method.
    score: The R^2 score of the fitted linear model on the perturbations and
      their labels. This reflects how well the linear model was able to fit to
      the perturbation set.
    prediction: The prediction of the linear model on the full input sentence,
      i.e., an all-true boolean mask.
  """
  features: Sequence[str]
  feature_importance: np.ndarray
  intercept: Optional[float] = None
  model: Optional[Any] = None
  score: Optional[float] = None
  prediction: Optional[float] = None

def sample_masks(num_samples: int,
                 num_features: int,
                 seed: Optional[int] = None):
  """Samples LIME masks with at least 1 position disabled per sampled mask.

  The number of disabled features is sampled from a uniform distribution.

  Args:
    num_samples: The number of samples.
    num_features: The number of features to sample a mask for. Typically this is
      the number of tokens in the sentence.
    seed: Set this to an integer to make the sampling deterministic.

  Returns:
    Masks <bool>[num_samples, num_features] indicating which features are
    enabled (True) and which ones are disabled (False).
  """
  rng = np.random.RandomState(seed)
  positions = np.tile(np.arange(num_features), (num_samples, 1))
  permutation_fn = np.vectorize(rng.permutation, signature='(n)->(n)')
  permutations = permutation_fn(positions)  # A shuffled range of positions.
  num_disabled_features = rng.randint(1, num_features + 1, (num_samples, 1))
  # For num_disabled_features[i] == 2, this will set indices 0 and 1 to False.
  return permutations >= num_disabled_features


def get_perturbations(tokens: Sequence[str],
                      masks: np.ndarray,
                      mask_token: str = '<unk>') -> Iterable[str]:
  """Returns strings with the masked tokens replaced with `mask_token`."""
  for mask in masks:
    parts = [t if mask[i] else mask_token for i, t in enumerate(tokens)]
    yield ' '.join(parts)


def explain(
    sentence: str,
    predict_fn: Callable[[Iterable[str]], np.ndarray],
    class_to_explain: Optional[int] = None,
    num_samples: int = DEFAULT_NUM_SAMPLES,
    tokenizer: Any = str.split,
    mask_token: str = DEFAULT_MASK_TOKEN,
    alpha: float = 1.0,
    solver: str = DEFAULT_SOLVER,
    kernel: Callable[..., np.ndarray] = exponential_kernel,
    distance_fn: Callable[..., np.ndarray] = functools.partial(
        metrics.pairwise.pairwise_distances, metric='cosine'),
    distance_scale: float = 100.,
    return_model: bool = False,
    return_score: bool = False,
    return_prediction: bool = False,
    seed: Optional[int] = None,
) -> PosthocExplanation:
  """Returns the LIME explanation for a given sentence.

  By default, this function returns an explanation object containing feature
  importance scores and the intercept. Optionally, more information can be
  returned, such as the linear model, the score of the model on the perturbation
  set, and the prediction that the linear model makes on the original sentence.

  Args:
    sentence: An input to be explained.
    predict_fn: A prediction function that returns an array of outputs given a
      list of inputs. The output shape is [len(inputs)] for regression and
      binary classification (with scalar output), and [len(inputs), num_classes]
      for multi-class classification.
    class_to_explain: The class ID to explain in case of multi-class
      classification, where `predict_fn` returns outputs with multiple
      dimensions for each input. For example, use 2 to explain the third class
      in 3-class classification. For regression and binary classification, where
      `predict_fn` returns a scalar for each input, this does not need to be
      set.
    num_samples: The number of n-grams to sample.
    tokenizer: A function that splits the input sentence into tokens.
    mask_token: The token that is used for masking tokens, e.g., '<unk>'.
    alpha: Regularization strength of the linear approximation model. See
      `sklearn.linear_model.Ridge` for details.
    solver: Solver to use in the linear approximation model. See
      `sklearn.linear_model.Ridge` for details.
    kernel: A kernel function to be used on the distance function. By default,
      use the exponential kernel with kernel width utils.DEFAULT_KERNEL_WIDTH.
    distance_fn: A distance function to use in range [0, 1]. Default: cosine.
    distance_scale: A scalar factor multiplied with the distances before the
      kernel is applied.
    return_model: Returns the fitted linear model.
    return_score: Returns the score of the linear model on the perturbations.
      This is the R^2 of the linear model predictions w.r.t. their targets.
    return_prediction: Returns the prediction of the linear model on the full
      original sentence.
    seed: Optional random seed to make the explanation deterministic.

  Returns:
    The explanation for the requested class.
  """
  # TODO(bastings): Provide sentence already tokenized to reduce split/join ops.
  tokens = tokenizer(sentence)

  if not tokens:
    return PosthocExplanation(
        features=[], feature_importance=np.array([], dtype=np.float32))

  masks = sample_masks(num_samples + 1, len(tokens), seed=seed)
  assert masks.shape[0] == num_samples + 1, 'Expected num_samples + 1 masks.'
  all_true_mask = np.ones_like(masks[0], dtype=bool)
  masks[0] = all_true_mask  # First mask is the full sentence.

  perturbations = list(get_perturbations(tokens, masks, mask_token))
  outputs = predict_fn(perturbations)

  if len(outputs.shape) > 1:
    assert class_to_explain is not None, \
        'class_to_explain needs to be set when `predict_fn` returns a 2D tensor'
    outputs = outputs[:, class_to_explain]  # We are only interested in 1 class.

  distances = distance_fn(all_true_mask.reshape(1, -1), masks).flatten()
  distances = distance_scale * distances
  distances = kernel(distances)

  # Fit a linear model for the requested output class.
  model = linear_model.Ridge(
      alpha=alpha, solver=solver, random_state=seed).fit(
          masks, outputs, sample_weight=distances)

  explanation = PosthocExplanation(
      features=tokens,
      feature_importance=model.coef_,
      intercept=model.intercept_)

  if return_model:
    explanation.model = model

  if return_score:
    explanation.score = model.score(masks, outputs)

  if return_prediction:
    explanation.prediction = model.predict(all_true_mask.reshape(1, -1))

  return explanation
