import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import tensorflow as tf

def LS_star(w, M, r_1, r_2):
    """Return the majority-side separability measure LS_* along direction ``w``.

    Every row of ``M`` is projected onto ``w``. The larger of the two counts
    (strictly positive projections, strictly negative projections) is divided
    by ``r_1 * r_2``. Zero projections count toward neither side.

    Args:
        w: direction vector, e.g. the ``(1, d)`` row built by ``W``.
        M: matrix whose rows are projected onto ``w``.
        r_1, r_2: class sizes; their product is the normaliser.

    Returns:
        The normalised majority count as a NumPy float.
    """
    projections = np.dot(w, M.T)

    # Summing boolean masks counts the entries on each side of the hyperplane.
    above = (projections > 0).sum()
    below = (projections < 0).sum()

    majority = max(above, below)
    return majority / (r_1 * r_2)


def LS_0(w,M,r_1,r_2):
    """Return the sign-balance separability measure LS_0 along direction ``w``.

    Every row of ``M`` is projected onto ``w``. The result is
    ``|#positive - #negative| / (r_1 * r_2)``: the absolute value of the summed
    signs of the projections, normalised by the number of pairs.

    Args:
        w: direction vector, e.g. the ``(1, d)`` row built by ``W``.
        M: matrix whose rows are projected onto ``w``.
        r_1, r_2: class sizes; their product is the normaliser.

    Returns:
        A TensorFlow scalar tensor, because the sign reduction runs through ``tf``.
    """
    # The measure is the larger of the summed signs for w and for -w. Because
    # sign(-x) == -sign(x), that is max(s, -s) == |s|. A single projection
    # therefore suffices, instead of a second full product with -w and a
    # Python-level comparison of two tensors.
    signed_total = tf.reduce_sum(tf.sign(np.dot(w, M.T)))
    return tf.abs(signed_total) / (r_1 * r_2)

def LS_1(w,M,r_1,r_2):
    """Return the L1 separability measure LS_1 along direction ``w``.

    Computes ``|sum(p)| / sum(|p|)`` over the projections ``p`` of the rows of
    ``M`` onto ``w``. By the triangle inequality the value lies in ``[0, 1]``.
    It is NaN when every projection is zero.

    ``r_1`` and ``r_2`` are accepted only for a signature shared with the other
    measures; they are not used.
    """
    projections = np.dot(w, M.T)
    signed_mass = np.abs(projections.sum())
    total_mass = np.abs(projections).sum()
    return signed_mass / total_mass
    
def LS_2(w,M,r_1,r_2):
    """Return the L2 separability measure LS_2 along direction ``w``.

    Computes ``(sum(p))**2 / sum(p**2)`` over the projections ``p`` of the rows
    of ``M`` onto ``w``. By Cauchy-Schwarz the value lies in ``[0, N]``, where
    ``N`` is the number of projections. It is NaN when every projection is zero.

    ``r_1`` and ``r_2`` are accepted only for a signature shared with the other
    measures; they are not used, so the value is not normalised by them.
    """
    projections = np.dot(w, M.T)
    squared_sum = np.square(projections.sum())
    sum_of_squares = np.square(projections).sum()
    return squared_sum / sum_of_squares

def LDA(w,new_data_A,new_data_B):
    """Return the Fisher criterion ``(w^T S_b w) / (w^T S_w w)`` along ``w``.

    ``S_w`` is the within-class scatter, summed over both classes. ``S_b`` is
    the between-class scatter ``(mu_a - mu_b)(mu_a - mu_b)^T``. Both quadratic
    forms are evaluated through projections onto ``w``, so the ``d x d``
    scatter matrices are never built. The original version built them, which
    costs O(n*d^2) time and d^2 memory for d-dimensional features.

    Args:
        w: direction, e.g. the ``(1, d)`` row vector built by ``W``.
        new_data_A: samples of class A, shape ``(r_1, d)``. Either an object
            with a ``.numpy()`` method (e.g. a TensorFlow eager tensor) or
            anything ``np.asarray`` accepts.
        new_data_B: samples of class B, shape ``(r_2, d)``, same conventions.

    Returns:
        A ``(1, 1)`` array when ``w`` is a ``(1, d)`` row; NaN/inf when the
        within-class projection variance is zero.
    """

    def _as_columns(data):
        # Feature-major view: one column per sample.
        array = data.numpy() if hasattr(data, 'numpy') else np.asarray(data)
        return array.T

    w = np.asarray(w).T  # (d, 1) column when w is a (1, d) row

    A_class = _as_columns(new_data_A)
    B_class = _as_columns(new_data_B)

    nu_a = np.mean(A_class, axis=1).reshape(-1, 1)
    nu_b = np.mean(B_class, axis=1).reshape(-1, 1)

    # Broadcasting subtracts each class mean from every column.
    A_c = A_class - nu_a
    B_c = B_class - nu_b

    # w^T S_b w = (w^T (nu_a - nu_b)) (w^T (nu_a - nu_b))^T
    between = np.dot(w.T, nu_a - nu_b)
    LDA_value_up = np.dot(between, between.T)

    # w^T S_w w = (w^T A_c)(w^T A_c)^T + (w^T B_c)(w^T B_c)^T
    proj_a = np.dot(w.T, A_c)
    proj_b = np.dot(w.T, B_c)
    LDA_value_down = np.dot(proj_a, proj_a.T) + np.dot(proj_b, proj_b.T)

    return LDA_value_up / LDA_value_down


def W(original_X,original_Y,max_features=12000):
    """Compute the separability measures between the two classes of a dataset.

    Samples labelled 1 form class A (``r_1`` samples). Samples labelled 0 form
    class B (``r_2`` samples). Each sample is flattened to a feature vector.
    If either class has ``max_features`` or more flattened features, the last
    axis of the raw samples is first averaged away; this is presumably the
    channel axis of conv feature maps (NOTE(review): confirm).

    ``M`` is the ``(r_1 * r_2, d)`` matrix of all pairwise differences
    ``a_i - b_j``. The direction ``w`` is the unit-normalised sum of the rows
    of ``M``.

    Args:
        original_X: samples (array-like or tensor), first axis indexing samples.
        original_Y: labels aligned with ``original_X``.
        max_features: flattened-feature threshold that triggers averaging over
            the last axis. The default keeps the previous behaviour.

    Returns:
        ``(LS_star, LS_0, LS_1, LS_2, LDA)`` as returned by the helper measures.

    Raises:
        ValueError: if either class has no samples.
    """
    original_X = tf.convert_to_tensor(original_X)
    original_Y = tf.convert_to_tensor(original_Y)

    # Integer/bool inputs (e.g. raw uint8 images) would wrap around on a - b
    # and truncate in reduce_mean, so promote them to float before any arithmetic.
    if original_X.dtype.is_integer or original_X.dtype.is_bool:
        original_X = tf.cast(original_X, tf.float32)

    data_A = tf.gather(original_X, axis=0, indices=tf.where(original_Y==1)[:,0])
    data_B = tf.gather(original_X, axis=0, indices=tf.where(original_Y==0)[:,0])

    r_1 = len(data_A)
    r_2 = len(data_B)
    if r_1 == 0 or r_2 == 0:
        raise ValueError(
            'W needs samples of both classes (label 1 and label 0); '
            'got %d and %d' % (r_1, r_2))

    new_data_A = tf.reshape(data_A,(r_1,-1))
    new_data_B = tf.reshape(data_B,(r_2,-1))

    if new_data_A.shape[1]>=max_features or new_data_B.shape[1]>=max_features:
        new_data_A = tf.reshape(tf.reduce_mean(data_A,axis=-1),(r_1,-1))
        new_data_B = tf.reshape(tf.reduce_mean(data_B,axis=-1),(r_2,-1))

    # Row i * r_2 + j of M holds a_i - b_j. Each step fills one contiguous
    # (r_2, d) slice from NumPy data; every row is overwritten, so np.empty is safe.
    A_np = new_data_A.numpy()
    B_np = new_data_B.numpy()
    M = np.empty((r_1*r_2,A_np.shape[1]),dtype='float32')
    for i in range(r_1):
        M[i*r_2:(i+1)*r_2,:] = A_np[i]-B_np

    m = np.sum(M,axis=0).reshape(1,-1)
    w = m/np.linalg.norm(m)

    LS_star_value = LS_star(w,M,r_1,r_2)
    LS_0_value = LS_0(w,M,r_1,r_2)
    LS_1_value = LS_1(w,M,r_1,r_2)
    LS_2_value = LS_2(w,M,r_1,r_2)
    LDA_value = LDA(w,new_data_A,new_data_B)

    return LS_star_value,LS_0_value,LS_1_value,LS_2_value,LDA_value


def get_layer_name(model):
    """Return ordinal display names such as ``'2nd_conv2d'`` for recognised layers.

    Every layer in ``model.layers`` is checked for each layer-type substring
    below, in order. Each match appends ``'<ordinal>_<type>'``, where the
    ordinal counts how many times that type has matched so far (starting at
    1st). The checks are independent rather than elif, so a name containing
    several substrings yields one entry per match. Layers matching none are
    skipped.

    NOTE(review): plain substring matching means 'add' also matches names such
    as 'zero_padding2d', and 'conv2d' matches e.g. 'depthwise_conv2d'. This is
    kept as-is because callers may rely on the resulting list length.
    """

    def _ordinal(n):
        # 11th-13th (and 111th-113th, ...) are exceptions to the st/nd/rd rule.
        if 10 <= n % 100 <= 20:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
        return str(n) + suffix

    # Substrings searched for in layer.name; each also serves as the label suffix.
    layer_types = (
        'conv2d',
        'max_pooling',
        'flatten',
        'dense',
        'batch_normalization',
        'dropout',
        'add',
        'activation',
        'average_pooling2d',
        'embedding',
        'conv1d',
    )
    counts = dict.fromkeys(layer_types, 0)

    layer_name_list = []
    for layer in model.layers:
        for layer_type in layer_types:
            if layer_type in layer.name:
                counts[layer_type] += 1
                layer_name_list.append(_ordinal(counts[layer_type]) + '_' + layer_type)

    return layer_name_list





def plot_Separability_figure(layer_name_list,x_plot,LS_star_squence,LS_0_squence,LS_1_squence,LS_2_squence,LDA_squence,LS_star_squence_base,LS_0_squence_base,LS_1_squence_base,LS_2_squence_base,LDA_squence_base):
    """Plot per-layer separability curves for LS_*, LS_0, LS_1, LS_2 and JW.

    Each measure gets one panel of a 3x2 grid. Row ``i`` of a ``*_squence``
    array is drawn against ``x_plot`` with label ``layer_name_list[i]``; rows
    whose name contains 'flatten' are skipped. Row 0 of the matching
    ``*_squence_base`` array is drawn as a black 'base_line' curve. Only the
    last (JW) panel gets a legend.

    Returns:
        The created ``matplotlib`` Figure.
    """
    # Raw strings: '\m' is not a valid escape, so the plain literals only
    # worked by accident and trigger a Deprecation/SyntaxWarning.
    panels = (
        (LS_star_squence, LS_star_squence_base, r'$\mathrm{LS}_*$'),
        (LS_0_squence, LS_0_squence_base, r'$\mathrm{LS}_0$'),
        (LS_1_squence, LS_1_squence_base, r'$\mathrm{LS}_1$'),
        (LS_2_squence, LS_2_squence_base, r'$\mathrm{LS}_2$'),
        (LDA_squence, LDA_squence_base, r'$\mathrm{JW}$'),
    )

    figure = plt.figure(figsize=(16,15))

    for position, (squence, squence_base, title) in enumerate(panels, start=1):
        plt.subplot(3,2,position)
        for i in range(len(squence)):
            if 'flatten' not in layer_name_list[i]:
                plt.plot(x_plot,squence[i,:],label=layer_name_list[i])
        plt.plot(x_plot,squence_base[0,:],'-k',label='base_line')
        plt.title(title)

    # Attaches to the current (JW) axes; loc=3 with bbox_to_anchor=(1, 0)
    # puts the legend's lower-left corner at that axes' lower-right corner.
    plt.legend(bbox_to_anchor=(1,0),loc=3)

    plt.tight_layout()

    return figure
    
    
def plot_net_figure(layer_name_list,x_plot,train_loss_squence,train_accuracy_squence,test_loss_squence,test_accuracy_squence):
    """Plot train/test loss and accuracy curves on a 2x2 grid.

    Each sequence is drawn against ``x_plot`` in its own titled panel.
    ``layer_name_list`` is accepted for call compatibility but not used.

    Returns:
        The created ``matplotlib`` Figure.
    """
    curves = (
        (train_loss_squence, 'Train Loss'),
        (train_accuracy_squence, 'Train Accuracy'),
        (test_loss_squence, 'Test Loss'),
        (test_accuracy_squence, 'Test Accuracy'),
    )

    fig = plt.figure(figsize=(16,10))
    for slot, (values, caption) in enumerate(curves, start=1):
        plt.subplot(2,2,slot)
        plt.plot(x_plot,values)
        plt.title(caption)

    plt.tight_layout()
    return fig


