import abc
from typing import NamedTuple, Sequence
from absl import logging

import numpy as np
import data
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC

import sklearn.preprocessing as skl_prep
import scipy

import jax
from jax import numpy as jnp
from jax import scipy as jsc

import math


class Query(NamedTuple):
  """Actions sampled by a (randomized) policy for a batch of contexts.

  Both fields hold one entry per context, in the order the contexts were
  passed to the policy.

  Attributes:
    actions: Array of length n -- chosen (sampled) actions.
    probabilities: Array of length n -- probability the policy assigned to
      each chosen action.
  """

  actions: np.ndarray
  probabilities: np.ndarray


class Policy(abc.ABC):
  """Abstract interface of a policy: samples actions given contexts.

  Concrete policies must implement both `query` (sample actions) and
  `get_probs` (full action distribution per context).
  """

  @abc.abstractmethod
  def query(self, contexts: np.ndarray) -> Query:
    """Returns actions and their probs sampled by Policy given the contexts.

    Args:
      contexts: Array of contexts (n-times-d, d=data dim., n=sample size).
        Some implementations only accept the exact context arrays they were
        initialized with.
    Returns: A Query holding an array of sampled actions (int) and an array of
      the corresponding probabilities (float), one entry per context.
    """

  @abc.abstractmethod
  def get_probs(self, contexts: np.ndarray) -> np.ndarray:
    """Returns probability distribution over actions for each context.

    How the distribution is obtained (softmax over model scores, softmax over
    memorized labels, epsilon-greedy, ...) is up to the concrete policy.

    Args:
      contexts: Array of contexts (n-times-d, d=data dim., n=sample size).
        Some implementations only accept the exact context arrays they were
        initialized with.
    Returns:  Array of probabilities according to the policy, where K
      is the number of actions (size n-times-K).

    Raises:
      NotImplementedError: may be raised by implementations that only support
        a fixed set of context arrays when given any other contexts.
    """

# ==============================================================================
# ==============================================================================
# ==============================================================================
#         Softmax Policy Based on a reward model.
# ==============================================================================
# ==============================================================================
# ==============================================================================
  
class LinearSoftmaxPolicy(Policy):
  """Policy defined using an sklearn model of the reward and a softmax transform.

  A classifier is fitted on (train_contexts, train_labels) at construction
  time. Its predicted class probabilities act as reward estimates r(a, x),
  which are mapped to action probabilities either through a softmax with the
  given temperature or, when eps_greedy > 0, through an epsilon-greedy rule
  around the predicted label.

  Attributes:
    train_contexts: A n-times-d array of training contexts(d=data dim., n=sample size).
    train_labels: A n-array of training labels.
    temperature: A positive float controlling the temp. of a Softmax policy.
    num_actions: Number of actions K.
    action_set: Array of the integer actions 0, ..., K-1.
    eps_greedy: Exploration rate; the epsilon-greedy rule replaces the softmax
      only when this is > 0.
    model: Name of the reward model: 'logreg', 'tree', 'knn', 'svm' or 'forest'.
    logreg: The fitted sklearn classifier selected by `model` (the attribute
      name is historical; it is not necessarily a logistic regression).
    rand: Random state of numpy.random.RandomState type, used to sample actions.
  """

  def __init__(
      self,
      train_contexts: np.ndarray,
      train_labels: np.ndarray,
      temperature: float,
      num_actions: int,
      eps_greedy: float = -1,
      model: str = 'logreg',
  ):
    """Constructs a Policy and fits its reward model.

    Args:
      train_contexts:  Array of training contexts (n-times-d, d=data dim., n=sample size).
      train_labels: Array of training labels (size n).
      temperature: Positive float controlling the temperature of a Softmax policy.
      num_actions: Number of actions K.
      eps_greedy: Exploration rate of the epsilon-greedy rule; any value <= 0
        (the default) selects the softmax rule instead.
      model: Name of the reward model: 'logreg', 'tree', 'knn', 'svm' or 'forest'.

    Raises:
      ValueError: if `model` is not one of the supported names.
    """
    self.train_contexts = train_contexts
    self.train_labels = train_labels
    self.temperature = temperature
    self.num_actions = num_actions
    self.action_set = np.arange(num_actions)
    self.reset_noise(0)

    self.eps_greedy = eps_greedy
    self.model = model

    self.train_reward_model()

  def reset_noise(self, seed: int):
    """Resets a random state given a seed.

    Args:
      seed: Integer seed for random state
    """
    self.rand = np.random.RandomState(seed)

  def train_reward_model(self):
    """Fits the sklearn classifier selected by `self.model` on the training data.

    Raises:
      ValueError: if `self.model` is not a supported model name.
    """
    if self.model == 'logreg':
      self.logreg = LogisticRegression(C=1e-3)
    elif self.model == 'tree':
      self.logreg = DecisionTreeClassifier(max_depth=8)
    elif self.model == 'knn':
      self.logreg = KNeighborsClassifier()
    elif self.model == 'svm':
      self.logreg = SVC(C=1e-1, probability=True)
    elif self.model == 'forest':
      self.logreg = RandomForestClassifier(max_depth=5)
    else:
      # Without this branch an unknown name surfaced later as an opaque
      # AttributeError on self.logreg.
      raise ValueError(f"Unknown reward model: {self.model!r}")

    self.logreg.fit(self.train_contexts, self.train_labels)
    if len(self.logreg.classes_) != self.num_actions:
      print('Classes of LogReg do not cover all actions.')

  def get_probs(self, contexts: np.ndarray):
    """Returns probability distribution over actions for given contexts.

    The softmax policy is defined as a probability vector
      exp(r(a, x) / temp) / sum(exp( r(a', x) / temp))
      where temp is a temperature of a policy and
      r(a, x) the predicted probabilities of the action a by the reward model.
    When eps_greedy > 0 the policy instead puts mass (1 - eps_greedy) on the
    predicted label and spreads eps_greedy uniformly over all actions.

    Args:
      contexts: Array of contexts (n-times-d, d=data dim., n=sample size)

    Returns:  Array of probabilities according to the policy, where K is the number of actions (size n-times-K).
    """
    if self.eps_greedy > 0:
      greedy = self.logreg.predict(contexts)
      one_hot = np.eye(self.num_actions)[greedy]
      return (1 - self.eps_greedy) * one_hot + self.eps_greedy / self.num_actions

    scores = self.logreg.predict_proba(contexts)
    if scores.shape[1] != self.num_actions:
      # The classifier only outputs columns for the classes it saw in training.
      # Scatter them into an n-times-K matrix so that column a always refers to
      # action a; unseen actions keep a reward estimate of 0.
      full_scores = np.zeros((scores.shape[0], self.num_actions))
      seen = np.asarray(self.logreg.classes_).astype(int)
      full_scores[:, seen] = scores
      scores = full_scores

    # Subtracting the row maximum leaves the softmax unchanged but prevents
    # exp(...) from overflowing when the temperature is small.
    logits = scores / self.temperature
    logits = logits - logits.max(axis=1, keepdims=True)
    v = np.exp(logits)
    return v / v.sum(axis=1, keepdims=True)

  def query(self, contexts: np.ndarray) -> Query:
    """Returns actions and their probs sampled for the given contexts.

    Sampling uses `self.rand`, so results are reproducible after
    `reset_noise(seed)`.

    Args:
      contexts: Array of contexts (n-times-d, d=data dim., n=sample size)

    Returns: A Tuple of arrays of actions (int) and corresponding probs (float)
    """
    probs = self.get_probs(contexts)
    actions = [self.rand.choice(self.action_set, p=pi) for pi in probs]
    p_0 = [pi[a] for pi, a in zip(probs, actions)]

    return Query(np.array(actions), np.array(p_0))

  def __str__(self):
    """Returns a string representation of a policy with parametrization."""
    # NOTE: the misspelling "Sotfmax" is kept so printed labels stay unchanged.
    return f"LinearSotfmaxPolicy(τ={self.temperature})"
  

# ==============================================================================
# ==============================================================================
# ==============================================================================
#         Softmax Memorization/Oracle Policy.
# ==============================================================================
# ==============================================================================
# ==============================================================================


class SoftmaxDataPolicy(Policy):
  """Memorization policy (using true labels).

  This object holds the training, interaction and testing samples of a dataset
  (each of which consists of contexts and labels).
  When one of those context arrays is passed to the policy (get_probs(...))
  it returns action probabilities derived from the matching labels.
  Note that this is a mock-up policy: contexts are recognized by object
  identity, so only the very arrays stored at construction are supported.


  Attributes:
    num_actions: Number of actions K.
    action_set: Array of the integer actions 0, ..., K-1.
    train_contexts: A n-times-d array of training contexts.
    train_labels: A n-array of training labels.
    inter_contexts: A n-times-d array of interaction contexts(d=data dim., n=sample size).
    inter_labels: A n-array of interaction labels.

    test_contexts: A n'-times-d array of test contexts(d=data dim., n'=sample size).
    test_labels: A n'-array of test labels.

    temperature: A positive float controlling the temp. of a Softmax policy.
    faulty_actions: A set of labels where the behavior policy makes mistakes.
    eps_greedy: Exploration rate; the epsilon-greedy rule replaces the softmax
      only when this is > 0.
    rand: Random state of numpy.random.RandomState type, used to sample actions.
  """

  def __init__(
      self,
      dataset: data.Dataset,
      temperature: float,
      faulty_actions: Sequence[int],
      eps_greedy: float = -1,
  ):
    """Constructs a Policy.


    Args:
      dataset: an openML dataset
      temperature: Positive float controlling the temperature of a Softmax
        policy.
      faulty_actions: List of labels on which the behavior policy makes
        mistakes.
      eps_greedy: Exploration rate of the epsilon-greedy rule; any value <= 0
        (the default) selects the softmax rule instead.
    """
    self.train_contexts = dataset.contexts_train
    self.train_labels = dataset.labels_train

    self.inter_contexts = dataset.contexts_log
    self.inter_labels = dataset.labels_log

    self.test_contexts = dataset.contexts_test
    self.test_labels = dataset.labels_test

    self.num_actions = dataset.num_actions
    self.action_set = np.arange(self.num_actions)
    self.temperature = temperature
    self.faulty_actions = set(faulty_actions)
    self.reset_noise(0)

    self.eps_greedy = eps_greedy

  def reset_noise(self, seed: int):
    """Resets a random state given a seed.

    Args:
      seed: Integer seed for random state
    """
    self.rand = np.random.RandomState(seed)

  def alter_labels(self, labels: np.ndarray):
    """Returns altered labels according to the self.faulty_actions spec.

    Labels are altered by shifting each label contained in self.faulty_action
    to one forward (or to 0 if we have an overflow).

    Args:
      labels: Vector of labels (size 1 by n=sample size)

    Returns:
      A float vector of the same size with all entries in self.faulty_actions
      shifted.
    """
    num_actions = len(self.action_set)

    # 1.0 where the label is faulty, 0.0 elsewhere. np.isin needs a sequence
    # (not a set) as its second argument.
    fault = np.isin(labels, list(self.faulty_actions)).astype(float)

    return (labels + fault) % num_actions  # faulty actions get shifted by one

  def get_probs(self, contexts: np.ndarray):
    """Returns probability distribution over actions for given contexts.

    The softmax policy is defined as a probability vector
      exp(alt_bin_labels / temp) / sum(exp(alt_bin_labels / temp))
      where temp is a temperature of a policy and
      alt_bin_labels is a binary encoding of labels altered by alter_labels(...)
    When eps_greedy > 0 the policy instead puts mass (1 - eps_greedy) on the
    altered label and spreads eps_greedy uniformly over all actions.

    Args:
      contexts: One of the context arrays stored at construction (training,
        interaction or testing); it is matched by object identity, not value.
    Returns:  Array of probabilities according to the policy, where K
      is the number of actions (size n-times-K).

    Raises:
      NotImplementedError: when contexts is not one of the stored arrays.
    """
    # predictions get altered by internal noise :
    if contexts is self.train_contexts:
      alt_labels = self.alter_labels(self.train_labels)
    elif contexts is self.inter_contexts:
      alt_labels = self.alter_labels(self.inter_labels)
    elif contexts is self.test_contexts:
      alt_labels = self.alter_labels(self.test_labels)
    else:
      raise NotImplementedError(
          "contexts must be the training, interaction or testing contexts "
          "held by this policy")

    # Explicit one-hot encoding: sklearn's label_binarize collapses the
    # two-class case to a single column, which breaks the n-times-K contract.
    # alter_labels reduces modulo K, so every index is within [0, K).
    bin_alt_labels = np.eye(self.num_actions)[np.asarray(alt_labels).astype(int)]

    if self.eps_greedy > 0:
      return (1 - self.eps_greedy) * bin_alt_labels + self.eps_greedy / self.num_actions

    # Subtracting the row maximum leaves the softmax unchanged but prevents
    # exp(1 / temp) from overflowing when the temperature is small.
    logits = bin_alt_labels / self.temperature
    logits = logits - logits.max(axis=1, keepdims=True)
    v = np.exp(logits)
    return v / v.sum(axis=1, keepdims=True)

  def query(self, contexts: np.ndarray) -> Query:
    """Returns actions and their probs sampled for the given contexts.

    Sampling uses `self.rand`, so results are reproducible after
    `reset_noise(seed)`.

    Args:
      contexts: One of the context arrays stored at construction (see
        get_probs).
    Returns: A Tuple of arrays of actions (int) and corresponding probs (float)
    """
    probs = self.get_probs(contexts)
    actions = [self.rand.choice(self.action_set, p=pi) for pi in probs]
    p_0 = [pi[a] for pi, a in zip(probs, actions)]

    return Query(np.array(actions), np.array(p_0))

  def __str__(self):
    """Returns a string representation of a policy with parametrization."""
    # NOTE: the misspelling "fauly_actions" is kept so printed labels stay
    # unchanged.
    return f"SoftmaxDataPolicy(τ={self.temperature}, fauly_actions=[{str(self.faulty_actions)}])"


# ==============================================================================
# ==============================================================================
# ==============================================================================
#         Softmax Policy trained on Interaction Data.
# ==============================================================================
# ==============================================================================
# ==============================================================================
  

def log_vhat_importance_weighting(
    parameters: np.ndarray,
    temperature: float,
    contexts: np.ndarray,
    actions: np.ndarray,
    rewards: np.ndarray,
    b_prob: np.ndarray,
) -> np.ndarray:
  """Computes the log of the importance-weighted (IW) value estimate.

  The target policy is a softmax over the linear potentials
  <parameters[:, a], x> / temperature. The estimate is
    (1/n) * sum_i  pi(a_i | x_i) / b_prob_i * rewards_i
  and is evaluated entirely in log-space through logsumexp.

  Args:
    parameters: Parameters of the linear model of a target policy
      (indexed as parameters[:, action]).
    temperature: Positive float controlling the temperature of a Softmax
      policy.
    contexts: Array of contexts (n-times-d, d=data dim., n=sample size).
    actions: Actions (integers).
    rewards: Rewards (float).
    b_prob: Probabilities corresponding to (context, action) pairs
      according to the behavior policy.
  Returns: The logarithm of importance-weighted estimate.
  """
  num_samples, _ = contexts.shape
  inv_temp = 1.0 / temperature

  # Potentials of every action, and of the logged action, for each context.
  all_potentials = inv_temp * contexts.dot(parameters)
  chosen_columns = parameters[:, actions].T
  chosen_potentials = inv_temp * (contexts * chosen_columns).sum(axis=1)

  # log( r_i / (n * b_i) ) minus the log-normalizer of the softmax.
  log_normalizer = jsc.special.logsumexp(all_potentials, axis=1)
  log_scale = jnp.log(rewards / (num_samples * b_prob)) - log_normalizer

  return jsc.special.logsumexp(chosen_potentials + log_scale, axis=0)


def log_vhat_sn_importance_weighting(
    parameters: np.ndarray,
    temperature: float,
    contexts: np.ndarray,
    actions: np.ndarray,
    rewards: np.ndarray,
    b_prob: np.ndarray,
) -> np.ndarray:
  """Computes the log of the self-normalized (SN) importance-weighted estimate.

  The target policy is a softmax over the linear potentials
  <parameters[:, a], x> / temperature. With importance weights
  w_i = pi(a_i | x_i) / b_prob_i the estimate is
    sum_i w_i * rewards_i / sum_i w_i
  and is evaluated entirely in log-space through logsumexp.

  Args:
    parameters: Parameters of the linear model of a target policy
      (indexed as parameters[:, action]).
    temperature: Positive float controlling the temperature of a Softmax
      policy.
    contexts: Array of contexts (n-times-d, d=data dim., n=sample size).
    actions: Actions (integers).
    rewards: Rewards (float).
    b_prob: Probabilities corresponding to (context, action) pairs
      according to the behavior policy.
  Returns: The logarithm of SN importance-weighted estimate.
  """
  inv_temp = 1.0 / temperature

  # Potentials of every action, and of the logged action, for each context.
  all_potentials = inv_temp * contexts.dot(parameters)
  chosen_columns = parameters[:, actions].T
  chosen_potentials = inv_temp * (contexts * chosen_columns).sum(axis=1)

  # Numerator: log sum_i w_i * r_i.
  numer_scale = jnp.log(rewards / b_prob) - jsc.special.logsumexp(
      all_potentials, axis=1)
  log_numerator = jsc.special.logsumexp(chosen_potentials + numer_scale, axis=0)

  # Denominator: log sum_i w_i.
  denom_scale = -jnp.log(b_prob) - jsc.special.logsumexp(
      all_potentials, axis=1)
  log_denominator = jsc.special.logsumexp(
      chosen_potentials + denom_scale, axis=0)

  return log_numerator - log_denominator
    

class SoftmaxGAPolicy(Policy):
  """Softmax gradient ascent fitted policy.

  This softmax policy is defined as a probability vector
  x |-> exp(<W,x> / temp) / sum(exp(<W,x> / temp))
      where temp is a temperature of a policy and
      W is a d-times-K matrix of parameters (here K is a number of actions
      and d is a context dimension).
  Parameters W are fitted by the gradient ascent either w.r.t. the
  importance-weighted estimator or its self-normalized version.

  Attributes:
    action_set: Array of the actions the policy can choose from; the i-th
      entry corresponds to the i-th column of `parameters`.
    n_actions: Number of actions.
    temperature: Positive float controlling the temp. of a Softmax policy.
    steps: Number of gradient ascent steps for fitting the policy
    step_size: step size of the gradient ascent for fitting the policy.
    obj_type: Objective type, 'IW' = importance-weighted estimator,
      'SNIW' = self-normalized importance-weighted estimator.
    parameters: Parameters of the linear model in the softmax policy
      (None until train(...) has been called).
    ln_obj: Reference to the function implementing the log-objective.
  """

  def __init__(
      self,
      action_set: Sequence[int],
      temperature: float,
      steps: int = 10000,
      step_size: float = 1e-2,
      obj_type: str = 'IW',
  ):
    """Constructs a Softmax Gradient Ascent Policy.

    Args:
      action_set: List of unique integer actions.
      temperature: Positive float controlling the temperature of a Softmax
        policy.
      steps: Number of gradient ascent steps for fitting the policy.
      step_size: Step size of the gradient ascent for fitting the policy.
      obj_type: Objective type, 'IW' = importance-weighted estimator,
        'SNIW' = self-normalized importance-weighted estimator.

    Raises:
      NotImplementedError: if obj_type is neither 'IW' nor 'SNIW'.
    """
    # query(...) needs the action set itself, not only its size.
    self.action_set = np.asarray(action_set)
    self.n_actions = len(action_set)
    self.temperature = temperature
    self.steps = steps
    self.step_size = step_size
    self.parameters = None

    self.obj_type = obj_type
    if obj_type == 'IW':
      self.ln_obj = log_vhat_importance_weighting
    elif obj_type == 'SNIW':
      self.ln_obj = log_vhat_sn_importance_weighting
    else:
      raise NotImplementedError

  def train(
      self,
      contexts: np.ndarray,
      actions: np.ndarray,
      rewards: np.ndarray,
      b_prob: np.ndarray,
  ):
    """Fits the softmax policy according to the chosen objective.

    Fits the softmax policy according to the objective chosen during
    initialization. The gradient ascent is run for a fixed number of
    steps and a step size (specified during initialization), starting from
    all-zero parameters.
    Gradient computation is done through autodiff jax library.

    Args:
      contexts: Array of contexts (n-times-d, d=data dim., n=sample size)
      actions: Actions (integers); used as column indices into the parameters.
      rewards: Rewards (float).
      b_prob: Probabilities corresponding to (context, action) pairs
        according to the behavior policy.
    """
    contexts = jnp.array(contexts)
    actions = jnp.array(actions)
    rewards = jnp.array(rewards)
    b_prob = jnp.array(b_prob)

    _, d = contexts.shape

    # Gradient of the log-objective w.r.t. its first argument (the parameters).
    grad_v = jax.jit(jax.grad(self.ln_obj))

    obj_params = (self.temperature, contexts, actions, rewards, b_prob)

    logging.debug("%s(softmax): iter\t\temp_value ", self.obj_type)
    logging.debug("%s(softmax): --------------------------------- ",
                  self.obj_type)

    def update_step_ga(_, parameters: np.ndarray):
      """Returns updated parameters after a single step of gradient ascent.

      Args:
        _: gradient ascent step
        parameters: Parameters to be updated.
      Returns: Updated parameters.
      """
      g = grad_v(parameters, *obj_params)
      return parameters + self.step_size * g

    parameters_init = jnp.array(np.zeros(shape=(d, self.n_actions)))

    self.parameters = jax.lax.fori_loop(0, self.steps, update_step_ga,
                                        parameters_init)

    logging.debug("%s(softmax): %d\t\t%.2f ", self.obj_type, self.steps,
                  math.exp(self.ln_obj(self.parameters, *obj_params)))

  def get_probs(self, contexts: np.ndarray):
    """Returns probability distribution over actions for the given contexts.

    The softmax policy is defined as a probability vector
      exp(<W,x> / temp) / sum(exp(<W,x> / temp))
      where temp is a temperature of a policy and
      W is a d-times-K matrix of parameters (here K is a number of actions
      and d is a context dimension) fitted by gradient ascent.

    Args:
      contexts:  Array of contexts (n-times-d, d=data dim., n=sample size).
    Returns: Array of probabilities according to the policy (n-times-K).

    Raises:
      RuntimeError: if the policy has not been trained yet.
    """
    if self.parameters is None:
      raise RuntimeError("SoftmaxGAPolicy must be trained before it is queried")

    # train(...) leaves a jax array behind; do the evaluation in plain numpy.
    weights = np.asarray(self.parameters)
    v = (1.0 / self.temperature) * np.asarray(contexts).dot(weights)
    logprob = v - np.expand_dims(scipy.special.logsumexp(v, axis=1), axis=1)

    return np.exp(logprob)

  def query(self, contexts: np.ndarray) -> Query:
    """Returns actions and their probs sampled by the policy given the contexts.

    Args:
      contexts: Array of contexts (n-times-d, d=data dim., n=sample size).
    Returns: Array integer actions and array of corresponding probabilities.

    Raises:
      RuntimeError: if the policy has not been trained yet.
    """
    actions, p_0 = [], []
    probs = self.get_probs(contexts)
    for pi in probs:
      # Sample a column index, then translate it into the actual action, so
      # the reported probability always matches the sampled action.
      idx = np.random.choice(self.n_actions, p=pi)
      actions.append(self.action_set[idx])
      p_0.append(pi[idx])

    return Query(np.array(actions), np.array(p_0))

  def __str__(self):
    """Returns a string representation of a policy with parametrization."""
    return ("Softmax (linear potential): %s max`d by GA (T=%s, eta=%s)" %
            (self.obj_type, self.steps, self.step_size))
