import numpy as np
import utils
from collections import defaultdict
import copy

class MisraGries():
    '''
    Misra-Gries frequent-items summary holding at most k counters.

    Counters are stored in ``self.counters`` (item -> count). Items that are
    not currently tracked have an estimate of 0.
    '''

    def __init__(self, k):
        '''
        Initialize Misra-Gries with k counters
        '''
        self.k = k
        self.counters = defaultdict(int)

    def vectorUpdate(self, freqs):
        '''
        Update sketch with a vector (frequency histogram).

        freqs[i] is the number of occurrences of item i; item i is fed to
        update() that many times. Counts may be stored as ints or as
        integral-valued floats (e.g. a float numpy array). Negative counts
        contribute no updates.

        Raises ValueError if a nonzero count is not a whole number.
        '''
        for i in np.nonzero(freqs)[0]:
            freq = freqs[i]
            # range() rejects float counts, so convert explicitly and make
            # sure nothing is lost in the conversion.
            count = int(freq)
            if count != freq:
                raise ValueError(
                    "frequency for item %r must be an integer, got %r" % (i, freq))
            for __ in range(count):
                self.update(i)

    def update(self, x):
        '''
        Process one occurrence of item x.

        If x is already tracked, or a counter slot is free, its counter is
        incremented. Otherwise every counter is decremented by one and the
        counters that reach zero are removed; x itself is not inserted.
        '''
        assert len(self.counters) <= self.k
        if x in self.counters or len(self.counters) < self.k:
            self.counters[x] += 1
            return

        # Table is full and x is untracked: decrement and remove zero counters.
        # Deletions are deferred so the dict is not resized while iterating.
        zero_keys = []
        for y in self.counters:
            self.counters[y] -= 1
            assert self.counters[y] >= 0
            if self.counters[y] == 0:
                zero_keys.append(y)
        for y in zero_keys:
            del self.counters[y]

    def estimate(self, x):
        '''
        Return the current counter for x, or 0 if x is not tracked.

        Uses .get() so that querying does not create an entry in the
        defaultdict.
        '''
        return self.counters.get(x, 0)

class CountSketch():
    '''
    CountSketch table for items 0..n-1.

    The table ``self.sketch`` has shape (W, L). Column j is driven by its own
    pair of hash objects: rowHashes[j] selects the row (hash value mod W) and
    signHashes[j] selects a sign of +1 or -1 (from the hash value mod 2).
    '''

    def __init__(self, n, W, L):
        '''
        Initialize a CountSketch table with W rows and L columns

        Also precomputes, for every column j, a (W, n) matrix
        sketchMats[j] whose only nonzero entry in column i is the sign of
        item i placed at the row item i hashes to. These matrices are what
        vectorUpdate() multiplies by.
        '''
        self.n = n
        self.W = W
        self.L = L
        # NOTE(review): utils.HashFns is defined outside this file; it is
        # only assumed here to expose cwHash(x) returning an integer.
        self.rowHashes = [utils.HashFns(n) for _ in range(L)]
        self.signHashes = [utils.HashFns(n) for _ in range(L)]
        self.sketch = np.zeros((W, L))

        self.sketchMats = [np.zeros((W, n)) for _ in range(L)]
        for item in range(n):
            per_column = zip(self.sketchMats, self.rowHashes, self.signHashes)
            for mat, row_fn, sign_fn in per_column:
                bucket = row_fn.cwHash(item) % W
                mat[bucket, item] = 1 - 2 * (sign_fn.cwHash(item) % 2)

    def __ids(self, x):
        '''Return (rows, cols): the table cell of x in each of the L columns.'''
        buckets = [fn.cwHash(x) % self.W for fn in self.rowHashes]
        return np.array(buckets), np.arange(self.L)

    def __signs(self, x):
        '''Return the +1/-1 sign of x for each of the L columns.'''
        parities = np.array([fn.cwHash(x) % 2 for fn in self.signHashes])
        return 1 - (2 * parities)

    def vectorUpdate(self, freqs):
        '''Update sketch with a vector (frequency histogram)'''
        assert len(freqs) == self.n
        # Column j of the table receives the projection of freqs through
        # the precomputed sign/row matrix for that column.
        for col in range(self.L):
            self.sketch[:, col] += self.sketchMats[col] @ freqs

    def update(self, x, d=1):
        '''Update sketch with d copies of x'''
        rows, cols = self.__ids(x)
        signs = self.__signs(x)
        self.sketch[rows, cols] += signs * d

    def getSketch(self):
        '''Return the underlying (W, L) table itself, not a copy.'''
        return self.sketch

    def estimate(self, x):
        '''
        Estimate the count of x as the median, over the L columns, of the
        sign-corrected cell that x maps to.
        '''
        rows, cols = self.__ids(x)
        signs = self.__signs(x)
        signed_cells = self.sketch[rows, cols] * signs
        return np.median(signed_cells)

if __name__ == "__main__":
    # Small demo: insert two items and print the estimate for one of them
    # along with the raw table (transposed so each hash column is a row).
    U = 10000 # universe size
    W = 10
    L = 3
    # CountSketch builds its own hash functions; its constructor takes
    # (n, W, L), so the universe size is passed as the first argument.
    cs = CountSketch(U, W, L)
    cs.update(10)
    cs.update(20)
    print(cs.estimate(10))
    print(cs.getSketch().transpose())
