import numpy as np
import torch

class MeatballsGenerator:
    """Generate random blob-shaped ("meatball") binary masks on a 3D grid.

    The grid coordinates are normalised so that each axis spans [0, 1].
    A scalar field is built by summing inverse distances to a handful of
    seed points, and the mask is obtained by thresholding that field at a
    quantile, which controls the fraction of voxels that end up set.
    """

    def __init__(self, size_x = 64, size_y = 64, size_z = 64):
        """Precompute the normalised coordinate grid.

        Args:
            size_x: Number of voxels along the first axis.
            size_y: Number of voxels along the second axis.
            size_z: Number of voxels along the third axis.
        """
        self.size_x = size_x
        self.size_y = size_y
        self.size_z = size_z
        # space has shape (3, size_x, size_y, size_z); space[k] holds the
        # k-th coordinate of every voxel.
        self.space = np.indices((self.size_x, self.size_y, self.size_z)).astype(np.float32)
        for axis, size in enumerate((self.size_x, self.size_y, self.size_z)):
            self.space[axis] /= self._axis_scale(size)

    @staticmethod
    def _axis_scale(size):
        """Return the divisor mapping indices 0..size-1 onto [0, 1].

        A length-1 axis would otherwise divide by zero and fill the grid
        with NaN, so its single index is left at 0 instead.
        """
        return max(size - 1, 1)

    def get_distance_sum(self, points):
        """Sum the inverse Euclidean distances from every voxel to each point.

        Args:
            points: Array-like of 3D points expressed in the same normalised
                coordinates as ``self.space``. Anything reshapeable to
                ``(N, 3)`` is accepted, including a single ``(3,)`` point
                or an empty sequence.

        Returns:
            Array of shape ``(size_x, size_y, size_z)`` holding the summed
            inverse distances (all zeros when no points are given).
        """
        eps = 1e-8  # keeps the field finite at voxels that coincide with a point
        grid_shape = (self.size_x, self.size_y, self.size_z)
        dist_sum = np.zeros(grid_shape, dtype=np.float32)
        for point in np.asarray(points).reshape(-1, 3):
            dist = np.linalg.norm(self.space - point.reshape(3, 1, 1, 1), axis=0)
            dist_sum = dist_sum + 1 / (dist + eps)
        return dist_sum

    def meatballs_binmask(self, points, min_volume=0.0, max_volume=1.0):
        """Threshold the inverse-distance field into a boolean mask.

        The threshold is the q-quantile of the field, with q drawn uniformly
        between ``1 - max_volume`` and ``1 - min_volume``, so roughly a
        fraction ``1 - q`` of the voxels is selected (fewer when field
        values tie at the threshold, because the comparison is strict).

        Args:
            points: Seed points, as accepted by ``get_distance_sum``.
            min_volume: Lower bound of the selected volume fraction.
            max_volume: Upper bound of the selected volume fraction.

        Returns:
            Boolean array of shape ``(size_x, size_y, size_z)``.
        """
        dist_sum = self.get_distance_sum(points)
        q = np.random.uniform(1 - max_volume, 1 - min_volume)
        quantile = np.quantile(dist_sum, q)
        return dist_sum > quantile

    def generate_random_meatball(self, min_volume=0.0, max_volume=1.0):
        """Build a mask from 4 to 6 seed points placed on random grid nodes.

        Args:
            min_volume: Lower bound of the selected volume fraction.
            max_volume: Upper bound of the selected volume fraction.

        Returns:
            Boolean array of shape ``(size_x, size_y, size_z)``.
        """
        starting_points_num = int(np.random.randint(4, 7, size=1)[0])
        # Draw x, then y, then z so the consumption order of the global
        # NumPy random stream stays stable for seeded callers.
        coords = [
            np.random.randint(0, size, size=starting_points_num) / self._axis_scale(size)
            for size in (self.size_x, self.size_y, self.size_z)
        ]
        starting_points = np.stack(coords, axis=1)
        return self.meatballs_binmask(
            starting_points, min_volume=min_volume, max_volume=max_volume
        )

