import numpy as np

class TabularAgent():
  ''' Base class for tabular agents that interact with an environment.

  Subclasses must override `__init__`, `init`, `parameters`, `choose_action`
  and `update`. The methods below read the environment from `self._env`, so
  subclasses must set it. The environment is used through `init()`, `reset()`
  (returning the initial state), `step(a)` (returning `(sp, r, T, τ)`) and a
  `max_steps` attribute.
  '''

  def __init__(self, env, *args):
    ''' Constructor (must be overridden by subclasses)
    Args:
      env: Environment instance
      *args: Agent hyperparameters
    '''
    raise NotImplementedError

  def init(self):
    ''' Agent initialization; called at the start of `run_steps` and
    `run_episodes` (must be overridden by subclasses)
    '''
    raise NotImplementedError

  def parameters(self):
    ''' Get agent's parameters (e.g., value function, action preferences)
    Returns:
      params (dict of numpy arrays): Agent's parameters
    '''
    raise NotImplementedError

  def choose_action(self, s):
    ''' Agent policy
    Args:
      s (int): An environment state
    Returns:
      a (int): Action sampled from policy at given state
    '''
    raise NotImplementedError

  def update(self, s, a, r, sp, T):
    ''' Perform agent updates given a transition
    Args:
      s (int): Previous state
      a (int): Action taken
      r (float): Reward received
      sp (int): Next state
      T (bool): Terminal state
    '''
    raise NotImplementedError

  def reset(self):
    ''' Reset the environment and store its initial observation as the
    agent's current state `self.s`
    '''
    self.s = self._env.reset()

  def step(self):
    ''' Take an action from the current state, perform the agent update, and
    advance the current state (resetting the environment if the episode ended)
    Returns:
      s (int): Previous state
      a (int): Action taken
      r (float): Reward received
      sp (int): Next state
      T (bool): Terminal state
      τ (bool): Time limit reached
    '''
    s = self.s
    a = self.choose_action(s)
    sp, r, T, τ = self._env.step(a)
    self.update(s, a, r, sp, T)
    if T or τ:
      # Episode ended (terminal or truncated): start a new one so the next
      # call to step() acts from a fresh initial state.
      self.reset()
    else:
      self.s = sp
    return s, a, r, sp, T, τ

  def episode(self):
    ''' Run one episode, for at most `self._env.max_steps` steps
    Returns:
      G: Episodic return (undiscounted sum of rewards)
      ep_length: Number of steps taken (0 if `max_steps` is 0)
    '''
    G = 0
    ep_length = 0  # defined even when the loop body never runs
    self.reset()
    for ep_length in range(1, self._env.max_steps + 1):
      _, _, r, _, T, τ = self.step()
      G += r
      if T or τ:
        break
    return G, ep_length

  def _alloc_param_log(self, n):
    ''' Allocate one zero-filled buffer per parameter, with a leading axis of length n '''
    return {key: np.zeros((n, *np.shape(value))) for key, value in self.parameters().items()}

  def _record_params(self, log, i):
    ''' Copy the agent's current parameters into row i of buffers from `_alloc_param_log` '''
    for key, value in self.parameters().items():
      log[key][i] = value

  def run_steps(self, steps, log_rewards=False, log_returns=False, log_ep_lengths=False, log_n_eps=False, log_params=False):
    ''' Run a number of steps in the environment, starting from a freshly
    initialized environment and agent. Episodes that end are restarted by `step`.
    Args:
      steps (int): Number of steps to run for
      log_rewards (bool): Whether to log the reward at each step
      log_returns (bool): Whether to log, at each step, the return of the most
        recently completed episode (0 until the first episode ends)
      log_ep_lengths (bool): Whether to log, at each step, the length of the most
        recently completed episode (0 until the first episode ends)
      log_n_eps (bool): Whether to log the total number of episodes completed by each step
      log_params (bool): Whether to log the agent's parameters after each step
    Returns:
      results (dict): Logged arrays of length `steps`, under the keys 'rewards',
        'returns', 'ep_lengths', 'n_eps', 'params' (only those enabled);
        'params' maps each parameter name to an array of shape (steps, *param_shape)
    '''
    self._env.init()
    self.init()
    self.reset()
    if log_rewards: rewards = np.zeros(steps)
    if log_returns: returns, ep_return_prev, ep_return = np.zeros(steps), 0.0, 0.0
    if log_ep_lengths: ep_lengths, ep_len_prev, ep_len = np.zeros(steps), 0, 0
    if log_n_eps: n_eps, n_done = np.zeros(steps), 0
    if log_params: params = self._alloc_param_log(steps)
    for step in range(steps):
      _, _, r, _, T, τ = self.step()
      done = T or τ
      if log_rewards: rewards[step] = r
      if log_returns:
        ep_return += r
        if done: ep_return_prev, ep_return = ep_return, 0.0
        returns[step] = ep_return_prev
      if log_ep_lengths:
        ep_len += 1
        if done: ep_len_prev, ep_len = ep_len, 0
        ep_lengths[step] = ep_len_prev
      if log_n_eps:
        if done: n_done += 1
        n_eps[step] = n_done
      if log_params: self._record_params(params, step)
    results = {}
    if log_rewards: results['rewards'] = rewards
    if log_returns: results['returns'] = returns
    if log_ep_lengths: results['ep_lengths'] = ep_lengths
    if log_n_eps: results['n_eps'] = n_eps
    if log_params: results['params'] = params
    return results

  def run_episodes(self, episodes, log_returns=False, log_ep_lengths=False, log_params=False):
    ''' Run a number of episodes in the environment, starting from a freshly
    initialized environment and agent
    Args:
      episodes (int): Number of episodes to run for
      log_returns (bool): Whether to log the return of each episode
      log_ep_lengths (bool): Whether to log the length of each episode
      log_params (bool): Whether to log the agent's parameters after each episode
    Returns:
      results (dict): Logged arrays of length `episodes`, under the keys
        'returns', 'ep_lengths', 'params' (only those enabled); 'params' maps
        each parameter name to an array of shape (episodes, *param_shape)
    '''
    self._env.init()
    self.init()
    if log_returns: returns = np.zeros(episodes)
    if log_ep_lengths: ep_lengths = np.zeros(episodes)
    if log_params: params = self._alloc_param_log(episodes)
    for ep in range(episodes):
      G, ep_length = self.episode()
      if log_returns: returns[ep] = G
      if log_ep_lengths: ep_lengths[ep] = ep_length
      if log_params: self._record_params(params, ep)
    results = {}
    if log_returns: results['returns'] = returns
    if log_ep_lengths: results['ep_lengths'] = ep_lengths
    if log_params: results['params'] = params
    return results
