import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimage
import numpy as np
import os
from glob import glob
import random
import scipy
import scipy.io as scio
from sklearn.metrics import confusion_matrix
import pickle

# Optional XLA device flag, currently disabled.
#os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen with some MKL/TensorFlow installs -- confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# NOTE(review): `d` is not referenced anywhere in the visible part of this file.
d = 20

# Feature dimension of every token vector: the row count of the V/Q/K matrices
# and the width of the model's input projections.
ma = 20
# Number of units in each of the two hidden branches (fc1 / fc2) of the model.
m_half = 500



def generate_data(mu, sp_method, sampled_num, omega1, flag, L, alpha_nd, M, sigma0, sigma, delta, P_m1, P_m2):
    """Build the noise-free value/query/key matrices for one sample.

    The ``L`` token columns are split into ``M`` consecutive segments and
    every column of segment ``i`` is set to pattern ``i`` (column ``i`` of
    ``P_m1`` for V, column ``i`` of ``P_m2`` for both Q and K).

    Segment boundaries are ``0``, ``o1``, ``L - num_nd`` and ``L`` plus
    ``M - 3`` random cut points drawn from ``[L - num_nd, L]`` with the
    ``random`` module, where ``num_nd = int(L * alpha_nd)`` and ``o1`` is
    ``omega1`` when ``flag == 1`` and ``L - num_nd - omega1`` otherwise.
    Segments may therefore be empty.

    :param mu, sp_method, sampled_num, sigma0, sigma, delta: accepted for
        call-site compatibility; not used by this function.
    :param omega1: length of the first segment for ``flag == 1``.
    :param flag: sample label; only compared against ``1``.
    :param L: number of token columns.
    :param alpha_nd: fraction of ``L`` that forms the tail region in which
        the random cut points fall.
    :param M: number of segments / patterns. With ``M < 3`` only the first
        ``M`` of the three fixed segments are filled.
    :param P_m1: pattern matrix for V, shape ``(rows, >= M)``.
    :param P_m2: pattern matrix for Q and K, same row count as ``P_m1``.
    :return: tuple ``(V, Q, K)`` of arrays with shape ``(rows, L)``.
    """
    P_m1 = np.asarray(P_m1)
    P_m2 = np.asarray(P_m2)
    # The row count is dictated by the pattern matrix itself, so it is read
    # from there instead of from a module-level constant.
    n_rows = P_m1.shape[0]

    num_nd = int(L * alpha_nd)
    first_cut = omega1 if flag == 1 else L - num_nd - omega1

    separate = [0, first_cut, L - num_nd, L]
    separate.extend(random.randint(L - num_nd, L) for _ in range(M - 3))
    separate.sort()

    V = np.zeros((n_rows, L))
    Q = np.zeros((n_rows, L))
    K = np.zeros((n_rows, L))
    for i in range(M):
        start, stop = separate[i], separate[i + 1]
        # A (rows, 1) column broadcasts across every column of the segment.
        V[:, start:stop] = P_m1[:, i].reshape(n_rows, 1)
        key_query_pattern = P_m2[:, i].reshape(n_rows, 1)
        Q[:, start:stop] = key_query_pattern
        K[:, start:stop] = key_query_pattern

    return V, Q, K


def sampling_method(inp, sigma, delta, sp_method, L, omega1, sampled_num):
    """Add noise to one (V, Q, K) triple and keep a subset of its columns.

    Noise model:
      * V: each column gets an independent Gaussian vector (scale
        ``0.1 * sigma**2``) whose Euclidean norm is clipped to ``sigma``.
      * Q, K: element-wise Gaussian noise with scale ``delta**2``.

    Column selection, chosen by ``sp_method``:
      * 1 -- ``sampled_num`` columns drawn uniformly without replacement
        (returned in draw order).
      * 2 -- always keep the first ``omega1`` columns and fill up to
        ``sampled_num`` with random columns from the rest. If
        ``sampled_num <= omega1`` exactly the first ``omega1`` columns are
        returned.
      * 3 -- keep the ``sampled_num`` columns whose raw (pre-clipping) V-noise
        norm is smallest, ordered by increasing norm.
      * 4 -- keep the first ``omega1`` columns plus random others up to
        ``sampled_num``; rejected (returns ``False``) when those first
        columns make up less than 40% of ``L`` or exceed ``sampled_num``.

    :param inp: sequence ``(V, Q, K)`` of arrays with shape ``(rows, L)``.
    :param sigma: norm bound for the per-column V noise.
    :param delta: controls the Q/K noise scale (``delta**2``).
    :param sp_method: selection strategy, one of 1, 2, 3, 4.
    :param L: number of columns in each input matrix.
    :param omega1: number of leading columns treated as "always keep" by
        strategies 2 and 4.
    :param sampled_num: number of columns to return.
    :return: tuple ``(V_s, Q_s, K_s)`` of arrays with shape
        ``(n_selected, rows)`` (selected columns, transposed), or ``False``
        when strategy 4 rejects the request.
    :raises ValueError: if ``sp_method`` is not one of the known strategies
        (previously this fell through and returned ``None``).
    """
    V, Q, K = inp
    n_rows = np.shape(V)[0]

    # Per-column V noise. norm_ns[c] must line up with column c because
    # strategy 3 uses argsort(norm_ns) directly as column indices, so the
    # loop walks the columns in index order rather than over a set.
    noise_sigma = np.zeros((n_rows, L))
    norm_ns = np.zeros(L)
    for col in range(L):
        ns = np.random.normal(0, 0.1 * sigma * sigma, (n_rows,))
        ns_norm = np.linalg.norm(ns)
        norm_ns[col] = ns_norm
        if ns_norm <= sigma:
            # Also covers sigma == 0, where rescaling would compute 0/0.
            noise_sigma[:, col] = ns
        else:
            noise_sigma[:, col] = ns * sigma / ns_norm

    V = V + noise_sigma
    Q = Q + np.random.normal(0, delta * delta, (n_rows, L))
    K = K + np.random.normal(0, delta * delta, (n_rows, L))

    if sp_method == 1:
        # Uniform sampling.
        sampled_index = random.sample(list(range(L)), sampled_num)
    elif sp_method == 2:
        # Only remove non-class-discriminative features.
        sampled_index = list(range(omega1))
        extra = sampled_num - omega1
        if extra > 0:
            sampled_index += random.sample(list(range(omega1, L)), extra)
        sampled_index.sort()
    elif sp_method == 3:
        # Only denoise: prefer the columns that received the least V noise.
        sampled_index = list(np.argsort(norm_ns)[:sampled_num])
    elif sp_method == 4:
        # Reduce all noisy and non-class-discriminative features.
        best = list(range(max(0, min(omega1, L))))
        extra = sampled_num - len(best)
        if len(best) < L * 0.4 or extra < 0:
            return False
        # extra == 0 is a valid request (keep exactly the best columns);
        # it used to be rejected, which made the caller retry forever.
        remaining = list(range(len(best), L))
        sampled_index = sorted(best + random.sample(remaining, extra))
    else:
        raise ValueError("unknown sp_method: %r" % (sp_method,))

    return tuple(np.transpose(mat[:, sampled_index]) for mat in (V, Q, K))


def batch_pred(model, data, index, L):
    """Run ``model`` on the selected samples one at a time.

    Each selected sample ``data[i]`` is reshaped to ``(1, 3, L, -1)`` (the
    trailing feature dimension is inferred from the sample) and passed to
    ``model``; the scalar outputs are stacked into a column vector.

    Predictions follow the order of ``index`` and repeated indices are
    evaluated again. The previous implementation iterated ``set(index)``,
    which reordered the results and failed on the final reshape whenever
    ``index`` contained duplicates.

    :param model: callable returning one scalar-sized prediction per call.
    :param data: indexable collection of samples, each holding
        ``3 * L * feature_dim`` values.
    :param index: sequence of sample indices to evaluate.
    :param L: number of tokens per sample.
    :return: NumPy array of shape ``(len(index), 1)``.
    """
    y_pred = [model(data[i].reshape(1, 3, L, -1)) for i in index]
    return np.array(y_pred).reshape(len(index), 1)

def hg_loss(y_true, y_pred):
    """Return the mean hinge loss ``mean(relu(1 - y_pred * y_true))``."""
    margin_violation = tf.nn.relu(1 - y_pred * y_true)
    return tf.reduce_mean(margin_violation)

class neural_netowrk(tf.keras.Model):
    """Attention layer followed by two dense branches, producing one score.

    The forward pass projects the value/query/key slices of the input with
    identity-initialised linear layers, applies dot-product attention, feeds
    the result through two ReLU dense branches and returns the difference of
    their means as a single score per sample.
    """

    def __init__(self, seed=1):
        """Create the layers.

        :param seed: value passed to ``tf.random.set_seed`` so that weight
            initialisation is repeatable. Note this sets TensorFlow's global
            seed every time a model is constructed.
        """
        super(neural_netowrk, self).__init__()
        # use random seed to make the initialization repeat
        tf.random.set_seed(seed)

        def small_normal():
            return tf.keras.initializers.RandomNormal(0, 1e-2)

        def identity_projection(layer_name):
            # Linear map that starts out as the identity on the feature axis.
            return tf.keras.layers.Dense(
                units=ma,
                kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)),
                activation=None,
                name=layer_name,
            )

        # Two dense (fully connected) branches applied to the attention output.
        self.fc1 = tf.keras.layers.Dense(m_half, kernel_initializer=small_normal(), activation='relu')
        self.fc2 = tf.keras.layers.Dense(m_half, kernel_initializer=small_normal(), activation='relu')

        self.in_v = identity_projection('in_v')
        self.in_q = identity_projection('in_q')
        self.in_k = identity_projection('in_k')

        # Built once here instead of on every forward pass.
        self.attention = tf.keras.layers.Attention()

        # Not used by call(); kept so the attribute stays available. An
        # earlier Dense(2 * m_half) assignment to the same name was removed
        # because this one immediately overwrote it.
        self.fc3 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, input):
        '''
        here we define the forward function
        :param input: tensor indexed as (batch, 3, tokens, features), where
            index 0 / 1 / 2 on the second axis selects the value / query /
            key matrix
        :return: tensor of shape (batch, 1)
        '''
        # For each layer, a bias will also be initialized and add to the output after matrix multiply.
        v = self.in_v(input[:, 0])
        q = self.in_q(input[:, 1])
        k = self.in_k(input[:, 2])
        # Keras Attention expects its inputs ordered [query, value, key].
        x = self.attention([q, v, k])
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        # Mean over hidden units of each branch, then branch difference.
        x = tf.math.reduce_mean(x1, 2) - tf.math.reduce_mean(x2, 2)
        # Mean over tokens -> one value per sample.
        output = tf.math.reduce_mean(x, 1)

        # -1 lets TensorFlow infer the batch size, so this also works when
        # the static batch dimension is unknown (None) during tracing.
        return tf.reshape(output, shape=[-1, 1]) / k.shape[1]


def _draw_sampled_batch(raw_samples, sigma, delta, sp_method, L, omega1, sampled_num, max_tries=10000):
    """Apply sampling_method to every raw (V, Q, K) triple and stack the results."""
    batch = np.zeros((len(raw_samples), 3, sampled_num, ma))
    for row, triple in enumerate(raw_samples):
        for _ in range(max_tries):
            drawn = sampling_method(triple, sigma, delta, sp_method, L, omega1, sampled_num)
            # sampling_method signals a rejected draw with the sentinel False.
            if drawn is not False:
                batch[row] = drawn
                break
        else:
            # Previously the sample was silently left as all zeros.
            raise RuntimeError(
                'sampling_method rejected every draw for sample %d '
                '(sp_method=%r, sampled_num=%r)' % (row, sp_method, sampled_num))
    return batch


def main():
    """Sweep the number of kept tokens and record the test hinge loss.

    Steps:
      1. Generate ``N`` training and ``Test`` test samples with random
         +/-1 labels from two random orthonormal pattern matrices.
      2. For ``sampled_num`` from ``L`` down to ``omega1`` in steps of 2,
         repeat ``test_max`` times: re-draw noisy, sub-sampled data, train a
         fresh model and evaluate the hinge loss on train and test data.
      3. Save the ``(settings, test_max)`` matrix of test losses with
         ``scipy.io.savemat``.
    """
    # Data-generation settings.
    L = 50
    M = 20
    sigma = 0.1
    sigma0 = 0.1
    delta = 0.5
    alpha_nd = 0.35
    mu = 0.5
    alpha = 0.6
    omega1 = int(np.round(alpha * L))

    # Experiment settings.
    N = 80
    Test = 100
    test_max = 20
    num_epoch = 10
    sp_method = 1

    loss_sampling = []
    loss_sp_mean = []

    P_m1 = scipy.linalg.orth(np.random.normal(0, 1, (ma, M)))
    P_m2 = scipy.linalg.orth(np.random.normal(0, 1, (ma, M)))

    def make_labelled_set(count):
        # Random +/-1 labels plus the matching noise-free (V, Q, K) triples.
        labels = np.zeros((count, 1))
        triples = []
        for i in range(count):
            lb = np.random.binomial(1, 0.5, 1)
            labels[i] = lb * 2 - 1
            triples.append(generate_data(mu, sp_method, L, omega1, labels[i], L, alpha_nd, M,
                                         sigma0, sigma, delta, P_m1, P_m2))
        return labels, triples

    train_label, OUT1 = make_labelled_set(N)
    test_label, OUT2 = make_labelled_set(Test)

    for sampled_num in range(L, omega1 - 1, -2):
        loss = []
        for t in range(test_max):
            train_data = _draw_sampled_batch(OUT1, sigma, delta, sp_method, L, omega1, sampled_num)
            test_data = _draw_sampled_batch(OUT2, sigma, delta, sp_method, L, omega1, sampled_num)

            model = neural_netowrk()
            model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),
                          loss='hinge',
                          metrics=['hinge'])
            model.fit(x=train_data, y=train_label,
                      batch_size=4, validation_split=0.1, epochs=num_epoch)

            # Keras losses take (y_true, y_pred) in that order.
            hinge = tf.keras.losses.Hinge()
            train_loss = hinge(train_label, model(train_data))
            test_loss = hinge(test_label, model(test_data))
            # Report the repetition counter (the builtin `iter` was printed before).
            print('Iteration', t, '| Training loss:', train_loss.numpy(), '| Test loss', test_loss.numpy())
            loss.append(test_loss)

        loss_sampling.append(np.array(loss))
        loss_sp_mean.append(tf.reduce_mean(np.array(loss)))
        print('sampled_num:', sampled_num, '| loss:', loss_sampling)
        print(loss_sp_mean)

    loss_sampling = np.array(loss_sampling)
    # NOTE(review): "xxx" looks like a placeholder output path -- set it
    # before running.
    data_path = r"xxx"
    scio.savemat(data_path, {"loss_sampling": loss_sampling})

# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()

