import gym
import numpy as np
import pytest
from ray.rllib.policy.sample_batch import SampleBatch
import torch

from offline_rl.envs.bouncing_balls_env import (
    BouncingBallsEnv,
    BouncingBallsEnvFeasibilityRewardWrapper,
    BouncingBallsEnvRewardModel,
)
from offline_rl.utils.testing.bouncing_balls_env import (
    dummy_bouncing_balls_act_space,
    dummy_bouncing_balls_obs_space,
    get_colliding_states,
    get_dist_from_goal_states,
    get_goal_reached_states,
    get_noncolliding_states,
    get_noop_actions,
    get_one_actions,
)
from offline_rl.utils.testing.envs import sample_rollout, RandomPolicy
from offline_rl.utils.testing.rewards import ConstantRewardModel


class TestBouncingBallsEnv:
    """Tests for ``BouncingBallsEnv`` reset, dynamics and observation formats."""

    def test_reset(self):
        """``reset`` creates exactly the requested number of balls, all below ``side_length``."""
        num_balls = 2
        side_length = 50
        env = BouncingBallsEnv(
            min_num_balls=num_balls,
            max_num_balls=num_balls,
            side_length=side_length,
            obs_space_mode="dict",
            is_continuing=False,
        )
        state = env.reset()
        # min_num_balls == max_num_balls, so the ball count is fully determined.
        assert len(state["balls"]) == num_balls
        for ball in state["balls"]:
            assert np.all(ball["position"] < side_length)

    def test_wall_bouncing(self):
        """Ball positions stay below ``side_length`` (plus a tolerance) over many steps."""
        # Slack for positions slightly past the wall (presumably overshoot within one dt step).
        tolerance = 2.0
        side_length = 20
        env = BouncingBallsEnv(
            side_length=side_length,
            dt=0.5,
            obs_space_mode="dict",
            is_continuing=False,
        )
        max_steps = 200
        # A constant zero action; the test only exercises ball motion.
        action = np.zeros(2)
        env.reset()
        for _ in range(max_steps):
            state, _, done, _ = env.step(action)
            for ball in state["balls"]:
                assert np.all(ball["position"] < side_length + tolerance)
            # The env is episodic (is_continuing=False); stepping past a terminal state is
            # undefined under the gym API, so begin a fresh episode instead.
            if done:
                env.reset()

    def test_flat_obs_space(self):
        """The flat observation has a fixed length and lies in the declared observation space."""
        min_num_balls = 2
        max_num_balls = 10
        side_length = 50
        env = BouncingBallsEnv(
            min_num_balls=min_num_balls,
            max_num_balls=max_num_balls,
            side_length=side_length,
            obs_space_mode="flat",
            is_continuing=False,
        )
        state = env.reset()
        # The length depends only on max_num_balls, not on how many balls were actually spawned.
        assert len(state) == max_num_balls * 4 + 4 + 2
        assert env.observation_space.contains(state)


class TestBouncingBallsEnvRewardModel:
    """Tests for ``BouncingBallsEnvRewardModel``, isolating one reward term per test."""

    @staticmethod
    def _make_model(**overrides):
        """Build a reward model with every term zeroed except those passed in ``overrides``.

        ``shaping_discount`` defaults to 1.
        """
        kwargs = dict(
            reaching_goal_reward=0,
            obstacle_collision_reward=0,
            action_magnitude_reward=0,
            distance_from_goal_reward=0,
            shaping_toward_goal_factor=0,
            shaping_discount=1,
        )
        kwargs.update(overrides)
        return BouncingBallsEnvRewardModel(
            dummy_bouncing_balls_obs_space(),
            dummy_bouncing_balls_act_space(),
            **kwargs,
        )

    @staticmethod
    def _assert_rewards_match(actual, expected):
        """Assert ``actual`` has ``expected``'s shape and (to float tolerance) its values.

        ``torch.allclose`` broadcasts, so the shape is checked explicitly first. The tolerance
        keeps terms involving square roots from depending on exact float rounding. Both sides
        are promoted to float64 so that a dtype difference cannot make ``allclose`` raise.
        """
        assert actual.shape == expected.shape
        assert torch.allclose(actual.double(), expected.double())

    @pytest.mark.parametrize("states,actions,next_states,expected_rewards", [
        (get_noncolliding_states(2, 2), get_noop_actions(2), get_noncolliding_states(2, 2), torch.zeros((2, 1))),
        (get_colliding_states(2, 2), get_noop_actions(2), get_noncolliding_states(2, 2), torch.zeros((2, 1))),
        (get_noncolliding_states(2, 2), get_noop_actions(2), get_colliding_states(2, 2), torch.zeros((2, 1))),
        (get_noncolliding_states(2, 2), get_one_actions(2), get_noncolliding_states(2, 2), torch.zeros((2, 1))),
        (get_goal_reached_states(2, 2), get_noop_actions(2), get_noncolliding_states(2, 2), torch.zeros((2, 1))),
        (get_noncolliding_states(2, 2), get_noop_actions(2), get_goal_reached_states(2, 2), torch.ones((2, 1))),
        (get_noncolliding_states(1, 2), get_noop_actions(1), get_noncolliding_states(1, 2), torch.zeros((1, 1))),
    ])
    def test_ground_truth_reward_model(self, states, actions, next_states, expected_rewards):
        """The ground-truth model rewards only transitions whose next state reaches the goal."""
        model = BouncingBallsEnvRewardModel.make_ground_truth_reward_model(
            dummy_bouncing_balls_obs_space(),
            dummy_bouncing_balls_act_space(),
        )
        actual_rewards = model.reward(states, actions, next_states, None)
        self._assert_rewards_match(actual_rewards, expected_rewards)

    @pytest.mark.parametrize("states,actions,next_states,expected_rewards", [
        (get_noncolliding_states(2, 2), get_noop_actions(2), get_noncolliding_states(2, 2), torch.zeros((2, 1))),
        (get_colliding_states(2, 2), get_noop_actions(2), get_noncolliding_states(2, 2), torch.zeros((2, 1))),
        (get_noncolliding_states(2, 2), get_noop_actions(2), get_colliding_states(2, 2), -1 * torch.ones((2, 1))),
    ])
    def test_obstacle_collision_rewards(self, states, actions, next_states, expected_rewards):
        """The collision penalty applies only when the *next* state is colliding."""
        model = self._make_model(obstacle_collision_reward=-1)
        actual_rewards = model.reward(states, actions, next_states, None)
        self._assert_rewards_match(actual_rewards, expected_rewards)

    # The expected values are -sqrt(2) and -3*sqrt(2); presumably the negated L2 norm of
    # 2-D all-ones / all-threes actions. Compared with a tolerance, not bit-exactly.
    @pytest.mark.parametrize("states,actions,next_states,expected_rewards", [
        (get_noncolliding_states(2, 2), get_noop_actions(2), get_noncolliding_states(2, 2), torch.zeros((2, 1))),
        (
            get_colliding_states(2, 2),
            get_one_actions(2),
            get_noncolliding_states(2, 2),
            -2 / np.sqrt(2) * torch.ones((2, 1)),
        ),
        (
            get_colliding_states(2, 2),
            3 * get_one_actions(2),
            get_noncolliding_states(2, 2),
            -6 / np.sqrt(2) * torch.ones((2, 1)),
        ),
    ])
    def test_action_magnitude_rewards(self, states, actions, next_states, expected_rewards):
        """The action-magnitude term scales with the size of the action."""
        model = self._make_model(action_magnitude_reward=-1)
        actual_rewards = model.reward(states, actions, next_states, None)
        self._assert_rewards_match(actual_rewards, expected_rewards)

    @pytest.mark.parametrize("states,actions,next_states,expected_rewards", [
        (
            get_dist_from_goal_states(2, 2, 10),
            get_noop_actions(2),
            get_noncolliding_states(2, 2, 10),
            -10 * torch.ones((2, 1)),
        ),
        (
            get_dist_from_goal_states(2, 2, 1),
            get_noop_actions(2),
            get_noncolliding_states(2, 2),
            -1 * torch.ones((2, 1)),
        ),
        (
            get_dist_from_goal_states(2, 2, 0),
            get_noop_actions(2),
            get_dist_from_goal_states(2, 2, 10),
            torch.zeros((2, 1)),
        ),
    ])
    def test_distance_from_goal_rewards(self, states, actions, next_states, expected_rewards):
        """The distance-from-goal term is checked against fixed goal distances."""
        model = self._make_model(distance_from_goal_reward=1)
        actual_rewards = model.reward(states, actions, next_states, None)
        self._assert_rewards_match(actual_rewards, expected_rewards)


class TestBouncingBallsEnvFeasibilityRewardWrapper:
    """Tests for ``BouncingBallsEnvFeasibilityRewardWrapper``.

    Each wrapper here wraps ``ConstantRewardModel(1)`` and ``ConstantRewardModel(2)``.
    ``test_reward`` expects 1 for a feasible transition and 2 for an infeasible one.
    """

    @staticmethod
    def _make_wrapper(dt, **kwargs):
        """Build the wrapper under test with constant models 1/2 and ``max_position_error=10``."""
        return BouncingBallsEnvFeasibilityRewardWrapper(
            ConstantRewardModel(1),
            ConstantRewardModel(2),
            dt,
            max_position_error=10,
            **kwargs,
        )

    @staticmethod
    def _as_batch(value):
        """Return ``value`` as a tensor with a leading batch dimension of size one.

        The value is stacked with ``np.array`` first, because ``torch.tensor`` on a list of
        ndarrays is slow and warns in recent torch versions.
        """
        return torch.tensor(np.array([value]))

    @pytest.mark.parametrize("dt,num_steps,random_seed", [
        (0.1, 10, 1),
        (0.2, 100, 2),
        (0.5, 100, 3),
    ])
    def test_transitions_from_env_are_feasible(self, dt, num_steps, random_seed):
        """Every transition in a rollout of the real env is judged feasible."""
        np.random.seed(random_seed)
        env = BouncingBallsEnv(dt=dt)
        # Gym spaces sample from their own RNG, which np.random.seed does not reach (assuming
        # action_space is a gym space, as its .sample() use elsewhere suggests). Seed it too so
        # the sampled actions are reproducible for each parametrization.
        env.action_space.seed(random_seed)

        reward_model = self._make_wrapper(env.dt)

        policy = RandomPolicy(env.action_space)
        batch = sample_rollout(env, policy, num_steps)

        feasibility = reward_model.transitions_are_feasible(
            torch.tensor(batch[SampleBatch.OBS]),
            torch.tensor(batch[SampleBatch.ACTIONS]),
            torch.tensor(batch[SampleBatch.NEXT_OBS]),
        )
        assert torch.all(feasibility)
        expected_num_steps = len(batch)
        assert feasibility.shape == (expected_num_steps, 1)

    def test_transitions_are_feasible_invalid_ego(self):
        """With ``check_ego=True``, shifting the first four entries by 100 is infeasible."""
        env = BouncingBallsEnv()
        x = env.reset()
        a = env.action_space.sample()
        nx = x.copy()
        # Presumably the ego's state block; shifted far beyond max_position_error.
        nx[:4] += 100

        reward_model = self._make_wrapper(env.dt, check_ego=True)

        feasibility = reward_model.transitions_are_feasible(
            self._as_batch(x),
            self._as_batch(a),
            self._as_batch(nx),
        )
        assert not feasibility[0].item()

    def test_transitions_are_feasible_invalid_ado(self):
        """Shifting the entries between the first four and the last two by 100 is infeasible."""
        env = BouncingBallsEnv()
        x = env.reset()
        a = env.action_space.sample()
        nx = x.copy()
        # Presumably the other balls' states (excluding the leading 4 and trailing 2 entries).
        nx[4:-2] += 100

        reward_model = self._make_wrapper(env.dt)

        feasibility = reward_model.transitions_are_feasible(
            self._as_batch(x),
            self._as_batch(a),
            self._as_batch(nx),
        )
        assert not feasibility[0].item()

    def test_reward(self):
        """A real env step yields reward 1; the same step with a perturbed ego yields reward 2."""
        env = BouncingBallsEnv()
        x = env.reset()
        a = env.action_space.sample()
        nx, _, _, _ = env.step(a)

        reward_model = self._make_wrapper(env.dt, check_ego=True)

        rewards = reward_model.reward(
            self._as_batch(x),
            self._as_batch(a),
            self._as_batch(nx),
            None,
        )
        assert torch.all(rewards == 1)

        # Perturb a copy rather than the array the env returned.
        nx_infeasible = nx.copy()
        nx_infeasible[:4] += 100
        rewards = reward_model.reward(
            self._as_batch(x),
            self._as_batch(a),
            self._as_batch(nx_infeasible),
            None,
        )
        assert torch.all(rewards == 2)
