import random
import torch

from Hash.KWiseHash import KWiseHash

# Initial value of PrioritySampling.tau; it stands in for "no threshold" until
# hash() replaces it with an actual priority cut-off.
INF = 1e10
# Small constant added to tau * sampling_ratio in PrioritySampling.__mul__
# before that quantity is used as a divisor.
EPSILON = 0.00001

class PrioritySampling:
    """Keep the lowest-priority rows of a matrix and estimate matrix products.

    Each row gets the priority ``hash_function / sampling_ratio``. ``hash()``
    keeps the rows with the smallest priorities and records the largest kept
    priority as the threshold ``tau``. ``a * b`` then combines two sketches
    into an estimate built from the rows both sketches kept.
    """

    def __init__(self, hash_function, sample_size=0):
        """Create an empty sketch.

        Args:
            hash_function: Per-row values that ``hash()`` divides by the
                sampling ratio to obtain priorities. Presumably a 1-D tensor
                of hash values with one entry per data row -- confirm against
                the caller.
            sample_size: Sampling budget; ``hash()`` keeps ``sample_size + 1``
                rows whenever the data set has more rows than this.
        """
        self.tau = INF
        self.hash_function = hash_function
        self.sample_size = sample_size
        self.stored_items = []
        # Populated by hash(); declared here so the attributes always exist.
        self.sampled_indices = None
        self.sampling_ratio = None

    def __len__(self):
        """Return the number of non-zero entries held in the sketch.

        This counts non-zero matrix entries, not rows. A sketch that has not
        been filled by ``hash()`` yet has length 0.
        """
        if not torch.is_tensor(self.stored_items):
            # Before hash() runs, stored_items is still the initial empty list.
            return len(self.stored_items)
        # len() needs a plain int, not a 0-dim tensor.
        return int(torch.count_nonzero(self.stored_items))

    def __mul__(self, other):
        """Estimate the product of the two sketched matrices.

        For every row index kept by both sketches, the outer product of the
        two stored rows is divided by
        ``min(1, self.tau * ratio_self + EPSILON, other.tau * ratio_other + EPSILON)``
        and the results are summed.

        Args:
            other: Another ``PrioritySampling`` whose ``hash()`` has been run.

        Returns:
            A float64 tensor of shape
            ``(self.stored_items.shape[1], other.stored_items.shape[1])`` on
            the device of ``self.stored_items``.
        """
        # Positions (into each sketch) of the row indices present in both.
        shared = self.sampled_indices.unsqueeze(1) == other.sampled_indices.unsqueeze(0)
        rows, cols = shared.nonzero(as_tuple=True)

        # Per-pair divisor, capped at 1.
        p_self = (self.tau * self.sampling_ratio[rows] + EPSILON).to(torch.float64)
        p_other = (other.tau * other.sampling_ratio[cols] + EPSILON).to(torch.float64)
        divisor = torch.minimum(p_self, p_other).clamp(max=1.0)

        # sum_k outer(left_k, right_k) / divisor_k == (left / divisor)^T @ right.
        # With no shared rows this is a (d1, 0) @ (0, d2) product, i.e. zeros.
        left = self.stored_items[rows].to(torch.float64) / divisor.unsqueeze(1)
        right = other.stored_items[cols].to(torch.float64)
        return left.T @ right

    def __str__(self):
        """Return the string form of the stored rows."""
        return str(self.stored_items)

    def hash(self, data_set, sampling_ratio, pi=None):
        """Select the rows of ``data_set`` to keep and store them.

        When ``sample_size`` is smaller than the number of rows, the
        ``sample_size + 1`` rows with the smallest priority are kept and
        ``tau`` becomes the largest of those priorities (so the row that
        defines ``tau`` is stored as well). Otherwise every row is kept and
        ``tau`` is reset to ``INF``.

        Args:
            data_set: 2-D tensor whose rows are sampled.
            sampling_ratio: Per-row tensor the hash values are divided by;
                afterwards ``self.sampling_ratio`` holds only the entries of
                the kept rows.
            pi: Unused; retained so existing callers keep working.
        """
        self.sampling_ratio = sampling_ratio
        priority = self.hash_function / self.sampling_ratio

        row_count = data_set.shape[0]
        if self.sample_size < row_count:
            sampled_priority, self.sampled_indices = torch.topk(
                priority, k=self.sample_size + 1, largest=False
            )
            self.tau = sampled_priority[-1]
        else:
            # Keep everything, on the data's own device rather than assuming
            # CUDA, and drop any threshold left over from an earlier call.
            self.sampled_indices = torch.arange(row_count, device=data_set.device)
            self.tau = INF
        self.sampling_ratio = self.sampling_ratio[self.sampled_indices]

        self.stored_items = data_set[self.sampled_indices, :]