import os
import sys
import math
import enum
import numpy as np
import tensorflow as tf

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_MODULE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_MODULE_DIR, "tf_ops"))

from IEProtLibModule import find_neighbors_topo

from IEProtLib.pc import MCAABB
from IEProtLib.pc import MCPointCloud
from IEProtLib.pc import MCGrid
from IEProtLib.pc import MCNeighborhood
from IEProtLib.pc import MCKDEMode
from IEProtLib.pc import MCHProjLRELUFactory
from IEProtLib.graph import MCGraphConvBuilder
from IEProtLib.graph import MCGraph
from IEProtLib.tf_utils import MC_BN_AF_DO, spectral_norm
from IEProtLib.tf_utils.MCConvBuilder import MCConvBuilder
from IEProtLib.mol import MCMolConv
from IEProtLib.mol import MCMolecule


class MCMolConvBuilder(MCConvBuilder):
    """Class to create convolutions and pooling operations on molecules.

    Attributes:
        graphConvBuilder_ (MCGraphConvBuilder): Convolution builder for graphs.
        molConvFactory_ (MCMolConv): Factory used to create molecular convolutions.
        hProjFactory_ (MCHProjLRELUFactory): Factory used to create the "spatial"
            point convolutions.
    """

    def __init__(self, 
        pWeightRegCollection = "weight_regularization_collection"):
        """Constructor.

        Args:
            pWeightRegCollection (string): Weight regularization collection name.
        """
        super(MCMolConvBuilder, self).__init__(pWeightRegCollection, None)

        self.graphConvBuilder_ = MCGraphConvBuilder(pWeightRegCollection, None)
        self.molConvFactory_ = MCMolConv()
        self.hProjFactory_ = MCHProjLRELUFactory()


    def _create_pooled_molecule(self, pPrevMolecule, pNewPos, pNewBatchIds,
        pNewGraph, pNewGraph2):
        """Method to build the molecule object of a pooled level.

        Args:
            pPrevMolecule (MCMolecule): Molecule of the level being pooled; only
                its batchSize_ is used.
            pNewPos (float tensor): Positions of the pooled nodes.
            pNewBatchIds (int tensor): Batch ids of the pooled nodes.
            pNewGraph (MCGraph): Pooled main graph.
            pNewGraph2 (MCGraph): Pooled secondary graph (may wrap None tensors).
        Returns:
            (MCMolecule): Molecule of the pooled level.
        """
        return MCMolecule(
            pNewPos, 
            pNewGraph.neighbors_, 
            pNewGraph.nodeStartIndexs_, 
            pNewBatchIds, 
            pPrevMolecule.batchSize_,
            pNewGraph2.neighbors_, 
            pNewGraph2.nodeStartIndexs_)


    def create_prot_pooling(self, pInFeatures, pProtein, pLevel, pBNAFDO):
        """Method to create a protein pooling operation.

        The pooling type is read from pProtein.poolType_[pLevel-1]:
            - "GRA": graph aggregation over the graph of level pLevel-1, then
              a gather with pProtein.poolIds_[pLevel-1].
            - "AVG": segment mean using pProtein.poolIds_[pLevel-1] as segment
              ids; the number of segments is the number of batch ids of
              pProtein.molObjects_[pLevel], which must already exist.
            - "GRAPH_DROP*": node-drop pooling. Writes pProtein.molObjects_[pLevel]
              and, for "GRAPH_DROP_AMINO", pProtein.poolIds_[pLevel].
            - "GRAPH_EDGE*": edge-collapse pooling. Writes pProtein.molObjects_[pLevel]
              and, for "GRAPH_EDGE_AMINO", pProtein.poolIds_[pLevel].

        Args:
            pInFeatures (float tensor nxf): Input features of level pLevel-1.
            pProtein (MCProtein): Protein. May be modified (see above).
            pLevel (int): Level that receives the pooled features.
            pBNAFDO (MCBNAFDO): MCBNAFDO object.
        Returns:
            (float tensor n'xf): Pooled features of level pLevel.
        Raises:
            ValueError: If the pooling type is not supported.
        """

        poolType = pProtein.poolType_[pLevel-1]
        prevMolecule = pProtein.molObjects_[pLevel-1]

        if poolType == "GRA":

            poolFeatures = self.graphConvBuilder_.create_graph_aggregation(
                pInFeatures = pInFeatures,
                pGraph = prevMolecule.graph_, 
                pNormalize = True, 
                pSpectralApprox = False)
            poolFeatures = tf.gather(poolFeatures, pProtein.poolIds_[pLevel-1])
        
        elif poolType == "AVG":

            poolFeatures = tf.unsorted_segment_mean(pInFeatures,
                pProtein.poolIds_[pLevel-1],
                tf.shape(pProtein.molObjects_[pLevel].batchIds_)[0])

        elif poolType.startswith("GRAPH_DROP"):

            maskValueBool, poolFeatures, newGraph = \
                self.graphConvBuilder_.create_graph_node_pooling(
                    "Graph_drop_pooling_"+str(pLevel), 
                    prevMolecule.batchIds_,
                    prevMolecule.graph_, 
                    pInFeatures, 
                    prevMolecule.batchSize_,
                    0.5, pBNAFDO)

            # Keep only the nodes selected by the pooling mask.
            newPos = tf.boolean_mask(prevMolecule.pc_.pts_, maskValueBool)
            newBatchIds = tf.boolean_mask(prevMolecule.batchIds_, maskValueBool)

            if prevMolecule.graph2_ is None:
                newGraph2 = MCGraph(None, None)
            else:
                newGraph2 = prevMolecule.graph2_.pool_graph_drop_nodes(
                    maskValueBool, tf.shape(newPos)[0])

            pProtein.molObjects_[pLevel] = self._create_pooled_molecule(
                prevMolecule, newPos, newBatchIds, newGraph, newGraph2)

            if poolType == "GRAPH_DROP_AMINO":
                pProtein.poolIds_[pLevel] = tf.boolean_mask(
                    pProtein.atomAminoIds_, maskValueBool)

        elif poolType.startswith("GRAPH_EDGE"):
            
            newIndices, poolFeatures, newGraph = \
                self.graphConvBuilder_.create_graph_edge_pooling(
                    "Graph_edge_pooling_"+str(pLevel), 
                    prevMolecule.graph_, 
                    pInFeatures, 
                    pBNAFDO)

            # newIndices maps each old node to its pooled node: average the
            # positions and take the max batch id within each pooled node.
            newPos = tf.unsorted_segment_mean(prevMolecule.pc_.pts_,
                newIndices, tf.shape(poolFeatures)[0])
            newBatchIds = tf.unsorted_segment_max(prevMolecule.batchIds_,
                newIndices, tf.shape(poolFeatures)[0])

            if prevMolecule.graph2_ is None:
                newGraph2 = MCGraph(None, None)
            else:
                newGraph2 = prevMolecule.graph2_.pool_graph_collapse_edges(
                    newIndices, tf.shape(newPos)[0])

            pProtein.molObjects_[pLevel] = self._create_pooled_molecule(
                prevMolecule, newPos, newBatchIds, newGraph, newGraph2)

            if poolType == "GRAPH_EDGE_AMINO":
                pProtein.poolIds_[pLevel] = tf.unsorted_segment_max(
                    pProtein.atomAminoIds_, newIndices, tf.shape(poolFeatures)[0])

        else:
            # Previously this fell through to an UnboundLocalError on return.
            raise ValueError("Unknown pooling type: " + str(poolType))

        return poolFeatures


    def _create_topo_neighborhood(self, pGraph, pMolecule, pGrid, pRadiiTensor,
        pBwTensor, pKDEMode):
        """Method to create a neighborhood from the connectivity of a graph.

        Args:
            pGraph (MCGraph): Graph whose topology defines the neighbors.
            pMolecule (MCMolecule): Input molecule.
            pGrid (MCGrid): Grid computed over the molecule point cloud.
            pRadiiTensor (float tensor 3): Radii tensor of the neighborhood.
            pBwTensor (float tensor 3): Bandwidth tensor used in the KDE.
            pKDEMode (MCKDEMode): Mode used to determine the bandwidth in the KDE.
        Returns:
            (MCNeighborhood): Neighborhood whose neighbor lists are replaced by
                the ones returned by find_neighbors_topo.
        """
        # NOTE(review): the meaning of the 4.1 and True arguments is defined by
        # find_neighbors_topo; confirm before changing them.
        auxNeighs = find_neighbors_topo(pGraph, pMolecule.pc_.pts_, 4.1, True)

        # Count the entries of each node id in column 1 and accumulate them;
        # presumably these are the end offsets of each node's neighbor range.
        startIds = tf.unsorted_segment_sum(
            tf.ones_like(auxNeighs[:, 1]),
            auxNeighs[:, 1],
            tf.shape(pMolecule.batchIds_)[0])
        startIds = tf.cumsum(startIds)

        neigh = MCNeighborhood(pGrid, pRadiiTensor, pMolecule.pc_, 0)
        neigh.originalNeighIds_ = auxNeighs
        neigh.samplesNeighRanges_ = startIds
        neigh.neighbors_ = auxNeighs
        # NOTE(review): grid_.sortedPts_ is replaced by the points in their
        # input order, while the convolutions still gather the features with
        # grid.sortedIndices_; confirm both orderings agree.
        neigh.grid_.sortedPts_ = pMolecule.pc_.pts_
        neigh.compute_pdf(pBwTensor, pKDEMode, False)
        return neigh


    def create_molconv_resnet_blocks(self, 
        pMolecule,
        pInFeatures, 
        pNumBlocks, 
        pOutNumFeatures,
        pRadii, 
        pBNAFDO,
        pSmoothWeights = False,
        pKDEMode = MCKDEMode.constant, 
        pPDFBandwidth = 0.2,
        pConvName = None,
        pConvType = "molconv",
        pNumBasis = 32,
        pNeighType = "spatial",
        pAtomDropOut = 0.05):
        """Method to create a set of molconv resnet bottleneck blocks.

        Each block applies: batch norm on the input (kept as the shortcut),
        noise/activation/dropout, a 1x1 convolution to pOutNumFeatures//4
        features, batch norm/noise/activation, atom dropout, the convolution
        selected by pConvType, batch norm/noise/activation, a 1x1 convolution to
        pOutNumFeatures features and batch norm. The shortcut is projected with
        a 1x1 convolution when its feature count differs, and is then added.

        Args:
            pMolecule (MCMolecule): Input molecule.
            pInFeatures (float tensor nxf): Input features.
            pNumBlocks (int): Number of bottleneck blocks.
            pOutNumFeatures (int): Number of output features.
            pRadii (float): Radius of the convolution. It is also passed,
                multiplied by 2, to compute_topo_distance.
            pBNAFDO (MC_BRN_AF_DO): Layer to apply batch renorm, activation function,
                and drop out.
            pSmoothWeights (bool): Boolean that indicates if we smooth the boundaries
                of the weights in the 3d space (only used with the "spatial"
                neighborhood).
            pKDEMode (MCKDEMode): Mode used to determine the bandwidth in the KDE.
            pPDFBandwidth (float): Constant bandwidth used in the KDE in mode constant.
            pConvName (string): Name of the convolution. If None, a name is
                derived from a hash of the neighborhood and the input tensor name.
            pConvType (string): Type of the second convolution of each block:
                "molconv", "molconv_dist", "spatial", "geo", "geo_graph2"
                (skipped if the molecule has no graph2_), "graph", or
                "graph_spatial". Any other value skips this convolution.
            pNumBasis (int): Number of basis vectors used.
            pNeighType (string): Neighborhood used by the convolutions:
                "spatial", "graph1" (topology of graph_) or "graph2"
                (topology of graph2_).
            pAtomDropOut (float): Dropout rate used during training to randomly
                discard entire atoms (all features of a row are dropped together).
        Return:
            (tensor nxf): Output features.
        Raises:
            ValueError: If pNeighType is not supported.
        """

        # Validate before creating any op; previously an unknown type failed
        # later with an UnboundLocalError on `neigh`.
        if pNeighType not in ("spatial", "graph1", "graph2"):
            raise ValueError("Unknown neighborhood type: " + str(pNeighType))

        #Create the radii tensor.
        radiiTensor = tf.convert_to_tensor(
            np.full((3), pRadii, dtype=np.float32), np.float32)

        #Create the bandwidth tensor.
        bwTensor = tf.convert_to_tensor(
            np.full((3), pPDFBandwidth, dtype=np.float32), np.float32)

        #Get the point cloud.
        inPointCloud = pMolecule.pc_

        #Compute the bounding box.
        aabb = MCAABB(inPointCloud)

        #Compute the grid key.
        grid = MCGrid(inPointCloud, aabb, radiiTensor)

        #Create the neighborhood.
        if pNeighType == "spatial":
            neigh = MCNeighborhood(grid, radiiTensor, inPointCloud, 0)
            neigh.compute_pdf(bwTensor, pKDEMode, False)
            if pSmoothWeights:
                neigh.compute_smooth_weights(False)
        elif pNeighType == "graph1":
            neigh = self._create_topo_neighborhood(pMolecule.graph_, pMolecule,
                grid, radiiTensor, bwTensor, pKDEMode)
        else:
            neigh = self._create_topo_neighborhood(pMolecule.graph2_, pMolecule,
                grid, radiiTensor, bwTensor, pKDEMode)

        #Compute the topological distance.
        topoDists = pMolecule.compute_topo_distance(neigh, pRadii*2.0)

        #Create the convolution name if is not user defined.
        curConvName = pConvName
        if curConvName is None:
            # hash() returns a (possibly negative) int: convert it to a
            # non-negative string with a letter prefix so that it can be
            # concatenated below and used as a valid TF op/scope name.
            curConvName = "MolConv_" + str(
                hash((neigh, pInFeatures.name)) & 0xFFFFFFFFFFFFFFFF)
        
        #Create the different bottleneck blocks.
        curInFeatures = pInFeatures
        for curBlock in range(pNumBlocks):

            #Define a name for the bottleneck block.
            bnName = curConvName+"_resnetb_"+str(curBlock)

            #Save the input features of the block (after batch norm).
            blockInFeatures = pBNAFDO(curInFeatures, bnName+"_In_B",
                pApplyBN = True, pApplyNoise = False, pApplyAF = False, pApplyDO = False)

            curInFeatures = pBNAFDO(blockInFeatures, bnName+"_In_AD",
                pApplyBN = False, pApplyNoise = True, pApplyAF = True, pApplyDO = True)

            #Create the first convolution of the block.
            curInFeatures = self.create_1x1_convolution(curInFeatures, pOutNumFeatures//4,
                bnName+"_Conv_1x1_1")
            curInFeatures = pBNAFDO(curInFeatures, bnName+"_Conv_1_BAD",
                pApplyBN = True, pApplyNoise = True, pApplyAF = True, pApplyDO = False)

            #Atom dropout: rate pAtomDropOut while training, 0 otherwise. The
            #second positional argument of tf.nn.dropout is the keep
            #probability (TF1 API); noise_shape [n, 1] drops whole rows.
            tfDORate = tf.cond(pBNAFDO.isTraining_, 
                    true_fn = lambda: pAtomDropOut,
                    false_fn = lambda: 0.0)
            curInFeatures = tf.nn.dropout(curInFeatures, 1.0 - tfDORate, 
                    name=bnName+"_Atom_DO", noise_shape=[tf.shape(curInFeatures)[0], 1])

            #Create the second convolution of the block.
            if pConvType == "molconv":
                
                curInFeatures = tf.gather(curInFeatures, grid.sortedIndices_)
                curInFeatures = self.molConvFactory_.create_convolution( 
                    pConvName = bnName+"_Conv_2",
                    pNeighborhood = neigh, 
                    pTopoDists = topoDists,
                    pFeatures = curInFeatures,
                    pNumOutFeatures = pOutNumFeatures//4,
                    pWeightRegCollection = self.weightRegCollection_,
                    pNumBasis = pNumBasis)

            elif pConvType == "molconv_dist":
                
                curInFeatures = tf.gather(curInFeatures, grid.sortedIndices_)
                curInFeatures = self.molConvFactory_.create_convolution( 
                    pConvName = bnName+"_Conv_2",
                    pNeighborhood = neigh, 
                    pTopoDists = topoDists,
                    pFeatures = curInFeatures,
                    pNumOutFeatures = pOutNumFeatures//4,
                    pWeightRegCollection = self.weightRegCollection_,
                    pNumBasis = pNumBasis,
                    pTopoOnly = False,
                    pUse3DDistOnly = True)
                    
            elif pConvType == "spatial":
                
                curInFeatures = tf.gather(curInFeatures, grid.sortedIndices_)
                curInFeatures = self.hProjFactory_.create_convolution(
                    pConvName = bnName+"_Conv_2",
                    pNeighborhood = neigh, 
                    pFeatures = curInFeatures,
                    pNumOutFeatures = pOutNumFeatures//4,
                    pPtGradients = False,
                    pWeightSpectralNorm = None,
                    pWeightRegCollection = self.weightRegCollection_,
                    numBasis = pNumBasis)

            elif pConvType == "geo":
                
                curInFeatures = tf.gather(curInFeatures, grid.sortedIndices_)
                curInFeatures = self.molConvFactory_.create_convolution( 
                    pConvName = bnName+"_Conv_2",
                    pNeighborhood = neigh, 
                    pTopoDists = topoDists,
                    pFeatures = curInFeatures,
                    pNumOutFeatures = pOutNumFeatures//4,
                    pWeightRegCollection = self.weightRegCollection_,
                    pNumBasis = pNumBasis,
                    pTopoOnly = True)

            elif pConvType == "geo_graph2":
                
                if pMolecule.graph2_ is not None:
                    #Use only the second column of the topological distances.
                    auxDists = tf.reshape(topoDists[:, 1], [-1, 1])

                    curInFeatures = tf.gather(curInFeatures, grid.sortedIndices_)
                    curInFeatures = self.molConvFactory_.create_convolution( 
                        pConvName = bnName+"_Conv_2",
                        pNeighborhood = neigh, 
                        pTopoDists = auxDists,
                        pFeatures = curInFeatures,
                        pNumOutFeatures = pOutNumFeatures//4,
                        pWeightRegCollection = self.weightRegCollection_,
                        pNumBasis = pNumBasis,
                        pTopoOnly = True)

            elif pConvType == "graph":
                
                #Prefer the secondary graph when the molecule has one.
                curGraph = pMolecule.graph_
                if pMolecule.graph2_ is not None:
                    curGraph = pMolecule.graph2_
                curInFeatures = self.graphConvBuilder_.create_graph_aggregation(
                    pInFeatures = curInFeatures, 
                    pGraph = curGraph, 
                    pNormalize = True, 
                    pSpectralApprox = True)

            elif pConvType == "graph_spatial":

                auxGraph = MCGraph(neigh.originalNeighIds_, neigh.samplesNeighRanges_)
                curInFeatures = self.graphConvBuilder_.create_graph_aggregation(
                    pInFeatures = curInFeatures, 
                    pGraph = auxGraph, 
                    pNormalize = True, 
                    pSpectralApprox = True)

            curInFeatures = pBNAFDO(curInFeatures, bnName+"_Conv_2_BAD",
                pApplyBN = True, pApplyNoise = True, pApplyAF = True, pApplyDO = False)

            #Create the third convolution of the block.
            curInFeatures = self.create_1x1_convolution(curInFeatures, pOutNumFeatures,
                bnName+"_Conv_1x1_2")

            #Batch norm.
            curInFeatures = pBNAFDO(curInFeatures, bnName+"_Out_B",
                pApplyBN = True, pApplyNoise = False, pApplyAF = False, pApplyDO = False)

            #If the number of input features is different than the desired output,
            #project the shortcut with a 1x1 convolution.
            if blockInFeatures.get_shape()[-1] != pOutNumFeatures:
                blockInFeatures = pBNAFDO(blockInFeatures, bnName+"_Shortcut_AD",
                    pApplyBN = False, pApplyNoise = True, pApplyAF = True, pApplyDO = True)
                blockInFeatures = self.create_1x1_convolution(blockInFeatures, pOutNumFeatures,
                    bnName+"_Conv_1x1_Shortcut")
                blockInFeatures = pBNAFDO(blockInFeatures, bnName+"_Shortcut_Out_B",
                    pApplyBN = True, pApplyNoise = False, pApplyAF = False, pApplyDO = False)

            #Add the new features to the input features
            curInFeatures = curInFeatures + blockInFeatures

        #Return the resulting features.
        return curInFeatures
