import os
import sys
sys.path.append(os.getcwd())

import numpy as np
import matplotlib.pyplot as plt

from src.envs.Environments.TabularEnvironment import TabularEnvironment
from src.envs.Environments.MDPy import MDP

class AdaptChain(TabularEnvironment):
  ''' Chain environment backed by a tabular MDP with 2 actions.

  States ``0 .. n-1`` each get two outgoing transitions; the extra state ``n``
  receives none here (presumably absorbing / terminal inside ``MDP`` --
  confirm against MDPy). Actions:

    * action 0 ("left"):  s -> max(0, s - 1), reward -1
    * action 1 ("right"): s -> s + 1, reward 0, except 10 when stepping from
      state ``n - 1`` into state ``n``

  Every transition is registered with weight 1.0 (presumably its
  probability, i.e. moves are deterministic -- confirm against
  ``MDP.add_transition``). Episodes start in state 0.
  '''
  def __init__(self, n=10):
    ''' Build the underlying MDP once.
    Args:
      n (int): Number of chain states with outgoing transitions; must be
        >= 1. The MDP is constructed with ``n + 1`` states in total.
    Raises:
      ValueError: If ``n`` is smaller than 1.
    '''
    if n < 1:
      # With n == 0 the start state 0 would get no transitions at all, and a
      # negative n would ask MDP for a non-positive number of states.
      raise ValueError(f'AdaptChain requires n >= 1, got {n}')
    self._mdp = MDP(n + 1, 2)
    for s in range(n):
      # Left: step back one state, clamped at the start of the chain.
      self._mdp.add_transition(s, 0, 1.0, max(0, s - 1), -1)
      # Right: advance; only the final step into state n is rewarded.
      self._mdp.add_transition(s, 1, 1.0, s + 1, 0 if s < n - 1 else 10)
    self._mdp.build()

  def init(self):
    ''' Environment initialization (nothing to do; the MDP is built in
    the constructor).
    '''
    pass

  def init_state(self):
    ''' Compute initial environment state
    Returns:
      s0 (int): Initial environment state (always 0, the left end of the chain)
    '''
    return 0

  def sample(self, s, a):
    ''' Produce next state and reward given a state-action pair
    (delegates to the underlying MDP's ``sample``).
    Args:
      s (any): An environment state
      a (int): Action to take
    Returns:
      sp (any): Next state
      r (float): Reward received
      T (bool): Terminal state
    '''
    return self._mdp.sample(s, a)

  @property
  def n_states(self):
    ''' Number of states in the environment
    Returns:
      n_states (int): Number of states reported by the underlying MDP
        (constructed with n + 1 states)
    '''
    return self._mdp.n_states

  @property
  def n_actions(self):
    ''' Number of actions in the environment
    Returns:
      n_actions (int): Number of actions reported by the underlying MDP
        (constructed with 2 actions)
    '''
    return self._mdp.n_actions

  @property
  def max_steps(self):
    ''' Episodic time limit
    Returns:
      max_steps (int): Episodic time limit; the largest value of NumPy's
        default integer type is used as an effectively unbounded limit
        (inf for continuing)
    '''
    return np.iinfo(int).max

  def q_star(self, γ):
    ''' Compute optimal action-value function (delegates to the MDP)
    Args:
      γ (float): Discount factor
    Returns:
      q_star (numpy array): Optimal action-value function q_star[s, a]
    '''
    return self._mdp.get_q_star(γ)

  def v_pi(self, π, γ):
    ''' Compute a policy's state-value function (delegates to the MDP)
    Args:
      π (numpy array): Policy π[s, a]
      γ (float): Discount factor
    Returns:
      v_pi (numpy array): Policy's state-value function v_pi[s]
    '''
    return self._mdp.get_v_fixed_pi(π, γ)