import os
import sys
import math
import numpy as np
import tensorflow as tf

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_MODULE_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_MODULE_DIR, "tf_ops"))

from IEProtLibModule import basis_proj

from IEProtLib.pc import MCConvFactory
from IEProtLib.tf_utils import spectral_norm
from IEProtLib.py_utils.py_pc import uniform_samples_surface_sphere, farthest_point_sampling

class MCHProjFactory:
    """Factory to create a convolution with half projections.

    Concrete factories inherit from this class and call
    `__create_convolution__` with the activation index they implement.

    Attributes:
        staticRandomState_ (numpy.random.RandomState): Static random state.
    """

    staticRandomState_ = np.random.RandomState(None)


    @staticmethod
    def _get_num_vectors(pKwargs):
        """Extract and validate the number of projection vectors.

        Args:
            pKwargs (dict): Keyword arguments passed to the convolution.

        Returns:
            int: Value of the 'numBasis' key, or 16 when it is absent.

        Raises:
            RuntimeError: If the number of projection vectors is not 8, 16, or 32.
        """
        numVectors = pKwargs.get('numBasis', 16)
        if numVectors not in (8, 16, 32):
            raise RuntimeError('The number of projection vectors should be 8, 16, or 32.')
        return numVectors


    def __create_convolution__(self, 
        pConvName,
        pNeighborhood, 
        pFeatures,
        pNumOutFeatures,
        pPtGradients,
        pWeightSpectralNorm,
        pWeightRegCollection,
        pActivation,
        **kwargs):
        """Function to create a spatial convolution.

        The input features are projected onto a trainable basis of
        half-projection vectors (each with a bias) through the `basis_proj`
        op, and the projected features are then mixed with a dense weight
        matrix.

        Args:
            pConvName (string): String with the name of the convolution.
            pNeighborhood (MCNeighborhood): Input neighborhood.
            pFeatures (float tensor nxf): Input point cloud features.
            pNumOutFeatures (int): Number of output features.
            pPtGradients (bool): Boolean that indicates if the gradients for the
                point coordinates are computed.
            pWeightSpectralNorm (bool): Boolean that indicates if the weights will be 
                normalized using spectral norm.
            pWeightRegCollection (string): Weight regularization collection name.
            pActivation (int): Activation function used in the kernel 
                (0: RELU, 1: LRELU, 2: ELU).
            kwargs (list of parameters): Parameters to pass to each specific kernel.
                Only 'numBasis' (number of projection vectors: 8, 16, or 32;
                default 16) is read here; other keys are ignored.

        Returns:
            float tensor n'xpNumOutFeatures: Tensor with the result of the convolution.

        Raises:
            RuntimeError: If 'numBasis' is not 8, 16, or 32.
            ValueError: If the number of input features is not statically known.
        """

        #Get and validate the number of projection vectors.
        numVectors = self._get_num_vectors(kwargs)

        #Get the number of input features and the point dimensionality.
        numInFeatures = pFeatures.shape.as_list()[1]
        if numInFeatures is None:
            #Fail before any variable is created; the weight shape needs this value.
            raise ValueError('The number of input features (pFeatures.shape[1]) '
                'must be statically known.')
        numDims = pNeighborhood.pcSamples_.pts_.shape.as_list()[1]

        #Define the trainable half-projection vectors and biases.
        projStdDev = math.sqrt(1.0/3.0)
        hProjVecTF = tf.get_variable(pConvName+'_half_proj_vectors', shape=[numVectors, numDims], 
            initializer=tf.initializers.truncated_normal(stddev=projStdDev), dtype=tf.float32, trainable=True)
        hProjBiasTF = tf.get_variable(pConvName+'_half_proj_biases', shape=[numVectors, 1], 
            initializer=tf.initializers.zeros(), dtype=tf.float32, trainable=True)
        #Each basis row is [vector, bias], giving shape numVectors x (numDims+1).
        basisTF = tf.concat([hProjVecTF, hProjBiasTF], axis=1)

        #NOTE: a non-trainable alternative (farthest point sampling of uniform
        #sphere samples with a constant -0.5 bias) was previously tried here.

        #Create the weights; stddev is sqrt(2/fan_in), fan_in = numVectors*numInFeatures.
        weightStdDev = math.sqrt(2.0/float(numVectors*numInFeatures))
        weights = tf.get_variable(pConvName+'_conv_weights', shape=[numVectors * numInFeatures, pNumOutFeatures], 
            initializer=tf.initializers.truncated_normal(stddev=weightStdDev), dtype=tf.float32, trainable=True)

        #Apply spectral norm.
        if pWeightSpectralNorm:
            weights = spectral_norm(weights)

        #Get the input features projected on the half-projection basis.
        #NOTE(review): the activation index is offset by 2 before it is passed to
        #basis_proj; the meaning of the resulting code is defined by the op - confirm.
        inWeightFeat = basis_proj(pNeighborhood, pFeatures, basisTF, 2+pActivation, pPtGradients)

        #Compute the convolution.
        convFeatures = tf.matmul(tf.reshape(inWeightFeat, [-1, numInFeatures*numVectors]), weights)

        #Add to collection for weight regularization.
        tf.add_to_collection(pWeightRegCollection, tf.reshape(weights, [-1]))

        return convFeatures



class MCHProjRELUFactory(MCConvFactory, MCHProjFactory):
    """Half-projection convolution factory whose kernel uses RELU activations.

    Attributes:
        staticRandomState_ (numpy.random.RandomState): Static random state.
    """

    staticRandomState_ = np.random.RandomState(None)


    def get_kernel_type(self):
        """Return the identifier of this convolution factory.

        Returns:
            string: Name of the convolution factory ("hprelu").
        """
        kernelType = "hprelu"
        return kernelType

    def create_convolution(self, 
        pConvName,
        pNeighborhood, 
        pFeatures,
        pNumOutFeatures,
        pPtGradients,
        pWeightSpectralNorm,
        pWeightRegCollection,
        **kwargs):
        """Build a half-projection spatial convolution with RELU activations.

        Args:
            pConvName (string): Name of the convolution.
            pNeighborhood (MCNeighborhood): Input neighborhood.
            pFeatures (float tensor nxf): Input point cloud features.
            pNumOutFeatures (int): Number of output features.
            pPtGradients (bool): Whether gradients with respect to the point
                coordinates are computed.
            pWeightSpectralNorm (bool): Whether the weights are normalized
                using spectral norm.
            pWeightRegCollection (string): Name of the weight regularization collection.
            kwargs (list of parameters): Extra parameters forwarded to the kernel.

        Returns:
            n'pNumOutFeatures tensor: Result of the convolution.
        """
        #Activation index understood by __create_convolution__ (0: RELU).
        activationId = 0
        convArgs = (pConvName, pNeighborhood, pFeatures, pNumOutFeatures,
            pPtGradients, pWeightSpectralNorm, pWeightRegCollection, activationId)
        return self.__create_convolution__(*convArgs, **kwargs)


class MCHProjLRELUFactory(MCConvFactory, MCHProjFactory):
    """Half-projection convolution factory whose kernel uses Leaky-RELU activations.

    Attributes:
        staticRandomState_ (numpy.random.RandomState): Static random state.
    """

    staticRandomState_ = np.random.RandomState(None)


    def get_kernel_type(self):
        """Return the identifier of this convolution factory.

        Returns:
            string: Name of the convolution factory ("hplrelu").
        """
        kernelType = "hplrelu"
        return kernelType

    def create_convolution(self, 
        pConvName,
        pNeighborhood, 
        pFeatures,
        pNumOutFeatures,
        pPtGradients,
        pWeightSpectralNorm,
        pWeightRegCollection,
        **kwargs):
        """Build a half-projection spatial convolution with Leaky-RELU activations.

        Args:
            pConvName (string): Name of the convolution.
            pNeighborhood (MCNeighborhood): Input neighborhood.
            pFeatures (float tensor nxf): Input point cloud features.
            pNumOutFeatures (int): Number of output features.
            pPtGradients (bool): Whether gradients with respect to the point
                coordinates are computed.
            pWeightSpectralNorm (bool): Whether the weights are normalized
                using spectral norm.
            pWeightRegCollection (string): Name of the weight regularization collection.
            kwargs (list of parameters): Extra parameters forwarded to the kernel.

        Returns:
            n'pNumOutFeatures tensor: Result of the convolution.
        """
        #Activation index understood by __create_convolution__ (1: LRELU).
        activationId = 1
        convArgs = (pConvName, pNeighborhood, pFeatures, pNumOutFeatures,
            pPtGradients, pWeightSpectralNorm, pWeightRegCollection, activationId)
        return self.__create_convolution__(*convArgs, **kwargs)


class MCHProjELUFactory(MCConvFactory, MCHProjFactory):
    """Half-projection convolution factory whose kernel uses ELU activations.

    Attributes:
        staticRandomState_ (numpy.random.RandomState): Static random state.
    """

    staticRandomState_ = np.random.RandomState(None)


    def get_kernel_type(self):
        """Return the identifier of this convolution factory.

        Returns:
            string: Name of the convolution factory ("hpelu").
        """
        kernelType = "hpelu"
        return kernelType

    def create_convolution(self, 
        pConvName,
        pNeighborhood, 
        pFeatures,
        pNumOutFeatures,
        pPtGradients,
        pWeightSpectralNorm,
        pWeightRegCollection,
        **kwargs):
        """Build a half-projection spatial convolution with ELU activations.

        Args:
            pConvName (string): Name of the convolution.
            pNeighborhood (MCNeighborhood): Input neighborhood.
            pFeatures (float tensor nxf): Input point cloud features.
            pNumOutFeatures (int): Number of output features.
            pPtGradients (bool): Whether gradients with respect to the point
                coordinates are computed.
            pWeightSpectralNorm (bool): Whether the weights are normalized
                using spectral norm.
            pWeightRegCollection (string): Name of the weight regularization collection.
            kwargs (list of parameters): Extra parameters forwarded to the kernel.

        Returns:
            n'pNumOutFeatures tensor: Result of the convolution.
        """
        #Activation index understood by __create_convolution__ (2: ELU).
        activationId = 2
        convArgs = (pConvName, pNeighborhood, pFeatures, pNumOutFeatures,
            pPtGradients, pWeightSpectralNorm, pWeightRegCollection, activationId)
        return self.__create_convolution__(*convArgs, **kwargs)

