import random

import numpy as np
import scipy
import scipy.linalg
import tensorflow as tf

# Problem sizes shared by the data generators and the model.
#   N: number of graph nodes (the adjacency matrix is N x N).
#   c: number of pattern ids (columns of P_m1 / P_m2 indexed by pattern id).
#   d: NOTE(review): not referenced by any code visible alongside these constants.
N, c, d = 1000, 10, 20

#   ma: feature dimension of a pattern (rows of P_m1 / P_m2; units of in_v/in_q/in_k).
#   m_half: width of fc1 / fc2; fc3 uses 2 * m_half.
ma, m_half = 20, 100

def generate_label(N, gamma):
    """Return an (N, 1) float array of +1/-1 node labels.

    Indices in [0, int(N*gamma/2)) are labelled +1, indices in
    [int(N*gamma/2), int(N*gamma)) are labelled -1, and every remaining
    index receives an independent fair coin flip mapped to -1 or +1.
    """
    positives = set(range(int(N * gamma / 2)))
    negatives = set(range(int(N * gamma))) - positives
    remaining = set(range(N)) - positives - negatives

    label = np.zeros((N, 1))
    label[list(positives)] = 1
    label[list(negatives)] = -1
    for node in remaining:
        coin = np.random.binomial(1, 0.5, 1)  # array([0]) or array([1])
        label[node] = 2 * coin - 1
    return label


def _fill_within_group(adj_matrix, group, deg):
    """Top up each node of ``group`` to deg + 1 nonzeros using edges inside ``group`` (in place)."""
    for i in group:
        num = deg + 1 - len(np.nonzero(adj_matrix[i, :])[0])
        if num > 0:
            # Candidates: group members not yet adjacent to i (the self-loop excludes i itself).
            l_in = random.sample(list(group - set(np.nonzero(adj_matrix[i, :])[0])), num)
            adj_matrix[i, l_in] = 1
            adj_matrix[l_in, i] = 1


def generateAdj(N, gamma, deg, n_st, n_sr, label):
    """Build a random symmetric 0/1 adjacency matrix with planted group structure.

    Nodes are split by index exactly as in generate_label:
    set1 = [0, int(N*gamma/2)), set2 = [int(N*gamma/2), int(N*gamma)),
    setM = every other index.

    Each setM node i is linked to n_st random nodes of set1 and n_sr of set2
    when label[i] == 1 (the roles are swapped otherwise). If its row then had
    at most deg - n_st - n_sr nonzeros beforehand (self-loop included), it
    is also linked to enough not-yet-adjacent setM nodes to reach deg + 1
    nonzeros. Afterwards every set1 / set2 node with fewer than deg + 1
    nonzeros is topped up with edges inside its own group.

    Args:
        N: number of nodes.
        gamma: fraction of nodes placed in set1 and set2.
        deg: target number of nonzeros per row, excluding the self-loop.
        n_st: edges from a setM node into the group matching its label.
        n_sr: edges from a setM node into the opposite group.
        label: per-node labels; only ``label[i] == 1`` is tested.

    Returns:
        (adj_matrix, 1 - ep): adj_matrix is an (N, N) float array with ones on
        the diagonal; ep = gamma + (1 - gamma) * n_st / (n_st + n_sr).

    Raises:
        ValueError: from random.sample when a group is too small to supply
            the requested number of distinct neighbours.
    """
    adj_matrix = np.eye(N)
    set1 = set(range(int(N * gamma / 2)))
    set2 = set(range(int(N * gamma))) - set1
    setM = set(range(N)) - set1 - set2

    for i in setM:
        if label[i] == 1:
            l_star = random.sample(list(set1), n_st)
            l_sharp = random.sample(list(set2), n_sr)
        else:
            l_star = random.sample(list(set2), n_st)
            l_sharp = random.sample(list(set1), n_sr)
        # Row i so far: the self-loop plus edges added by earlier setM nodes.
        n0 = len(np.nonzero(adj_matrix[i, :])[0])
        if n0 > deg - n_sr - n_st:
            l_nd = []
        else:
            l_nd = random.sample(list(setM - set(np.nonzero(adj_matrix[i, :])[0])),
                                 int(deg + 1 - n_st - n_sr - n0))
        targets = l_star + l_sharp + l_nd
        adj_matrix[i, targets] = 1
        adj_matrix[targets, i] = 1

    _fill_within_group(adj_matrix, set1, deg)
    _fill_within_group(adj_matrix, set2, deg)

    ep = gamma + (1 - gamma) * n_st / (n_st + n_sr)
    return adj_matrix, 1 - ep


def neighbor_12(A):
    """Return each node's direct neighbours and its non-neighbours.

    Args:
        A: square adjacency matrix (2-D array). Any nonzero A[i, j] counts as
            an edge i -> j; the diagonal is ignored.

    Returns:
        (nb1, nb2): two lists of length A.shape[0]. nb1[i] holds the indices
        j != i with A[i, j] != 0. nb2[i] holds every other index j != i that
        is not in nb1[i], i.e. the complement of the 1-hop neighbourhood
        rather than a strict 2-hop set. Order inside each list follows Python
        set iteration and is not guaranteed to be sorted.
    """
    n_nodes = A.shape[0]
    all_nodes = set(range(n_nodes))  # loop-invariant; built once
    nb1 = []
    nb2 = []
    for i in range(n_nodes):
        neighbours = set(np.nonzero(A[i, :])[0]) - {i}
        nb1.append(list(neighbours))
        nb2.append(list(all_nodes - neighbours - {i}))
    return nb1, nb2


def feature_assign(A, nb1, nb2):
    """Assign pattern ids around anchor nodes 0 and 1 and score neighbourhood purity.

    Pattern ids:
    - Every node starts with a uniform random id in {3, ..., c}.
    - Neighbours of node 0 get id 1 and neighbours of node 1 get id 2.
    - A fair coin picks 1 or 2 for nodes that neighbour both anchors.
    - Finally the first c nodes are forced to ids 1..c.

    For node i, with s0 = |nb1[i] & nb1[0]| and s1 = |nb1[i] & nb1[1]|,
    vote[i] = s0 / (s0 + s1), or 0.5 when both counts are zero.

    Args:
        A: unused; kept for call compatibility.
        nb1: per-node neighbour index lists; nb1[0] and nb1[1] define the anchors.
        nb2: unused; kept for call compatibility.

    Returns:
        (vote, det_node_ratio, epsilon_S, pattern_short):
        vote is an (n, 1) float array with n = len(nb1).
        det_node_ratio is a (k, 1) array of min(vote, 1 - vote) for the k nodes
        whose vote != 0.5; nodes with vote > 0.5 come first.
        epsilon_S is the mean of det_node_ratio (NaN, with a NumPy warning,
        when k == 0).
        pattern_short is an (n, 1) integer array of pattern ids.
    """
    n_nodes = len(nb1)
    # Loop-invariant anchor neighbourhoods, built once instead of once per node.
    anchor0 = set(nb1[0])
    anchor1 = set(nb1[1])

    pattern_short = np.random.randint(low=3, high=c + 1, size=(n_nodes, 1))
    pattern_short[list(nb1[0])] = 1
    pattern_short[list(nb1[1])] = 2
    for i in anchor0.intersection(anchor1):
        pattern_short[i] = np.random.binomial(1, 0.5, 1) + 1
    for i in range(c):
        pattern_short[i] = i + 1

    vote = 0.5 * np.ones((n_nodes, 1))
    for i in range(n_nodes):
        nb_set = set(nb1[i])
        n_via0 = len(anchor0 & nb_set)
        n_via1 = len(anchor1 & nb_set)
        if n_via0 or n_via1:
            vote[i] = n_via0 / (n_via0 + n_via1)

    above = vote[:, 0] > 0.5
    below = vote[:, 0] < 0.5
    node1_ratio = 1 - vote[above]
    node2_ratio = vote[below]
    det_node_ratio = np.concatenate((node1_ratio, node2_ratio), 0)
    epsilon_S = np.mean(det_node_ratio)
    return vote, det_node_ratio, epsilon_S, pattern_short


def generate_data(nb1, nb2, P_m1, P_m2, sigma, delta, deg, gamma):
    """Build per-node V/Q/K token stacks from group patterns and neighbour patterns.

    Every node gets a pattern id, using the same index split as
    generate_label:
    - 0 for indices in [0, int(N*gamma/2));
    - 1 for indices in [int(N*gamma/2), int(N*gamma));
    - a uniform draw from {2, ..., c-1} for every other index.

    Node i is then described by L = int(deg/4) + 1 tokens: its own pattern
    followed by the patterns of L-1 neighbours sampled without replacement
    from nb1[i].

    Args:
        nb1: per-node neighbour index lists (e.g. from neighbor_12).
        nb2: unused; kept for call compatibility.
        P_m1: matrix with ma rows whose columns are the V patterns; needs at
            least c columns.
        P_m2: same layout; used for both Q and K.
        sigma: V noise per token is np.random.normal(0, 0.1*sigma*sigma, (ma,)),
            rescaled to norm sigma when its norm is >= sigma.
        delta: Q and K noise is np.random.normal(0, delta*delta, (ma, L)).
            NOTE(review): delta*delta (and 0.1*sigma*sigma above) is passed as
            numpy's ``scale`` (std) argument -- confirm whether a variance was meant.
        deg: controls the token count L.
        gamma: group fraction, same meaning as in generate_label.

    Returns:
        (data, pattern_short): data has shape (N, 3, L, ma), holding the
        transposed V, Q and K matrices in slots 0, 1 and 2. pattern_short is
        the list of per-node pattern ids.

    Raises:
        ValueError: from random.sample when a node has fewer than L - 1
            entries in nb1.
    """
    L = int(deg / 4) + 1
    set1 = set(range(int(N * gamma / 2)))
    set2 = set(range(int(N * gamma))) - set1
    setM = set(range(N)) - set1 - set2

    pattern_short = [0] * len(set1) + [1] * len(set2)
    for _ in setM:
        num = np.random.randint(low=3, high=c + 1, size=(1,))
        # num[0], not int(num): converting an ndim>0 array is deprecated in NumPy >= 1.25.
        pattern_short.append(int(num[0]) - 1)
    # Loop-invariant array view of the ids; previously rebuilt for every node.
    pt = np.asarray(pattern_short, dtype=np.int64)

    data = np.zeros((N, 3, L, ma))
    for i in range(N):
        V = np.zeros((ma, L))
        Q = np.zeros((ma, L))
        K = np.zeros((ma, L))
        V[:, 0] = P_m1[:, pattern_short[i]]
        Q[:, 0] = P_m2[:, pattern_short[i]]
        K[:, 0] = P_m2[:, pattern_short[i]]
        ls = random.sample(list(pt[nb1[i]]), L - 1)
        V[:, 1:L] = P_m1[:, ls]
        Q[:, 1:L] = P_m2[:, ls]
        K[:, 1:L] = P_m2[:, ls]

        # Per-token V noise, clipped to norm sigma.
        noise_sigma = np.zeros((ma, L))
        for j in range(L):
            ns = np.random.normal(0, 0.1 * sigma * sigma, (ma,))
            ns_norm = np.linalg.norm(ns)
            if ns_norm < sigma:
                noise_sigma[:, j] = ns
            else:
                noise_sigma[:, j] = ns * sigma / ns_norm

        V = V + noise_sigma
        Q = Q + np.random.normal(0, delta * delta, (ma, L))
        K = K + np.random.normal(0, delta * delta, (ma, L))
        data[i, 0, :, :] = V.T
        data[i, 1, :, :] = Q.T
        data[i, 2, :, :] = K.T

    return data, pattern_short


def dataset_generate(data, label, test_ind):
    """Split samples into train and test partitions by index along axis 0.

    Args:
        data: array whose first axis indexes samples.
        label: array whose first axis indexes the same samples.
        test_ind: indices of the test samples; their given order is kept.

    Returns:
        (train_data, train_label, test_data, test_label). The training indices
        are every index in range(data.shape[0]) not in test_ind, in Python set
        iteration order, and are shared by train_data and train_label.
    """
    train_ind = list(set(range(data.shape[0])) - set(test_ind))
    test_ind = list(test_ind)
    return data[train_ind], label[train_ind], data[test_ind], label[test_ind]

class BiasLayer(tf.keras.layers.Layer):
    """Add ``t = concat([10], ones(spd_num1) * bias1)`` to the layer input.

    ``bias1`` (shape ``pe1_dim``) must broadcast against shape (spd_num1,),
    e.g. pe1_dim=1. The result ``t`` has 1 + spd_num1 entries, so the input
    must broadcast against a trailing axis of that length. ``bias2`` is
    created but not used by ``call``.
    """

    def __init__(self, pe1_dim, pe2_dim):
        super(BiasLayer, self).__init__()
        self.spd_num1 = 40
        self.spd_num2 = 40
        self.pe1_dim = pe1_dim
        self.pe2_dim = pe2_dim

    def build(self, input_shape):
        # name= is passed by keyword: Keras 3 made `shape` the first positional
        # parameter of add_weight, so a positional name string raises TypeError.
        # Distinct names also avoid clashes when saving weights by name.
        self.bias1 = self.add_weight(name='bias1',
                                     shape=self.pe1_dim,
                                     initializer='zeros',
                                     trainable=True)
        self.bias2 = self.add_weight(name='bias2',
                                     shape=self.pe2_dim,
                                     initializer='zeros',
                                     trainable=True)

    def call(self, x):
        fixed = 10 * tf.ones((1,))
        learned = tf.math.multiply(tf.ones(shape=(self.spd_num1,)), self.bias1)
        t = tf.concat([fixed, learned], 0)
        return x + t

class neural_netowrk(tf.keras.Model):
    """Keras model mapping a stacked token tensor to one score per example.

    Only ``input[:, 0]`` is read by ``call``. ``in_q`` and ``in_k`` are
    constructed but never applied.
    """

    def __init__(self, seed=1):
        super(neural_netowrk, self).__init__()
        # Reset TF's global seed so the RandomNormal kernel initialisers are
        # repeatable across model constructions.
        tf.random.set_seed(seed)
        self.spd1_num = 40
        self.spd2_num = 40
        # Fully connected (Dense) layers; each acts on the last axis of its input.
        self.fc1 = tf.keras.layers.Dense(m_half, kernel_initializer=tf.keras.initializers.RandomNormal(0, 1e-2),
                                         activation='relu')
        self.fc2 = tf.keras.layers.Dense(m_half, kernel_initializer=tf.keras.initializers.RandomNormal(0, 1e-2),
                                         activation='relu')
        self.fc3 = tf.keras.layers.Dense(int(m_half * 2),
                                         kernel_initializer=tf.keras.initializers.RandomNormal(0, 1e-2),
                                         activation='relu')
        # Input projections start from an ma x ma identity kernel.
        self.in_v = tf.keras.layers.Dense(units=ma, kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)),
                                          activation=None, name='in_v')
        self.in_q = tf.keras.layers.Dense(units=ma, kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)),
                                          activation=None, name='in_q')
        self.in_k = tf.keras.layers.Dense(units=ma, kernel_initializer=tf.keras.initializers.Constant(tf.eye(ma)),
                                          activation=None, name='in_k')

    def call(self, input):
        """Forward pass.

        Args:
            input: batch tensor; ``input[:, 0]`` is assumed to be (batch, L, ma),
                as produced by generate_data's (N, 3, L, ma) layout -- TODO
                confirm for other callers.

        Returns:
            Tensor of shape (batch, 1).
        """
        # Shape comments below assume input[:, 0] is (batch, L, ma).
        v = self.in_v(input[:, 0])              # (batch, L, ma)
        v = tf.transpose(v, perm=[0, 2, 1])     # (batch, ma, L)
        x = self.fc3(v)                         # (batch, ma, 2 * m_half)
        x = tf.transpose(x, perm=[0, 2, 1])     # (batch, 2 * m_half, ma)
        x1 = self.fc1(x)                        # (batch, 2 * m_half, m_half)
        x2 = self.fc2(x)
        x = tf.math.reduce_mean(x1, 2) - tf.math.reduce_mean(x2, 2)  # (batch, 2 * m_half)
        output = x[:, 0]
        # -1 lets TF infer the batch size. output.shape[0] is None while
        # tracing under tf.function (e.g. model.fit), and tf.reshape rejects None.
        return tf.reshape(output, shape=[-1, 1])

def _edge_split(deg, epsilonS):
    """Return (n_st, n_sr), the per-node edge counts passed to generateAdj."""
    n_sr = int(deg * epsilonS / 1.1)
    n_st = int(deg / 1.1 - n_sr)
    return n_st, n_sr


def _run_trial(gamma, deg, n_st, n_sr, train_N, epochs, trial):
    """Run one synthetic experiment and return its test hinge loss as a float.

    Draws labels, a graph and token data, and holds out N - train_N random
    nodes for testing. It then trains a fresh neural_netowrk for ``epochs``
    epochs and prints epsilon_S plus the train/test hinge losses.
    """
    label = generate_label(N, gamma)
    A, epsilon_S = generateAdj(N, gamma, deg, n_st, n_sr, label)
    nb1, nb2 = neighbor_12(A)

    test_ind = random.sample(list(range(N)), N - train_N)
    sigma = 0.1
    delta = 0.2
    P_m1 = scipy.linalg.orth(np.random.normal(0, 1, (ma, c)))
    P_m2 = scipy.linalg.orth(np.random.normal(0, 1, (ma, c)))
    data, _ = generate_data(nb1, nb2, P_m1, P_m2, sigma, delta, deg, gamma)
    train_data, train_label, test_data, test_label = dataset_generate(data, label, test_ind)
    print(epsilon_S)

    model = neural_netowrk()
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),
                  loss='hinge', metrics=['hinge'])
    model.fit(x=train_data, y=train_label,
              batch_size=4, validation_split=0.2, epochs=epochs)
    # Keras losses are called as loss(y_true, y_pred). Passing the prediction
    # first makes Hinge treat the model output as the label (including its
    # 0/1 -> -1/+1 label conversion).
    hinge = tf.keras.losses.Hinge()
    train_loss = hinge(train_label, model(train_data))
    test_loss = hinge(test_label, model(test_data))
    print('Iteration', trial, '| Training loss:', train_loss.numpy(), '| Test loss', test_loss.numpy())
    return float(test_loss.numpy())


def main_iteration():
    """Sweep gamma against training length and record the success rate.

    For i in 0..10, gamma = sqrt(1 / (i + 6)). For j in 0..9 the model is
    trained for 2 * j epochs on 800 nodes. Each (i, j) cell runs Tmax trials
    and stores in sp[9 - j, i] the fraction whose test hinge loss is <= 1e-3.

    Returns:
        The (10, 11) array ``sp``; it is also printed after every cell.
    """
    Tmax = 20
    deg = 120
    n_st, n_sr = _edge_split(deg, 0.2)
    sp = np.zeros((10, 11))
    for i in range(11):
        # np.sqrt rather than np.math.sqrt: the np.math alias was removed in NumPy 2.0.
        gamma = np.sqrt(1 / (i + 6))
        for j in range(10):
            count = 0
            for t in range(Tmax):
                test_loss = _run_trial(gamma, deg, n_st, n_sr,
                                       train_N=800, epochs=j * 2, trial=t)
                if test_loss <= 1e-3:
                    count += 1
            sp[9 - j, i] = count / Tmax
            print(sp)
    return sp


def main_sp():
    """Sweep gamma against training-set size and record the success rate.

    For i in 0..10, gamma = (1 / (20 * i + 20)) ** (1/4). For j in 0..9 the
    model is trained for 20 epochs on 560 + 40 * j nodes. Each (i, j) cell
    runs Tmax trials and stores in sp[9 - j, i] the fraction whose test hinge
    loss is <= 1e-3.

    Returns:
        The (10, 11) array ``sp``; it is also printed after every cell.
    """
    Tmax = 20
    deg = 120
    n_st, n_sr = _edge_split(deg, 0.2)
    sp = np.zeros((10, 11))
    for i in range(11):
        gamma = np.power(1 / (20 * i + 20), 1 / 4)
        for j in range(10):
            train_N = j * 40 + 560
            count = 0
            for t in range(Tmax):
                test_loss = _run_trial(gamma, deg, n_st, n_sr,
                                       train_N=train_N, epochs=20, trial=t)
                if test_loss <= 1e-3:
                    count += 1
            sp[9 - j, i] = count / Tmax
            print(sp)
    return sp


if __name__ == '__main__':
    # Swap in main_sp() to run the training-set-size sweep instead.
    main_iteration()