"""Tests for epsilon greedy policy."""
import collections
import pickle

import numpy as np

from garage.np.exploration_policies import EpsilonGreedyPolicy

from tests.fixtures.envs.dummy import DummyDiscreteEnv


class SimplePolicy:
    """Minimal stand-in policy that draws its actions from the action space.

    Args:
        env_spec: Object exposing an ``action_space`` attribute that
            provides a ``sample()`` method.

    """

    def __init__(self, env_spec):
        self.env_spec = env_spec

    def get_action(self, _):
        """Sample a single action; the observation argument is ignored.

        Returns:
            tuple: The sampled action and an empty info dict.

        """
        sampled = self.env_spec.action_space.sample()
        return sampled, {}

    def get_actions(self, observations):
        """Sample one action and repeat it once per observation.

        Args:
            observations: Sized collection of observations; only its
                length is used.

        Returns:
            tuple: An array of ``len(observations)`` copies of a single
                sampled action, and an empty info dict.

        """
        batch_size = len(observations)
        # A single draw is shared by the whole batch.
        shared_action = self.env_spec.action_space.sample()
        return np.full(batch_size, shared_action), {}

    def get_param_values(self):
        """Return None; this policy holds no parameters."""
        return None

    def set_param_values(self, params):
        """Accept and discard ``params``; there is nothing to set."""
        del params


class TestEpsilonGreedyPolicy:
    """Tests for EpsilonGreedyPolicy wrapped around SimplePolicy.

    The expected values used throughout (0.902 after one step, 0.412 after
    six steps) correspond to epsilon starting at ``max_epsilon=1.0`` and
    dropping by 0.098 per environment step.
    """

    def setup_method(self):
        """Build a fresh environment and exploration policy for each test."""
        self.env = DummyDiscreteEnv()
        self.policy = SimplePolicy(env_spec=self.env)
        self.epsilon_greedy_policy = EpsilonGreedyPolicy(env_spec=self.env,
                                                         policy=self.policy,
                                                         total_timesteps=100,
                                                         max_epsilon=1.0,
                                                         min_epsilon=0.02,
                                                         decay_ratio=0.1)

        self.env.reset()

    def test_epsilon_greedy_policy(self):
        """Check that actions are valid and epsilon decays once per action."""
        obs, _, _, _ = self.env.step(1)

        action, _ = self.epsilon_greedy_policy.get_action(obs)
        assert self.env.action_space.contains(action)

        # epsilon decay by 1 step, new epsilon = 1 - 0.098 = 0.902
        # Assert on epsilon itself instead of estimating it from random
        # draws, so the check is deterministic and exact.
        assert np.isclose(self.epsilon_greedy_policy._epsilon(), 0.902)

        actions, _ = self.epsilon_greedy_policy.get_actions([obs] * 5)

        # epsilon decay by 6 steps in total
        # new epsilon = 1 - 6 * 0.098 = 0.412
        assert np.isclose(self.epsilon_greedy_policy._epsilon(), 0.412)

        for batch_action in actions:
            assert self.env.action_space.contains(batch_action)

    def test_set_param(self):
        """Check that setting total_env_steps to 6 yields the 6-step epsilon."""
        params = self.epsilon_greedy_policy.get_param_values()
        params['total_env_steps'] = 6
        self.epsilon_greedy_policy.set_param_values(params)
        assert np.isclose(self.epsilon_greedy_policy._epsilon(), 0.412)

    def test_update(self):
        """Check epsilon after update() with a batch whose lengths sum to 6."""
        # Only the ``lengths`` field is provided on this stand-in batch.
        DummyBatch = collections.namedtuple('EpisodeBatch', ['lengths'])
        batch = DummyBatch(np.array([1, 2, 3]))
        self.epsilon_greedy_policy.update(batch)
        assert np.isclose(self.epsilon_greedy_policy._epsilon(), 0.412)

    def test_epsilon_greedy_policy_is_pickleable(self):
        """Check that a pickle round trip preserves the current epsilon."""
        obs, _, _, _ = self.env.step(1)
        # Take a few actions first so the pickled policy carries decayed
        # state rather than its initial state.
        for _ in range(5):
            self.epsilon_greedy_policy.get_action(obs)

        h_data = pickle.dumps(self.epsilon_greedy_policy)
        policy = pickle.loads(h_data)
        assert policy._epsilon() == self.epsilon_greedy_policy._epsilon()
