import numpy as np
import utils

class CountSketch():
    '''
    Count-sketch table: a W x L array of signed counters.

    Each of the L columns has its own row-hash and sign-hash object
    (created via utils.HashFns(n)). An item x touches exactly one cell per
    column: row ``rowHash.cwHash(x) % W`` with sign ``+1`` or ``-1``
    depending on the parity of ``signHash.cwHash(x)``.
    '''

    def __init__(self, n, W, L):
        '''
        Build an empty sketch with W rows and L columns.

        n is handed to every utils.HashFns constructor and is also the
        length vectorUpdate expects of its input (presumably the universe
        size -- confirm against utils.HashFns).
        '''
        self.n, self.W, self.L = n, W, L
        # Row hashes are created before sign hashes, one pair per column.
        self.rowHashes = [utils.HashFns(n) for _ in range(L)]
        self.signHashes = [utils.HashFns(n) for _ in range(L)]
        self.sketch = np.zeros((W, L))

        # One dense W x n matrix per column; entry [row, item] holds the
        # item's sign in the row it hashes to, so that (matrix @ freqs)
        # gives that column's counters for a whole frequency vector.
        self.sketchMats = [np.zeros((W, n)) for _ in range(L)]
        for item in range(n):
            hash_pairs = zip(self.sketchMats, self.rowHashes, self.signHashes)
            for mat, row_fn, sign_fn in hash_pairs:
                bucket = row_fn.cwHash(item) % W
                parity = sign_fn.cwHash(item) % 2
                mat[bucket, item] = 1 - 2 * parity

    def __ids(self, x):
        '''Return (rows, cols) index arrays selecting x's cell in every column'''
        buckets = [fn.cwHash(x) % self.W for fn in self.rowHashes]
        return np.array(buckets), np.arange(self.L)

    def __signs(self, x):
        '''Return an array with x's sign (+1 or -1) for every column'''
        parities = np.array([fn.cwHash(x) % 2 for fn in self.signHashes])
        return 1 - (2 * parities)

    def vectorUpdate(self, freqs):
        '''Add a whole frequency vector (length n) to the sketch at once'''
        assert len(freqs) == self.n
        for col in range(self.L):
            projected = self.sketchMats[col] @ freqs
            self.sketch[:, col] += projected

    def update(self, x, d=1):
        '''Add d occurrences of item x to the sketch'''
        rows, cols = self.__ids(x)
        delta = self.__signs(x) * d
        # cols is arange(L), so every (row, col) pair is distinct and the
        # fancy-indexed += touches each selected cell exactly once.
        self.sketch[rows, cols] += delta

    def getSketch(self):
        '''Return the underlying W x L counter array (not a copy)'''
        return self.sketch

    def estimate(self, x):
        '''Estimate x's count as the median of its sign-corrected cells'''
        rows, cols = self.__ids(x)
        signs = self.__signs(x)
        per_column = self.sketch[rows, cols] * signs
        return np.median(per_column)

if __name__ == "__main__":
    # Smoke test: insert two items and print one estimate plus the table.
    U = 10000  # universe size
    W = 10  # number of rows in the sketch table
    L = 3  # number of columns (one row-hash/sign-hash pair each)
    # CountSketch.__init__ takes (n, W, L) and creates its own hash
    # functions, so only the dimensions are passed. The previous call,
    # CountSketch(W, L, rowHashes, signHashes), raised TypeError.
    cs = CountSketch(U, W, L)
    cs.update(10)
    cs.update(20)
    print(cs.estimate(10))
    # Transposed so each printed line is one column of the sketch.
    print(cs.getSketch().transpose())
