import random
import torch

from Hash.KWiseHash import KWiseHash

# Large finite stand-in for "infinity".
INF = 1e10
# Small positive offset added to ``tau * sampling_ratio`` in
# ThresholdSampling.__mul__ before the divisor is chosen.
EPSILON = 0.00001

class ThresholdSampling:
    """Threshold-sampling sketch over the rows of a 2-D tensor.

    ``hash()`` keeps every row whose priority ``hash_value / sampling_ratio``
    is at most ``tau``.  Two sketches can then be multiplied (``a * b``) to
    obtain the sum, over rows kept by both, of the outer products of the two
    stored rows, each scaled by the inverse of
    ``min(1, tau_a * ratio_a + EPSILON, tau_b * ratio_b + EPSILON)``.
    NOTE(review): this looks like an inverse-probability estimate of
    ``A.T @ B`` -- confirm against the callers.

    ``hash()`` must be called on both operands before ``*`` is used.
    """

    def __init__(self, hash_function, tau):
        """Store the hash values and the sampling threshold.

        Args:
            hash_function: Despite the name this is used as data, not called:
                ``hash()`` divides it element-wise by the sampling ratios.
                Presumably a 1-D tensor with one value per data row
                (TODO confirm against callers).
            tau: Threshold the row priorities are compared against.
        """
        self.hash_function = hash_function
        self.tau = tau
        # Filled in by hash(); until then the sketch holds nothing.
        self.stored_items = []
        self.sampled_indices = None
        self.sampling_ratio = None

    def __len__(self):
        """Return the number of non-zero entries in the stored sample."""
        if not torch.is_tensor(self.stored_items):
            # hash() has not run yet, so stored_items is still the empty list
            # and torch.count_nonzero would reject it.
            return len(self.stored_items)
        return int(torch.count_nonzero(self.stored_items))

    def __mul__(self, other):
        """Combine two sketches into a ``(d_self, d_other)`` float64 matrix.

        Rows are paired when both sketches kept the same original row index.
        The result lives on the device of ``self.stored_items``.

        Args:
            other: Another sketch exposing ``sampled_indices``,
                ``sampling_ratio``, ``stored_items`` and ``tau``.

        Returns:
            torch.Tensor: float64 matrix of shape
            ``(self.stored_items.shape[1], other.stored_items.shape[1])``;
            all zeros when the sketches share no row.
        """
        # Positions (into each sketch's stored rows) of the shared row indices.
        shared = self.sampled_indices.unsqueeze(1) == other.sampled_indices.unsqueeze(0)
        rows, cols = torch.nonzero(shared, as_tuple=True)

        # Per-pair divisor: min(1, tau_a * ratio_a + eps, tau_b * ratio_b + eps).
        p_self = self.tau * self.sampling_ratio[rows] + EPSILON
        p_other = other.tau * other.sampling_ratio[cols] + EPSILON
        divisor = torch.minimum(p_self, p_other).clamp(max=1.0)

        left = self.stored_items[rows].to(torch.float64)
        right = other.stored_items[cols].to(torch.float64)
        # The ratios may live on a different device than the data.
        divisor = divisor.to(device=left.device, dtype=torch.float64)

        # sum_k outer(left[k], right[k]) / divisor[k], done as one matmul.
        return (left / divisor.unsqueeze(1)).T @ right

    def __str__(self):
        """Return the string form of the stored sample."""
        return str(self.stored_items)

    def hash(self, data_set, sampling_ratio, pi=None):
        """Select the rows of ``data_set`` whose priority is at most ``tau``.

        Sets ``sampled_indices`` (1-D indices of the kept rows),
        ``sampling_ratio`` (the ratios of the kept rows only) and
        ``stored_items`` (the kept rows of ``data_set``).

        Args:
            data_set: Tensor indexed as ``data_set[indices, :]``.
            sampling_ratio: Tensor the hash values are divided by; presumably
                one entry per row of ``data_set`` (TODO confirm).
            pi: Unused; kept so existing call sites keep working.
        """
        priority = self.hash_function / sampling_ratio
        keep = priority <= self.tau

        # squeeze(-1) rather than squeeze(): a bare squeeze collapses a single
        # surviving index to a 0-d tensor, which turns stored_items into a 1-D
        # tensor and breaks __mul__.
        self.sampled_indices = torch.nonzero(keep, as_tuple=False).squeeze(-1)
        self.sampling_ratio = sampling_ratio[self.sampled_indices]
        self.stored_items = data_set[self.sampled_indices, :]