import numpy as np
import matplotlib.pyplot as plt

from Environments.MDPy import MDP

class TabularEnvironment():
  ''' Base class for tabular (discrete state/action) environments.

  Subclasses supply the dynamics by overriding `__init__`, `init_state`,
  `sample`, and the `n_states` / `n_actions` / `max_steps` properties (every
  one of them raises NotImplementedError here). This base class provides the
  episode bookkeeping on top of them: `reset` and `step` track the current
  state in `self._s` and the number of steps taken this episode in `self._t`.
  '''

  def __init__(self):
    ''' Constructor to handle any environment parameters
    Subclasses must override this; they are free to accept their own arguments.
    Raises:
      NotImplementedError: Always, in the base class
    '''
    raise NotImplementedError

  def init(self):
    ''' Environment initialization
    Not called by `reset` or `step` in this class.
    Raises:
      NotImplementedError: Always, in the base class
    '''
    raise NotImplementedError

  def init_state(self):
    ''' Compute initial environment state
    Returns:
      s0 (any): Initial environment state, in the environment's internal
        representation (`reset` converts it with `state_id`)
    Raises:
      NotImplementedError: Always, in the base class
    '''
    raise NotImplementedError

  def sample(self, s, a):
    ''' Produce next state and reward given a state-action pair
    Args:
      s (any): An environment state
      a (int): Action to take
    Returns:
      sp (any): Next state
      r (float): Reward received
      T (bool): Terminal state
    Raises:
      NotImplementedError: Always, in the base class
    '''
    raise NotImplementedError

  @property
  def n_states(self):
    ''' Number of states in the environment
    Returns:
      n_states (int): Number of states in the environment
    '''
    raise NotImplementedError

  @property
  def n_actions(self):
    ''' Number of actions in the environment
    Returns:
      n_actions (int): Number of actions in the environment
    '''
    raise NotImplementedError

  @property
  def max_steps(self):
    ''' Episodic time limit
    Returns:
      max_steps (int): Episodic time limit (inf for continuing)
    '''
    raise NotImplementedError

  def state_id(self, s):
    ''' Computes a state ID for environments that use non-integer states internally (e.g., coordinates)
    The default implementation returns `s` unchanged.
    Args:
      s (any): An environment state
    Returns:
      s_id (int): A unique integer identifier
    '''
    return s

  def reset(self):
    ''' Environment reset (set and return initial environment state)
    Stores the state produced by `init_state` and zeroes the step counter.
    Returns:
      s0 (int): ID of the initial environment state
    '''
    self._s = self.init_state()
    self._t = 0
    return self.state_id(self._s)

  def step(self, a):
    ''' Sample environment given an action for the current environment state
    `reset` must be called first: it creates the `_s` and `_t` attributes
    that this method reads and updates.
    Args:
      a (int): Action to take
    Returns:
      sp (int): ID of the next state
      r (float): Reward received
      T (bool): Terminal state
      timeout (bool): Whether the episodic time limit has been reached
    '''
    sp, r, T = self.sample(self._s, a)
    self._s = sp
    self._t += 1
    # `>=` rather than `==`: if the caller keeps stepping past the time limit,
    # the flag stays set instead of silently reverting to False.
    timeout = self._t >= self.max_steps
    return self.state_id(sp), r, T, timeout
