import numpy as np

from Agents.TabularAgent import TabularAgent

class IPGOmega(TabularAgent):
  ''' Tabular softmax policy-gradient agent with an ω-mixed incremental update

  The policy is a softmax over a table of logits of shape
  (n_states, n_actions). Each transition accumulates a γ-decayed trace of
  ∇log π and folds `α_π * r * trace` into a running update `Δπ_logits`.
  A fraction ω of that running update is applied to the logits at every
  step; the remaining (1 - ω) fraction is applied when the episode ends.

  Expected call order: `init()` once to create the logits, `reset()` at the
  start of every episode (it creates the trace, the running update and the
  step counter), then `choose_action()` / `update()` per transition.
  '''

  def __init__(self, env, γ, α_π, ω, semigrad=True):
    ''' Constructor
    Args:
      env: Environment instance; must expose `n_states`, `n_actions` and
        `reset()`
      γ (float): Discount factor
      α_π (float): Step size for the policy-logit updates
      ω (float): Mixing coefficient. Each step applies `ω * Δπ_logits` to the
        logits and decays the running update by `(1 - ω)`
      semigrad (bool): If True, every step's ∇log π enters the trace with
        weight 1; if False, it is weighted by `γ ** t`, with t the number of
        updates since the last `reset()`
    '''
    self._env = env
    self._γ = γ
    self._α_π = α_π
    self._ω = ω
    self._semigrad = semigrad

  def init(self):
    ''' Agent initialization

    Creates all-zero policy logits, i.e. a uniform policy in every state.
    '''
    self.π_logits = np.zeros((self._env.n_states, self._env.n_actions))

  def π(self, s=None):
    ''' Compute agent policy from logits
    Args:
      s (int): An environment state
    Returns:
      π (numpy array): Action probabilities π[s, a] (full policy if s is None)
    '''
    # Identity test: `s == None` would be evaluated elementwise (and be
    # ambiguous as a condition) if s were ever array-like.
    if s is None:
      logits = self.π_logits
      # Subtract the row-wise max before exponentiating for numerical
      # stability; keepdims lets broadcasting line the shapes up.
      weights = np.exp(logits - logits.max(1, keepdims=True))
      return weights / weights.sum(1, keepdims=True)

    row = self.π_logits[s]
    weights = np.exp(row - row.max())
    return weights / weights.sum()

  def parameters(self):
    ''' Get agent's parameters (e.g., value function, action preferences)
    Returns:
      params (dict of numpy arrays): Agent's parameters; the full policy
        table is stored under the key 'π'
    '''
    return {'π': self.π()}

  def choose_action(self, s):
    ''' Agent policy
    Args:
      s (int): An environment state
    Returns:
      a (int): Action sampled from policy at given state
    '''
    return np.random.choice(self._env.n_actions, p=self.π(s))

  def update(self, s, a, r, sp, T):
    ''' Perform agent updates given a transition

    Requires `init()` and `reset()` to have been called beforehand, since it
    reads and writes `π_logits`, `e`, `Δπ_logits` and `t`.

    Args:
      s (int): Previous state
      a (int): Action taken
      r (float): Reward received
      sp (int): Next state (not used by this agent)
      T (bool): Terminal state
    '''
    π = self.π(s)

    # Column `a` of the softmax Jacobian diag(π) - π πᵀ, built directly
    # instead of materialising the full (n_actions x n_actions) matrix.
    dπ_a = -π * π[a]
    dπ_a[a] += π[a]
    # Dividing by π[a] turns ∂π[a]/∂logits into ∂log π[a]/∂logits; the small
    # constant guards against division by zero.
    dlogπ = dπ_a / (1e-12 + π[a])

    # Decay the trace, then add the current step's contribution.
    self.e *= self._γ
    step_weight = 1 if self._semigrad else self._γ ** self.t
    self.e[s] += step_weight * dlogπ

    # Fold the reward-weighted trace into the running update and apply an
    # ω fraction of it right away.
    self.Δπ_logits = (1 - self._ω) * self.Δπ_logits + self._α_π * r * self.e
    self.π_logits += self._ω * self.Δπ_logits
    self.t += 1

    if T:
      # End of episode: apply the part of the running update held back so far.
      self.π_logits += (1 - self._ω) * self.Δπ_logits

  def reset(self):
    ''' Agent reset (initial environment observation)

    Resets the environment and stores its first observation in `self.s`,
    then zeroes the trace, the running logit update and the step counter.
    '''
    self.s = self._env.reset()
    table_shape = (self._env.n_states, self._env.n_actions)
    self.e = np.zeros(table_shape)
    self.Δπ_logits = np.zeros(table_shape)
    self.t = 0