import numpy as np
import sklearn.preprocessing as skl_prep
from typing import Sequence, NamedTuple

class Query(NamedTuple):
  """Outcome of querying a (randomized) policy on a batch of n contexts.

  Attributes:
    actions: Array holding the action sampled for each context (one entry per
      context).
    probabilities: Array holding, for each context, the probability the policy
      assigned to the sampled action (aligned entry-by-entry with `actions`).
  """

  # Sampled action, one per context.
  actions: np.ndarray
  # Probability of the matching entry in `actions`.
  probabilities: np.ndarray

class SoftmaxDataPolicy:
  """Memorization policy (using true labels).

  This object holds a training sample and a testing sample (each of which
  consists of contexts and labels). When one of the two samples is selected
  by name ('train' or 'test') in get_probs(...), the policy returns action
  probabilities for the contexts of that sample. Note that this is a mock-up
  policy: it only supports these two stored samples, not arbitrary contexts.

  Probabilities are a softmax (with temperature) over the one-hot encoding of
  the labels, after faulty labels are shifted by alter_labels(...), plus
  uniform noise drawn from self.rand.

  Attributes:
    action_set: A list of unique integer actions. Column j of the arrays
      returned by get_probs(...) corresponds to action_set[j].
    train_contexts: A n-times-d array of training contexts (d=data dim.,
      n=sample size).
    train_labels: A n-array of training labels.
    test_contexts: A n'-times-d array of testing contexts (d=data dim.,
      n'=sample size).
    test_labels: A n'-array of testing labels.
    temperature: A positive float controlling the temp. of a Softmax policy.
    faulty_actions: A set of labels where the behavior policy makes mistakes.
    rand: Random state of numpy.random.RandomState type; the source of all
      randomness of the policy (softmax noise and action sampling).
  """

  def __init__(
      self,
      train_contexts: np.ndarray,
      train_labels: np.ndarray,
      test_contexts: np.ndarray,
      test_labels: np.ndarray,
      action_set: Sequence[int],
      temperature: float,
      faulty_actions: Sequence[int],
  ):
    """Constructs a Policy.

    Args:
      train_contexts: Array of training contexts (n-times-d, d=data dim.,
        n=sample size).
      train_labels: Array of training labels (size n).
      test_contexts: Array of testing contexts (n'-times-d, d=data dim.,
        n'=sample size).
      test_labels: Array of testing labels (size n').
      action_set: List of unique integer actions.
      temperature: Positive float controlling the temperature of a Softmax
        policy.
      faulty_actions: List of labels on which the behavior policy makes
        mistakes.
    """
    self.action_set = action_set

    self.train_contexts = train_contexts
    self.train_labels = train_labels
    self.test_contexts = test_contexts
    self.test_labels = test_labels
    self.temperature = temperature
    self.faulty_actions = set(faulty_actions)
    # Seed deterministically so a freshly constructed policy is reproducible.
    self.reset_noise(0)

  def reset_noise(self, seed: int):
    """Resets the policy's random state given a seed.

    Args:
      seed: Integer seed for random state
    """
    self.rand = np.random.RandomState(seed)

  def alter_labels(self, labels: np.ndarray) -> np.ndarray:
    """Returns altered labels according to the self.faulty_actions spec.

    Each label contained in self.faulty_actions is shifted one forward,
    modulo len(self.action_set) (so the last action wraps around to 0).
    Labels not in self.faulty_actions are left unchanged.

    Args:
      labels: Vector of labels (size n=sample size).

    Returns:
      A float vector of the same size with all entries in self.faulty_actions
      shifted.
    """
    num_actions = len(self.action_set)
    # 1.0 where the true label is one the behavior policy gets wrong, else 0.0.
    fault = np.isin(labels, list(self.faulty_actions)).astype(float)
    return (labels + fault) % num_actions  # faulty actions get shifted by one

  def _action_columns(self, actions) -> np.ndarray:
    """Maps action values to their column index in get_probs(...) output."""
    column_of = {action: j for j, action in enumerate(self.action_set)}
    return np.array([column_of[a] for a in actions], dtype=int)

  def get_probs(self, context: str) -> np.ndarray:
    """Returns probability distribution over actions for given contexts.

    The softmax policy is defined row-wise as the probability vector
      exp((alt_bin_labels + noise) / temp)
        / sum(exp((alt_bin_labels + noise) / temp))
      where temp is the temperature of the policy, alt_bin_labels is the
      one-hot encoding (over self.action_set) of the labels altered by
      alter_labels(...), and noise is uniform on [0, 1) drawn from self.rand.
      Fresh noise is drawn on every call.

    Args:
      context: Which stored sample to act on: 'train' (training contexts) or
        'test' (testing contexts) provided during the initialization.

    Returns:
      Array of probabilities according to the policy (size n-times-K, where K
      is the number of actions); column j corresponds to self.action_set[j].

    Raises:
      NotImplementedError: when context is neither 'train' nor 'test'.
    """
    # predictions get altered by internal noise :
    if context == 'train':
      labels = self.train_labels
    elif context == 'test':
      labels = self.test_labels
    else:
      raise NotImplementedError(
          f"context must be 'train' or 'test', got {context!r}")
    alt_labels = np.asarray(self.alter_labels(labels))

    # One-hot encode over action_set, columns in action_set order. Built
    # directly because sklearn's label_binarize returns a single column for
    # two classes, which would collapse the softmax to all ones.
    bin_alt_labels = (
        alt_labels[:, np.newaxis]
        == np.asarray(self.action_set)[np.newaxis, :]).astype(int)

    noise = self.rand.rand(*bin_alt_labels.shape)

    logits = (bin_alt_labels + noise) / self.temperature
    # Subtracting the row max leaves the softmax unchanged but avoids
    # overflow in exp for small temperatures.
    v = np.exp(logits - logits.max(axis=1, keepdims=True))
    return v / v.sum(axis=1, keepdims=True)

  def get_probs_by_actions(self, contexts: str, actions: np.ndarray):
    """Returns probabilities for each given action in each given context.

    Args:
      contexts: Which stored sample to act on: 'train' or 'test' (see
        get_probs).
      actions: Array of actions (integers, members of self.action_set) for
        which probabilities are requested; actions[i] is paired with row i.
    Returns: Probabilities according to the policy (size len(actions)). Note
      that get_probs draws fresh noise, so these are not the probabilities
      used by an earlier query(...) call.
    """
    all_probs = self.get_probs(contexts)
    columns = self._action_columns(actions)
    return all_probs[np.arange(len(columns)), columns]

  def query(self, contexts: str) -> Query:
    """Returns actions and their probs sampled for the given contexts.

    Args:
      contexts: Which stored sample to act on: 'train' or 'test' (see
        get_probs).
    Returns: A Query of arrays of actions (one per context) and the
      probability with which each action was sampled.
    """
    probs = self.get_probs(contexts)
    num_actions = len(self.action_set)
    # Sample a column index per row, then map it back to the action value.
    columns = np.array(
        [self.rand.choice(num_actions, p=pi) for pi in probs], dtype=int)
    actions = np.asarray(self.action_set)[columns]
    # Pair row i with its own sampled column (not all columns for every row).
    probs_by_actions = probs[np.arange(probs.shape[0]), columns]
    return Query(actions, probs_by_actions)

  def __str__(self):
    """Returns a string representation of a policy with parametrization."""
    return f"SoftmaxDataPolicy(τ={self.temperature}, fauly_actions=[{str(self.faulty_actions)}])"
