"""Games to be played."""
import logging
import numpy as np
import ml_collections
from typing import Text, Callable, Tuple, Any

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random

import data_game


def get_logits_and_probs(
    consumers: jnp.ndarray,
    tau: float = 1.0
) -> Callable:
  """Build a closure that maps producer strategies to recommendation
  log-probabilities and probabilities for a fixed set of consumers.

  Args:
    consumers: An `m x d` array holding the `m` consumer embeddings.
    tau: Temperature; the consumer-producer dot products are divided by it
        before the softmax normalisation.

  Returns:
    A function taking an `n x d` array of producer strategy vectors and
    returning `(logits, probs)`. `logits` is the `m x n` array of scaled dot
    products minus their row-wise log-sum-exp (i.e., log-probabilities), and
    `probs = exp(logits)`, so that row `i` is the distribution over the `n`
    producers recommended to consumer `i`.
  """
  def logits_and_probs(v: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    log_probs = consumers @ v.T / tau
    # subtracting the row-wise log-normaliser turns each row into a log-softmax
    log_norm = jsp.special.logsumexp(log_probs, axis=-1, keepdims=True)
    log_probs -= log_norm
    return log_probs, jnp.exp(log_probs)

  return logits_and_probs


def get_utility(
    consumers: jnp.ndarray,
    tau: float = 1.0,
    version: Text = 'exposure'
) -> Callable:
  """
  Build a closure evaluating the producers' (expected) exposure or (expected)
  rating utilities against a fixed set of consumer vectors.

  Args:
    consumers: An `m x d` array holding the `m` consumer embeddings.
    tau: Temperature; the consumer-producer dot products are divided by it
        before being passed through the softmax.
    version: Which utility to compute, `'exposure'` or `'rating'`. Any other
        value makes the returned function raise `NotImplementedError` when it
        is called (not when it is constructed).

  Returns:
    A function taking an `n x d` array of producer strategy vectors and
    returning an `n`-dimensional vector of utilities: the recommendation
    probabilities (`'exposure'`), or the temperature-scaled dot products
    weighted by those probabilities (`'rating'`), averaged over consumers.
  """
  def utility(v: jnp.ndarray) -> jnp.ndarray:
    scaled_dots = consumers @ v.T / tau
    rec_probs = jax.nn.softmax(scaled_dots, axis=-1)

    # pick the per-(consumer, producer) payoff; both versions share the final
    # average over the consumer axis
    if version == 'exposure':
      payoff = rec_probs
    elif version == 'rating':
      payoff = scaled_dots * rec_probs
    else:
      raise NotImplementedError(version)
    return payoff.mean(0)

  return utility


def sample_game(
    key: jnp.ndarray,
    n_producers: int,
    n_consumers: int,
    dim: int,
    nonnegative: bool = False
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Any]:
  """
  Draw a random game: consumers and initial producer strategies are sampled
  i.i.d. as standard normal vectors rescaled to unit `l2` norm, i.e., they lie
  on the `S^{dim - 1}` sphere.

  Args:
    key: JAX PRNG key.
    n_producers: The number of producers (`n` in most of our notation).
    n_consumers: The number of consumers (`m` in most of our notation).
    dim: The dimension of the embeddings.
    nonnegative: If `True`, every coordinate is replaced by its absolute value,
        so both consumers and (initial) producers lie in the first orthant.

  Returns:
    A pair `((producers, consumers), aux)` where `producers` is the `n x d`
    array of initial producer strategies, `consumers` is the `m x d` array of
    consumer embeddings, and `aux` is always `None`.
  """
  def unit_vectors(subkey, count):
    gaussians = random.normal(subkey, (count, dim))
    points = gaussians / jnp.linalg.norm(
      gaussians, ord=2, axis=-1, keepdims=True)
    # taking absolute values leaves the `l2` norm untouched
    return jnp.abs(points) if nonnegative else points

  # two chained splits: the first yields the producer key, the second (on the
  # carried-over key) yields the consumer key
  key, producer_key = random.split(key)
  _, consumer_key = random.split(key)

  producers = unit_vectors(producer_key, n_producers)
  consumers = unit_vectors(consumer_key, n_consumers)

  return (producers, consumers), None


def recommender_game(
    key: jnp.ndarray,
    n_producers: int,
    dim: int,
    dataset: str,
    recommender: str,
    random_start: bool = True,
    normalize_consumers: bool = False,
    random_state: np.random.RandomState = None
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Any]:
  """Build a game from a recommender fitted on real data (wraps
  `data_game.DataGame`).

  Args:
    key: JAX PRNG key used to sample the initial producer locations.
    n_producers: The number of producers (`n` in most of our notation). If
        `None`, the number of learned item embeddings is used instead.
    dim: The dimension of the embeddings (passed to `DataGame` as `d`).
    dataset: Dataset name, e.g., 'lastfm-360k' or 'ml-100k'. See
        `data_game.DataGame`.
    recommender: Recommender name, one of 'nmf', 'mf-unbiased', or
        'mf-biased'. See `data_game.DataGame`.
    random_start: If `True`, the initial producer locations are drawn with
        `sample_game`, i.e., randomly on the sphere (restricted to the first
        orthant when `recommender == 'nmf'`). If `False`, they are a random
        subset (without replacement) of the "base producers" -- the item
        embeddings learned within the `DataGame` -- rescaled to unit `l2` norm.
    normalize_consumers: Forwarded as `normalize` to `DataGame.fit_vectors`;
        if `True`, consumer embeddings are renormalized to the unit sphere,
        otherwise the original learned embeddings are returned.
    random_state: Optional numpy random state handed to the `DataGame` to make
        recommender training and data pre-processing reproducible.

  Returns:
    A pair `((producers, consumers), aux)` where `producers` and `consumers`
    are respectively the `n x d` and `m x d` arrays of producer and consumer
    embeddings (see `random_start` and `normalize_consumers`). `aux` is a
    dictionary with keys 'game' (the `data_game.DataGame` object),
    'base_producers' (the learned item embeddings), and 'sample_producers'
    (a function `(skey, n_producers)` which draws a fresh set of initial
    producer locations in the same way `producers` was drawn).

  Raises:
    ValueError: If `random_start` is `False` and more producers are requested
        than there are learned item embeddings.
  """
  # set up the recommender, fit it, and pull out the learned embeddings
  game = data_game.DataGame(
    setting=f'{dataset}_{recommender}', d=dim, random_state=random_state)
  base_producers, consumers = game.fit_vectors(normalize=normalize_consumers)

  def sample_producers(skey, n_producers):
    # TODO: sampling causes `inner_id` mismatch within `data_game`!
    if not random_start:
      # draw a random subset of the learned item embeddings
      if n_producers > len(base_producers):
        raise ValueError(
          f'tried using more `n_producers` than there are items in the dataset: '
          f'`{n_producers} > {len(base_producers)}` (dataset "{dataset}")')

      chosen = jax.random.choice(
        skey, len(base_producers), shape=(n_producers,), replace=False)
      subset = base_producers[chosen, :]
      # rescale every selected item embedding to unit `l2` norm
      subset /= jnp.linalg.norm(subset, ord=2, axis=-1, keepdims=True)
      return subset

    # fresh locations on the sphere; the consumers sampled alongside are unused
    (sampled, _), _ = sample_game(
      skey, n_producers=n_producers, n_consumers=len(consumers),
      dim=dim, nonnegative=recommender == 'nmf')
    return sampled

  if n_producers is None:
    n_producers = len(base_producers)
  key, pkey = random.split(key)
  producers = sample_producers(pkey, n_producers=n_producers)

  # consumers are never subsampled (would require changing `inner_id`!)

  return (producers, consumers), {
    'game': game,
    'base_producers': base_producers,
    'sample_producers': sample_producers
  }


def game_from_config(
    key: jnp.ndarray,
    C: ml_collections.ConfigDict,
    rng: np.random.RandomState = None
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Any]:
  """Construct a game as specified by a config.

  Args:
    key: JAX PRNG key forwarded to the game constructor.
    C: Config. `C.dataset == 'random'` selects `sample_game` (reading `C.dim`,
        `C.n_producers`, `C.n_consumers`, `C.nonnegative`); any other value is
        treated as a dataset name for `recommender_game` (reading `C.dim`,
        `C.n_producers`, `C.recommender`, `C.random_start`,
        `C.normalize_consumers`).
    rng: Optional numpy random state, passed to `recommender_game` as
        `random_state`. Unused for the random game.

  Returns:
    The `((producers, consumers), aux)` pair produced by the selected game
    constructor.

  Raises:
    NotImplementedError: If a dataset-based game is requested together with
        `C.n_consumers` other than `None`.
  """
  if C.dataset == 'random':
    (producers, consumers), aux = sample_game(
      key=key, dim=C.dim, n_producers=C.n_producers, n_consumers=C.n_consumers,
      nonnegative=C.nonnegative)
    return (producers, consumers), aux

  # from here on `C.dataset` is assumed to name a dataset known to
  # `data_game.DataGame`; anything else will fail inside `recommender_game`
  if C.n_consumers is not None:
    raise NotImplementedError(
      'consumer subsetting not implemented for the recsys games')

  if C.nonnegative and C.recommender != 'nmf':
    logging.warning(
      f'`C.nonnegative` is ignored with `recsys` games; use `C.recommender` '
      f'which incorporates the non-negative constraint instead')

  recsys_kwargs = dict(
    key=key, dim=C.dim, n_producers=C.n_producers, recommender=C.recommender,
    dataset=C.dataset, random_start=C.random_start, random_state=rng,
    normalize_consumers=C.normalize_consumers)
  (producers, consumers), aux = recommender_game(**recsys_kwargs)

  return (producers, consumers), aux
