import numpy

from ..sketch import Sketcher
from ..random import discrete_laplace

class OneHotSketcher(Sketcher):
    """
    Sketcher that hashes every identity (optionally combined with a category
    value) to a single bucket and a +/-1 sign, and accumulates those signs
    in a one-dimensional array of counters.
    """

    def __init__(self, **kwds):
        """
        Parameters are all forwarded to the parent class (Sketcher).
        """
        super().__init__(**kwds)

    def sketch_values(self, epsilon, identities, num_categories, values,
                      num_buckets, difference_sketch=False):
        """
        Build a sketch of length num_buckets from (identity, value) pairs.

        This function requires one additional parameter compared to the base
        class: num_buckets is the number of hash buckets to use.

        identities and values are iterated in lockstep. In the default mode
        each (identity, value) pair is hashed to a bucket and a sign, and the
        sign is added to that bucket. If difference_sketch is true,
        num_categories must be 2; each identity is then hashed on its own and
        its sign is negated when the corresponding value is truthy.

        If epsilon is not None, noise drawn from discrete_laplace (with
        inv_scale = epsilon) is added to every bucket before returning.
        """
        if difference_sketch:
            assert num_categories == 2, (
                "difference sketches require exactly two categories")
        sketch = numpy.zeros(num_buckets, dtype=numpy.int64)
        for identity, value in zip(identities, values):
            if difference_sketch:
                bucket, sign = self.__bucket_and_sign(num_buckets, identity)
                if value:
                    sign = -sign
            else:
                bucket, sign = self.__bucket_and_sign(
                    num_buckets, identity, value)
            sketch[bucket] += sign
        if epsilon is not None:
            sketch = sketch + discrete_laplace(
                self._rng, inv_scale=epsilon, size=num_buckets)
        return sketch

    def estimate_membership(self, sketch, identity, value):
        """
        Return the counter of the bucket that (identity, value) hashes to,
        multiplied by the sign that pair hashes to.
        """
        hash_bucket, sign = self.__bucket_and_sign(
            sketch.shape[0], identity, value)
        return sign * sketch[hash_bucket]

    def estimate_training_weights(self, sketch, identities, num_categories):
        """
        Return a float array of shape (len(identities), num_categories).

        Entry [i, v] is 0 if the bucket that (identities[i], v) hashes to
        holds 0 in the sketch. Otherwise it is the hashed sign of that pair
        times the sign of the bucket's counter, divided by the number of
        (identity, value) pairs from this call that share the bucket.
        """
        num_buckets = sketch.shape[0]
        shape = (len(identities), num_categories)
        buckets = numpy.zeros(shape, dtype=numpy.int64)
        signs = numpy.zeros(shape, dtype=numpy.int64)
        for i, identity in enumerate(identities):
            for value in range(num_categories):
                buckets[i, value], signs[i, value] = self.__bucket_and_sign(
                    num_buckets, identity, value)
        return self.__signed_shares(sketch, buckets, signs)

    def estimate_training_labels_from_difference_sketch(
            self, sketch, identities):
        """
        Return a float array of shape (len(identities),).

        Entry [i] is 0 if the bucket that identities[i] hashes to holds 0 in
        the sketch. Otherwise it is the hashed sign of that identity times
        the sign of the bucket's counter, divided by the number of identities
        from this call that share the bucket.
        """
        num_buckets = sketch.shape[0]
        buckets = numpy.zeros(len(identities), dtype=numpy.int64)
        signs = numpy.zeros(len(identities), dtype=numpy.int64)
        for i, identity in enumerate(identities):
            buckets[i], signs[i] = self.__bucket_and_sign(
                num_buckets, identity)
        return self.__signed_shares(sketch, buckets, signs)

    @staticmethod
    def __signed_shares(sketch, buckets, signs):
        """
        Shared computation behind the two training estimators.

        buckets and signs are integer arrays of identical shape. The result
        has that same shape and holds, element-wise,
        signs * sign(sketch[buckets]) / (occurrences of that bucket in
        buckets). Buckets whose counter is 0 contribute 0.
        """
        bucket_counts = numpy.bincount(
            buckets.ravel(), minlength=sketch.shape[0])
        # Every bucket that is looked up occurs at least once in `buckets`,
        # so the divisor below is never zero. True division makes the result
        # a float array even when every entry is zero or the input is empty.
        return signs * numpy.sign(sketch[buckets]) / bucket_counts[buckets]

    def __bucket_and_sign(self, num_buckets, identity, value=None):
        """
        identity should be a string. value should be a nonnegative integer, or
        None if we're generating a difference sketch. Returns a pair
        (bucket, sign) where sign is -1 or 1 and bucket is a nonnegative
        integer less than num_buckets.
        """
        hash_input = (
            identity if value is None else f"{identity}:{value}")
        hash_int = self._hash_function(hash_input)
        # The lowest bit of the hash selects the sign; the remaining bits
        # select the bucket.
        sign = 2 * (hash_int % 2) - 1
        return (hash_int // 2) % num_buckets, sign
