import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimage
import numpy as np
import os
from glob import glob
import random
import scipy
from sklearn.metrics import confusion_matrix
import pickle

# NOTE(review): presumably set so that loading more than one OpenMP runtime
# (e.g. through TensorFlow and an MKL-linked NumPy) does not abort the
# process — confirm. Note it runs after those libraries were imported above.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Module-wide sizes used by the code below.
d = 5  # not referenced in the visible part of this file
ma = 10  # per-position feature dimension (generate_data builds (ma, L) arrays)
m_half = 500  # number of units in the hidden Dense layers fc1 / fc2
L = 100  # number of positions; CNN_V1's fixed averaging layer is L x L


def generate_data(mu, omega1, flag, L, alpha_nd, M, sigma0, sigma, delta, P_m1, P_m2):
    """Generate one synthetic (V, Q, K) sample made of piecewise-constant segments.

    The L positions are split into M contiguous segments. Segment i is filled
    with column i of ``P_m1`` (for V) or of ``P_m2`` (for Q and K), and noise
    is then added.

    The label is encoded by the first boundary. The first segment has length
    ``omega1`` when ``flag == 1`` and ``L - num_nd - omega1`` otherwise, where
    ``num_nd = int(L * alpha_nd)``. The trailing ``num_nd`` positions are
    further split at M-3 boundaries drawn with ``random.randint``.

    :param mu: unused; kept for interface compatibility.
    :param omega1: length parameter of the first segment (see above).
    :param flag: label; compared with ``== 1`` (a one-element array works).
    :param L: number of positions.
    :param alpha_nd: fraction of positions in the trailing region.
    :param M: number of segments; ``P_m1`` / ``P_m2`` need at least M columns.
    :param sigma0: unused; kept for interface compatibility.
    :param sigma: V noise is drawn per position with ``scale=0.1*sigma**2``.
        Each position's noise vector is then shrunk to Euclidean norm at most
        ``sigma``.
    :param delta: Q and K noise is drawn with ``scale=delta**2``.
    :param P_m1: basis for V, shape (feature_dim, >= M).
    :param P_m2: basis for Q and K, shape (feature_dim, >= M).
    :return: tuple ``(V, Q, K)``, each of shape (L, feature_dim).
    :raises ValueError: if the first boundary falls outside ``[0, L - num_nd]``.
    """
    # NOTE(review): np.random.normal's ``scale`` is a standard deviation, but
    # squared quantities (0.1*sigma**2, delta**2) are passed — confirm intent.
    num_nd = int(L * alpha_nd)
    o1 = omega1 if flag == 1 else L - num_nd - omega1
    if not 0 <= o1 <= L - num_nd:
        # An out-of-range boundary would be reordered by sort() (or sliced
        # with negative indices), silently breaking the label encoding.
        raise ValueError(
            f"first segment boundary {o1} outside [0, {L - num_nd}]; "
            f"check omega1={omega1}, L={L}, alpha_nd={alpha_nd}")

    separate = [0, o1, L - num_nd, L]
    separate.extend(random.randint(L - num_nd, L) for _ in range(M - 3))
    separate.sort()

    P_m1 = np.asarray(P_m1)
    P_m2 = np.asarray(P_m2)
    dim_v = P_m1.shape[0]
    dim_q = P_m2.shape[0]

    V = np.zeros((dim_v, L))
    Q = np.zeros((dim_q, L))
    for i in range(M):
        start, stop = separate[i], separate[i + 1]
        # Broadcast the (dim, 1) basis column across the segment's positions.
        V[:, start:stop] = P_m1[:, i:i + 1]
        Q[:, start:stop] = P_m2[:, i:i + 1]
    # K shares Q's basis (P_m2); only the noise added below differs.
    # NOTE(review): confirm K was not meant to use a separate basis.
    K = Q.copy()

    # One draw of shape (L, dim_v) consumes the RNG stream in the same order
    # as drawing one (dim_v,) vector per position.
    noise = np.random.normal(0, 0.1 * sigma * sigma, (L, dim_v)).T
    norms = np.linalg.norm(noise, axis=0)
    # Rescale only positions whose noise exceeds sigma. Using '>' keeps
    # zero-norm noise as-is instead of computing 0/0 when sigma == 0.
    too_big = norms > sigma
    noise[:, too_big] *= sigma / norms[too_big]
    V = V + noise

    Q = Q + np.random.normal(0, delta * delta, (dim_q, L))
    K = K + np.random.normal(0, delta * delta, (dim_q, L))
    return V.T, Q.T, K.T

def batch_pred(model, data, index, L):
    """Run ``model`` on ``data[i]`` for each ``i`` in ``index``, one sample at a time.

    Predictions keep the order of ``index``, and duplicate indices are
    predicted again. The original iterated ``set(index)``, which dropped
    duplicates and ignored the caller's order.

    :param model: callable taking a batch of shape (1, 3, L, feature_dim)
        and returning a single value.
    :param data: indexable collection of samples, each reshapeable to
        (3, L, feature_dim).
    :param index: sequence of sample indices.
    :param L: number of positions per sample.
    :return: NumPy array of shape (len(index), 1).
    """
    # The feature dimension is inferred (-1) instead of read from a global.
    preds = [np.asarray(model(np.asarray(data[i]).reshape(1, 3, L, -1)))
             for i in index]
    return np.array(preds).reshape(len(index), 1)

def hg_loss(y_true, y_pred):
    """Return the mean hinge loss, mean(max(0, 1 - y_pred * y_true)).

    Labels are used as given (expected in {-1, +1}); no 0/1 conversion is done.
    """
    margins = y_pred * y_true
    hinge = tf.nn.relu(1 - margins)
    return tf.reduce_mean(hinge)

class neural_netowrk(tf.keras.Model):
    """Attention-based scorer that returns one real-valued score per sample.

    The input is indexed as ``input[:, 0]``, ``input[:, 1]`` and ``input[:, 2]``
    for the value, query and key streams. main() builds arrays of shape
    (N, 3, L, ma); presumably the same layout is intended here — confirm.
    """

    def __init__(self, seed=1):
        super(neural_netowrk, self).__init__()
        # Seed TensorFlow's global RNG so weight initialization is repeatable.
        tf.random.set_seed(seed)
        # Two parallel ReLU branches applied to the attended features.
        self.fc1 = tf.keras.layers.Dense(
            m_half, kernel_initializer=tf.keras.initializers.RandomNormal(0, 1e-2), activation='relu')
        self.fc2 = tf.keras.layers.Dense(
            m_half, kernel_initializer=tf.keras.initializers.RandomNormal(0, 1e-2), activation='relu')
        # Linear input projections for V, Q and K, initialised to the identity.
        self.in_v = tf.keras.layers.Dense(
            units=ma, kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)), activation=None, name='in_v')
        self.in_q = tf.keras.layers.Dense(
            units=ma, kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)), activation=None, name='in_q')
        self.in_k = tf.keras.layers.Dense(
            units=ma, kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)), activation=None, name='in_k')
        # Not used by call(); kept so existing references to ``fc3`` still work.
        # (An earlier duplicate definition was immediately overwritten.)
        self.fc3 = tf.keras.layers.Dense(10, activation='softmax')
        # Dot-product attention, created once so it is a tracked sub-layer
        # rather than a new layer object on every forward pass.
        self.attention = tf.keras.layers.Attention()

    def call(self, input):
        '''
        Forward pass.

        :param input: batch with value/query/key streams at ``input[:, 0..2]``.
        :return: tensor of shape (batch, 1) with one score per sample.
        '''
        v = self.in_v(input[:, 0])
        q = self.in_q(input[:, 1])
        k = self.in_k(input[:, 2])
        # Keras Attention takes its inputs in the order [query, value, key].
        x = self.attention([q, v, k])
        # Both branches see the same attended features; the score is the
        # difference of their means over the last axis, then averaged over
        # axis 1 (assumed to be positions — confirm).
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        x = tf.math.reduce_mean(x1, 2) - tf.math.reduce_mean(x2, 2)
        output = tf.math.reduce_mean(x, 1)
        # -1 keeps the batch dimension dynamic. output.shape[0] is None when
        # Keras traces call() inside model.fit, which breaks tf.reshape.
        return tf.reshape(output, shape=[-1, 1])


class CNN_V1(tf.keras.Model):
    """Baseline scorer that averages the value stream over positions.

    Only ``input[:, 0]`` (the value stream) is used. main() feeds arrays of
    shape (N, 3, L, ma), where L must equal the module-level ``L`` because
    ``fc3`` is built with an L x L kernel.
    """

    def __init__(self, seed=1):
        super(CNN_V1, self).__init__()
        # Seed TensorFlow's global RNG so weight initialization is repeatable.
        tf.random.set_seed(seed)
        # Two parallel ReLU branches applied to the averaged features.
        self.fc1 = tf.keras.layers.Dense(
            m_half, kernel_initializer=tf.keras.initializers.RandomNormal(0, 1e-2), activation='relu')
        self.fc2 = tf.keras.layers.Dense(
            m_half, kernel_initializer=tf.keras.initializers.RandomNormal(0, 1e-2), activation='relu')
        # Frozen (trainable=False) layer with an all-1/L kernel. In call() it
        # acts along the position axis, so each output position is the mean of
        # the input positions plus the layer's bias, which is also frozen.
        self.fc3 = tf.keras.layers.Dense(
            int(L), kernel_initializer=tf.keras.initializers.Constant(1/L*tf.ones([L,L])), trainable=False)
        # Linear input projection for V, initialised to the identity.
        self.in_v = tf.keras.layers.Dense(
            units=ma, kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)), activation=None, name='in_v')
        # in_q, in_k and fc4 are not used by call(); they are kept so the
        # attributes still exist for any external code that references them.
        self.in_q = tf.keras.layers.Dense(
            units=ma, kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)), activation=None, name='in_q')
        self.in_k = tf.keras.layers.Dense(
            units=ma, kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)), activation=None, name='in_k')
        self.fc4 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, input):
        '''
        Forward pass.

        :param input: batch whose value stream is ``input[:, 0]``
            (shape (batch, L, ma) given main()'s data layout).
        :return: tensor of shape (batch, 1) with one score per sample.
        '''
        v = self.in_v(input[:, 0])
        # Put positions on the last axis so fc3 mixes across positions, then
        # restore the (batch, positions, features) layout.
        v = tf.transpose(v, perm=[0, 2, 1])
        x = self.fc3(v)
        x = tf.transpose(x, perm=[0, 2, 1])
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        x = tf.math.reduce_mean(x1, 2) - tf.math.reduce_mean(x2, 2)
        output = tf.math.reduce_mean(x, 1)
        # -1 keeps the batch dimension dynamic. output.shape[0] is None when
        # Keras traces call() inside model.fit, which breaks tf.reshape.
        return tf.reshape(output, shape=[-1, 1])


def _make_dataset(n, mu, omega1, L, alpha_nd, M, sigma0, sigma, delta, P_m1, P_m2):
    """Draw ``n`` samples with fair-coin labels in {-1, +1}.

    :return: ``(data, labels)`` with shapes (n, 3, L, ma) and (n, 1).
    """
    data = np.zeros((n, 3, L, ma))
    labels = np.zeros((n, 1))
    for i in range(n):
        # Bernoulli(0.5) in {0, 1} mapped to a label in {-1, +1}.
        labels[i] = np.random.binomial(1, 0.5, 1) * 2 - 1
        # generate_data returns (V, Q, K); they fill the 3 channels of data[i].
        data[i] = generate_data(mu, omega1, labels[i], L, alpha_nd, M,
                                sigma0, sigma, delta, P_m1, P_m2)
    return data, labels


def _run_trial(N, T, num_epoch, mu, omega1, L, alpha_nd, M, sigma0, sigma, delta):
    """Train a fresh CNN_V1 on N generated samples.

    :return: ``(train_loss, test_loss)`` hinge losses as scalar tensors.
    """
    # Explicit submodule import; a bare ``import scipy`` does not guarantee
    # that ``scipy.linalg`` is loaded.
    import scipy.linalg

    P_m1 = scipy.linalg.orth(np.random.normal(0, 1, (ma, M)))
    P_m2 = scipy.linalg.orth(np.random.normal(0, 1, (ma, M)))
    train_data, train_label = _make_dataset(
        N, mu, omega1, L, alpha_nd, M, sigma0, sigma, delta, P_m1, P_m2)
    test_data, test_label = _make_dataset(
        T, mu, omega1, L, alpha_nd, M, sigma0, sigma, delta, P_m1, P_m2)

    # Release Keras global state; many models are built in main()'s loops.
    tf.keras.backend.clear_session()
    model = CNN_V1()
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),
                  loss='hinge',
                  metrics=['hinge'])
    model.fit(x=train_data, y=train_label,
              batch_size=4, validation_split=0.1, epochs=num_epoch)

    # Keras losses take (y_true, y_pred); Hinge converts y_true from 0/1 to
    # -1/+1 when it looks binary, so the argument order matters.
    hinge = tf.keras.losses.Hinge()
    train_loss = hinge(train_label, model(train_data))
    test_loss = hinge(test_label, model(test_data))
    return train_loss, test_loss


def main():
    """Estimate, on a grid of (alpha, N), how often a trained CNN_V1 reaches ~0 test loss.

    ``sample_s[9 - j, col]`` holds the fraction of ``test_max`` trials whose
    test hinge loss is below 1e-3. The training size is ``N = 80 * (j + 1)``
    (largest N in the top row). Column ``col`` corresponds to
    ``i0 = 120 + 10 * col`` and ``alpha = i0 ** -0.25``.
    """
    # Must match the module-level L: CNN_V1's fixed averaging layer is built
    # from the global L.
    L = 100
    M = 5
    sigma = 0.1
    sigma0 = 0.1
    delta = 0.2
    alpha_nd = 0.5
    mu = 1
    test_max = 20   # trials per grid cell
    num_epoch = 10
    T = 100         # test-set size

    sample_s = np.zeros((10, 9))
    # Step i0 by 10 so each i0 owns one column of sample_s. Stepping by 1
    # wrote each column ten times, keeping only the last i0's result.
    for col, i0 in enumerate(range(120, 210, 10)):
        alpha = np.power(1 / i0, 0.25)
        omega1 = int(np.round(alpha * L))
        for j in range(10):
            N = 80 * (j + 1)
            count = 0
            for t in range(test_max):
                train_loss, test_loss = _run_trial(
                    N, T, num_epoch, mu, omega1, L, alpha_nd, M, sigma0, sigma, delta)
                print('Iteration', t, '| Training loss:', train_loss.numpy(), '| Test loss', test_loss.numpy())
                if test_loss < 1e-3:
                    count += 1
            sample_s[9 - j, col] = count / test_max
            print('omega1:', omega1, '| j:', j, '| value:', sample_s[9 - j, col])
            print(sample_s)


# Run the full experiment only when executed as a script, not on import.
if __name__ == '__main__':
    main()