"""
Superclass for all semi-supervised envs

These are toy problems but the principle is useful -- RL agents in the real world
will likely be learning from an inconsistent signal. For example, a human might
use a clicker to reward an RL agent but likely wouldn't do so with perfect consistency.

Note: In all semisupervised environmenvts, we judge the RL agent based on their total
true_reward, not their percieved_reward. This means that even if the true_reward happens to
not be shown to the agent for an entire episode, the agent is still being judged
and should still perform as well as possible.
"""
from rand_param_envs import gym


class SemisuperEnv(gym.Env):
    def step(self, action):
        assert self.action_space.contains(action)

        observation, true_reward, done, info = self._step(action)
        info['true_reward'] = true_reward  # Used by monitor for evaluating performance

        assert self.observation_space.contains(observation)

        perceived_reward = self._distort_reward(true_reward)
        return observation, perceived_reward, done, info


"""
true_reward is only shown to the agent 1/10th of the time.
"""


class SemisuperRandomEnv(SemisuperEnv):
    PROB_GET_REWARD = 0.1

    def _distort_reward(self, true_reward):
        if self.np_random.uniform() < SemisuperRandomEnv.PROB_GET_REWARD:
            return true_reward
        else:
            return 0


"""
semisuper_pendulum_noise is the pendulum task but where reward function is noisy.
"""


class SemisuperNoiseEnv(SemisuperEnv):
    NOISE_STANDARD_DEVIATION = 3.0

    def _distort_reward(self, true_reward):
        return true_reward + self.np_random.normal(scale=SemisuperNoiseEnv.NOISE_STANDARD_DEVIATION)


"""
semisuper_pendulum_decay is the pendulum task but where the reward function
is given to the agent less and less often over time.
"""


class SemisuperDecayEnv(SemisuperEnv):
    DECAY_RATE = 0.999

    def __init__(self):
        super(SemisuperDecayEnv, self).__init__()

        # This probability is only reset when you create a new instance of this env:
        self.prob_get_reward = 1.0

    def _distort_reward(self, true_reward):
        self.prob_get_reward *= SemisuperDecayEnv.DECAY_RATE

        # Then we compute the perceived_reward
        if self.np_random.uniform() < self.prob_get_reward:
            return true_reward
        else:
            return 0


"""
Now let's make some envs!
"""
from rand_param_envs.gym.envs.classic_control.pendulum import PendulumEnv


class SemisuperPendulumNoiseEnv(SemisuperNoiseEnv, PendulumEnv): pass


class SemisuperPendulumRandomEnv(SemisuperRandomEnv, PendulumEnv): pass


class SemisuperPendulumDecayEnv(SemisuperDecayEnv, PendulumEnv): pass
