import tensorflow as tf
import keras.backend as K
import math
import numpy as np
from keras import losses

def se_loss(z, zc, batch_size, alpha):
    """Scaled squared-error loss between ``z`` and ``zc``.

    Computes ``alpha / (2 * batch_size) * sum((z - zc) ** 2)``, summing over
    every element of the residual tensor.

    Args:
        z: tensor of values.
        zc: tensor broadcast-compatible with ``z``.
        batch_size: normaliser for the summed error.
        alpha: scalar weight.

    Returns:
        Scalar loss tensor.
    """
    scale = alpha / (2 * batch_size)
    squared_residuals = tf.pow(tf.subtract(z, zc), 2)
    return scale * tf.reduce_sum(squared_residuals)

def se_corr_loss(z, zc, batch_size, alpha, sigm):
    """Welsch/correntropy-style loss between ``z`` and ``zc``.

    Each residual ``e = z - zc`` contributes ``1 - exp(-e**2 / sigma**2)``.
    The contributions are summed and scaled by ``alpha / batch_size``.

    Args:
        z: tensor of values.
        zc: tensor of the same shape as ``z``. The Silverman branch reduces
            over axis 0, presumably the batch axis -- TODO confirm.
        batch_size: normaliser; also the ``N`` in Silverman's rule.
        alpha: scalar weight.
        sigm: either the string ``"silverman"`` or a number.
            With ``"silverman"``, the bandwidth is
            ``1.06 * mean(std(E, axis=0)) * N ** (-1/5)``.
            With a number, the bandwidth is ``sigm * max|E|``.

    Returns:
        Scalar loss tensor.
    """
    N = tf.constant(batch_size, dtype=tf.float32)
    E = tf.subtract(z, zc)
    if isinstance(sigm, str) and sigm == "silverman":
        s = tf.reduce_mean(tf.math.reduce_std(E, 0))
        sigma = tf.multiply(tf.multiply(s, 1.06), tf.pow(N, -1/5))
    else:
        sigma = tf.multiply(sigm, tf.reduce_max(tf.abs(E)))
    # The bandwidth is 0 when every residual is 0 (or, for Silverman, when every
    # column is constant); e**2 / sigma**2 would then be 0/0 = NaN. Clamp it
    # to a tiny positive value: zero residuals give phi = 0 and non-zero ones
    # give phi -> 1, which is the sigma -> 0 limit of the kernel.
    sigma = tf.maximum(sigma, K.epsilon())
    phi_mat = 1.0 - tf.exp(-tf.square(E) / tf.square(sigma))
    return tf.multiply(alpha / batch_size, tf.reduce_sum(phi_mat))

def c_loss(kernel, batch_size, alpha):
    """Squared Frobenius norm of ``kernel``, weighted and batch-normalised.

    Args:
        kernel: tensor whose squared entries are summed.
        batch_size: normaliser.
        alpha: scalar weight.

    Returns:
        ``alpha * sum(kernel ** 2) / batch_size`` as a scalar tensor.
    """
    sum_of_squares = tf.reduce_sum(tf.pow(kernel, 2))
    return alpha * sum_of_squares / batch_size

def reconstruction_loss(input_features, output_features):
    """Half the summed squared difference between outputs and inputs.

    Args:
        input_features: reference tensor.
        output_features: tensor broadcast-compatible with ``input_features``.

    Returns:
        ``0.5 * sum((output_features - input_features) ** 2)`` as a scalar
        tensor (no batch normalisation).
    """
    residual = tf.subtract(output_features, input_features)
    return 0.5 * tf.reduce_sum(tf.pow(residual, 2))

def bd_reg(kernel,k,alpha):
    """Block-diagonal regulariser: weighted sum of the k smallest Laplacian eigenvalues.

    The kernel is symmetrised into an affinity matrix
    ``A = 0.5 * (|C| + |C^T|)`` and its diagonal is set to zero. The
    unnormalised graph Laplacian ``L = D - A`` is then formed, where ``D`` is
    the diagonal matrix of the row sums of ``A``.

    Args:
        kernel: square 2-D tensor (it is added to its own transpose).
        k: number of smallest eigenvalues to penalise.
        alpha: scalar weight.

    Returns:
        Scalar tensor ``alpha * sum(|lambda_i| for the k smallest lambda_i)``,
        where the ``lambda_i`` are the eigenvalues of ``L``.
    """
    affinity = 0.5 * (tf.abs(kernel) + tf.abs(tf.transpose(kernel)))
    # Remove self-loops. Building the zeros from the tensor itself keeps its
    # dtype and also works when the static shape is unknown.
    affinity = tf.linalg.set_diag(affinity, tf.zeros_like(tf.linalg.diag_part(affinity)))
    rowsum = K.sum(affinity, axis=1)

    # tf.linalg.diag is available in TF 1.x and 2.x (tf.diag was removed in TF 2).
    degree_mat = tf.linalg.diag(rowsum)  # degree matrix
    L_A = degree_mat - affinity  # unnormalised Laplacian
    # eigh returns eigenvalues in non-decreasing order, so [:k] are the k smallest.
    eigval, _ = tf.linalg.eigh(L_A, name="eigen_bd_reg")
    return alpha * K.sum(tf.abs(eigval[:k]))

def bd_reg_shifted_LA(kernel,k,alpha):
    """Block-diagonal regulariser computed via the shifted normalised Laplacian.

    Let ``A`` be the symmetrised, zero-diagonal affinity ``0.5 * (|C| + |C^T|)``
    and ``D`` its degree matrix. The normalised Laplacian
    ``L = I - D^{-1/2} A D^{-1/2}`` has eigenvalues in ``[0, 2]``.

    The shifted matrix ``M = 2I - L = I + D^{-1/2} A D^{-1/2}`` is symmetric
    positive semi-definite. Its singular values therefore equal its
    eigenvalues ``mu_i = 2 - lambda_i``. It follows that the sum of the k
    smallest eigenvalues of ``L`` equals ``2k`` minus the sum of the k
    largest singular values of ``M``.

    Args:
        kernel: square 2-D tensor (it is added to its own transpose).
        k: number of smallest normalised-Laplacian eigenvalues to penalise.
        alpha: scalar weight.

    Returns:
        Scalar tensor ``alpha * (2k - sum of the k largest singular values of M)``.
    """
    affinity = tf.multiply(0.5, tf.abs(kernel) + tf.abs(tf.transpose(kernel)))
    # Remove self-loops. Building the zeros from the tensor itself keeps its
    # dtype and also works when the static shape is unknown.
    affinity = tf.linalg.set_diag(affinity, tf.zeros_like(tf.linalg.diag_part(affinity)))
    rowsum = K.sum(affinity, axis=1)

    # D^{-1/2}. A zero-degree (isolated) node would give 1/sqrt(0) = inf, and
    # then inf * 0 = NaN in the normalised affinity; map it to 0 instead.
    # The inner tf.where keeps sqrt away from 0 so the gradient stays finite.
    has_edges = rowsum > 0
    safe_rowsum = tf.where(has_edges, rowsum, tf.ones_like(rowsum))
    inv_sqrt_deg = tf.where(has_edges, 1.0 / tf.sqrt(safe_rowsum), tf.zeros_like(rowsum))

    # D^{-1/2} A D^{-1/2} by broadcasting: entry (i, j) is scaled by d_i * d_j.
    norm_affinity = inv_sqrt_deg[:, None] * affinity * inv_sqrt_deg[None, :]
    identity = tf.eye(tf.shape(affinity)[0], dtype=affinity.dtype)
    L_A = identity + norm_affinity  # shifted Laplacian M = 2I - L
    s = tf.linalg.svd(L_A, compute_uv=False)
    # tf.linalg.svd sorts singular values in DESCENDING order, so the k
    # largest singular values of M are s[:k].
    return alpha * (2 * k - K.sum(tf.abs(s[:k])))

def c_a_reg(kernel, alpha):
    """Penalty pulling ``kernel`` towards its symmetrised absolute value.

    Computes ``(alpha / 2) * sum((C - 0.5 * (|C| + |C^T|)) ** 2)``.

    Args:
        kernel: square 2-D tensor ``C`` (it is added to its own transpose).
        alpha: scalar weight.

    Returns:
        Scalar tensor.
    """
    sym_abs = 0.5 * (tf.abs(kernel) + tf.abs(tf.transpose(kernel)))
    deviation = tf.subtract(kernel, sym_abs)
    half_alpha = alpha / 2
    return half_alpha * tf.reduce_sum(tf.pow(deviation, 2))

def bd_loss(kernel, k, lambd, gamma):
    """Combined block-diagonal objective.

    Adds two terms: the symmetry penalty from ``c_a_reg`` (weighted by
    ``lambd``) and the spectral block-diagonal regulariser from ``bd_reg``
    (weighted by ``gamma``, over the ``k`` smallest eigenvalues).

    Returns:
        Scalar tensor.
    """
    symmetry_term = c_a_reg(kernel, lambd)
    spectral_term = bd_reg(kernel, k, gamma)
    return symmetry_term + spectral_term

def c_q_loss(c, input_labels, lambda_cq, batch_size):
    """Penalise coefficient mass linking samples with different labels.

    Each row of ``input_labels`` is reduced to its argmax class. A
    ``(B, B)`` matrix is built that is 1.0 where the two samples' classes
    differ and 0.0 where they match. That matrix weights ``|c|``, the
    weighted entries are summed, and the sum is scaled by
    ``lambda_cq / batch_size``.

    Args:
        c: tensor broadcast-compatible with a ``(B, B)`` matrix, where ``B``
            is the number of rows of ``input_labels``.
        input_labels: 2-D tensor; the class is taken as the argmax over axis 1
            (presumably one-hot labels -- TODO confirm).
        lambda_cq: scalar weight.
        batch_size: normaliser.

    Returns:
        Scalar tensor.
    """
    labels = tf.dtypes.cast(tf.argmax(input_labels, axis=1), dtype=tf.float32)
    # 1.0 where the pair of samples carries different labels, 0.0 otherwise.
    mismatch = tf.cast(tf.not_equal(labels[None, :], labels[:, None]), tf.float32)
    coeff_mag = tf.abs(K.cast(c, dtype='float32'))
    return lambda_cq * K.sum(tf.multiply(mismatch, coeff_mag)) / batch_size

def cec_loss(centerLossLayer, y_true, y_pred, batch_size, alpha=6, tau=0.5):
    """Weighted sum of per-sample cross-entropy and center loss.

    Computes ``alpha * (mean(CE(y_true, y_pred)) + tau * centerLossLayer / batch_size)``.

    Args:
        centerLossLayer: center-loss value; it is divided by ``batch_size``
            here (presumably a batch-summed scalar -- TODO confirm).
        y_true: target class distributions.
        y_pred: predicted class distributions.
        batch_size: normaliser for the center-loss term.
        alpha: overall weight.
        tau: relative weight of the center-loss term.

    Returns:
        Scalar loss tensor.
    """
    # categorical_crossentropy returns one value per sample. Average them so
    # this term is per-sample like the center-loss term; a plain sum would
    # grow with the batch size and make tau's weighting batch-dependent.
    ce_term = tf.reduce_mean(losses.categorical_crossentropy(y_true, y_pred))
    center_term = tau * centerLossLayer / batch_size  # normalise by batch size
    return alpha * (ce_term + center_term)

def crossentropy_loss(y_true, y_pred, alpha=1):
    """Weighted mean categorical cross-entropy.

    Args:
        y_true: target class distributions.
        y_pred: predicted class distributions.
        alpha: scalar weight.

    Returns:
        ``alpha`` times the mean of the per-sample cross-entropy values.
    """
    per_sample = losses.categorical_crossentropy(y_true, y_pred)
    return alpha * tf.reduce_mean(per_sample)

def center_loss(centerLossLayer, batch_size, alpha=1):
    """Weighted center-loss value divided by the batch size.

    Args:
        centerLossLayer: center-loss value (presumably batch-summed -- TODO
            confirm against the layer producing it).
        batch_size: normaliser.
        alpha: scalar weight.

    Returns:
        ``alpha * centerLossLayer / batch_size``.
    """
    weighted = alpha * centerLossLayer
    return weighted / batch_size

def l2_loss(kernel, alpha=1):
    """Weighted sum of squared entries of ``kernel``.

    Args:
        kernel: tensor to penalise.
        alpha: scalar weight.

    Returns:
        ``alpha * sum(kernel ** 2)`` as a scalar tensor.
    """
    squares = tf.pow(kernel, 2.0)
    return alpha * tf.reduce_sum(squares)
