import numpy
from numpy.testing import assert_array_equal
import unittest

from .one_hot import OneHotSketcher


def fake_hash(s):
    """
    In order to get deterministic results during testing, we use a fake hash
    function that avoids collisions under carefully controlled conditions.

    A bare identity "identity" hashes to int(identity), and an input of the
    form "identity:value" hashes to int(identity) + 100 * int(value). So, for
    example, "26" -> 26 and "26:42" -> 4226.

    This function will avoid collisions under the following assumptions:
    - The output of the hash function is interpreted the way OneHotSketcher
      interprets it (i.e. last bit used as a sign, and the rest mod num buckets
      is the hash bucket).
    - The input to the hash function is either an identity or has the form
      identity:value (that is, identity followed by ":" followed by the value
      encoded as a string).
    - Identities are ints in the range 0..99 encoded as strings. The identities
      2*i and 2*i+1 are never both used, for any i. (They would hash to the same
      bucket and get opposite signs.)
    - There are at least 100*num_categories hash buckets if the input has the
      form identity:value, or at least 100 if the input is just an identity.

    Raises ValueError if either part of s is not an integer string, or if s
    contains more than one ":".
    """
    identity, separator, value = s.partition(":")
    if not separator:
        # Bare identity: hash it to itself.
        return int(identity)
    # Identities are < 100, so scaling the value by 100 keeps distinct
    # (identity, value) pairs from producing the same integer.
    return int(identity) + 100 * int(value)

def one_hot(dim, i):
    """
    Build a float vector of length ``dim`` that is 1 at index ``i`` and 0
    everywhere else (the expected training weights for a single value).
    """
    encoded = numpy.zeros(dim)
    encoded[i] = 1
    return encoded

class OneHotSketcherTest(unittest.TestCase):
    """
    Test OneHotSketcher with epsilon = None.

    Unless a test deliberately forces a collision (the *_from_collision
    tests), the identities used here hash to different buckets under
    fake_hash, so the sketch should answer queries perfectly.
    """
    def setUp(self):
        # Note that fake_hash makes assumptions about how the sketcher will use
        # it, so it may need to be updated if the implementation changes.
        self._sketcher = OneHotSketcher(hash_function = fake_hash)

    def test_single_category_one_identity(self):
        """
        A sketch with only one category (with value 0) allowed. We add just one
        identity to it.
        """
        identity = "26"
        another_identity = "28"
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = [identity], num_categories = 1,
            values = [0], num_buckets = 400)
        self.assertEqual(
            self._sketcher.estimate_membership(sketch, identity, 0),
            1)
        self.assertEqual(
            self._sketcher.estimate_membership(sketch, another_identity, 0),
            0)
        assert_array_equal(
            self._sketcher.estimate_training_weights(
                sketch, (another_identity, identity), 1),
            numpy.array(([0], [1])))

    def test_two_categories_one_identity(self):
        """
        A sketch with two categories (values 0 and 1). We add one identity
        with value 0; only that (identity, value) pair should be reported as
        present.
        """
        identity = "26"
        another_identity = "28"
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = [identity], num_categories = 2,
            values = [0], num_buckets = 400)
        self.assertEqual(
            self._sketcher.estimate_membership(sketch, identity, 0),
            1)
        self.assertEqual(
            self._sketcher.estimate_membership(sketch, another_identity, 0),
            0)
        self.assertEqual(
            self._sketcher.estimate_membership(sketch, identity, 1),
            0)
        assert_array_equal(
            self._sketcher.estimate_training_weights(
                sketch, (identity, another_identity), 2),
            numpy.array(((1, 0), (0, 0))))

    def test_difference_sketch(self):
        """
        A two-category difference sketch with two identities. An identity that
        was never added should get the "not present" label 0.
        """
        identity = "29"
        another_identity = "44"
        unused_identity = "31"
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = [identity, another_identity],
            num_categories = 2, values = [1, 0], num_buckets = 400,
            difference_sketch = True)
        # Sketcher.estimate_training_labels_from_difference_sketch maps the
        # value 0 to 1 and 1 to -1 (leaving 0 to mean "not present"), so
        # we expect identity and another_identity to get values -1 and 1.
        assert_array_equal(
            self._sketcher.estimate_training_labels_from_difference_sketch(
                sketch, (identity, unused_identity, another_identity)),
            numpy.array((-1, 0, 1)))

    def test_recover_one_value(self):
        """
        A single identity with one value out of 128 categories should be
        recovered exactly, and nothing else should be reported as present.
        """
        # Sketch one person.
        identity = "26"
        another_identity = "12"
        value = 42
        another_value = 101
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = [identity], num_categories = 128,
            values = [value], num_buckets = 40_000)
        self.assertEqual(
            self._sketcher.estimate_membership(sketch, identity, value),
            1)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identity, another_value),
            0)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, another_identity, value),
            0)
        assert_array_equal(
            self._sketcher.estimate_training_weights(
                sketch, (identity, another_identity), 128),
            numpy.array((one_hot(128, value), numpy.zeros(128))))

    def test_recover_two_rows(self):
        """
        Two identities with different values: each (identity, value) pair is
        recovered, and swapped or unknown pairs are reported as absent.
        """
        identities = ["3", "7"]
        another_identity = "12"
        values = [42, 20]
        another_value = 31
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = identities, num_categories = 128,
            values = values, num_buckets = 40_000)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identities[0], values[0]),
            1)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identities[1], values[1]),
            1)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identities[0], values[1]),
            0)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identities[1], values[0]),
            0)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identities[0], another_value),
            0)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, another_identity, values[0]),
            0)
        assert_array_equal(
            self._sketcher.estimate_training_weights(
                sketch, [another_identity] + identities, 128),
            numpy.array(
                [numpy.zeros(128)] +
                [one_hot(128, value) for value in values]))

    def test_recover_two_values_with_same_id(self):
        """
        Generally, we assume each identity appears only once, but certain
        operations, including estimate_membership, should work okay when
        identities can be repeated.
        """
        identity = "29"
        another_identity = "12"
        values = [1200, 209]
        another_value = 31
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = [identity, identity],
            num_categories = 2048, values = values, num_buckets = 300_000)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identity, values[0]),
            1)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identity, values[1]),
            1)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identity, another_value),
            0)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, another_identity, values[0]),
            0)
        # Both values recorded for the repeated identity get weight 1.
        expected_weights = numpy.zeros(2048)
        expected_weights[values] = 1
        assert_array_equal(
            self._sketcher.estimate_training_weights(
                sketch, (identity, another_identity), 2048),
            numpy.array(
                (expected_weights, numpy.zeros(2048))))

    def test_recover_two_ids_with_same_value(self):
        """
        Two identities sharing the same value should both be recovered with
        that value, while an unknown identity gets all-zero weights.
        """
        # A tuple, so it can be concatenated with (another_identity,) below.
        identities = ("42", "38")
        another_identity = "1"
        value = 49
        another_value = 31
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = identities, num_categories = 128,
            values = [value, value], num_buckets = 40_000)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identities[0], value),
            1)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identities[1], value),
            1)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, identities[0], another_value),
            0)
        self.assertEqual(
            self._sketcher.estimate_membership(
                sketch, another_identity, value),
            0)
        assert_array_equal(
            self._sketcher.estimate_training_weights(
                sketch, identities + (another_identity,), 128),
            numpy.array(
                [one_hot(128, value)] * 2 + [numpy.zeros(128)]))

    def test_adjust_training_weights_from_collision(self):
        """
        If there are hash collisions, the training weights should be adjusted
        accordingly. As a hack, we force a "hash" collision by simply repeating
        the same identity more than once.
        """
        identities = ["3", "7"]
        another_identity = "12"
        values = [12, 7]
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = identities, num_categories = 16,
            values = values, num_buckets = 3200)
        # "3" is queried twice, so its weight is split evenly between the two
        # queries.
        assert_array_equal(
            self._sketcher.estimate_training_weights(
                sketch, ["3", "7", "3", another_identity], 16),
            numpy.array((
                one_hot(16, 12) / 2,
                one_hot(16, 7),
                one_hot(16, 12) / 2,
                numpy.zeros(16))))

    def test_difference_sketch_adjust_estimates_from_collision(self):
        """
        Like test_adjust_training_weights_from_collision, but with a difference
        sketch.
        """
        identities = ["3", "7"]
        another_identity = "12"
        values = [0, 1]
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = identities, num_categories = 2,
            values = values, num_buckets = 100, difference_sketch = True)
        # Recall that Sketcher.estimate_training_labels_from_difference_sketch
        # maps the value 0 to 1 and 1 to -1.
        assert_array_equal(
            self._sketcher.estimate_training_labels_from_difference_sketch(
                sketch, ["3", "7", "3", another_identity]),
            numpy.array((1/2, -1, 1/2, 0)))

    def test_adjust_training_weights_by_category_from_collision(self):
        """
        An even hackier version of test_adjust_training_weights_from_collision.
        Here we test that even for the same ID, different re-weightings can be
        applied to different categories (i.e. training labels). We can't use the
        same trick of just repeating the identity, so we rely on the details of
        how fake_hash works and how OneHotSketcher calls it.
        """
        # Identity "100" deliberately breaks fake_hash's 0..99 assumption:
        # fake_hash("0:1") == fake_hash("100:0") == 100, forcing a collision.
        sketch = self._sketcher.sketch_values(
            epsilon = None, identities = ("0", "100", "100"),
            num_categories = 6, values = (1, 5, 0), num_buckets = 2000)
        expected_id_0_weights = numpy.zeros(6)
        expected_id_0_weights[1] = 1/2  # collides with (id 100, value 0)
        expected_id_100_weights = numpy.zeros(6)
        expected_id_100_weights[0] = 1/2  # collides with (id 0, value 1)
        expected_id_100_weights[5] = 1
        assert_array_equal(
            self._sketcher.estimate_training_weights(
                sketch, ("0", "100"), 6),
            numpy.array((expected_id_0_weights, expected_id_100_weights)))

if __name__ == "__main__":
    # Run the tests in this module when it is executed directly as a script.
    unittest.main()
