import unittest

import einops
import jax
import jax.numpy as jnp

from tabular_mvdrl.kernels import energy_distance
from tabular_mvdrl.utils.discrete_distributions import (
    DiscreteDistribution,
    SquaredMMDMetric,
    SupremalMetric,
    Wasserstein2Metric,
)


class TestDiscreteDistributions(unittest.TestCase):
    """Tests for ``DiscreteDistribution`` and the metrics defined over it.

    Covers three areas:

    * construction of batched distributions through ``jax.vmap``;
    * ``pushforward_linear`` for every combination of batched / unbatched
      distribution and reward vector, checked against an explicit einsum;
    * ``SquaredMMDMetric``, ``Wasserstein2Metric`` and ``SupremalMetric``:
      the distance of a distribution to itself is (approximately) zero and
      the distance between two different distributions is not.

    Every test derives independent sub-keys from ``self.rng`` with
    ``jax.random.split``; handing the same key to several samplers would
    produce correlated draws.
    """

    # Tolerances for comparing float results that were computed along two
    # different code paths (plain einsum vs. vmapped method).
    _RTOL = 1e-5
    _ATOL = 1e-6

    def setUp(self):
        """Fix the base PRNG key and the problem sizes shared by all tests."""
        self.rng = jax.random.PRNGKey(0)
        self.n_atoms = 5
        self.n_dims = 2
        self.batch_size = 4

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _random_locs(self, key, *batch_shape):
        """Sample normal atom locations of shape ``(*batch_shape, n_atoms, n_dims)``."""
        return jax.random.normal(
            key, shape=(*batch_shape, self.n_atoms, self.n_dims)
        )

    def _random_probs(self, key, *batch_shape):
        """Sample weights of shape ``(*batch_shape, n_atoms)`` summing to 1 on the last axis."""
        logits = jax.random.uniform(key, shape=(*batch_shape, self.n_atoms))
        return jax.nn.softmax(logits, axis=-1)

    def _random_reward(self, key, *batch_shape):
        """Sample normal reward vectors of shape ``(*batch_shape, n_dims)``."""
        return jax.random.normal(key, shape=(*batch_shape, self.n_dims))

    @staticmethod
    def _index_tree(tree, index):
        """Apply ``leaf[index]`` to every leaf of the pytree ``tree``."""
        return jax.tree_util.tree_map(lambda leaf: leaf[index], tree)

    def _assert_pushforward_matches(self, dist, reward, pushforward_dist):
        """Check one (unbatched) pushforward against an einsum reference.

        The expected locations are the inner product of ``reward`` with each
        atom of ``dist``; the probabilities must be carried over unchanged.
        """
        expected_locs = einops.einsum(reward, dist.locs, "r, n r -> n")
        # allclose broadcasts silently, so pin the shape first.
        self.assertEqual(expected_locs.shape, pushforward_dist.locs.shape)
        self.assertTrue(
            bool(
                jnp.allclose(
                    expected_locs,
                    pushforward_dist.locs,
                    rtol=self._RTOL,
                    atol=self._ATOL,
                )
            )
        )
        # Probabilities are only passed through, so they must match exactly.
        self.assertTrue(bool(jnp.array_equal(pushforward_dist.probs, dist.probs)))

    # ------------------------------------------------------------------
    # Construction under vmap
    # ------------------------------------------------------------------
    def test_vmap(self):
        """The constructor can be vmapped over a leading batch axis."""
        locs = self._random_locs(self.rng, self.batch_size)
        probs = jnp.ones((self.batch_size, self.n_atoms)) / self.n_atoms
        dist = jax.vmap(DiscreteDistribution)(locs=locs, probs=probs)
        self.assertEqual(
            dist.locs.shape, (self.batch_size, self.n_atoms, self.n_dims)
        )
        self.assertEqual(dist.probs.shape, (self.batch_size, self.n_atoms))

    def test_empirical_vmap(self):
        """``empirical_from`` can be vmapped over a leading batch axis."""
        locs = self._random_locs(self.rng, self.batch_size)
        dist = jax.vmap(DiscreteDistribution.empirical_from)(locs)
        self.assertEqual(
            dist.locs.shape, (self.batch_size, self.n_atoms, self.n_dims)
        )
        self.assertEqual(dist.probs.shape, (self.batch_size, self.n_atoms))

    # ------------------------------------------------------------------
    # pushforward_linear
    # ------------------------------------------------------------------
    def test_pushforward_linear(self):
        """Single distribution, single reward vector."""
        locs_key, probs_key, reward_key = jax.random.split(self.rng, 3)
        dist = DiscreteDistribution(
            locs=self._random_locs(locs_key),
            probs=self._random_probs(probs_key),
        )
        reward = self._random_reward(reward_key)
        pushforward_dist = dist.pushforward_linear(reward)
        self._assert_pushforward_matches(dist, reward, pushforward_dist)

    def test_pushforward_linear_batched_dist(self):
        """Batch of distributions, single shared reward vector."""
        locs_key, probs_key, reward_key = jax.random.split(self.rng, 3)
        dist_batch = jax.vmap(DiscreteDistribution)(
            locs=self._random_locs(locs_key, self.batch_size),
            probs=self._random_probs(probs_key, self.batch_size),
        )
        reward = self._random_reward(reward_key)
        pushforward_dist_batch = jax.vmap(
            DiscreteDistribution.pushforward_linear, in_axes=(0, None)
        )(dist_batch, reward)
        for i in range(self.batch_size):
            with self.subTest(dist_index=i):
                self._assert_pushforward_matches(
                    self._index_tree(dist_batch, i),
                    reward,
                    self._index_tree(pushforward_dist_batch, i),
                )

    def test_pushforward_linear_batched_reward(self):
        """Single distribution, batch of reward vectors."""
        locs_key, probs_key, reward_key = jax.random.split(self.rng, 3)
        dist = DiscreteDistribution(
            locs=self._random_locs(locs_key),
            probs=self._random_probs(probs_key),
        )
        reward_batch = self._random_reward(reward_key, self.batch_size)
        pushforward_dist_batch = jax.vmap(
            DiscreteDistribution.pushforward_linear, in_axes=(None, 0)
        )(dist, reward_batch)
        for i in range(self.batch_size):
            with self.subTest(reward_index=i):
                self._assert_pushforward_matches(
                    dist,
                    reward_batch[i, :],
                    self._index_tree(pushforward_dist_batch, i),
                )

    def test_pushforward_linear_batched_reward_and_dist(self):
        """Batch of distributions crossed with a batch of reward vectors."""
        reward_batch_size = 6
        locs_key, probs_key, reward_key = jax.random.split(self.rng, 3)
        dist_batch = jax.vmap(DiscreteDistribution)(
            locs=self._random_locs(locs_key, self.batch_size),
            probs=self._random_probs(probs_key, self.batch_size),
        )
        reward_batch = self._random_reward(reward_key, reward_batch_size)
        # Inner vmap maps over distributions, outer vmap over rewards, so the
        # result is laid out as (reward, dist, atom).
        pushforward_dist_batch = jax.vmap(
            jax.vmap(DiscreteDistribution.pushforward_linear, in_axes=(0, None)),
            in_axes=(None, 0),
        )(dist_batch, reward_batch)
        expected_shape = (reward_batch_size, self.batch_size, self.n_atoms)
        self.assertEqual(pushforward_dist_batch.locs.shape, expected_shape)
        self.assertEqual(pushforward_dist_batch.probs.shape, expected_shape)
        for reward_i in range(reward_batch_size):
            for dist_i in range(self.batch_size):
                with self.subTest(reward_index=reward_i, dist_index=dist_i):
                    self._assert_pushforward_matches(
                        self._index_tree(dist_batch, dist_i),
                        reward_batch[reward_i, :],
                        self._index_tree(
                            pushforward_dist_batch, (reward_i, dist_i)
                        ),
                    )

    # ------------------------------------------------------------------
    # Metrics
    # ------------------------------------------------------------------
    def test_mmd2_metric(self):
        """Squared MMD: zero against itself, non-zero for different weights."""
        mmd = SquaredMMDMetric(energy_distance)
        locs_key, probs_key = jax.random.split(self.rng)
        locs = self._random_locs(locs_key)
        probs = self._random_probs(probs_key)
        d1 = DiscreteDistribution.empirical_from(locs)
        d2 = DiscreteDistribution(locs=locs, probs=probs)
        # Float cancellation is not guaranteed to be bit-exact, so compare
        # with a tolerance instead of ``== 0.0``.
        self.assertAlmostEqual(mmd(d2, d2), 0.0, places=5)
        self.assertNotAlmostEqual(mmd(d1, d2), 0.0)

    def test_wasserstein2_metric(self):
        """Wasserstein-2: ~zero against itself, non-zero for different weights."""
        metric = Wasserstein2Metric(epsilon=1e-4)
        locs_key, probs_key = jax.random.split(self.rng)
        locs = self._random_locs(locs_key)
        probs = self._random_probs(probs_key)
        d1 = DiscreteDistribution.empirical_from(locs)
        d2 = DiscreteDistribution(locs=locs, probs=probs)
        self.assertAlmostEqual(metric(d2, d2), 0.0, places=3)
        self.assertNotAlmostEqual(metric(d1, d2), 0.0, places=3)

    def test_supremal_mmd2_metric(self):
        """Supremal metric over squared MMD on batched distributions."""
        metric = SupremalMetric(base_metric=SquaredMMDMetric(energy_distance))
        locs_key, probs_key = jax.random.split(self.rng)
        locs = self._random_locs(locs_key, self.batch_size)
        probs = self._random_probs(probs_key, self.batch_size)
        d1 = jax.vmap(DiscreteDistribution.empirical_from)(locs)
        d2 = jax.vmap(DiscreteDistribution)(locs=locs, probs=probs)
        self.assertAlmostEqual(metric(d2, d2), 0.0, places=5)
        self.assertNotAlmostEqual(metric(d1, d2), 0.0)

    def test_supremal_wasserstein2_metric(self):
        """Supremal metric over Wasserstein-2 on batched distributions."""
        metric = SupremalMetric(base_metric=Wasserstein2Metric(epsilon=1e-4))
        locs_key, probs_key = jax.random.split(self.rng)
        locs = self._random_locs(locs_key, self.batch_size)
        probs = self._random_probs(probs_key, self.batch_size)
        d1 = jax.vmap(DiscreteDistribution.empirical_from)(locs)
        d2 = jax.vmap(DiscreteDistribution)(locs=locs, probs=probs)
        self.assertAlmostEqual(metric(d2, d2), 0.0, places=3)
        self.assertNotAlmostEqual(metric(d1, d2), 0.0, places=3)


# When this file is executed directly (rather than imported), hand control to
# unittest's command-line runner, which discovers and runs the tests above.
if __name__ == "__main__":
    unittest.main()
