# pylint: disable=g-bad-file-header
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Classification-based testbed based around a logit_fn and x_generator."""

from typing import Callable, Optional, Tuple

import chex
from enn import metrics
import haiku as hk
import jax
import jax.numpy as jnp
from neural_testbed import base as testbed_base
from neural_testbed import likelihood

# Environment function: maps inputs x to class logits.
LogitFn = Callable[[chex.Array], chex.Array]  # x -> logits
# Input sampler: maps a PRNG key and a number of samples to inputs x.
XGenerator = Callable[[chex.PRNGKey, int], chex.Array]  # key, num_samples -> x


class ClassificationEnvLikelihood(likelihood.GenerativeDataSampler):
  """Classification environment defined by a logit_fn and x generators.

  Training data is fixed at construction time: it is either sampled from
  x_train_generator and logit_fn, or supplied via override_train_data. Test
  data is sampled afresh from x_test_generator and logit_fn for every key
  passed to test_data.
  """

  def __init__(self,
               logit_fn: LogitFn,
               x_train_generator: XGenerator,
               x_test_generator: XGenerator,
               num_train: int,
               key: chex.PRNGKey,
               override_train_data: Optional[testbed_base.Data] = None,
               tau: int = 1):
    """Initializes the environment and samples its fixed data.

    Args:
      logit_fn: maps inputs x to class logits.
      x_train_generator: samples training inputs; unused when
        override_train_data is given.
      x_test_generator: samples test inputs.
      num_train: number of training examples.
      key: random key used for the training data and the canonical test x.
      override_train_data: optional training data to use instead of sampling.
        Its x and y must both have leading dimension num_train.
      tau: number of examples in each batch returned by test_data.
    """
    rng = hk.PRNGSequence(key)

    self._logit_fn = logit_fn
    self._tau = tau
    self._x_test_generator = x_test_generator
    self._num_train = num_train

    # Optionally override training data where you want to allow for training
    # data that was *not* generated by the x_generator, logit_fn.
    if override_train_data is None:
      self._train_data, _ = sample_gaussian_data(
          logit_fn, x_train_generator, num_train, next(rng))
    else:
      assert num_train == override_train_data.x.shape[0]
      assert num_train == override_train_data.y.shape[0]
      self._train_data = override_train_data

    # Generate canonical x_test for DEBUGGING ONLY!!!
    num_test = 1000
    self._x_test = self._x_test_generator(next(rng), num_test)
    test_logits = self._logit_fn(self._x_test)  # [num_test, n_class]
    chex.assert_shape(test_logits, [num_test, None])
    self._test_probs = jax.nn.softmax(test_logits)

    # Build the jitted test sampler exactly once. jax.jit caches on the
    # identity of the wrapped function, so wrapping a fresh closure on every
    # test_data call would force a re-trace each time. tau is captured by the
    # closure and is therefore a static quantity under jit.
    def sample_test(k: chex.PRNGKey) -> Tuple[testbed_base.Data, float]:
      return sample_gaussian_data(
          self._logit_fn, self._x_test_generator, self._tau, key=k)
    self._jit_sample_test = jax.jit(sample_test)

  @property
  def train_data(self) -> testbed_base.Data:
    """Training data fixed at construction time."""
    return self._train_data

  @property
  def test_x(self) -> chex.Array:
    """Canonical test data for debugging only.

    This is not the test data x returned by the test data method.
    """
    return self._x_test

  @property
  def probabilities(self) -> chex.Array:
    """Return probabilities of classes for canonical test x.

    Use only for debugging/plotting purposes in conjunction with the test_x
    method. The test_data method does not use the same test_x.
    """
    return self._test_probs

  def test_data(self, key: chex.PRNGKey) -> Tuple[testbed_base.Data, float]:
    """Generates test data and evaluates log likelihood w.r.t. environment.

    The test data that is output will be of length tau examples.
    We wanted to "pass" tau here... but ran into jax.jit issues, so tau is
    fixed at construction time and baked into the jitted sampler.

    Args:
      key: Random number generator key.

    Returns:
      Tuple of data (with tau examples) and the log-likelihood of the sampled
      labels under the environment probabilities softmax(logit_fn(x)).
    """
    return self._jit_sample_test(key)


def make_gaussian_sampler(input_dim: int) -> XGenerator:
  """Returns an XGenerator that draws standard normal inputs.

  Args:
    input_dim: dimension of each sampled x.

  Returns:
    Function mapping (key, num_samples) to an array of shape
    [num_samples, input_dim] drawn with jax.random.normal.
  """

  def gaussian_generator(key: chex.PRNGKey, num_samples: int) -> chex.Array:
    sample_shape = (num_samples, input_dim)
    return jax.random.normal(key, shape=sample_shape)

  return gaussian_generator


# TODO(author2): Migrate to experimental directory.
def make_weibull_sampler(input_dim: int,
                         concentration: Optional[float] = None,
                         scale: Optional[float] = None) -> XGenerator:
  """Returns Weibull sampler around initial reference point.

  Each call to the returned generator draws a single N(0, I) reference point
  and returns tau points around it. The first point is the reference itself
  (its distance is fixed to zero); each of the remaining tau - 1 points is the
  reference plus a standard normal perturbation vector multiplied by a
  Weibull-distributed distance.

  Args:
    input_dim: dimension of each sampled x.
    concentration: concentration parameter passed to
      jax.random.weibull_min. Defaults to log10(log2(10)) (about 0.52).
    scale: scale parameter passed to jax.random.weibull_min. Defaults to
      1 / log(10)**(1 / concentration), evaluated with whichever concentration
      is in effect. Under the weibull_min parameterisation this puts the 0.9
      quantile of the distance at 1; combined with the default concentration
      the median distance is 0.1.

  Returns:
    Function mapping (key, tau) to an array of shape [tau, input_dim]. tau
    must be at least 1.
  """
  if concentration is None:
    concentration = jnp.log10(jnp.log2(10))
  if scale is None:
    scale = 1 / (jnp.log(10)**(1 / concentration))

  def weibull_generator(key: chex.PRNGKey, tau: int) -> chex.Array:
    key_ref, key_dist, key_perturb = jax.random.split(key, 3)
    x_ref = jax.random.normal(key_ref, [input_dim])

    # Distance zero for the first point so that it coincides with x_ref.
    distances = jnp.concatenate([
        jnp.zeros(1),
        jax.random.weibull_min(key_dist, scale, concentration, shape=[tau - 1])
    ])
    chex.assert_shape(distances, [tau])

    perturbations = jax.random.normal(key_perturb, [tau, input_dim])

    # Scale row i of the perturbations by distances[i], then offset by x_ref.
    x_test = x_ref + jnp.einsum('ij, i -> ij', perturbations, distances)
    chex.assert_shape(x_test, [tau, input_dim])
    return x_test

  return weibull_generator


def make_polyadic_sampler(input_dim: int, kappa: int = 2) -> XGenerator:
  """Samples batches made of repeats of kappa N(0, 1) anchor points.

  On every call the returned generator draws kappa fresh anchor points and
  then fills the batch of tau points by picking anchors uniformly at random
  with replacement. The returned points are exact copies of the anchors; this
  implementation does not add any noise on top of them.

  Args:
    input_dim: input dimension.
    kappa: number of anchor reference points. kappa anchors are always drawn,
      whatever the value of tau.

  Returns:
    Polyadic sampling XGenerator mapping (key, tau) to an array of shape
    [tau, input_dim].
  """

  def polyadic_generator(key: chex.PRNGKey, tau: int) -> chex.Array:
    key_anchor, key_index = jax.random.split(key)

    # Candidate locations: one row per anchor.
    anchors = jax.random.normal(key_anchor, (kappa, input_dim))

    # Choose, with replacement, which anchor each of the tau outputs copies.
    chosen = jax.random.randint(
        key_index, shape=(tau,), minval=0, maxval=kappa)
    batch = anchors[chosen]
    chex.assert_shape(batch, (tau, input_dim))
    return batch

  return polyadic_generator


def sample_gaussian_data(logit_fn: LogitFn,
                         x_generator: XGenerator,
                         num_train: int,
                         key: chex.PRNGKey) -> Tuple[testbed_base.Data, float]:
  """Samples (x, y) data from the environment and scores it.

  Inputs come from x_generator; each label is drawn from the categorical
  distribution softmax(logit_fn(x)) for its input.

  Args:
    logit_fn: maps inputs x to class logits.
    x_generator: maps (key, num_samples) to inputs of shape [num_samples, dim].
    num_train: number of examples to sample.
    key: random key, split between the inputs and the labels.

  Returns:
    Tuple of the sampled data (x is [num_train, dim], y is [num_train, 1]) and
    the value of metrics.categorical_log_likelihood for the sampled labels
    under the environment probabilities.
  """
  key_x, key_y = jax.random.split(key)

  # Draw the inputs and confirm they form a [num_train, dim] matrix.
  inputs = x_generator(key_x, num_train)
  chex.assert_shape(inputs, [num_train, inputs.shape[1]])

  # Environment class probabilities; the class count is read off logit_fn.
  logits = logit_fn(inputs)  # [num_train, n_class]
  num_classes = logits.shape[-1]
  chex.assert_shape(logits, [num_train, num_classes])
  class_probs = jax.nn.softmax(logits)

  # One independent key per example; each label has shape (1,).
  def draw_label(row_key: chex.PRNGKey, row_probs: chex.Array) -> chex.Array:
    return jax.random.choice(row_key, num_classes, shape=(1,), p=row_probs)

  label_keys = jax.random.split(key_y, num_train)
  labels = jax.vmap(draw_label)(label_keys, class_probs)

  # Score the sampled labels against the environment probabilities.
  log_likelihood = metrics.categorical_log_likelihood(class_probs, labels)
  return testbed_base.Data(x=inputs, y=labels), log_likelihood
