# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Tests for dm_control.suite domains."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Internal dependencies.
from absl.testing import absltest
from absl.testing import parameterized
from dm_control import suite
from dm_control.rl import control
import mock
import numpy as np
import six
from six.moves import range
from six.moves import zip


def uniform_random_policy(action_spec, random=None):
    """Builds a policy that samples actions uniformly within the spec bounds.

    Args:
      action_spec: Object exposing `minimum` and `maximum` bound arrays.
      random: Optional seed, forwarded to `np.random.RandomState`.

    Returns:
      A callable that takes a time step (which it ignores) and returns an
      action drawn uniformly between the lower and upper bounds.
    """

    def _replace_infinite(bounds, fallback):
        # Infinite entries are swapped for a finite fallback so that values for
        # unbounded actions are drawn between -1 and 1.
        return np.where(np.isinf(bounds), fallback, bounds)

    low = _replace_infinite(action_spec.minimum, -1.0)
    high = _replace_infinite(action_spec.maximum, 1.0)
    rng = np.random.RandomState(random)

    def policy(time_step):
        del time_step  # Unused.
        return rng.uniform(low, high)

    return policy


def step_environment(env, policy, num_episodes=5, max_steps_per_episode=10):
    """Yields the time steps produced by running `policy` in `env`.

    Each episode starts with `env.reset()`, whose time step is yielded first.
    The environment is then stepped with `policy(time_step)` until either the
    time step reports `last()` or the per-episode step limit is reached.

    Args:
      env: Object providing `reset()` and `step(action)`.
      policy: Callable mapping a time step to an action.
      num_episodes: Number of episodes to run.
      max_steps_per_episode: Step limit, checked after each step is taken.

    Yields:
      Every time step returned by `env.reset()` and `env.step()`, in order.
    """
    for _ in range(num_episodes):
        time_step = env.reset()
        yield time_step
        steps_taken = 0
        finished = time_step.last()
        while not finished:
            time_step = env.step(policy(time_step))
            steps_taken += 1
            yield time_step
            # The step limit is tested first, so `last()` is only consulted
            # when the episode has not already hit the limit.
            finished = steps_taken >= max_steps_per_episode or time_step.last()


def make_trajectory(domain, task, seed, **trajectory_kwargs):
    """Returns a time-step generator for a seeded environment and policy.

    Args:
      domain: Passed as the first argument to `suite.load`.
      task: Passed as the second argument to `suite.load`.
      seed: Used both as the task's `random` kwarg and as the seed of the
        uniform random policy.
      **trajectory_kwargs: Forwarded unchanged to `step_environment`.

    Returns:
      The generator produced by `step_environment`.
    """
    environment = suite.load(domain, task, task_kwargs=dict(random=seed))
    random_policy = uniform_random_policy(environment.action_spec(), random=seed)
    return step_environment(environment, random_policy, **trajectory_kwargs)


class DomainTest(parameterized.TestCase):
    """Tests run on all the tasks registered."""

    def test_constants(self):
        """Checks that `ALL_TASKS` counts every task in `TASKS_BY_DOMAIN`."""
        num_tasks = sum(len(tasks) for tasks in six.itervalues(suite.TASKS_BY_DOMAIN))

        self.assertLen(suite.ALL_TASKS, num_tasks)

    def _validate_observation(self, observation_dict, observation_spec):
        """Checks that an observation matches its spec and is finite.

        Args:
          observation_dict: Mapping from observation name to array.
          observation_spec: Mapping from observation name to a spec exposing
            `shape` and `dtype`.
        """
        obs = observation_dict.copy()
        for name, spec in six.iteritems(observation_spec):
            # Report a missing observation as a readable test failure instead of
            # letting `pop` raise a bare `KeyError`.
            self.assertIn(
                name,
                obs,
                msg="Observation is missing array {!r} listed in the spec.".format(
                    name
                ),
            )
            arr = obs.pop(name)
            self.assertEqual(arr.shape, spec.shape)
            self.assertEqual(arr.dtype, spec.dtype)
            self.assertTrue(
                np.all(np.isfinite(arr)),
                msg="{!r} has non-finite value(s): {!r}".format(name, arr),
            )
        # Anything left over was returned by the environment but never declared.
        self.assertEmpty(
            obs,
            msg="Observation contains array(s) that are not in the spec: {!r}".format(
                obs
            ),
        )

    def _validate_reward_range(self, time_step):
        """Checks the reward is `None` on first steps, else a float in [0, 1]."""
        if time_step.first():
            self.assertIsNone(time_step.reward)
        else:
            self.assertIsInstance(time_step.reward, float)
            self.assertBetween(time_step.reward, 0, 1)

    def _validate_discount(self, time_step):
        """Checks the discount is `None` on first steps, else a float in [0, 1]."""
        if time_step.first():
            self.assertIsNone(time_step.discount)
        else:
            self.assertIsInstance(time_step.discount, float)
            self.assertBetween(time_step.discount, 0, 1)

    def _validate_control_range(self, lower_bounds, upper_bounds):
        """Checks that every lower bound is -1.0 and every upper bound is 1.0."""
        for b in lower_bounds:
            self.assertEqual(b, -1.0)
        for b in upper_bounds:
            self.assertEqual(b, 1.0)

    @parameterized.parameters(*suite.ALL_TASKS)
    def test_components_have_names(self, domain, task):
        """Tests that every object of the listed types has a non-empty name."""
        env = suite.load(domain, task)
        model = env.physics.model

        # Pairs of (object type passed to `id2name`, model attribute holding the
        # number of objects of that type).
        object_types_and_size_fields = [
            ("body", "nbody"),
            ("joint", "njnt"),
            ("geom", "ngeom"),
            ("site", "nsite"),
            ("camera", "ncam"),
            ("light", "nlight"),
            ("mesh", "nmesh"),
            ("hfield", "nhfield"),
            ("texture", "ntex"),
            ("material", "nmat"),
            ("equality", "neq"),
            ("tendon", "ntendon"),
            ("actuator", "nu"),
            ("sensor", "nsensor"),
            ("numeric", "nnumeric"),
            ("text", "ntext"),
            ("tuple", "ntuple"),
        ]
        for object_type, size_field in object_types_and_size_fields:
            for idx in range(getattr(model, size_field)):
                object_name = model.id2name(idx, object_type)
                self.assertNotEqual(
                    object_name,
                    "",
                    msg="Model {!r} contains unnamed {!r} with ID {}.".format(
                        model.name, object_type, idx
                    ),
                )

    @parameterized.parameters(*suite.ALL_TASKS)
    def test_model_has_at_least_2_cameras(self, domain, task):
        """Tests that the model's camera count (`ncam`) is at least 2."""
        env = suite.load(domain, task)
        model = env.physics.model
        self.assertGreaterEqual(
            model.ncam,
            2,
            "Model {!r} should have at least 2 cameras, has {}.".format(
                model.name, model.ncam
            ),
        )

    @parameterized.parameters(*suite.ALL_TASKS)
    def test_task_conforms_to_spec(self, domain, task):
        """Tests that the environment timesteps conform to specifications."""
        is_benchmark = (domain, task) in suite.BENCHMARKING
        env = suite.load(domain, task)
        observation_spec = env.observation_spec()
        action_spec = env.action_spec()

        # Check action bounds.
        if is_benchmark:
            self._validate_control_range(action_spec.minimum, action_spec.maximum)

        # Step through the environment, applying random actions sampled within the
        # valid range and check the observations, rewards, and discounts.
        policy = uniform_random_policy(action_spec)
        for time_step in step_environment(env, policy):
            self._validate_observation(time_step.observation, observation_spec)
            self._validate_discount(time_step)
            if is_benchmark:
                self._validate_reward_range(time_step)

    @parameterized.parameters(*suite.ALL_TASKS)
    def test_environment_is_deterministic(self, domain, task):
        """Tests that identical seeds and actions produce identical trajectories."""
        seed = 0
        # Iterate over two trajectories generated using identical sequences of
        # random actions, and with identical task random states. Check that the
        # observations, rewards, discounts and step types are identical.
        trajectory1 = make_trajectory(domain=domain, task=task, seed=seed)
        trajectory2 = make_trajectory(domain=domain, task=task, seed=seed)
        for time_step1, time_step2 in zip(trajectory1, trajectory2):
            self.assertEqual(time_step1.step_type, time_step2.step_type)
            self.assertEqual(time_step1.reward, time_step2.reward)
            self.assertEqual(time_step1.discount, time_step2.discount)
            # Compare the key sets first: the loop below only walks the keys of
            # the first observation, so extra keys in the second would otherwise
            # go unnoticed.
            self.assertEqual(
                set(time_step1.observation),
                set(time_step2.observation),
                msg="Observation keys differ between the two trajectories.",
            )
            for key in six.iterkeys(time_step1.observation):
                np.testing.assert_array_equal(
                    time_step1.observation[key],
                    time_step2.observation[key],
                    err_msg="Observation {!r} is not equal.".format(key),
                )

    def assertCorrectColors(self, physics, reward):
        """Asserts that material colors are the expected blend for `reward`.

        For each of the "self", "effector" and "target" materials, the expected
        color is a blend of its `_highlight` and `_default` variants, weighted
        by `reward ** 4`.

        Args:
          physics: Object exposing `named.model.mat_rgba`, indexable by
            material name.
          reward: Reward value the colors are expected to reflect.
        """
        colors = physics.named.model.mat_rgba
        for material_name in ("self", "effector", "target"):
            highlight = colors[material_name + "_highlight"]
            default = colors[material_name + "_default"]
            blend_coef = reward ** 4
            expected = blend_coef * highlight + (1.0 - blend_coef) * default
            actual = colors[material_name]
            err_msg = (
                "Material {!r} has unexpected color.\nExpected: {!r}\n"
                "Actual: {!r}".format(material_name, expected, actual)
            )
            np.testing.assert_array_almost_equal(expected, actual, err_msg=err_msg)

    @parameterized.parameters(*suite.ALL_TASKS)
    def test_visualize_reward(self, domain, task):
        """Tests that material colors track the (clipped) reward on reset and step."""
        env = suite.load(domain, task)
        env.task.visualize_reward = True
        action = np.zeros(env.action_spec().shape)

        # `get_reward` is mocked so that each reset/step sees a known reward.
        with mock.patch.object(env.task, "get_reward") as mock_get_reward:
            mock_get_reward.return_value = -3.0  # Rewards < 0 should be clipped.
            env.reset()
            mock_get_reward.assert_called_with(env.physics)
            self.assertCorrectColors(env.physics, reward=0.0)

            mock_get_reward.reset_mock()
            mock_get_reward.return_value = 0.5
            env.step(action)
            mock_get_reward.assert_called_with(env.physics)
            self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value)

            mock_get_reward.reset_mock()
            mock_get_reward.return_value = 2.0  # Rewards > 1 should be clipped.
            env.step(action)
            mock_get_reward.assert_called_with(env.physics)
            self.assertCorrectColors(env.physics, reward=1.0)

            mock_get_reward.reset_mock()
            mock_get_reward.return_value = 0.25
            env.reset()
            mock_get_reward.assert_called_with(env.physics)
            self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value)

    @parameterized.parameters(*suite.ALL_TASKS)
    def test_task_supports_environment_kwargs(self, domain, task):
        """Tests that `environment_kwargs` given to `suite.load` take effect."""
        env = suite.load(domain, task, environment_kwargs=dict(flat_observation=True))
        # Check that the kwargs are actually passed through to the environment.
        self.assertSetEqual(set(env.observation_spec()), {control.FLAT_OBSERVATION_KEY})

    @parameterized.parameters(*suite.ALL_TASKS)
    def test_observation_arrays_dont_share_memory(self, domain, task):
        """Tests that observations from consecutive time steps don't share memory."""
        env = suite.load(domain, task)
        first_timestep = env.reset()
        action = np.zeros(env.action_spec().shape)
        second_timestep = env.step(action)
        for name, first_array in six.iteritems(first_timestep.observation):
            second_array = second_timestep.observation[name]
            self.assertFalse(
                np.may_share_memory(first_array, second_array),
                msg="Consecutive observations of {!r} may share memory.".format(name),
            )

    @parameterized.parameters(*suite.ALL_TASKS)
    def test_observations_dont_contain_constant_elements(self, domain, task):
        """Tests that no observation element stays constant over a trajectory."""
        trajectory = make_trajectory(
            domain=domain, task=task, seed=0, num_episodes=2, max_steps_per_episode=1000
        )
        # Group the observed arrays by name directly from the trajectory; this
        # avoids loading a second environment only to read observation names.
        observations = {}
        for time_step in trajectory:
            for name, array in six.iteritems(time_step.observation):
                observations.setdefault(name, []).append(array)

        failures = []

        for name, array_list in six.iteritems(observations):
            # Sampling random uniform actions generally isn't sufficient to trigger
            # these touch sensors.
            if (domain in ("manipulator", "stacker") and name == "touch") or (
                domain == "quadruped" and name == "force_torque"
            ):
                continue
            stacked_arrays = np.array(array_list)
            # Element-wise: True where the value never differs from the first
            # time step's value.
            is_constant = np.all(stacked_arrays == stacked_arrays[0], axis=0)
            has_constant_elements = (
                is_constant if np.isscalar(is_constant) else np.any(is_constant)
            )
            if has_constant_elements:
                failures.append((name, is_constant))

        self.assertEmpty(
            failures,
            msg="The following observation(s) contain constant elements:\n{}".format(
                "\n".join(
                    ":\t".join([name, str(is_constant)])
                    for (name, is_constant) in failures
                )
            ),
        )

    @parameterized.parameters(*suite.ALL_TASKS)
    def test_initial_state_is_randomized(self, domain, task):
        """Tests that two consecutive resets give differing observations."""
        env = suite.load(domain, task, task_kwargs={"random": 42})
        obs1 = env.reset().observation
        obs2 = env.reset().observation
        self.assertFalse(
            all(np.all(obs1[k] == obs2[k]) for k in obs1),
            "Two consecutive initial states have identical observations.\n"
            "First: {}\nSecond: {}".format(obs1, obs2),
        )


# When executed as a script, hand control to absltest's `main()`.
if __name__ == "__main__":
    absltest.main()
