import pytest
import torch

from offline_rl.envs.custom_reacher_env import CustomReacherEnv, CustomReacherEnvRewardModel


class TestCustomReacherEnv:
    """Smoke tests for ``CustomReacherEnv`` construction, reset and stepping."""

    @pytest.mark.parametrize("frame_skip", [10, 20])
    def test_reset_step(self, frame_skip):
        """The env stores ``frame_skip`` and runs a reset plus 40 random steps without raising."""
        env = CustomReacherEnv(frame_skip=frame_skip)
        try:
            assert env.frame_skip == frame_skip
            env.reset()
            # Actions are sampled from the env's own action space, so every
            # step input is valid by construction.
            for _ in range(40):
                env.step(env.action_space.sample())
        finally:
            # Release whatever resources the env holds, even if an assertion
            # above fails.
            env.close()


class TestCustomReacherEnvRewardModel:
    """Unit tests that isolate individual terms of ``CustomReacherEnvRewardModel``."""

    # Sizes of the (batch=1, dim) tensors fed to ``reward`` below.
    # NOTE(review): presumably these match CustomReacherEnv's observation /
    # action spaces -- keep in sync if the env changes.
    _OBS_DIM = 11
    _ACT_DIM = 2
    # Absolute tolerance for comparing float reward tensors.
    _TOL = 1e-6

    @staticmethod
    def _make_reward_model(**factors):
        """Build a reward model over a fresh env's spaces with the given factors.

        The env is only needed for its observation/action spaces, so it is
        closed as soon as the model has been constructed.
        """
        env = CustomReacherEnv(frame_skip=20)
        try:
            return CustomReacherEnvRewardModel(
                env.observation_space,
                env.action_space,
                **factors,
            )
        finally:
            env.close()

    def test_goal_reward(self):
        """With only the goal term enabled, an all-ones state scores 0 and an all-zero state scores 1."""
        reward_model = self._make_reward_model(
            reward_dist_factor=0,
            reward_ctrl_factor=0,
            reward_goal_factor=1,
            shaping_factor=0,
            shaping_discount=0,
        )

        action = torch.zeros(1, self._ACT_DIM)
        # next_state is held fixed; only `state` varies between the two cases.
        next_state = torch.zeros(1, self._OBS_DIM)

        outside_state = torch.ones(1, self._OBS_DIM)
        rewards = reward_model.reward(outside_state, action, next_state, None)
        assert rewards.shape == (1, 1)
        assert rewards.item() == pytest.approx(0.0, abs=self._TOL)

        inside_state = torch.zeros(1, self._OBS_DIM)
        rewards = reward_model.reward(inside_state, action, next_state, None)
        assert rewards.shape == (1, 1)
        assert rewards.item() == pytest.approx(1.0, abs=self._TOL)

    def test_shaping(self):
        """With only shaping enabled (discount 1), swapping state and next_state negates the reward."""
        reward_model = self._make_reward_model(
            reward_dist_factor=0,
            reward_ctrl_factor=0,
            reward_goal_factor=0,
            shaping_factor=1,
            shaping_discount=1,
        )
        state = torch.zeros(1, self._OBS_DIM)
        action = torch.zeros(1, self._ACT_DIM)
        next_state = torch.zeros(1, self._OBS_DIM)
        next_state[0, -1] = 5

        rewards = reward_model.reward(state, action, next_state, None)
        assert rewards.item() == pytest.approx(5 - 0, abs=self._TOL)

        rewards = reward_model.reward(next_state, action, state, None)
        assert rewards.item() == pytest.approx(0 - 5, abs=self._TOL)
