import gym
import numpy as np
import pytest
import torch

from offline_rl.rewards.evaluation.distances import compute_distance_between_reward_pairs
from offline_rl.rewards.evaluation.epic import EPIC
from offline_rl.rewards.evaluation.model_collection import ModelCollection
from offline_rl.rewards.evaluation.transition_sampler import FixedDistributionTransitionSampler
from offline_rl.rewards.tabular_reward_model import TabularRewardModel
from offline_rl.utils.testing.rewards import ConstantRewardModel, TabularPotentialShaping


class TestEPIC:
    """Tests for the mean-reward, canonicalization and batching helpers exposed by ``EPIC``."""

    # (num_states, state_dim, num_transitions_per_state, action_dim) combinations shared by the constant-reward
    # tests. Includes the single-state and single-transition edge cases.
    _CONSTANT_REWARD_CASES = [
        (100, 5, 50, 2),
        (25, 5, 50, 2),
        (1, 5, 50, 2),
        (25, 5, 1, 2),
        (1, 5, 1, 2),
    ]

    @pytest.mark.parametrize("num_states,state_dim,num_transitions_per_state,action_dim", _CONSTANT_REWARD_CASES)
    def test_compute_mean_rewards(self, num_states, state_dim, num_transitions_per_state, action_dim):
        """Check that mean rewards of constant models have one entry per state equal to the constant."""
        models = ModelCollection({
            "one": ConstantRewardModel(1.0),
            "two": ConstantRewardModel(2.0),
        })
        states = torch.zeros((num_states, state_dim))
        transition_sampler = FixedDistributionTransitionSampler(
            torch.ones((num_transitions_per_state, action_dim)),
            torch.ones((num_transitions_per_state, state_dim)),
        )

        distance = EPIC()
        rewards = distance.compute_mean_rewards(models, states, transition_sampler)

        assert "one" in rewards
        assert "two" in rewards
        assert rewards["one"].shape == (num_states, )
        assert rewards["two"].shape == (num_states, )
        assert torch.allclose(rewards["one"], torch.ones_like(rewards["one"]))
        assert torch.allclose(rewards["two"], torch.ones_like(rewards["two"]) * 2)

    @pytest.mark.parametrize("num_states,state_dim,num_transitions_per_state,action_dim", _CONSTANT_REWARD_CASES)
    def test_compute_canonical_rewards_simple(self, num_states, state_dim, num_transitions_per_state, action_dim):
        """Check that canonicalized rewards of constant models are all zeros, with one entry per transition."""
        models = ModelCollection({
            "one": ConstantRewardModel(1.0),
            "two": ConstantRewardModel(2.0),
        })

        states = torch.zeros((num_states, state_dim))
        # Actions live in the action space, so their trailing dimension is action_dim (not state_dim).
        actions = torch.zeros((num_states, action_dim))
        next_states = torch.zeros((num_states, state_dim))
        terminals = None

        transition_sampler = FixedDistributionTransitionSampler(
            torch.ones((num_transitions_per_state, action_dim)),
            torch.ones((num_transitions_per_state, state_dim)),
        )

        distance = EPIC()
        rewards = distance.compute_canonical_rewards(
            models,
            states,
            actions,
            next_states,
            terminals,
            transition_sampler,
            discount=1,
            total_mean_mode="per_batch_approximation",
            should_normalize_scale=True,
        )

        assert "one" in rewards
        assert "two" in rewards
        assert rewards["one"].shape == (num_states, )
        assert rewards["two"].shape == (num_states, )
        assert torch.allclose(rewards["one"], torch.zeros_like(rewards["one"]))
        assert torch.allclose(rewards["two"], torch.zeros_like(rewards["two"]))

    @pytest.mark.parametrize("state_dim,action_dim,num_samples,num_transitions_per_state,discount", [
        (2, 2, 1000, 100, 1.0),
        (2, 2, 1000, 100, 0.9),
        (5, 2, 1000, 100, 1.0),
        (2, 5, 1000, 100, 1.0),
    ])
    def test_compute_canonical_rewards_general_tabular(
            self,
            state_dim,
            action_dim,
            num_samples,
            num_transitions_per_state,
            discount,
            tolerance=1e-4,
    ):
        """Check distances between canonicalized random tabular rewards.

        A random tabular reward and a potential-shaped copy of it must end up closer than ``tolerance``,
        while each of them must be farther than 0.25 from an independent random tabular reward.
        """
        # TODO(redacted): Switch to using the hypothesis package for this test instead of setting the random seed.
        torch.manual_seed(0)
        random_reward_1 = TabularRewardModel(torch.rand(state_dim, action_dim, state_dim))
        random_reward_2 = TabularRewardModel(torch.rand(state_dim, action_dim, state_dim))
        shaped_random_reward_1 = TabularPotentialShaping(random_reward_1, torch.rand(state_dim), discount)
        models = ModelCollection({
            "random_1": random_reward_1,
            "random_2": random_reward_2,
            "shaped_random_1": shaped_random_reward_1,
        })

        state_space = gym.spaces.Discrete(state_dim)
        action_space = gym.spaces.Discrete(action_dim)
        # The gym spaces sample from their own RNG, which torch.manual_seed does not control, so seed them
        # explicitly to make the sampled transitions reproducible. The seeds differ so that state and action
        # draws are not identical sequences when state_dim == action_dim.
        state_space.seed(0)
        action_space.seed(1)

        # Get the transitions used for the reward calculation.
        # This assumes a transition model that uniformly transitions between all states regardless of the action.
        states = torch.LongTensor([[state_space.sample()] for _ in range(num_samples)])
        next_states = torch.LongTensor([[state_space.sample()] for _ in range(num_samples)])
        actions = torch.LongTensor([[action_space.sample()] for _ in range(num_samples)])
        terminals = None

        # Get the states and actions used for the mean reward calculation.
        transition_sampler = FixedDistributionTransitionSampler(
            torch.LongTensor([[action_space.sample()] for _ in range(num_transitions_per_state)]),
            torch.LongTensor([[state_space.sample()] for _ in range(num_transitions_per_state)]),
        )

        distance = EPIC()
        rewards = distance.compute_canonical_rewards(
            models,
            states,
            actions,
            next_states,
            terminals,
            transition_sampler,
            discount=discount,
        )

        matrix = compute_distance_between_reward_pairs(rewards)
        assert matrix.distance_between("random_1", "shaped_random_1") < tolerance
        assert matrix.distance_between("random_1", "random_2") > 0.25
        assert matrix.distance_between("shaped_random_1", "random_2") > 0.25

    @pytest.mark.parametrize("state_dim,action_dim,num_samples,num_transitions_per_state", [
        (2, 2, 1000, 100),
        (2, 2, 1, 100),
        (2, 2, 1000, 1),
        (2, 2, 1, 1),
    ])
    def test_compute_batch_indices(
            self,
            state_dim,
            action_dim,
            num_samples,
            num_transitions_per_state,
    ):
        """Check that consecutive batch indices, read as half-open ranges, cover every sample in order."""
        states = torch.zeros((num_samples, state_dim))
        transition_sampler = FixedDistributionTransitionSampler(
            torch.zeros((num_transitions_per_state, action_dim)),
            torch.zeros((num_transitions_per_state, state_dim)),
        )
        distance = EPIC()
        indices = distance.compute_batch_indices(states, transition_sampler)
        # Each adjacent pair (start, end) is expanded to the range [start, end) and the ranges are joined.
        concatenated = np.concatenate([np.arange(start, end) for (start, end) in zip(indices, indices[1:])])
        # Drop anything past the last sample before comparing against 0..num_samples-1.
        concatenated = concatenated[:num_samples]
        np.testing.assert_array_equal(concatenated, np.arange(num_samples))

    def test_compute_conditional_per_state_out_of_distribution_total_mean_rewards(self):
        """Check that the conditional per-state total mean rewards have one entry per input state."""
        states = torch.tensor([
            [0],
            [1],
        ])
        next_states = torch.tensor([
            [2],
            [3],
        ])

        transition_actions = torch.zeros((3, 1))
        transition_next_states = torch.tensor([
            [4],
            [5],
            [6],
        ])
        transition_sampler = FixedDistributionTransitionSampler(transition_actions, transition_next_states)
        distance = EPIC()

        models = ModelCollection({
            "one": ConstantRewardModel(1.0),
            "two": ConstantRewardModel(2.0),
        })

        total_mean_rewards = distance._compute_conditional_per_state_out_of_distribution_total_mean_rewards(
            models,
            states,
            next_states,
            transition_sampler,
        )

        # Two input states, so each model should yield a vector of length 2.
        assert total_mean_rewards["one"].shape == (2, )
        assert total_mean_rewards["two"].shape == (2, )
