from ARLLib.core import *
from abc import ABC, abstractmethod


class BaseEnv(ABC):
    """
    Base class of RL environments in which:
    1. in each time step the environment sends out an 'observation' and receives an 'action' from the agent. The agent-env
        interaction is always continual and goes on indefinitely. Different from openai-gym, in BRL
        there is no explicit and separate reward signal, nor is there the 'done' signal to indicate terminal states.
        This makes it explicit that both 'reward' and 'done' are, in their nature, not part of the problem definition,
        but rather, part of a particular solution method (which are now often handcrafted based on prior knowledge)

    2. in some (but not necessarily every) time steps, performance scores of the agent are given in info['performance'].
        'info' is meant to be a monitoring facility for evaluation purpose, which is NOT supposed to be accessed by the
        agent. Different from openai-gym, filling info['performance'] from time to time is mandatory in BRL, which makes
        it explicit about the "end metric" of the problem

    3.1 the environment will be assigned an initial state at creation time (= time 0) through the '_first_step()' routine.
        This time-0 step is then skipped in the BRL framework because it's either a terminal state (for episodic tasks)
        or an ignorable initial condition (for continual tasks). The actual agent-env interaction starts from time 1,
        after the first call of 'next_step()' which advances the time to 1 and returns obs_1 and info_1.

    3.2 For programming flexibility considerations, the '_first_step()' routine also returns obs and info, similar to
        the 'reset()' routine in openai-gym. But different from the latter, _first_step() is an internal routine that is
        supposed to be called only once, and automatically by the constructor, throughout the lifetime of the environment
        object.

    4.1 the environment object will maintain a group of endogenous variables of the environment such that each of them
        has a definite value within a certain time step. The value changes of endogenous variables may be subject to
        exogenous uncertainties which are fundamentally presumed to be time-invariant random variables. Such exogenous
        randomness may be either due to incomplete characterization of the "ground truth" environment state by the
        endogenous variables alone, or due to intrinsic randomness in the physical transition of the "ground-truth"
        state (or, of course, could also be due to both apparent and intrinsic randomness).

    4.2 as with openai-gym, BRL allows environment de-randomization through specifying a 'seed' at creation time of the
        environment object. Environments supporting such de-randomization should have a proper implementation of the
        '_remove_randomness(seed)' routine which returns True if and only if the exogenous randomness has been truly
        de-randomized

    4.3 As an environment successfully de-randomized by a seed will output definite obs's and info's under given sequence
        of actions, the de-randomization seed and the sequence of actions by time t collectively form a representation of
        a "deterministic state" of the environment at time t. This "deterministic state" not only has definite value at
        any given time, but its value change is also deterministic over time. For environments admitting deterministic
        state, the deterministic state is tracked by member variable '_deterministic_state', and can be cloned (using
        'clone_deterministic_state') or be used to reset another environment instance of the same class (using
        'set_deterministic_state')

    5. Different from openai-gym, actions from a categorical space are not represented by natural numbers (which impose
        an artificial metric over the action space); instead, a categorical action is a one-hot vector, which induces
        the true discrete metric (https://mathworld.wolfram.com/DiscreteMetric.html). For this reason, while most space
        classes in BRL are inherited from openai-gym, the Discrete space is not used in BRL, and is replaced by the
        Categorical space.
    """

    def __init__(self, name='UnknownEnv', seed=None):
        """
        this shared constructor routine is to be called AT THE END of child-class's constructor. Example:
        def __init__(self, seed=None):
            # define self.observation_space and self.action_space in the child-class's constructor
            # put your own stuff
            BaseEnv.__init__(self, seed=seed)
        """
        assert isinstance(self.observation_space, spaces.Space)  # self.observation_space must have been defined
        assert isinstance(self.action_space, spaces.Space)  # self.action_space must have been defined
        self.name = name
        self._initialized = False
        self._first_step(seed)

    def next_step(self, action):
        # if not self._initialized: raise RuntimeError('next_step() called before initialization.')
        if self._deterministic_state is not None: self._deterministic_state['action'].append(action)
        observation, info = self._gen_next_obs_info(action)
        return observation, info

    def render(self):
        print(self._deterministic_state)

    def set_deterministic_state(self, state):
        if isinstance(state, BaseEnv): state = state._deterministic_state
        obs, info = self._first_step(state['seed'])
        for action in state['action']: obs, info = self.next_step(action)
        return obs, info

    def clone_deterministic_state(self):
        return copy.deepcopy(self._deterministic_state)

    # internal methods
    def _first_step(self, seed=None):
        # if self._initialized: raise RuntimeError('first_step() can only be called once.')
        self._initialized = True

        # 1. try resetting the exogenous part of the environment if a seed is provided (not necessarily successful)
        self._seed = None  # stays None when no seed is given or when de-randomization fails
        if seed is not None and self._remove_randomness(seed):
            self._seed = seed

        # 2. start tracking the det. state if the env. is de-randomized, otherwise the trace of state is missing
        self._deterministic_state = None if self._seed is None else {'seed': self._seed, 'action': []}

        # 3. reset the endogenous variables of the environment
        observation, info = self._gen_first_obs_info()
        return observation, info

    def _remove_randomness(self, seed) -> bool:
        """
        try derandomizing the exogenous world and return True if succeeded, otherwise return False.
        The return value indicates whether the exogenous part of this environment has been de-randomized or not
        """
        return False

    @abstractmethod
    def _gen_first_obs_info(self):
        """
        should reset the endogenous variables to a terminal state
        """
        raise NotImplementedError

    @abstractmethod
    def _gen_next_obs_info(self, action):
        """
        should set the environment to implicitly initialize the next episode at terminal states
        """
        raise NotImplementedError

    # mandatory routine definitions in derived class
    #def _gen_first_obs_info(self):
    #def _gen_next_obs_info(self, action):

    # optional routine definitions in derived classes
    #def render(self):
    #def _remove_randomness(self, seed) -> bool:
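
"""
Illustrative sketch of items 4.2 and 4.3 of the BaseEnv docstring (not part of ARLLib): once the exogenous randomness
is pinned by a seed, the pair (seed, action sequence) is enough to rebuild the environment by replay. 'ToyEnv' and its
noisy transition are hypothetical stand-ins written only for this example; only the method names mirror BaseEnv.

```python
import copy
import random


class ToyEnv:
    # Hypothetical stand-in for a de-randomized environment; not an ARLLib class.
    def __init__(self, seed):
        self._first_step(seed)

    def _first_step(self, seed):
        # Seeding pins the exogenous randomness; the action trace starts empty.
        self._rng = random.Random(seed)
        self._state = {'seed': seed, 'action': []}
        self._x = 0.0
        return self._x

    def next_step(self, action):
        # Record the action, then apply a noisy but seed-determined transition.
        self._state['action'].append(action)
        self._x += action + self._rng.uniform(-1.0, 1.0)
        return self._x

    def clone_deterministic_state(self):
        return copy.deepcopy(self._state)

    def set_deterministic_state(self, state):
        # Replay: re-seed, then re-apply every recorded action in order.
        # Iterate over a copy because next_step() appends to the live trace.
        obs = self._first_step(state['seed'])
        for action in list(state['action']):
            obs = self.next_step(action)
        return obs


env = ToyEnv(seed=7)
for a in (1, -2, 3):
    last_obs = env.next_step(a)

other = ToyEnv(seed=99)
replayed_obs = other.set_deterministic_state(env.clone_deterministic_state())

# From here on, both environments evolve identically under the same actions.
same_future = env.next_step(5) == other.next_step(5)
```

Restoring by replay costs one environment step per recorded action, so it suits short traces better than long ones.
"""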



class GymEnv(BaseEnv):
    """
    This class translates a gym environment into a BRL environment, specifically,
    1. we pack (gym_obs, gym_reward, done) into an observation tuple, from which the state signal, the reward signal,
        and the gamma signal can be extracted directly from the agent observation (at 'obs[:-2]', 'obs[-2]', and 'obs[-1]'
        respectively) through trivial state, reward, and credit modules
    2. we define a default agent performance as the UNDISCOUNTED sum of raw rewards from gym in the time-bounded
        episode, provided in 'info' at the last step of the episode; this default performance score can be replaced with
        environment-specific one by overriding '_calc_perf'
    3. we automatically and implicitly reset the episode at termination steps (with 'done' flag on from gym)
    4. as with other BRL envs, we support the semantics of 'deterministic environment state': a rollout starting from
        a deterministic environment state is guaranteed to generate a deterministic trajectory under a fixed action
        sequence
    5. we add Categorical space as supplementary to the existing gym spaces, and represent all actions from gym.Discrete
        spaces as categorical variables from this space; this new space is also useful to subsume tabular RL techniques
    6. we support built-in features like reward scaling and time-out termination to regularize the original environment
    """

    def __init__(self,
                 gym_env: gym.Env,
                 reward_scaling: float = 1.0,
                 time_out: int = None,
                 name='UnknownEnv',
                 seed=None):
        self.gym = gym_env

        self._reward_scaling = reward_scaling
        # scale gym's reward range so that reward_space matches the scaled rewards placed in the observation
        reward_range = array(gym_env.reward_range, dtype=np.float64) * reward_scaling
        self.reward_space = spaces.Box(
            low=reward_range.min(),
            high=reward_range.max(),
            shape=(1,), dtype=np.float64
        )
        # TODO: automatically normalize possibly unbounded rewards in gym to a bounded range, using the reward_scaling
        #  feature
        # assert self.reward_space.is_bounded()

        self.observation_space = spaces.Tuple((gym_env.observation_space,
                                               self.reward_space,
                                               spaces.MultiBinary(1)))

        if isinstance(gym_env.action_space, spaces.Box):
            self.action_space = spaces.Box(low=gym_env.action_space.low, high=gym_env.action_space.high,
                                           dtype=np.float64)
        elif isinstance(gym_env.action_space, spaces.Discrete):
            self.action_space = spaces.Categorical(gym_env.action_space.n)
        else:
            raise NotImplementedError('GymEnv only supports Box action space or Discrete action space at the moment.')

        if time_out is not None: assert math.isfinite(time_out) and time_out > 0
        self._time_out = time_out

        self._gym_episode_t = None
        self._gym_episode_reward_sum = 0
        self._gym_to_reset = True

        BaseEnv.__init__(self, name, seed)

    def __del__(self):
        self.gym.close()

    def _gen_first_obs_info(self):
        gym_obs, gym_reward, gym_done, info = self.gym.reset(), 0, False, {}
        self._gym_episode_t = 1
        self._gym_episode_reward_sum = gym_reward
        self._gym_to_reset = gym_done
        return self._observe_from_gym(gym_obs, gym_reward, gym_done, info)

    def _gen_next_obs_info(self, action):
        # convert 'action' to gym format
        if isinstance(self.action_space, spaces.Categorical):
            action = self.action_space.onehot(action)  # gym uses natural numbers to encode actions with discrete metric

        if self._gym_to_reset:
            gym_obs, gym_reward, gym_done, info = self.gym.reset(), 0, False, {}
            self._gym_episode_t = 1
            self._gym_episode_reward_sum = gym_reward
            self._gym_to_reset = gym_done
        else:
            gym_obs, gym_reward, gym_done, info = self.gym.step(action)
            self._gym_episode_t += 1
            self._gym_episode_reward_sum += gym_reward
            self._gym_to_reset = gym_done or (self._time_out and self._gym_episode_t >= self._time_out)

        return self._observe_from_gym(gym_obs, gym_reward, gym_done, info)

    def _observe_from_gym(self, gym_obs, gym_reward, gym_done, info):
        # # add some noise to gym observation
        # noise = self.rng.uniform(low=-1e-3, high=1e-3, sample_size=len(state))
        # # noise = (self.env.observation_space.high - self.env.observation_space.low) * noise
        # state = np.clip(state + noise, a_min=self.env.observation_space.low, a_max=self.env.observation_space.high)

        """gym environments provide hint about reward and termination as input to the agent, from which the gym wrapper
        directly calculates r_t and gamma_t and "injects" them into x_t as the two new dimensions concatenated behind
        the raw gym observation"""
        scaled_reward = gym_reward * self._reward_scaling
        is_terminal = gym_done or (self._time_out and self._gym_episode_t >= self._time_out)
        obs = (gym_obs,
               array([scaled_reward], dtype=np.float64),
               array([1 if is_terminal else 0], dtype=np.int8))
        # obs = array(gym_obs, dtype=np.float)
        # obs = np.append(obs, np.float(scaled_reward))
        # obs = np.append(obs, np.int8(1 if is_terminal else 0))

        """keep records of raw reward and done signals from gym for the possible stand-alone analysis"""
        if info is None: info = {}
        #info['game_score'] = self._gym_episode_reward_sum
        #info['game_over'] = self._gym_to_reset
        #info['game_length'] = self._gym_episode_t
        info['gym_reward'] = gym_reward
        info['gym_done'] = gym_done
        info['episode_time'] = self._gym_episode_t
        self._calc_perf(info)

        return obs, info

    def _calc_perf(self, info, *args, **kwargs):
        if self._gym_to_reset:
            info['performance'] = self._gym_episode_reward_sum

    def render_state(self, mode='human'):
        return self.gym.render(mode)

    def _remove_randomness(self, seed) -> bool:
        # remove stochastic uncertainty by using the built-in facility seed() in gym
        self.gym.seed(seed)
        self.gym.observation_space.seed(seed + 1)
        self.gym.action_space.seed(seed + 2)
        self.observation_space.seed(seed + 3)
        self.action_space.seed(seed + 4)
        self.reward_space.seed(seed + 5)

        # TODO: more sanity check here may be needed as seeding the prng's in the GymEnv class does not necessarily
        #  remove all uncertainties in the environment dynamic, especially not when 'gym_env' is physically connected to
        #  part of the real world
        return True
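
"""
Illustrative sketch of items 1, 2, 3 and 6 of the GymEnv docstring (not part of ARLLib): rewards are scaled and packed
with a terminal flag into the observation, the undiscounted episode return is reported once at termination, and the
reset happens implicitly on the following step. 'CountdownGym', 'Wrapper' and 'onehot_to_index' are hypothetical names
introduced only for this example, and the old 4-tuple gym step API is assumed.

```python
class CountdownGym:
    # Hypothetical gym-like environment (old 4-tuple step API); not a real gym class.
    def __init__(self, length):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.length, {}


def onehot_to_index(onehot):
    # Hypothetical helper: a Discrete gym space expects the integer index.
    return list(onehot).index(1)


class Wrapper:
    def __init__(self, gym_env, reward_scaling=1.0, time_out=None):
        self.gym, self.scale, self.time_out = gym_env, reward_scaling, time_out
        self.t, self.reward_sum, self.to_reset = 1, 0.0, False
        self.gym.reset()

    def next_step(self, onehot_action):
        if self.to_reset:
            # Implicit reset: the incoming action is ignored at this step.
            gym_obs, reward, done = self.gym.reset(), 0.0, False
            self.t, self.reward_sum = 1, 0.0
        else:
            gym_obs, reward, done, _ = self.gym.step(onehot_to_index(onehot_action))
            self.t += 1
            self.reward_sum += reward
        timed_out = self.time_out is not None and self.t >= self.time_out
        self.to_reset = done or timed_out
        # Observation tuple: (raw observation, scaled reward, terminal flag).
        obs = (gym_obs, reward * self.scale, 1 if self.to_reset else 0)
        # The performance score is the unscaled, undiscounted episode return.
        info = {'performance': self.reward_sum} if self.to_reset else {}
        return obs, info


env = Wrapper(CountdownGym(length=10), reward_scaling=2.0, time_out=4)
trace = [env.next_step([0, 1]) for _ in range(5)]
```

With time_out=4 the third call is flagged terminal and reports a return of 3.0, and the fourth call performs the reset.
"""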


class Classification(BaseEnv):
    """
    An RL environment in which a classifier agent is rewarded based on correctness of its decisions.

    This RL paradigm of classification is fundamentally different from the popular supervised-learning paradigm in that
    the agent is not exposed to the groundtruth decision (or to a sample of the groundtruth), so it is not able to
    measure the distance from a candidate distribution over class labels to the groundtruth distribution -- a key info
    that most supervised learning algorithms nowadays rely on to define differentiable objective functions. What is
    available to the agent in an RL-based classification environment is, instead, a (possibly delayed) evaluative
    feedback about only the actual decisions that the agent has made. From the evaluative feedback the agent is aware of
    only *whether* it has done something wrong, but there is no hint on *how far* it is from doing the right thing.
    Moreover, the evaluation score the agent obtains from the RL environment is generally not differentiable w.r.t. the
    parameters of the agent's decision-making model -- in fact the agent (or the programmer behind the agent) does not
    expect to know a closed form of the reward function at all -- so if one considers the evaluation score itself as the
    objective function of parameter optimization, that objective is not required to be differentiable (and it doesn't
    have a known closed form in general cases).

    Despite the fundamental difference in the way the agent learns, this RL environment is powered by standard supervised
    learning data. Specifically, supervised learning data set consists of (x,y) pairs where x is the sample input, and y
    is a class label for that input, serving as the (sampled) groundtruth. The agent interacts with the environment by
    repeatedly doing the following: observing a sample input x_t, outputting a class label a_t as the agent's
    classification decision on x_t, then in the next time step observing a score r_{t+1} (typically a binary score) as
    the obtained evaluation of decision a_t. Such a two-step interaction upon the single (x,y) pair forms an episode,
    and an infinite loop of such two-step episodes forms the entire RL rollout.

    The environment feeds samples to the RL agent in the same way as how data is provided to a standard supervised
    learning agent. Specifically, the data is split into a training set and a testing set, and the environment can run
    under either the 'train' mode or the 'test' mode; in either case the environment repeatedly samples the corresponding
    data partition, uniformly without replacement, and uses each sample thus obtained to form an episode. The env. will
    internally reset the sampling when an epoch of the partition scanning is complete, without explicitly notifying the
    agent (but an epoch id will show up in 'info' at the first step of the first episode/sample of that epoch).

    When the RL environment runs with supervised-learning data, the reward function is by default set to be r_{t+1} =
    I(a_t=y_t), where I() is the indicator function. However, the reward function can be arbitrarily re-defined in child
    classes of this base class, and importantly, it does NOT necessarily depend on a groundtruth label. For example, the
    reward function of a child class in this framework may choose to open a prompt to ask for a real human to judge the
    correctness/quality of a specific classification result the RL agent gives, forming a human-in-the-loop learning
    environment.

    The format of observation of this environment is implemented in the way that it can be converted into a gym environment
    through WrapToGym with the default StateGym, RewardGym, and CreditGym modules
    """
    def __init__(self, mode='train', name='UnknownEnv', seed=None):
        self.num_class, self.observation_space, self.action_space = None, None, None
        self.x_train, self.y_train, self.x_test, self.y_test = self.load_data()
        self.observation_space = spaces.Tuple(
            (self.observation_space,
            spaces.Box(low=0, high=1, shape=(1,), dtype=np.float64),
            spaces.MultiBinary(1))
        )

        self.epoch_id = None
        self.sample_id = None
        self.x = None
        self.y = None
        self.mode = mode
        BaseEnv.__init__(self, name, seed)

    def _gen_first_obs_info(self):
        assert self.mode == 'train' or self.mode == 'test', \
            "%s is run in unknown mode: %s" % (self.__class__.__name__, self.mode)
        if self.mode == 'train':
            self.x, self.y = (self.x_train, self.y_train)
        else:
            self.x, self.y = (self.x_test, self.y_test)

        self.epoch_id = 0
        self._prepare_epoch()

        self.sample_id = self.shuffled_id[self.epoch_pos]
        obs = (self.x[self.sample_id], array([0], dtype=np.float64), array([0], dtype=np.int8))
        return obs, {'epoch': 1}

    def _gen_next_obs_info(self, action):
        reward, info = 0.0, {}
        if self.y_hat is not None:  # step 0 of the episode
            self.y_hat = None
            self.epoch_pos += 1
            if self.epoch_pos >= len(self.x):
                self._prepare_epoch()
                info['epoch'] = self.epoch_id
            is_terminal = np.int8(0)
        else:  # step 1 of the episode
            self.y_hat = onehot_decoding(action)
            reward = self.reward_func()
            is_terminal = np.int8(1)
            info['performance'] = reward

        self.sample_id = self.shuffled_id[self.epoch_pos]
        obs = (self.x[self.sample_id], array([reward], dtype=np.float64), array([is_terminal], dtype=np.int8))
        return obs, info

    def _prepare_epoch(self):
        self.epoch_id += 1
        self.epoch_pos = 0
        self.y_hat = None

        # TODO: shuffle data
        self.shuffled_id = list(range(len(self.x)))

    # the following routines are to be defined in child class
    @abstractmethod
    def load_data(self):
        """
        set self.num_class, self.observation_space and self.action_space
        :return: (x_train, y_train, x_test, y_test)
        """
        raise NotImplementedError

    # the following routines are to be re-defined in child class
    def reward_func(self):
        correct = (self.y_hat == self.y[self.sample_id])
        return 1 if correct else 0

    def render_state(self, mode='human'):
        print("mode: {}, epoch: {}, {} samples finished.".format(self.mode, self.epoch_id, self.epoch_pos))

    def _remove_randomness(self, seed) -> bool:
        return False
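
"""
Illustrative sketch of the two-step episode loop and the per-epoch sampling described in the Classification docstring
(not part of ARLLib). 'ToyClassification' and the oracle agent are hypothetical and exist only for this example; the
shuffle shown here is one possible way to fill in the 'shuffle data' TODO, not the library's actual implementation.

```python
import random


class ToyClassification:
    # Hypothetical miniature of the two-step episode loop; not the ARLLib class.
    def __init__(self, xs, ys, seed=0):
        self.xs, self.ys = xs, ys
        self.rng = random.Random(seed)
        self.epoch_id = 0
        self._prepare_epoch()
        self.y_hat = None

    def _prepare_epoch(self):
        # Visit every sample once per epoch, in a freshly shuffled order.
        self.epoch_id += 1
        self.order = list(range(len(self.xs)))
        self.rng.shuffle(self.order)
        self.pos = 0

    def first_obs(self):
        return self.xs[self.order[self.pos]], 0.0, 0

    def next_step(self, label):
        info = {}
        if self.y_hat is None:
            # Decision step: score the label with r = I(a_t = y_t), flag terminal.
            self.y_hat = label
            reward = 1.0 if label == self.ys[self.order[self.pos]] else 0.0
            info['performance'] = reward
            terminal = 1
        else:
            # Start of the next episode: the incoming action is ignored.
            self.y_hat, reward, terminal = None, 0.0, 0
            self.pos += 1
            if self.pos >= len(self.xs):
                self._prepare_epoch()
                info['epoch'] = self.epoch_id
        return (self.xs[self.order[self.pos]], reward, terminal), info


xs, ys = ['a', 'b', 'c'], [0, 1, 2]
env = ToyClassification(xs, ys, seed=1)
obs = env.first_obs()
scores, seen = [], []
while env.epoch_id == 1:
    seen.append(obs[0])
    label = ys[xs.index(obs[0])]      # an oracle agent, for illustration only
    obs, info = env.next_step(label)  # decision step: the score arrives
    scores.append(info['performance'])
    obs, info = env.next_step(None)   # the next sample is presented
```

Each sample is seen exactly once before the epoch counter advances, and the new epoch id appears in 'info' on the step
that presents the first sample of that epoch.
"""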


"""
example usage and simple unit test
"""
if __name__ == '__main__':
    # classic image classification environments, with discrete action space
    from ARLLib.environments.mnist import MNIST
    env = MNIST('train', seed=0)
    print('\n' + env.name)

    perf = RV("accuracy")

    obs, info = env._first_step()  # re-runs the time-0 step that the constructor has already performed once
    print(info)
    env.render_state()

    while True:
        action = env.action_space.sample()
        assert env.action_space.contains(action)
        obs, info = env.next_step(action)
        assert env.observation_space.contains(obs)

        if info.get('performance') is not None:
            # here is the end of an episode
            perf.append(info['performance'])
            assert obs[-1] == 1

        if info.get('epoch') is not None:
            # here is the start of an epoch
            print(perf)
            if env.epoch_id > 3: break
            print(info)
            env.render_state()

    # gym native version of cartpole, with discrete action space
    env = GymEnv(gym.make('CartPole-v0'), name='gym.CartPole-v0', reward_scaling=2, time_out=15, seed=1)
    print('\n'+env.name)
    t = 0
    obs, info = env._first_step()  # re-runs the time-0 step that the constructor has already performed once
    env.render_state()

    perf = RV(env.name)
    while perf.size() < 20:
        action = env.action_space.sample()
        # action = onehot_encoding(0, env.action_range.shape)
        assert env.action_space.contains(action)

        t += 1
        obs, info = env.next_step(action)
        assert env.observation_space.contains(obs)
        env.render_state()

        if info.get('performance') is not None:
            # print("\r")
            print("t: {}, perf: {}".format(t, info['performance']))
            perf.append(info['performance'])
            assert obs[-1] == 1
    print(perf)
    del env

    # custom version of inverted pendulum using CartPole GUI, with continuous action space
    from ARLLib.environments.inverted_pendulum import InvertedPendulumGym
    env = GymEnv(InvertedPendulumGym(), name='InvertedPendulum', seed=2)
    print('\n' + env.name)
    env_test = GymEnv(InvertedPendulumGym(), seed=3)
    perf = RV(env.name)

    t = 0
    obs, info = env._first_step(1234)  # re-seeds and re-runs the time-0 step
    env.render_state()

    while perf.size() < 20:
        action = env.action_space.sample()
        # action = np.zeros(env.action_range.shape)
        assert env.action_space.contains(action)

        t += 1
        obs, info = env.next_step(action)
        assert env.observation_space.contains(obs)
        env.render_state()

        if t % 100 == 0:
            obs, info = env_test.set_deterministic_state(env)
            env_test.render_state()

        if info.get('performance') is not None:
            # print("\r")
            print("t: {}, perf: {}".format(t, info['performance']))
            perf.append(info['performance'])
            assert obs[-1] == 1
    print(perf)
    del env, env_test

    # import roboschool
    # from OpenGL import GLU
    # env = GymEnv(gym.make('RoboschoolInvertedPendulum-v1'))
    # roboschool version of inverted pendulum implemented by pybullet, with continuous action space
    import pybulletgym
    env = GymEnv(gym.make('InvertedPendulumPyBulletEnv-v0'), name='pybulletgym.InvertedPendulumPyBulletEnv-v0', seed=4)
    print('\n'+env.name)
    perf = RV(env.name)
    env.render_state()

    obs, info = env._first_step(4321)  # re-seeds and re-runs the time-0 step
    env.render_state()

    while perf.size() < 100:
        action = env.action_space.sample()
        # action = np.zeros(env.action_range.shape)

        obs, info = env.next_step(action)
        env.render_state()

        if info.get('performance') is not None:
            print("\rperf: ", info['performance'], end=' ', flush=True)
            perf.append(info['performance'])
    print('\n')
    print(perf)
    del env

""" expected console output:

commit #:  8228584c7cf9995a55bee4955679ea05ddeb5fe3
preparing the MNIST environment
------------------------------
start loading MNIST data
finish loading MNIST data
image type: <class 'numpy.ndarray'>  [uint8 , (28, 28)]
label type: <class 'numpy.int64'>  [int64 , ()]
#image in the train set:  60000
#image in the  test set:  10000
------------------------------

MNIST_train
{'epoch': 1}
accuracy	mean= 0.10 (9.86e-02 , 1.03e-01)	std= 0.30	n= 60000
{'epoch': 2}
accuracy	mean= 0.10 (9.88e-02 , 1.02e-01)	std= 0.30	n= 120000
{'epoch': 3}
accuracy	mean= 0.10 (9.90e-02 , 1.01e-01)	std= 0.30	n= 180000

gym.CartPole-v0
t: 14, perf: 14.0
t: 29, perf: 14.0
t: 44, perf: 14.0
t: 59, perf: 14.0
t: 74, perf: 14.0
t: 89, perf: 14.0
t: 104, perf: 14.0
t: 119, perf: 14.0
t: 134, perf: 14.0
t: 149, perf: 14.0
t: 164, perf: 14.0
t: 179, perf: 14.0
t: 194, perf: 14.0
t: 209, perf: 14.0
t: 224, perf: 14.0
t: 239, perf: 14.0
t: 254, perf: 14.0
t: 265, perf: 10.0
t: 278, perf: 12.0
t: 293, perf: 14.0
gym.CartPole-v0	mean= 13.70 (1.33e+01 , 1.41e+01)	std= 0.98	n= 20

InvertedPendulum
t: 33, perf: 32.0
t: 43, perf: 8.0
t: 53, perf: 8.0
t: 71, perf: 16.0
t: 93, perf: 20.0
t: 106, perf: 11.0
t: 119, perf: 11.0
t: 128, perf: 7.0
t: 140, perf: 10.0
t: 154, perf: 12.0
t: 172, perf: 16.0
t: 191, perf: 17.0
t: 216, perf: 23.0
t: 230, perf: 12.0
t: 241, perf: 9.0
t: 260, perf: 17.0
t: 277, perf: 15.0
t: 307, perf: 28.0
t: 328, perf: 19.0
t: 350, perf: 20.0
InvertedPendulum	mean= 15.55 (1.31e+01 , 1.80e+01)	std= 6.75	n= 20
pybullet build time: Jan  7 2020 19:15:59

pybulletgym.InvertedPendulumPyBulletEnv-v0
starting thread 0
started testThreads thread 0 with threadHandle 00000000000007AC
argc=2
argv[0] = --unused
argv[1] = --start_demo_name=Physics Server
ExampleBrowserThreadFunc started
Version = 4.6.0 - Build 27.20.100.8187
Vendor = Intel
Renderer = Intel(R) UHD Graphics 620
b3Printf: Selected demo: Physics Server
starting thread 0
started MotionThreads thread 0 with threadHandle 00000000000005C4
MotionThreadFunc thread started
perf:  24.0 

pybulletgym.InvertedPendulumPyBulletEnv-v0	mean= 26.51 (2.45e+01 , 2.85e+01)	std= 11.96	n= 100
numActiveThreads = 0
stopping threads
Thread with taskId 0 with handle 00000000000005C4 exiting
Thread TERMINATED
finished
numActiveThreads = 0
btShutDownExampleBrowser stopping threads
Thread with taskId 0 with handle 00000000000007AC exiting
Thread TERMINATED
"""

