import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, LayerNormalization
from tensorflow.keras import Model
import matplotlib.pyplot as plt
import numpy as np
import copy
# Width of the softmax output layer used by the active-party models below.
class_num = 5

class newVFLActiveModelWithOneLayer(Model):
    """Classifier head: one 32-unit ReLU layer followed by a softmax layer.

    Unlike ``VFLActiveModelWithOneLayer``, ``call`` passes its input straight
    to the first dense layer; the ``concatenated`` layer is instantiated but
    never applied.

    NOTE(review): a class with this exact name is defined a second time
    further down in this module, and that later definition is the one the
    name refers to once the module has been imported.
    """

    def __init__(self):
        super().__init__()
        # Created as an attribute but not used by ``call``.
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(units=32, activation='relu', name="dense1")
        self.out = Dense(units=class_num, activation='softmax', name="out")

    def call(self, x):
        """Run ``x`` through the hidden layer and return the softmax output."""
        hidden = self.d1(x)
        return self.out(hidden)


class VFLPassiveModel(Model):
    """Bottom model: flattens its input and applies one 32-unit ReLU layer."""

    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.d1 = Dense(units=32, activation='relu', name="dense1")

    def call(self, x):
        """Flatten ``x`` and return the output of the dense layer."""
        flat = self.flatten(x)
        return self.d1(flat)

class newVFLActiveModelWithOneLayer(Model):
    """Classifier head: one 32-unit ReLU layer followed by a softmax layer.

    ``call`` passes its input straight to the first dense layer; the
    ``concatenated`` layer is instantiated but never applied.

    NOTE(review): an identical class with this name is also defined earlier
    in this module; this later definition replaces it at import time.
    """

    def __init__(self):
        super().__init__()
        # Created as an attribute but not used by ``call``.
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(units=32, activation='relu', name="dense1")
        self.out = Dense(units=class_num, activation='softmax', name="out")

    def call(self, x):
        """Run ``x`` through the hidden layer and return the softmax output."""
        hidden = self.d1(x)
        return self.out(hidden)

def get_poisoned_matrix(passive_matrix, need_poison, poison_grad, amplify_rate=10):
    """Overwrite selected entries of a matrix with an amplified replacement value.

    Args:
        passive_matrix: Object exposing ``.numpy()`` (presumably an eager
            TensorFlow tensor -- confirm against the caller).
        need_poison: Index expression selecting the entries to overwrite in
            the NumPy array.
        poison_grad: Replacement value; it is multiplied by ``amplify_rate``
            before being written.
        amplify_rate: Multiplier applied to ``poison_grad`` (default 10).

    Returns:
        A ``tf.float32`` tensor named ``'poisoned_matrix'`` holding the
        modified values.
    """
    matrix_np = passive_matrix.numpy()
    amplified = poison_grad * amplify_rate
    matrix_np[need_poison] = amplified
    return tf.convert_to_tensor(matrix_np, dtype=tf.float32, name='poisoned_matrix')

def copy_grad(passive_matrix, need_copy):
    """Return the first entry of the selection ``passive_matrix[need_copy]``.

    Args:
        passive_matrix: Indexable object whose selection exposes ``.numpy()``.
        need_copy: Index expression applied to ``passive_matrix``.

    Returns:
        Element ``0`` of the selected values after conversion to NumPy.
    """
    selected = passive_matrix[need_copy]
    return selected.numpy()[0]



class VFLActiveModelWithOneLayer(Model):
    """Top model: concatenate the inputs, one 32-unit ReLU layer, softmax output."""

    def __init__(self):
        super().__init__()
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(units=32, activation='relu', name="dense1")
        self.out = Dense(units=class_num, activation='softmax', name="out")

    def call(self, x):
        """Concatenate ``x``, apply the hidden layer and return the softmax output."""
        merged = self.concatenated(x)
        hidden = self.d1(merged)
        return self.out(hidden)
    
class VFLActiveModelWithTwoLayer(Model):
    """Top model: concatenate the inputs, two 32-unit ReLU layers, softmax output."""

    def __init__(self):
        super().__init__()
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(units=32, activation='relu', name="dense1")
        self.d2 = Dense(units=32, activation='relu', name="dense2")
        self.out = Dense(units=class_num, activation='softmax', name="out")

    def call(self, x):
        """Concatenate ``x``, apply the hidden layers in order, return the softmax output."""
        hidden = self.concatenated(x)
        for layer in (self.d1, self.d2):
            hidden = layer(hidden)
        return self.out(hidden)
    
class VFLActiveModelWithThreeLayer(Model):
    """Top model: concatenate the inputs, three 32-unit ReLU layers, softmax output."""

    def __init__(self):
        super().__init__()
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(units=32, activation='relu', name="dense1")
        self.d2 = Dense(units=32, activation='relu', name="dense2")
        self.d3 = Dense(units=32, activation='relu', name="dense3")
        self.out = Dense(units=class_num, activation='softmax', name="out")

    def call(self, x):
        """Concatenate ``x``, apply the hidden layers in order, return the softmax output."""
        hidden = self.concatenated(x)
        for layer in (self.d1, self.d2, self.d3):
            hidden = layer(hidden)
        return self.out(hidden)
    
class VFLActiveModelWithFourLayer(Model):
    """Top model: concatenate the inputs, four 32-unit ReLU layers, softmax output."""

    def __init__(self):
        super().__init__()
        self.concatenated = tf.keras.layers.Concatenate()
        self.d1 = Dense(units=32, activation='relu', name="dense1")
        self.d2 = Dense(units=32, activation='relu', name="dense2")
        self.d3 = Dense(units=32, activation='relu', name="dense3")
        self.d4 = Dense(units=32, activation='relu', name="dense4")
        self.out = Dense(units=class_num, activation='softmax', name="out")

    def call(self, x):
        """Concatenate ``x``, apply the hidden layers in order, return the softmax output."""
        hidden = self.concatenated(x)
        for layer in (self.d1, self.d2, self.d3, self.d4):
            hidden = layer(hidden)
        return self.out(hidden)

class RAE(Model):
    """Stack of four dense layers with a layer-normalized intermediate code.

    Data flow in ``call``: ``d3`` -> ``d1`` -> layer normalization -> ``d4``
    -> ``d2``. The model returns both the final 128-unit output and the
    normalized 64-unit intermediate activation.

    The normalization layer is created once in ``__init__`` so that its scale
    weights are tracked by the model (trainable, saved with the model) and
    reused across calls, rather than being rebuilt on every forward pass.
    """

    def __init__(self):
        super().__init__()

        self.d1 = Dense(64, name="dense1", activation='relu')
        self.d2 = Dense(128, name="dense2", activation=None)
        # Every layer gets its own name (matching its attribute) so that
        # name-based lookup and name-keyed weight formats are unambiguous.
        self.d3 = Dense(64, name="dense3", activation='relu')
        self.d4 = Dense(64, name="dense4", activation='relu')
        # Normalizes over the last axis with a learnable scale and no offset.
        self.norm = LayerNormalization(axis=-1, center=False, scale=True)

    def call(self, x):
        """Return ``(output, normalized_code)`` for the input ``x``.

        Returns:
            A tuple ``(x, x2)`` where ``x2`` is the layer-normalized output of
            ``d1`` and ``x`` is the result of passing ``x2`` through ``d4``
            and then ``d2``.
        """
        x = self.d3(x)
        x = self.d1(x)
        x2 = self.norm(x)
        x = self.d4(x2)
        x = self.d2(x)
        return x, x2

def calculate_l21_rownorm(X):
    r"""Return the row-wise L2,1 norm of ``X``, i.e. \sum_i ||X[i, :]||_2.

    Each row's Euclidean norm is computed and the norms are summed.

    Args:
        X: Two-dimensional NumPy array (or array-like accepted by
            ``np.multiply``).

    Returns:
        The sum of the L2 norms of the rows of ``X`` as a NumPy float scalar.
    """
    # A raw docstring is used above so that "\s" is not parsed as an
    # (invalid) escape sequence.
    squared = np.multiply(X, X)
    row_norms = np.sqrt(squared.sum(axis=1))
    return row_norms.sum()

def calculate_l21_colnorm(X):
    r"""Return the column-wise L2,1 norm of ``X``, i.e. \sum_j ||X[:, j]||_2.

    Each column's Euclidean norm is computed and the norms are summed.

    Args:
        X: Two-dimensional NumPy array (or array-like accepted by
            ``np.multiply``).

    Returns:
        The sum of the L2 norms of the columns of ``X`` as a NumPy float
        scalar.
    """
    # A raw docstring is used above so that "\s" is not parsed as an
    # (invalid) escape sequence.
    squared = np.multiply(X, X)
    col_norms = np.sqrt(squared.sum(axis=0))
    return col_norms.sum()
