import tensorflow as tf
from tensorflow.keras import Model, datasets, layers, models
from tensorflow.keras.layers import Dense, Flatten, LayerNormalization 

#  Active party trainable

# Display labels for the 10 digit classes, indexed by class id.
class_names = ['0', '1', '2', '3', '4',
               '5', '6', '7', '8', '9']
# Derived from class_names so the count can never drift out of sync with
# the label list.
class_num = len(class_names)

class VFLPassiveModel(Model):
    """Passive-party model: Flatten -> Dense(32, ReLU) -> Dense(10, ReLU).

    The input is flattened before the dense layers, and the model returns
    the 10-unit ReLU-activated output of the last dense layer.
    """

    def __init__(self):
        super().__init__()
        self.flatten = layers.Flatten()
        # Layer names must be unique within a model. Both layers used to be
        # named "dense1", which breaks name-keyed weight save/load (e.g. the
        # HDF5 weight format stores one group per layer name). Attribute
        # names d0/d1 are unchanged, so object-based checkpoints still match.
        self.d0 = layers.Dense(32, name="dense1", activation='relu')
        self.d1 = layers.Dense(10, name="dense2", activation='relu')

    def call(self, x):
        """Flatten ``x`` and pass it through the two dense layers in order."""
        x = self.flatten(x)
        x = self.d0(x)
        return self.d1(x)


class VFLPassiveModelCIFAR(Model):
    """Passive-party model: Conv2D(32, 3x3, ReLU) -> Flatten -> Dense(64, ReLU) -> Dense(10).

    The final Dense layer has no activation, so the 10 outputs are linear
    (unactivated) values.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (d0, d5, d6, d7) and construction order are kept
        # as-is so tracked variables and auto-generated layer names match.
        # NOTE(review): input_shape=(32, 32, 3) suggests CIFAR-sized RGB
        # input; confirm whether this kwarg has any effect inside a
        # subclassed Model with the Keras version in use.
        self.d0 = layers.Conv2D(
            32, (3, 3), activation='relu', input_shape=(32, 32, 3)
        )
        self.d5 = layers.Flatten()
        self.d6 = layers.Dense(64, activation='relu')
        self.d7 = layers.Dense(10)

    def call(self, x):
        """Apply the conv, flatten, hidden dense and output dense layers in sequence."""
        for stage in (self.d0, self.d5, self.d6, self.d7):
            x = stage(x)
        return x
    