import os
import sys
import numpy as np
import tensorflow as tf

from IEProtLib.pc import MCAABB
from IEProtLib.pc import MCGrid
from IEProtLib.pc import MCKnn

def cross_entropy(pLabels, pLogits, pBatchIds = None, pBatchSize = None, pWeights =None):
    """Method to compute the (sparse) softmax cross entropy loss.

    Only the first ``tf.shape(pLabels)[0]`` rows of ``pLogits`` are used, so
    any extra trailing rows of logits are ignored.

    Args:
        pLabels (tensor): Class index of each object (cast to int64). Its
            leading dimension selects how many rows of pLogits are used.
        pLogits (tensor nxc): Tensor with the logits to compute the probabilities.
        pBatchIds (tensor, optional): Batch index of each object. When given,
            the loss is first averaged per batch with
            tf.math.unsorted_segment_mean and then averaged over batches.
        pBatchSize (int, optional): Number of segments (batches) passed to
            tf.math.unsorted_segment_mean together with pBatchIds.
        pWeights (tensor, optional): Weights forwarded to
            tf.losses.sparse_softmax_cross_entropy (no reduction applied).
    Returns:
        (tensor): Scalar mean cross entropy loss.
    """
    intLabels = tf.cast(pLabels, tf.int64)
    # Drop logits rows that have no matching label.
    usedLogits = pLogits[0:tf.shape(intLabels)[0], :]

    if pWeights is None:
        lossPerObject = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=intLabels, logits=usedLogits, name='XEntropy')
    else:
        # Reduction.NONE keeps one (weighted) loss value per object so the
        # optional per-batch averaging below still applies.
        lossPerObject = tf.losses.sparse_softmax_cross_entropy(
            labels=intLabels, logits=usedLogits,
            weights=pWeights, reduction=tf.losses.Reduction.NONE)

    if pBatchIds is not None:
        lossPerObject = tf.math.unsorted_segment_mean(
            tf.reshape(lossPerObject, [-1, 1]), pBatchIds, pBatchSize)

    return tf.reduce_mean(lossPerObject, name='XEntropy_Mean')


def binary_cross_entropy(pLabels, pLogits, pBatchIds = None, pBatchSize = None, pPosWeight =None):
    """Method to compute the binary (sigmoid) cross entropy loss.

    Labels and logits are flattened; only the first ``tf.shape(pLabels)[0]``
    entries of ``pLogits`` (along its first axis) are used.

    Args:
        pLabels (tensor nx1): Tensor with the binary label of each object
            (cast to float32).
        pLogits (tensor nx1): Tensor with the logits to compute the probabilities.
        pBatchIds (tensor, optional): Batch index of each object. When given,
            the loss is averaged per batch and then over batches.
        pBatchSize (int, optional): Number of segments (batches) used with
            pBatchIds.
        pPosWeight (optional): Weight of the positive term, forwarded to
            tf.nn.weighted_cross_entropy_with_logits.
    Returns:
        (tensor): Scalar mean binary cross entropy loss.
    """
    flatLabels = tf.reshape(tf.cast(pLabels, tf.float32), [-1])
    flatLogits = tf.reshape(pLogits[0:tf.shape(pLabels)[0]], [-1])

    if pPosWeight is not None:
        lossPerObject = tf.nn.weighted_cross_entropy_with_logits(
            targets=flatLabels, logits=flatLogits,
            name='BinaryXEntropy', pos_weight=pPosWeight)
    else:
        lossPerObject = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=flatLabels, logits=flatLogits, name='BinaryXEntropy')

    if pBatchIds is not None:
        lossPerObject = tf.math.unsorted_segment_mean(
            tf.reshape(lossPerObject, [-1, 1]), pBatchIds, pBatchSize)

    return tf.reduce_mean(lossPerObject, name='BinaryXEntropy_Mean')


def binary_focal_loss(pLabels, pLogits, 
    pBatchIds = None, pBatchSize = None, 
    pAlpha = 0.75, pGamma = 2.0):
    """Method to compute the binary focal loss.

    loss = -pAlpha * (1 - p)^pGamma * log(p)        for labels equal to 1
           -(1 - pAlpha) * p^pGamma * log(1 - p)    for labels equal to 0
    where p = sigmoid(logit).

    Args:
        pLabels (tensor nx1): Tensor with the labels of each object (flattened).
        pLogits (tensor nx1): Tensor with the logits to compute the
            probabilities; flattened, then truncated to the number of labels.
        pBatchIds (tensor, optional): Batch index of each object. When given,
            the loss is averaged per batch and then over batches.
        pBatchSize (int, optional): Number of segments (batches) used with
            pBatchIds.
        pAlpha (float): Weight of the positive term (negatives get 1 - pAlpha).
        pGamma (float): Focusing exponent.
    Returns:
        (tensor): Scalar mean focal loss.
    """
    flatLabels = tf.reshape(pLabels, [-1])
    probs = tf.nn.sigmoid(tf.reshape(pLogits, [-1])[0:tf.shape(pLabels)[0]])

    # Non-positive entries get p=1 and non-negative entries get p=0, so the
    # "wrong" term of each entry is (almost) zero. The clipping below keeps
    # log() finite; masked entries still contribute a negligible amount.
    posProb = tf.where(tf.equal(flatLabels, 1), probs, tf.ones_like(probs))
    negProb = tf.where(tf.equal(flatLabels, 0), probs, tf.zeros_like(probs))

    eps = 1e-6
    posProb = tf.clip_by_value(posProb, eps, 1.0-eps)
    negProb = tf.clip_by_value(negProb, eps, 1.0-eps)

    posTerm = pAlpha*tf.pow(1.0 - posProb, pGamma)*tf.log(posProb)
    negTerm = (1.0-pAlpha)*tf.pow(negProb, pGamma)*tf.log(1.0-negProb)
    focalLoss = -posTerm - negTerm

    if pBatchIds is not None:
        focalLoss = tf.math.unsorted_segment_mean(
            tf.reshape(focalLoss, [-1, 1]), pBatchIds, pBatchSize)

    return tf.reduce_mean(focalLoss, name='FocalLoss_Mean')


def multiclass_binary_cross_entropy(pLabels, pLogits, pNumCategories,
    pBatchIds = None, pBatchSize = None, pPosWeight =None, pSamplingProbs = None):
    """Method to compute the binary cross entropy loss for multiclass
    (multi-label) classification, one sigmoid per category.

    Only the first ``tf.shape(pLabels)[0]`` rows and the first
    ``pNumCategories`` columns of ``pLogits`` are used.

    Args:
        pLabels (tensor nxc): Binary label of each object for each category
            (cast to float32).
        pLogits (tensor): Logits; sliced to [n, pNumCategories].
        pNumCategories (int): Number of categories.
        pBatchIds (tensor n, optional): Batch index of each object. When
            given, each object's loss is summed over its categories, averaged
            per batch, and then averaged over batches.
        pBatchSize (int, optional): Number of segments (batches) used with
            pBatchIds.
        pPosWeight (tensor c): Weight of the positive term for each category.
        pSamplingProbs (tensor nxc): Probabilities of selecting each prediction for
            the loss computation. Unselected entries contribute zero loss but
            are still counted by the final mean.
    Returns:
        (tensor): Scalar cross entropy loss.
    """
    labels = tf.cast(pLabels, tf.float32)
    logits = pLogits[0:tf.shape(labels)[0], 0:pNumCategories]

    if pPosWeight is None:
        cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, 
            logits=logits, name='BinaryXEntropy')
    else:
        # Row vector so the per-category weights broadcast over objects.
        auxWeights = tf.reshape(pPosWeight, [1, -1])
        cross_entropy = tf.nn.weighted_cross_entropy_with_logits(
            targets=labels, logits=logits, pos_weight=auxWeights)

    if pSamplingProbs is not None:
        # Keep each (object, category) entry with probability pSamplingProbs.
        rndNumbers = tf.random.uniform(tf.shape(logits), minval=0.0, maxval=1.0)
        cross_entropy = tf.where(tf.less(rndNumbers, pSamplingProbs),
            x=cross_entropy, y=tf.zeros_like(cross_entropy))

    if pBatchIds is not None:
        # Sum over the category axis so there is exactly one value per object,
        # matching the length of pBatchIds. (Flattening to [-1, 1] first would
        # yield n*c values and break the segment op whenever c > 1.)
        cross_entropy = tf.math.unsorted_segment_mean(
            tf.reduce_sum(cross_entropy, axis=-1), pBatchIds, pBatchSize)

    return tf.reduce_mean(cross_entropy, name='BinaryXEntropy_Mean')


def multiclass_binary_focal_loss(pLabels, pLogits, pNumCategories,
    pBatchIds = None, pBatchSize = None, pAlpha = 0.75, pGamma = 2.0):
    """Method to compute the binary focal loss for multiclass (multi-label)
    classification, one sigmoid per category.

    loss = -pAlpha * (1 - p)^pGamma * log(p)        for labels equal to 1
           -(1 - pAlpha) * p^pGamma * log(1 - p)    for labels equal to 0
    where p = sigmoid(logit). Uses the same alpha weighting as
    binary_focal_loss.

    Args:
        pLabels (tensor nxc): Binary label of each object for each category
            (cast to float32).
        pLogits (tensor): Logits; sliced to [n, pNumCategories].
        pNumCategories (int): Number of categories.
        pBatchIds (tensor n, optional): Batch index of each object. When
            given, each object's loss is summed over its categories, averaged
            per batch, and then averaged over batches.
        pBatchSize (int, optional): Number of segments (batches) used with
            pBatchIds.
        pAlpha (float): Weight of the positive term (negatives get 1 - pAlpha).
        pGamma (float): Focusing exponent.
    Returns:
        (tensor): Scalar focal loss.
    """
    labels = tf.cast(pLabels, tf.float32)
    logits = pLogits[0:tf.shape(labels)[0], 0:pNumCategories]
    probs = tf.nn.sigmoid(logits)

    # Mask so each entry only contributes its positive or its negative term;
    # clipping keeps log() finite (masked entries add a negligible amount).
    pt1 = tf.where(tf.equal(labels, 1.0), probs, tf.ones_like(probs))
    pt0 = tf.where(tf.equal(labels, 0.0), probs, tf.zeros_like(probs))

    epsilon = 1e-6
    pt1 = tf.clip_by_value(pt1, epsilon, 1.0-epsilon)
    pt0 = tf.clip_by_value(pt0, epsilon, 1.0-epsilon)

    # The negative term is weighted by (1 - pAlpha), as in binary_focal_loss
    # and the standard alpha-balanced focal loss.
    focalLoss = -pAlpha*tf.pow(1.0 - pt1, pGamma)*tf.log(pt1) \
        -(1.0-pAlpha)*tf.pow(pt0, pGamma)*tf.log(1.0-pt0)

    if pBatchIds is not None:
        # One value per object (sum over categories) so the leading dimension
        # matches pBatchIds.
        focalLoss = tf.math.unsorted_segment_mean(
            tf.reduce_sum(focalLoss, axis=-1), pBatchIds, pBatchSize)

    return tf.reduce_mean(focalLoss, name='FocalLoss_Mean')

