import numpy as np
from scipy import sparse
import torch 

class KWiseHash:
    """Seeded k-wise polynomial hashing of the non-zero positions of a vector.

    Each of the ``dimension_num`` hash functions is a polynomial of degree
    ``k_wise - 1`` in the element index, with random coefficients drawn from
    ``[1, prime)`` and evaluated modulo ``prime``.

    Args:
        dimension_num: Number of independent hash functions.
        k_wise: Number of polynomial coefficients per hash function.
        prime: Modulus of the polynomial. ``prime ** 2`` must fit in a signed
            64-bit integer for the modular arithmetic to be exact (the
            default does).
        device: Torch device the results are moved to. ``None`` (default)
            selects ``'cuda'`` when available and ``'cpu'`` otherwise.
    """

    def __init__(self, dimension_num=1, k_wise=4, prime=2147483587, device=None):
        self.dimension_num = dimension_num
        self.k_wise = k_wise
        self.prime = prime
        self.device = device

    def _resolve_device(self):
        """Return the configured device, or pick CUDA/CPU automatically."""
        device = getattr(self, "device", None)
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def hash(self, vec, seed):
        """Hash the non-zero positions of ``vec`` with seeded polynomials.

        Based on the MinHash scheme of
        A. Z. Broder, M. Charikar, A. M. Frieze, and M. Mitzenmacher,
        "Min-wise Independent Permutations", in STOC, 1998, pp. 518-529.

        Args:
            vec: Input vector; it must support ``vec != 0`` and integer-array
                indexing (presumably a 1-D numpy array -- confirm with
                callers).
            seed: Seed for the random polynomial coefficients. The same seed
                always yields the same hash functions.

        Returns:
            A ``(hashes, values)`` pair of ``float64`` tensors on the
            configured device. When ``dimension_num == 1`` there is one hash
            in ``[0, 1)`` per non-zero entry, with ``values`` holding those
            entries. Otherwise there is one entry per hash function: the
            minimum hash over the non-zero entries, and the value of ``vec``
            at the position attaining it.

        Raises:
            ValueError: If ``dimension_num != 1`` and ``vec`` has no non-zero
                entry, so no minimum exists.
        """
        # A private RandomState yields the same stream as seeding the global
        # generator, without clobbering global RNG state for other code.
        rng = np.random.RandomState(seed)
        hash_parameters = rng.randint(
            1, self.prime, (self.dimension_num, self.k_wise)
        ).astype(np.int64)

        nonzero_index = sparse.find(vec != 0)[1]

        # Horner evaluation with a reduction modulo ``prime`` at every step.
        # Intermediates stay below prime**2, so they fit in int64; computing
        # index**exp directly overflows for indices in the low thousands.
        points = (np.asarray(nonzero_index).astype(np.int64) % self.prime)[:, np.newaxis]
        hash_kwise = np.zeros((points.shape[0], self.dimension_num), dtype=np.int64)
        for exp in reversed(range(self.k_wise)):
            hash_kwise = (hash_kwise * points + hash_parameters[:, exp]) % self.prime
        hash_kwise = hash_kwise / self.prime

        if self.dimension_num == 1:
            hashes = hash_kwise.reshape(-1)
            values = vec[nonzero_index]
        else:
            if hash_kwise.shape[0] == 0:
                raise ValueError(
                    "cannot compute min-hash: vec has no non-zero entries"
                )
            hashes = np.min(hash_kwise, axis=0)
            positions = np.argmin(hash_kwise, axis=0)
            values = vec[nonzero_index[positions]]

        device = self._resolve_device()
        return (
            torch.tensor(hashes, dtype=torch.float64).to(device),
            torch.tensor(values, dtype=torch.float64).to(device),
        )