import os
import sys
import enum
import tensorflow as tf

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_MODULE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_MODULE_DIR, "tf_ops"))

from IEProtLibModule import knn

class MCKnn:
    """Class to represent a k-nearest neighbor operation.

    Attributes:
        pcSamples_ (MCPointCloud): Samples point cloud.
        grid_ (MCGrid): Regular grid data structure.
        radii_ (float tensor d): Radii used to select the neighbors.
        numKNeighs_ (int): Number of neighbors selected for each sample.
        nearest_ (bool): True if the nearest neighbors are selected, False
            if the farthest ones are selected.
        knnIndices_ (int tensor nxk): For each sample, the indices into
            grid_.sortedPts_ of its selected neighbors. Negative entries
            mark invalid neighbors.
    """

    def __init__(self, pGrid, pKNeighobrs, pNearest=True, pRadii=None, pPCSample=None):
        """Constructor.

        Args:
            pGrid (MCGrid): Regular grid data structure.
            pKNeighobrs (int): Number of k-nearest neighbors.
            pNearest (bool): Boolean that indicates if we select the nearest
                or the farthest neighbors.
            pRadii (float): Radius in which we have to search for the
                neighboring points. If None, pGrid.cellSizes_ is used.
            pPCSample (MCPointCloud): Samples point cloud. If None, the sorted
                points from the grid will be used.
        """

        # Save the attributes.
        if pPCSample is None:
            # NOTE(review): MCPointCloud is not imported in this module, so
            # this branch raises NameError unless the name is injected
            # elsewhere -- confirm and add the proper import.
            self.pcSamples_ = MCPointCloud(pGrid.sortedPts_,
                pGrid.sortedBatchIds_, pGrid.batchSize_)
        else:
            self.pcSamples_ = pPCSample

        if pRadii is None:
            self.radii_ = pGrid.cellSizes_
        else:
            self.radii_ = pRadii

        self.grid_ = pGrid
        self.numKNeighs_ = pKNeighobrs
        self.nearest_ = pNearest

        # Compute the knn.
        self.knnIndices_ = knn(self.grid_, self.pcSamples_,
            self.radii_, self.numKNeighs_, self.nearest_)

    def get_knn_pts_coords(self, pKIndex=None):
        """Method to get the point coordinates of the neighbors.

        Args:
            pKIndex (int): Column of knnIndices_ (which of the k neighbors)
                to obtain for every sample. If None, all k neighbors are
                returned.
        Return:
            If pKIndex is None:
                (float tensor nxkxd): Point coordinates of all neighbors.
                (bool tensor nxk): Valid coordinates mask.
            Otherwise:
                (float tensor nxd): Point coordinates of the selected
                    neighbor of each sample.
                (bool tensor n): Valid coordinates mask.
            Coordinates of invalid neighbors (negative index) are zero.
        """
        if pKIndex is None:
            indices = self.knnIndices_
        else:
            indices = self.knnIndices_[:, pKIndex]

        mask = tf.math.greater_equal(indices, 0)

        # tf.gather raises on CPU and writes zeros on GPU for out-of-range
        # (e.g. negative) indices. Clamp so invalid entries are gathered
        # safely, then zero them explicitly for device-independent output.
        safeIndices = tf.maximum(indices, 0)
        # Gathering with a rank-r index tensor yields indices.shape + [d],
        # so the nxk case directly produces an nxkxd tensor.
        pts = tf.gather(self.grid_.sortedPts_, safeIndices)
        pts = pts * tf.cast(tf.expand_dims(mask, -1), pts.dtype)
        return pts, mask