import tensorflow as tf
import sys
import os

# Directory containing this file; the compiled op library is looked up in its
# "build" subdirectory (see the load_op_library call below).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BUILD_DIR = os.path.join(BASE_DIR, "build")

# Make modules located next to this file importable.
sys.path.append(BASE_DIR)

# Load the compiled custom TensorFlow ops from build/IEProtLib.so at module
# import time. Every wrapper in this file calls into this module object.
IEProtLib_module = tf.load_op_library(os.path.join(BUILD_DIR, 'IEProtLib.so'))

def compute_keys(pPointCloud, pAABB, pNumCells, pCellSize):
    """Run the ``compute_keys`` custom op on a point cloud.

    The op does not receive the raw bounding-box minimum and cell size.
    It receives ``aabbMin_ / pCellSize`` and ``1 / pCellSize``.

    Args:
        pPointCloud: Object exposing ``pts_`` and ``batchIds_``.
        pAABB: Object exposing ``aabbMin_``.
        pNumCells: Forwarded to the op unchanged.
        pCellSize: Cell size; the op receives its reciprocal.

    Returns:
        The output of ``IEProtLib_module.compute_keys`` (presumably one grid
        key per point -- TODO confirm against the C++ op).
    """
    points = pPointCloud.pts_
    batchIds = pPointCloud.batchIds_
    scaledAabbMin = pAABB.aabbMin_ / pCellSize
    invCellSize = tf.math.reciprocal(pCellSize)
    return IEProtLib_module.compute_keys(
        points, batchIds, scaledAabbMin, pNumCells, invCellSize)


# The op is registered as non-differentiable.
tf.NoGradient('ComputeKeys')


def build_grid_ds(pKeys, pNumCells, pBatchSize):
    """Run the ``build_grid_ds`` custom op.

    Args:
        pKeys: Key tensor (presumably the sorted output of ``compute_keys``
            -- TODO confirm).
        pNumCells: Supplied to the op twice; see the note below.
        pBatchSize: Forwarded to the op unchanged.

    Returns:
        The op's output.
    """
    # NOTE(review): pNumCells is used for both the second and the third op
    # input. This looks deliberate (one input may only be consumed for its
    # shape), but confirm against the C++ op registration.
    opInputs = (pKeys, pNumCells, pNumCells, pBatchSize)
    return IEProtLib_module.build_grid_ds(*opInputs)


# The op is registered as non-differentiable.
tf.NoGradient('BuildGridDs')


def find_neighbors(pGrid, pPCSamples, pRadii, pMaxNeighbors):
    """Run the ``find_neighbors`` custom op for a set of sample points.

    Args:
        pGrid: Grid object exposing ``sortedPts_``, ``sortedKeys_``,
            ``fastDS_``, ``numCells_``, ``cellSizes_`` and ``aabb_.aabbMin_``.
        pPCSamples: Object exposing ``pts_`` and ``batchIds_``.
        pRadii: Radii; the op receives ``1 / pRadii``.
        pMaxNeighbors: Forwarded to the op unchanged.

    Returns:
        The op's output.
    """
    samplePts = pPCSamples.pts_
    sampleBatchIds = pPCSamples.batchIds_
    cellSizes = pGrid.cellSizes_
    # The op takes the AABB minimum divided by the cell sizes, plus the
    # inverse cell sizes and inverse radii, instead of the raw values.
    scaledAabbMin = pGrid.aabb_.aabbMin_ / cellSizes
    invCellSizes = tf.math.reciprocal(cellSizes)
    invRadii = tf.math.reciprocal(pRadii)
    return IEProtLib_module.find_neighbors(
        samplePts, sampleBatchIds,
        pGrid.sortedPts_, pGrid.sortedKeys_, pGrid.fastDS_, pGrid.numCells_,
        scaledAabbMin, invCellSizes, invRadii,
        pMaxNeighbors)


# The op is registered as non-differentiable.
tf.NoGradient('FindNeighbors')


def pooling(pNeighborhood, pPoolMode):
    """Run the ``pooling`` custom op over a neighborhood.

    Args:
        pNeighborhood: Object exposing ``grid_`` (with ``sortedPts_``,
            ``sortedBatchIds_``, ``sortedKeys_`` and ``numCells_``),
            ``neighbors_`` and ``samplesNeighRanges_``.
        pPoolMode: Forwarded to the op unchanged.

    Returns:
        The op's output.
    """
    grid = pNeighborhood.grid_
    gridInputs = (grid.sortedPts_, grid.sortedBatchIds_,
                  grid.sortedKeys_, grid.numCells_)
    neighInputs = (pNeighborhood.neighbors_,
                   pNeighborhood.samplesNeighRanges_)
    return IEProtLib_module.pooling(*(gridInputs + neighInputs + (pPoolMode,)))


# The op is registered as non-differentiable.
tf.NoGradient('Pooling')


def knn(pGrid, pPCSamples, pRadii, pNumNeighs, pNearest):
    """Run the ``knn`` custom op for a set of sample points.

    The nearest/farthest choice is encoded in the sign of the neighbor count
    handed to the op. ``pNumNeighs`` is passed unchanged when ``pNearest`` is
    truthy and negated otherwise. Presumably a negative count makes the op
    select the farthest neighbors -- TODO confirm against the C++ kernel.

    Args:
        pGrid: Grid object exposing ``sortedPts_``, ``sortedKeys_``,
            ``fastDS_``, ``numCells_``, ``cellSizes_`` and ``aabb_.aabbMin_``.
        pPCSamples: Object exposing ``pts_`` and ``batchIds_``.
        pRadii: Radii; the op receives ``1 / pRadii``.
        pNumNeighs: Number of neighbors to request. It is never modified.
        pNearest: Truthy to pass ``pNumNeighs`` unchanged, falsy to pass it
            negated.

    Returns:
        The op's output.
    """
    # Build the negated count as a new value instead of using ``*=``. For a
    # mutable argument such as a NumPy array, ``*=`` would flip the sign of
    # the caller's object in place.
    curNeighs = pNumNeighs if pNearest else pNumNeighs * -1
    return IEProtLib_module.knn(
        pPCSamples.pts_,
        pPCSamples.batchIds_,
        pGrid.sortedPts_,
        pGrid.sortedKeys_,
        pGrid.fastDS_,
        pGrid.numCells_,
        pGrid.aabb_.aabbMin_/pGrid.cellSizes_,
        tf.math.reciprocal(pGrid.cellSizes_),
        tf.math.reciprocal(pRadii),
        curNeighs)


# The wrapper above calls the op exposed as ``knn``. Under TensorFlow's
# CamelCase -> snake_case wrapper naming, that is op type 'Knn' (presumably;
# confirm against the C++ REGISTER_OP). The legacy 'KnnGrid' registration
# matches no op called here. Register 'Knn' so this op is marked
# non-differentiable like the other ops in this module. 'KnnGrid' is kept for
# backward compatibility.
tf.NoGradient('Knn')
tf.NoGradient('KnnGrid')

def compute_pdf(pNeighborhood, pBandwidth, pMode):
    """Run the non-differentiable ``compute_pdf`` custom op on a neighborhood.

    Args:
        pNeighborhood: Object exposing ``grid_.sortedPts_``, ``neighbors_``,
            ``samplesNeighRanges_`` and ``radii_``.
        pBandwidth: Bandwidth; the op receives ``1 / pBandwidth``.
        pMode: Forwarded to the op unchanged.

    Returns:
        The op's output.
    """
    invBandwidth = tf.math.reciprocal(pBandwidth)
    invRadii = tf.math.reciprocal(pNeighborhood.radii_)
    return IEProtLib_module.compute_pdf(
        pNeighborhood.grid_.sortedPts_,
        pNeighborhood.neighbors_,
        pNeighborhood.samplesNeighRanges_,
        invBandwidth, invRadii, pMode)


# The op is registered as non-differentiable.
tf.NoGradient('ComputePdf')

def compute_pdf_with_pt_grads(pNeighborhood, pBandwidth, pMode):
    """Run ``compute_pdf_with_pt_grads``, the differentiable PDF variant.

    Takes the same inputs as ``compute_pdf``. A Python gradient is
    registered for the ``ComputePdfWithPtGrads`` op type.

    Args:
        pNeighborhood: Object exposing ``grid_.sortedPts_``, ``neighbors_``,
            ``samplesNeighRanges_`` and ``radii_``.
        pBandwidth: Bandwidth; the op receives ``1 / pBandwidth``.
        pMode: Forwarded to the op (read back by the gradient as the
            ``mode`` attribute).

    Returns:
        The op's output.
    """
    invBandwidth = tf.math.reciprocal(pBandwidth)
    invRadii = tf.math.reciprocal(pNeighborhood.radii_)
    return IEProtLib_module.compute_pdf_with_pt_grads(
        pNeighborhood.grid_.sortedPts_,
        pNeighborhood.neighbors_,
        pNeighborhood.samplesNeighRanges_,
        invBandwidth, invRadii, pMode)
@tf.RegisterGradient("ComputePdfWithPtGrads")
def _compute_pdf_grad(op, *grads):
    """Gradient function for the ``ComputePdfWithPtGrads`` op.

    Only forward input 0 (the sorted grid points) receives a gradient. It is
    computed by ``compute_pdf_pt_grads`` from all five forward inputs, the
    upstream gradient and the op's ``mode`` attribute. Inputs 1-4 get None.
    """
    fwdInputs = [op.inputs[i] for i in range(5)]
    gradArgs = fwdInputs + [grads[0], op.get_attr("mode")]
    ptsGrad = IEProtLib_module.compute_pdf_pt_grads(*gradArgs)
    return [ptsGrad] + [None] * 4


def compute_topo_dist(pGraph, pNeighborhood, pMaxDistance, pConstEdge=False):
    """Run the ``compute_topo_dist`` custom op on a neighborhood and graph.

    Args:
        pGraph: Object exposing ``neighbors_`` and ``nodeStartIndexs_``.
        pNeighborhood: Object exposing ``pcSamples_.pts_`` and
            ``originalNeighIds_``.
        pMaxDistance: Forwarded to the op unchanged.
        pConstEdge: Truthy to pass ``1`` to the op, falsy to pass ``0``.

    Returns:
        The op's output.
    """
    constEdgeFlag = 1 if pConstEdge else 0
    return IEProtLib_module.compute_topo_dist(
        pNeighborhood.pcSamples_.pts_, pNeighborhood.originalNeighIds_,
        pGraph.neighbors_, pGraph.nodeStartIndexs_,
        pMaxDistance, constEdgeFlag)


# The op is registered as non-differentiable.
tf.NoGradient('ComputeTopoDist')

def find_neighbors_topo(pGraph, pPts, pMaxDistance, pConstEdge=False):
    """Run ``find_neighbors_topo`` and flatten its result into index pairs.

    The op output is indexed as ``[row, slot]``. Every entry that is ``>= 0``
    becomes one output row ``[neighborId, row]``. Negative entries are
    dropped (presumably padding -- TODO confirm). Output rows follow the
    row-major order of the op output.

    Args:
        pGraph: Object exposing ``neighbors_`` and ``nodeStartIndexs_``.
        pPts: Forwarded to the op unchanged.
        pMaxDistance: Forwarded to the op unchanged.
        pConstEdge: Truthy to pass ``1`` to the op, falsy to pass ``0``.

    Returns:
        A ``[M, 2]`` tensor of ``[neighborId, rowId]`` pairs.
    """
    constEdgeFlag = 1 if pConstEdge else 0
    neighs = IEProtLib_module.find_neighbors_topo(
        pPts, pGraph.neighbors_, pGraph.nodeStartIndexs_,
        pMaxDistance, constEdgeFlag)

    neighShape = tf.shape(neighs)
    # Row index of every slot, repeated across the columns of ``neighs``.
    rowIds = tf.tile(tf.expand_dims(tf.range(neighShape[0]), 1),
                     [1, neighShape[1]])

    flatNeighs = tf.reshape(neighs, [-1])
    flatRows = tf.reshape(rowIds, [-1])
    isValid = flatNeighs >= 0

    return tf.stack([tf.boolean_mask(flatNeighs, isValid),
                     tf.boolean_mask(flatRows, isValid)], axis=-1)


# The op is registered as non-differentiable.
tf.NoGradient('FindNeighborsTopo')


def compute_smooth_weights(pNeighborhood, pRadius):
    """Run the non-differentiable ``compute_smooth_w`` custom op.

    Args:
        pNeighborhood: Object exposing ``grid_.sortedPts_``,
            ``pcSamples_.pts_``, ``neighbors_`` and ``samplesNeighRanges_``.
        pRadius: Radius; the op receives ``1 / pRadius``.

    Returns:
        The op's output.
    """
    gridPts = pNeighborhood.grid_.sortedPts_
    samplePts = pNeighborhood.pcSamples_.pts_
    invRadius = tf.math.reciprocal(pRadius)
    return IEProtLib_module.compute_smooth_w(
        gridPts, samplePts,
        pNeighborhood.neighbors_, pNeighborhood.samplesNeighRanges_,
        invRadius)


# The op is registered as non-differentiable.
tf.NoGradient('ComputeSmoothW')


def compute_smooth_weights_with_pt_grads(pNeighborhood, pRadius):
    """Run ``compute_smooth_w_with_pt_grads``, the differentiable variant.

    Takes the same inputs as ``compute_smooth_weights``. A Python gradient
    is registered for the ``ComputeSmoothWWithPtGrads`` op type.

    Args:
        pNeighborhood: Object exposing ``grid_.sortedPts_``,
            ``pcSamples_.pts_``, ``neighbors_`` and ``samplesNeighRanges_``.
        pRadius: Radius; the op receives ``1 / pRadius``.

    Returns:
        The op's output.
    """
    gridPts = pNeighborhood.grid_.sortedPts_
    samplePts = pNeighborhood.pcSamples_.pts_
    invRadius = tf.math.reciprocal(pRadius)
    return IEProtLib_module.compute_smooth_w_with_pt_grads(
        gridPts, samplePts,
        pNeighborhood.neighbors_, pNeighborhood.samplesNeighRanges_,
        invRadius)
@tf.RegisterGradient("ComputeSmoothWWithPtGrads")
def _compute_smooth_w_grad(op, *grads):
    """Gradient function for the ``ComputeSmoothWWithPtGrads`` op.

    Forward inputs 0 and 1 (grid points and sample points) receive the two
    gradients produced by ``compute_smooth_w_pt_grads``. Inputs 2-4 get None.
    """
    fwdInputs = [op.inputs[i] for i in range(5)]
    ptsGrad, samplesGrad = IEProtLib_module.compute_smooth_w_pt_grads(
        *(fwdInputs + [grads[0]]))
    return [ptsGrad, samplesGrad, None, None, None]


def compute_protein_pooling(pGraph):
    """Run the non-differentiable ``protein_pooling`` custom op.

    Args:
        pGraph: Object exposing ``neighbors_`` and ``nodeStartIndexs_``.

    Returns:
        The op's output.
    """
    neighbors, startIndexs = pGraph.neighbors_, pGraph.nodeStartIndexs_
    return IEProtLib_module.protein_pooling(neighbors, startIndexs)


# The op is registered as non-differentiable.
tf.NoGradient('ProteinPooling')


def compute_graph_aggregation(pGraph, pFeatures, pNormalize):
    """Run the differentiable ``graph_aggregation`` custom op.

    Args:
        pGraph: Object exposing ``neighbors_`` and ``nodeStartIndexs_``.
        pFeatures: Feature tensor; the registered gradient flows back to it.
        pNormalize: Truthy to pass ``1`` to the op, falsy to pass ``0``
            (the gradient reads it back as the ``normalize`` attribute).

    Returns:
        The op's output.
    """
    normalizeFlag = 1 if pNormalize else 0
    neighbors = pGraph.neighbors_
    startIndexs = pGraph.nodeStartIndexs_
    return IEProtLib_module.graph_aggregation(
        pFeatures, neighbors, startIndexs, normalizeFlag)
@tf.RegisterGradient("GraphAggregation")
def _compute_graph_aggregation_grad(op, *grads):
    """Gradient function for the ``GraphAggregation`` op.

    Only forward input 0 (the features) is differentiated. Inputs 1 and 2
    (neighbor list and node start indices) get None.
    """
    upstreamGrad = grads[0]
    neighbors, startIndexs = op.inputs[1], op.inputs[2]
    featGrads = IEProtLib_module.graph_aggregation_grads(
        upstreamGrad, neighbors, startIndexs, op.get_attr("normalize"))
    return [featGrads, None, None]


def collapse_edges(pEdgeSortedIds, pEdgeIds, pStartNodeIds):
    """Run the non-differentiable ``collapse_edges`` custom op.

    All three arguments are forwarded to the op unchanged and in order.

    Returns:
        The op's output.
    """
    opInputs = [pEdgeSortedIds, pEdgeIds, pStartNodeIds]
    return IEProtLib_module.collapse_edges(*opInputs)


# The op is registered as non-differentiable.
tf.NoGradient('CollapseEdges')


def basis_proj(pNeighborhood, pInFeatures,
        pBasis, pBasisType, pPtGrads):
    """Run the differentiable ``basis_proj`` custom op on a neighborhood.

    The PDF handed to the op is ``pNeighborhood.pdf_``. When
    ``pNeighborhood.smoothW_`` is not None, the PDF is first divided by it.

    Args:
        pNeighborhood: Object exposing ``grid_.sortedPts_``,
            ``pcSamples_.pts_``, ``neighbors_``, ``samplesNeighRanges_``,
            ``radii_``, ``pdf_`` and ``smoothW_`` (may be None).
        pInFeatures: Forwarded to the op unchanged.
        pBasis: Forwarded to the op unchanged.
        pBasisType: Forwarded to the op (presumably the ``basis_type``
            attribute read back by the registered gradient).
        pPtGrads: Forwarded to the op (presumably the ``pt_grads``
            attribute read back by the registered gradient).

    Returns:
        The op's output.
    """
    smoothW = pNeighborhood.smoothW_
    pdf = pNeighborhood.pdf_
    curPDF = pdf if smoothW is None else pdf * tf.math.reciprocal(smoothW)
    invRadii = tf.math.reciprocal(pNeighborhood.radii_)
    return IEProtLib_module.basis_proj(
        pNeighborhood.grid_.sortedPts_, pInFeatures,
        pNeighborhood.pcSamples_.pts_,
        pNeighborhood.neighbors_, pNeighborhood.samplesNeighRanges_,
        invRadii, curPDF, pBasis, pBasisType, pPtGrads)
@tf.RegisterGradient("BasisProj")
def _basis_proj_grad(op, *grads):
    """Gradient function for the ``BasisProj`` op.

    Forward inputs 1 (features) and 7 (basis) always receive gradients. When
    the op's ``pt_grads`` attribute is truthy, inputs 0 (grid points),
    2 (sample points) and 6 (PDF) also receive gradients; otherwise they get
    None. Inputs 3-5 never receive a gradient.
    """
    usePtGrads = op.get_attr("pt_grads")
    fwdInputs = [op.inputs[i] for i in range(8)]
    gradArgs = fwdInputs + [grads[0], op.get_attr("basis_type")]

    if not usePtGrads:
        featGrads, basisGrads = IEProtLib_module.basis_proj_grads(*gradArgs)
        return [None, featGrads, None, None, None,
                None, None, basisGrads]

    featGrads, basisGrads, pointGrads, sampleGrads, pdfGrads = \
        IEProtLib_module.basis_proj_grads_with_pt_grads(*gradArgs)
    return [pointGrads, featGrads, sampleGrads, None, None,
            None, pdfGrads, basisGrads]


def basis_proj_bilateral(pNeighborhood, pNeighVals, pInFeatures,
        pBasis, pBasisType, pPtGrads):
    """Run the differentiable ``basis_proj_bil`` custom op on a neighborhood.

    This is the same as ``basis_proj``, plus an extra ``pNeighVals`` input
    placed between the PDF and the basis. The PDF handed to the op is
    ``pNeighborhood.pdf_``. When ``pNeighborhood.smoothW_`` is not None, the
    PDF is first divided by it.

    Args:
        pNeighborhood: Object exposing ``grid_.sortedPts_``,
            ``pcSamples_.pts_``, ``neighbors_``, ``samplesNeighRanges_``,
            ``radii_``, ``pdf_`` and ``smoothW_`` (may be None).
        pNeighVals: Forwarded to the op unchanged.
        pInFeatures: Forwarded to the op unchanged.
        pBasis: Forwarded to the op unchanged.
        pBasisType: Forwarded to the op (presumably the ``basis_type``
            attribute read back by the registered gradient).
        pPtGrads: Forwarded to the op (presumably the ``pt_grads``
            attribute read back by the registered gradient).

    Returns:
        The op's output.
    """
    smoothW = pNeighborhood.smoothW_
    pdf = pNeighborhood.pdf_
    curPDF = pdf if smoothW is None else pdf * tf.math.reciprocal(smoothW)
    invRadii = tf.math.reciprocal(pNeighborhood.radii_)
    return IEProtLib_module.basis_proj_bil(
        pNeighborhood.grid_.sortedPts_, pInFeatures,
        pNeighborhood.pcSamples_.pts_,
        pNeighborhood.neighbors_, pNeighborhood.samplesNeighRanges_,
        invRadii, curPDF, pNeighVals, pBasis, pBasisType, pPtGrads)
@tf.RegisterGradient("BasisProjBil")
def _basis_proj_bilateral_grad(op, *grads):
    """Gradient function for the ``BasisProjBil`` op.

    Forward inputs 1 (features) and 8 (basis) always receive gradients. When
    the op's ``pt_grads`` attribute is truthy, inputs 0 (grid points),
    2 (sample points), 6 (PDF) and 7 (neighbor values) also receive
    gradients; otherwise they get None. Inputs 3-5 never receive a gradient.
    """
    usePtGrads = op.get_attr("pt_grads")
    fwdInputs = [op.inputs[i] for i in range(9)]
    gradArgs = fwdInputs + [grads[0], op.get_attr("basis_type")]

    if not usePtGrads:
        featGrads, basisGrads = IEProtLib_module.basis_proj_bil_grads(
            *gradArgs)
        return [None, featGrads, None, None, None,
                None, None, None, basisGrads]

    featGrads, basisGrads, pointGrads, sampleGrads, pdfGrads, neighGrads = \
        IEProtLib_module.basis_proj_bil_grads_with_pt_grads(*gradArgs)
    return [pointGrads, featGrads, sampleGrads, None, None,
            None, pdfGrads, neighGrads, basisGrads]
