import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import tensorflow as tf

def LS_1(w,M,r_1,r_2):
    """Sign-agreement measure of the rows of ``M`` along direction ``w``.

    The rows of ``M`` are projected onto ``w`` and onto ``-w``; for each
    direction the signs of the projections are summed.  The larger of the
    two sums is divided by ``r_1 * r_2``.

    Args:
        w: direction passed as the first operand of ``np.dot(w, M.T)``.
        M: matrix whose transpose is projected onto ``w``.
        r_1, r_2: their product is used as the normalising constant.

    Returns:
        The normalised sign sum, as produced by the TensorFlow reduction.
    """
    M_t = M.T
    normaliser = r_1 * r_2

    sign_sum_forward = tf.reduce_sum(tf.sign(np.dot(w, M_t)))
    sign_sum_backward = tf.reduce_sum(tf.sign(np.dot(-w, M_t)))

    # Keep whichever orientation of w collects the larger sign total.
    best_sign_sum = max(sign_sum_forward, sign_sum_backward)
    return best_sign_sum / normaliser

def LS_2(w,M,r_1,r_2):
    """Return ``|sum(p)| / sum(|p|)`` for the projections ``p = np.dot(w, M.T)``.

    Args:
        w: direction passed as the first operand of the dot product.
        M: matrix whose transpose is projected onto ``w``.
        r_1, r_2: accepted but not used by this measure.

    Returns:
        The ratio of the absolute value of the summed projections to the
        sum of the absolute projections.
    """
    projections = np.dot(w, M.T)

    abs_of_total = np.abs(np.sum(projections))
    total_of_abs = np.sum(np.abs(projections))

    return abs_of_total / total_of_abs
    
def J_w(w,M,r_1,r_2):
    """Return ``sum(p)**2 / sum(p**2)`` for the projections ``p = np.dot(w, M.T)``.

    Args:
        w: direction passed as the first operand of the dot product.
        M: matrix whose transpose is projected onto ``w``.
        r_1, r_2: accepted but not used by this measure.

    Returns:
        The squared sum of the projections divided by the sum of the
        squared projections.
    """
    projections = np.dot(w, M.T)

    square_of_total = np.square(np.sum(projections))
    total_of_squares = np.sum(np.square(projections))

    return square_of_total / total_of_squares

def LDA(w,new_data_A,new_data_B,r_1,r_2):
    """Evaluate the ratio ``(w S_b w^T) / (w S_w w^T)`` for direction ``w``.

    ``S_w`` is the sum of the two per-class scatter matrices (each class
    centred on its own mean) and ``S_b`` is the outer product of the
    difference between the two class means.

    Args:
        w: direction; it is transposed on entry, so the caller is assumed
            to pass a row vector of shape (1, n_features) -- TODO confirm.
        new_data_A: object exposing ``.numpy()`` that yields the samples of
            the first class, one sample per row.
        new_data_B: same as ``new_data_A`` for the second class.
        r_1, r_2: accepted but not used by this measure.

    Returns:
        The result of the two chained dot products divided elementwise;
        for a (1, n_features) ``w`` this is a (1, 1) array.
    """
    direction = w.T

    # Work with features along axis 0 and samples along axis 1.
    samples_A = new_data_A.numpy().T
    samples_B = new_data_B.numpy().T

    mean_A = np.mean(samples_A, axis=1).reshape(-1, 1)
    mean_B = np.mean(samples_B, axis=1).reshape(-1, 1)

    # Broadcasting the column mean across all samples centres each class.
    centred_A = samples_A - mean_A
    centred_B = samples_B - mean_B

    within_scatter = np.dot(centred_A, centred_A.T) + np.dot(centred_B, centred_B.T)

    mean_gap = mean_A - mean_B
    between_scatter = np.dot(mean_gap, mean_gap.T)

    between_along_w = np.dot(np.dot(direction.T, between_scatter), direction)
    within_along_w = np.dot(np.dot(direction.T, within_scatter), direction)

    return between_along_w / within_along_w


def W(original_X,original_Y):
    """Compute the four separability measures for a two-class data set.

    Samples whose label equals 1 form class A and samples whose label
    equals 0 form class B.  Each sample is flattened to a vector, every
    (A, B) pair contributes one row ``a - b`` to the difference matrix
    ``M``, and the direction ``w`` is the column sum of ``M`` scaled to
    unit norm.

    Args:
        original_X: samples, indexed along axis 0.
        original_Y: labels compared against 1 and 0 to select the classes.

    Returns:
        Tuple ``(LS_1_value, LS_2_value, J_w_value, LDA_value)``.
    """
    samples = tf.constant(original_X)
    labels = tf.constant(original_Y)

    rows_of_A = tf.where(labels == 1)[:, 0]
    data_A = tf.gather(samples, axis=0, indices=rows_of_A)
    rows_of_B = tf.where(labels == 0)[:, 0]
    data_B = tf.gather(samples, axis=0, indices=rows_of_B)

    r_1 = len(data_A)
    r_2 = len(data_B)

    # Flatten every sample to a single feature vector.
    new_data_A = tf.reshape(data_A, (data_A.shape[0], -1))
    new_data_B = tf.reshape(data_B, (data_B.shape[0], -1))

    # Rows [i*r_2, (i+1)*r_2) hold the i-th A sample minus every B sample.
    M = np.zeros((r_1 * r_2, new_data_A.shape[1]), dtype='float32')
    for i in range(r_1):
        first_row = i * r_2
        M[first_row:first_row + r_2, :] = new_data_A[i] - new_data_B

    m = np.sum(M, axis=0).reshape(1, -1)
    w = m / np.linalg.norm(m)

    return (
        LS_1(w, M, r_1, r_2),
        LS_2(w, M, r_1, r_2),
        J_w(w, M, r_1, r_2),
        LDA(w, new_data_A, new_data_B, r_1, r_2),
    )


# Name of each layer

def _ordinal_label(position):
    """Return the English ordinal for a positive integer, e.g. 1 -> '1st'."""
    # 11th, 12th and 13th (and 111th, ...) take 'th' despite ending in 1/2/3.
    if 10 <= position % 100 <= 20:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(position % 10, 'th')
    return str(position) + suffix


def get_layer_name(model):
    """Build a readable label such as ``'2nd_conv2d'`` for the layers of a model.

    Each layer name is tested, in a fixed order, against a list of known
    layer-kind substrings.  Every substring that occurs in the name appends
    one label ``'<ordinal>_<kind>'``, where the ordinal counts how many
    layers of that kind have been seen so far.

    Ordinals are generated rather than read from a fixed table, so any
    number of layers of one kind is supported and 21/22/23 are rendered as
    '21st', '22nd', '23rd'.

    Args:
        model: object with a ``layers`` iterable whose items have a
            ``name`` string attribute.

    Returns:
        List of label strings, in the order the matches were found.
    """
    # NOTE: matching is by substring, so a layer whose name matches none of
    # these kinds contributes no label, and a name containing several kinds
    # contributes one label per kind (e.g. 'zero_padding2d' contains 'add').
    layer_kinds = (
        'conv2d',
        'max_pooling',
        'flatten',
        'dense',
        'batch_normalization',
        'dropout',
        'add',
        'activation',
        'average_pooling2d',
        'multi_head_attention',
        'layer_normalization',
        'patches',
        'patch_encoder',
    )

    seen_per_kind = dict.fromkeys(layer_kinds, 0)
    layer_name_list = []

    for layer in model.layers:
        for kind in layer_kinds:
            if kind in layer.name:
                seen_per_kind[kind] += 1
                layer_name_list.append(_ordinal_label(seen_per_kind[kind]) + '_' + kind)

    return layer_name_list





def plot_Separability_figure(layer_name_list,x_plot,LS_1_squence,LS_2_squence,J_w_squence,LDA_squence,LS_1_squence_base,LS_2_squence_base,J_w_squence_base,LDA_squence_base):
    """Draw the four separability measures in a 2x2 grid of subplots.

    Each panel plots one curve per row of the corresponding sequence,
    labelled with the matching entry of ``layer_name_list``; rows whose
    label contains ``'flatten'`` are skipped.  The first row of the
    matching ``*_base`` sequence is drawn on top as a black line labelled
    ``'base_line'``.  Only the right-hand panels (LS_2 and LDA) carry a
    legend, anchored just outside the axes.

    Args:
        layer_name_list: labels, indexed in step with the sequence rows.
        x_plot: x coordinates shared by every curve.
        LS_1_squence, LS_2_squence, J_w_squence, LDA_squence: per-layer
            curves, indexed as ``sequence[i, :]``.
        LS_1_squence_base, LS_2_squence_base, J_w_squence_base,
        LDA_squence_base: baseline curves; only row 0 is plotted.

    Returns:
        The matplotlib figure that was created.
    """
    figure = plt.figure(figsize=(16,10))

    # (curves, baseline, title, draw legend?) in subplot order 1..4.
    # Titles are raw strings so the mathtext backslash is not treated as an
    # (invalid) Python escape sequence.
    panels = (
        (LS_1_squence, LS_1_squence_base, r'$\mathrm{LS}_1$', False),
        (LS_2_squence, LS_2_squence_base, r'$\mathrm{LS}_2$', True),
        (J_w_squence, J_w_squence_base, r'$\mathrm{J}_w$', False),
        (LDA_squence, LDA_squence_base, r'$\mathrm{LDA}$', True),
    )

    for position, (curves, baseline, title, with_legend) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)

        for i in range(len(curves)):
            label = layer_name_list[i]
            if 'flatten' in label:
                continue
            plt.plot(x_plot, curves[i, :], label=label)

        plt.plot(x_plot, baseline[0, :], '-k', label='base_line')
        plt.title(title)

        if with_legend:
            plt.legend(bbox_to_anchor=(1,0), loc=3)

    plt.tight_layout()

    return figure
    
    
def plot_net_figure(layer_name_list,x_plot,train_loss_squence,train_accuracy_squence,test_loss_squence,test_accuracy_squence):
    """Plot train/test loss and accuracy curves in a 2x2 grid of subplots.

    Args:
        layer_name_list: accepted but not used by this function.
        x_plot: x coordinates shared by every panel.
        train_loss_squence: values for the 'Train Loss' panel (top left).
        train_accuracy_squence: values for the 'Train Accuracy' panel (top right).
        test_loss_squence: values for the 'Test Loss' panel (bottom left).
        test_accuracy_squence: values for the 'Test Accuracy' panel (bottom right).

    Returns:
        The matplotlib figure that was created.
    """
    figure = plt.figure(figsize=(16,10))

    # (title, series) in subplot order 1..4.
    panels = (
        ('Train Loss', train_loss_squence),
        ('Train Accuracy', train_accuracy_squence),
        ('Test Loss', test_loss_squence),
        ('Test Accuracy', test_accuracy_squence),
    )

    for slot, (panel_title, series) in enumerate(panels, start=1):
        plt.subplot(2, 2, slot)
        plt.plot(x_plot, series)
        plt.title(panel_title)

    plt.tight_layout()

    return figure


