import gym
import pytest
from ray.rllib.policy.sample_batch import SampleBatch
import torch

from offline_rl.rewards.learning.preference_based_reward_model import PreferenceBasedRewardModel
from offline_rl.rewards.learning.reward_model_networks import FullyConnectedRewardModel
from offline_rl.utils.testing.rewards import RandomRewardModel


def sample_segments_batch_from_space(space, batch_size, segment_length):
    """Sample a random batch of segment pairs from ``space``.

    Draws ``batch_size * 2 * segment_length`` independent samples via
    ``space.sample()`` and arranges them as a tensor of shape
    ``(batch_size, 2, segment_length, -1)``. The trailing dimension is the
    flattened size of a single sample (1 for scalar samples).

    Each sample is converted on its own with ``torch.as_tensor`` and the
    results are stacked. Passing a Python list of ndarrays straight to
    ``torch.tensor`` is very slow, and torch emits a warning about it.
    """
    num_samples = batch_size * segment_length * 2
    samples = torch.stack([torch.as_tensor(space.sample()) for _ in range(num_samples)])
    return samples.reshape(batch_size, 2, segment_length, -1)


def get_random_segments_batch(obs_space, act_space, batch_size, segment_length):
    """Build a ``SampleBatch`` of random segment pairs for testing.

    Observations, actions and next observations are each drawn independently
    (in that order) with ``sample_segments_batch_from_space``. Every column
    therefore has shape ``(batch_size, 2, segment_length, -1)``.
    """
    def draw(space):
        return sample_segments_batch_from_space(space, batch_size, segment_length)

    columns = {}
    columns[SampleBatch.OBS] = draw(obs_space)
    columns[SampleBatch.ACTIONS] = draw(act_space)
    columns[SampleBatch.NEXT_OBS] = draw(obs_space)
    return SampleBatch(columns)


class TestPreferenceBasedRewardModel:
    """Shape and sanity checks for ``PreferenceBasedRewardModel``.

    Every test builds a model with a small ``FullyConnectedRewardModel``
    submodel (hidden sizes ``[8, 8]``) and a ``RandomRewardModel`` target. It
    then feeds the model a batch of random segment pairs drawn from a 4-d
    observation ``Box`` and a 2-d action ``Box``.
    """

    # (batch_size, segment_length) combinations shared by every test. These
    # include a single-pair batch and single-step segments as edge cases.
    _SEGMENT_SHAPES = (
        (10, 5),
        (1, 5),
        (10, 1),
    )

    @staticmethod
    def _build_model_and_batch(batch_size, segment_length):
        """Return ``(model, batch)`` for the given segment-batch shape.

        The random batch is sampled before the model is constructed.
        """
        obs_space = gym.spaces.Box(low=-1, high=1, shape=(4, ))
        act_space = gym.spaces.Box(low=-1, high=1, shape=(2, ))
        batch = get_random_segments_batch(obs_space, act_space, batch_size, segment_length)

        submodel_kwargs = dict(
            obs_space=obs_space,
            act_space=act_space,
            hidden_sizes=[8, 8],
        )
        model = PreferenceBasedRewardModel(
            FullyConnectedRewardModel,
            submodel_kwargs,
            target_model=RandomRewardModel(),
        )
        return model, batch

    @pytest.mark.parametrize("batch_size,segment_length", _SEGMENT_SHAPES)
    def test_compute_target_preferences(self, batch_size, segment_length):
        """Target preferences contain one entry per segment pair."""
        model, batch = self._build_model_and_batch(batch_size, segment_length)

        preferences = model._compute_target_preferences(batch)
        assert preferences.shape == (batch_size, )

    @pytest.mark.parametrize("batch_size,segment_length", _SEGMENT_SHAPES)
    def test_predict_preferences(self, batch_size, segment_length):
        """Predictions return a (batch, 2) preference tensor and per-step rewards."""
        model, batch = self._build_model_and_batch(batch_size, segment_length)

        pref_log_probs, rewards = model._predict_preferences(batch)
        assert pref_log_probs.shape == (batch_size, 2)
        assert rewards.shape == (batch_size, 2, segment_length)

    @pytest.mark.parametrize("batch_size,segment_length", _SEGMENT_SHAPES)
    def test_generic_step(self, batch_size, segment_length):
        """A training step yields a scalar, strictly positive loss."""
        model, batch = self._build_model_and_batch(batch_size, segment_length)

        loss = model.generic_step(batch, 0, "train")
        assert loss.ndim == 0
        assert loss > 0
