import numpy as np
import pytest
import torch

from offline_rl.rewards.evaluation.reward_collection import RewardCollection


class TestRewardCollection:
    """Unit tests for ``RewardCollection``: validity, append, mean, and arithmetic."""

    @pytest.mark.parametrize("rewards,expected_validity", [
        # Empty collection.
        (RewardCollection(), True),
        # Single entry.
        (RewardCollection({"a": torch.zeros(10)}), True),
        # Two entries of equal length.
        (RewardCollection({
            "a": torch.zeros(10),
            "b": torch.zeros(10),
        }), True),
        # Two entries whose lengths differ (10 vs 5).
        (RewardCollection({
            "a": torch.zeros(10),
            "b": torch.zeros(5),
        }), False),
    ])
    def test_is_valid(self, rewards, expected_validity):
        """``is_valid()`` must return the expected flag for each parametrized case."""
        assert rewards.is_valid() == expected_validity

    def test_append_valid(self):
        """Appending a collection with matching keys keeps it valid and grows "a" to 20."""
        rewards_ab = RewardCollection({"a": torch.zeros(10), "b": torch.zeros(10)})
        # The collection is appended to itself and the assertions below are made
        # on the receiver, so this test expects ``append`` to update it in place.
        rewards_ab.append(rewards_ab)
        assert "a" in rewards_ab
        assert "b" in rewards_ab
        assert rewards_ab.is_valid()
        assert len(rewards_ab["a"]) == 20

    def test_append_invalid(self):
        """Appending a collection with different keys must raise ``AssertionError``."""
        rewards_ab = RewardCollection({"a": torch.zeros(10), "b": torch.zeros(10)})
        rewards_ac = RewardCollection({"a": torch.zeros(10), "c": torch.zeros(10)})
        # pytest.raises fails the test with a descriptive message if nothing is
        # raised, and lets any other exception type propagate as an error.
        with pytest.raises(AssertionError):
            rewards_ab.append(rewards_ac)

    def test_append_empty(self):
        """Appending to an empty collection yields the appended entries unchanged."""
        rewards_ab = RewardCollection({"a": torch.zeros(10), "b": torch.zeros(10)})
        rewards_empty = RewardCollection()
        # Unlike test_append_valid, this test checks the value returned by append.
        rewards = rewards_empty.append(rewards_ab)
        assert "a" in rewards
        assert "b" in rewards
        assert rewards.is_valid()
        assert len(rewards["a"]) == 10

    def test_mean_1d(self):
        """``mean()`` without arguments reduces each 1-D entry to its arithmetic mean."""
        rewards = RewardCollection({"a": 1 + torch.arange(10).to(float), "b": torch.arange(10).to(float)})
        mean_rewards = rewards.mean()
        assert "a" in mean_rewards
        assert "b" in mean_rewards
        # The expected means (5.5 and 4.5) are exactly representable as floats,
        # so an exact equality comparison is safe here.
        assert float(mean_rewards["a"]) == np.mean(range(10)) + 1
        assert float(mean_rewards["b"]) == np.mean(range(10))

    def test_mean_2d(self):
        """``mean(dim=1)`` on (2, 5) entries matches ``torch.mean`` along dim 1."""
        a = 1 + torch.arange(10).to(float).reshape(2, 5)
        b = torch.arange(10).to(float).reshape(2, 5)
        rewards = RewardCollection({"a": a, "b": b})
        mean_rewards = rewards.mean(dim=1)
        assert "a" in mean_rewards
        assert "b" in mean_rewards
        assert torch.equal(mean_rewards["a"], torch.mean(torch.arange(10).reshape(2, 5).to(float) + 1, 1))
        assert torch.equal(mean_rewards["b"], torch.mean(torch.arange(10).reshape(2, 5).to(float), 1))

    def test_mul(self):
        """Multiplying a collection by a weight tensor scales every entry by it."""
        rewards = RewardCollection({"a": torch.ones(10), "b": 2 * torch.ones(10)})
        weights = torch.arange(10).to(float)
        mult_rewards = rewards * weights
        assert "a" in mult_rewards
        assert "b" in mult_rewards
        assert torch.equal(mult_rewards["a"], torch.arange(10).to(float))
        assert torch.equal(mult_rewards["b"], torch.arange(10).to(float) * 2)

    def test_add(self):
        """Adding two collections sums their entries key by key."""
        rewards_1 = RewardCollection({"a": torch.ones(10), "b": 2 * torch.ones(10)})
        rewards_2 = RewardCollection({"a": -torch.ones(10), "b": 2 * torch.ones(10)})
        add_rewards = rewards_1 + rewards_2
        assert "a" in add_rewards
        assert "b" in add_rewards
        assert torch.equal(add_rewards["a"], torch.zeros(10))
        assert torch.equal(add_rewards["b"], torch.ones(10) * 4)

    def test_reshape(self):
        """``reshape(2, 5)`` gives every entry the shape (2, 5)."""
        rewards = RewardCollection({"a": torch.ones(10), "b": 2 * torch.ones(10)})
        reshaped_rewards = rewards.reshape(2, 5)
        assert "a" in reshaped_rewards
        assert "b" in reshaped_rewards
        assert reshaped_rewards["a"].shape == (2, 5)
        assert reshaped_rewards["b"].shape == (2, 5)
