import tensorflow as tf
import numpy as np

@tf.custom_gradient
def bernoulli(input):
    """Sample a 0/1 tensor, using each element of `input` as the probability of a 1.

    Forward pass: uniform noise with the shape of `input` is compared against
    `input`; the boolean result is cast to `input.dtype`.

    Backward pass: the sampling step is replaced by a surrogate gradient that
    multiplies the upstream gradient elementwise by `input`.
    """
    noise = tf.random.uniform(tf.shape(input))
    draws = tf.cast(noise < input, input.dtype)

    def backward(upstream):
        # Surrogate gradient: upstream gradient scaled elementwise by `input`.
        return upstream * input

    return draws, backward

@tf.custom_gradient
def bernoulli_dummy(input, mask):
    """Return the supplied `mask` unchanged while routing gradients to `input`.

    No sampling happens here: the forward value is `mask` itself. On the
    backward pass `input` receives the upstream gradient multiplied
    elementwise by `input`, and `mask` receives no gradient.
    """
    def backward(upstream):
        # Same surrogate as `bernoulli`; the fixed mask gets no gradient (None).
        return upstream * input, None

    return mask, backward

@tf.custom_gradient
def apply_mask(kernel, mask):
    """Multiply `kernel` by `mask` with a hand-written gradient.

    Forward pass: returns `kernel * mask`.

    Backward pass: `kernel` receives no gradient; `mask` receives
    `dy * kernel * mask * (mask * (1 - mask))`.
    """
    product = kernel * mask

    def backward(upstream):
        mask_term = mask * (1 - mask)
        # First slot is the kernel (no gradient), second slot is the mask.
        return None, upstream * kernel * mask * mask_term

    return product, backward

class MEInitializer(tf.keras.initializers.Initializer):
    """Initializer that sets every weight to either ``-c`` or ``+c``.

    ``c = e * sqrt(1 / in_features)``, where ``in_features`` is the product of
    all dimensions except the last for shapes of rank > 2, and the first
    dimension otherwise. The sign of each entry is drawn with equal
    probability for both values.
    """

    def __call__(self, shape, dtype=None, **kwargs):
        """Return a tensor of `shape` filled with random ``+/-c`` values.

        Args:
            shape: Shape of the tensor to create.
            dtype: Desired dtype. When omitted, the Keras default float type
                (``tf.keras.backend.floatx()``) is used, so the result no
                longer depends on how NumPy scalars are converted.
            **kwargs: Accepted for compatibility with the base initializer
                signature; ignored.
        """
        if dtype is None:
            dtype = tf.keras.backend.floatx()

        if len(shape) > 2:  # Convolutional layer case
            # Product of all dimensions except the last one (number of filters)
            in_features = np.prod(shape[:-1])
        else:  # Dense layer case
            in_features = shape[0]

        c = np.e * np.sqrt(1 / in_features)
        array = tf.constant([-c, c], dtype=dtype)

        # Equal logits for both entries -> each index (0 or 1) is equally likely.
        logits = tf.math.log(tf.ones([1, len(array)]))
        indices = tf.random.categorical(logits, np.prod(shape))

        samples = tf.gather(array, indices[0])
        return tf.reshape(samples, shape)

class MaskedConv2D(tf.keras.layers.Layer):
    """
    2D convolution whose frozen kernel and bias are gated by trainable masks.

    An inner, non-trainable ``Conv2D`` owns the kernel and bias (kernel
    initialised with ``MEInitializer``). This layer adds two trainable
    variables, ``mask`` and ``bias_mask``, which are squashed to mask values
    and multiplied onto the kernel and bias before the convolution is applied.
    """

    def __init__(self, filters, kernel_size, strides=1, padding="same", **kwargs):
        super(MaskedConv2D, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        # The wrapped layer is only used as a holder for kernel/bias and for its
        # stride/padding configuration; the convolution itself is done in `call`.
        self.conv = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, padding=padding, use_bias=True,
                                           trainable=False, kernel_initializer=MEInitializer(), strides=strides)
        self.built = False

    def build(self, input_shape):
        """Build the inner convolution and create the mask variables."""
        if not self.built:
            self.conv.build(input_shape)
            # One mask entry per kernel entry.
            self.mask = self.add_weight(shape=self.conv.kernel.shape,
                                        initializer='glorot_uniform',
                                        trainable=True,
                                        name='mask')
            if self.conv.use_bias:
                # One mask entry per bias entry.
                self.bias_mask = self.add_weight(shape=(self.filters,),
                                                 initializer="glorot_uniform",
                                                 trainable=True,
                                                 name="bias_mask")
            self.built = True
        super(MaskedConv2D, self).build(input_shape)

    @staticmethod
    def _uses_fixed_masks(masks):
        """Return True when `masks` holds supplied masks (neither None nor the -1 sentinel)."""
        # Compare by value: `is not -1` relies on interpreter int caching.
        return masks is not None and not (isinstance(masks, int) and masks == -1)

    def call(self, inputs, masks=None, training=None, verbose=False, md=False):
        """Convolve `inputs` with the masked kernel and add the masked bias.

        Args:
            inputs: Input tensor passed to ``tf.nn.conv2d``.
            masks: Optional mapping of fixed masks, looked up under
                ``self.name`` (kernel) and ``self.name + '_bias_mask'`` (bias).
                ``None`` or ``-1`` means "sample the masks".
            training: When truthy, binary masks are used (sampled or supplied).
                Otherwise (including ``None``) the mask values are multiplied
                onto the weights directly, without sampling.
            verbose: When true, print a slice of the mask values and samples.
            md: When true, ``tf.exp`` is applied to the mask variables instead
                of ``tf.sigmoid``.

        Returns:
            The convolution output with the masked bias added.
        """
        use_bias = self.conv.use_bias

        # The frozen weights never receive gradients; only the masks are trained.
        kernel_no_gradient = tf.stop_gradient(self.conv.kernel)
        bias_no_gradient = tf.stop_gradient(self.conv.bias) if use_bias else None

        squash = tf.exp if md else tf.sigmoid
        mask_probs = squash(self.mask)
        bias_mask_probs = squash(self.bias_mask) if use_bias else None
        if verbose:
            print("Parameters: ", mask_probs[0][0][0])

        if training:
            fixed_masks = self._uses_fixed_masks(masks)

            if fixed_masks:
                binary_mask = bernoulli_dummy(mask_probs, masks[self.name])
                if verbose:
                    print("Samples dummy: ", binary_mask[0][0][0])
            else:
                binary_mask = bernoulli(mask_probs)
                if verbose:
                    print("Samples: ", binary_mask[0][0][0])
            masked_kernel = kernel_no_gradient * binary_mask

            if use_bias:
                if fixed_masks:
                    binary_bias_mask = bernoulli_dummy(bias_mask_probs, masks[self.name + '_bias_mask'])
                else:
                    binary_bias_mask = bernoulli(bias_mask_probs)
                masked_bias = bias_no_gradient * binary_bias_mask
        else:
            # Use the mask values directly, without sampling, outside training.
            masked_kernel = kernel_no_gradient * mask_probs
            masked_bias = bias_no_gradient * bias_mask_probs if use_bias else None

        # Compute the convolution using the masked kernel
        outputs = tf.nn.conv2d(inputs, masked_kernel, strides=self.conv.strides, padding=self.conv.padding.upper())

        if use_bias:
            outputs = tf.nn.bias_add(outputs, masked_bias)

        return outputs


class MaskedDense(tf.keras.layers.Layer):
    """
    Dense (fully connected) layer whose frozen kernel and bias are gated by trainable masks.

    An inner, non-trainable ``Dense`` layer owns the kernel and bias (kernel
    initialised with ``MEInitializer``). This layer adds two trainable
    variables, ``mask`` and ``bias_mask``, which are squashed to mask values
    and multiplied onto the kernel and bias before the matrix product.
    """

    def __init__(self, units, **kwargs):
        super(MaskedDense, self).__init__(**kwargs)
        self.units = units
        # The wrapped layer only holds the kernel and bias; the product is done in `call`.
        self.dense = tf.keras.layers.Dense(units=units, use_bias=True, trainable=False, kernel_initializer=MEInitializer())
        self.built = False

    def build(self, input_shape):
        """Build the inner dense layer and create the mask variables."""
        if not self.built:
            self.dense.build(input_shape)
            # One mask entry per kernel entry.
            self.mask = self.add_weight(shape=self.dense.kernel.shape,
                                        initializer='glorot_uniform',
                                        trainable=True,
                                        name='mask')
            if self.dense.use_bias:
                # One mask entry per bias entry.
                self.bias_mask = self.add_weight(shape=(self.dense.units,),
                                                 initializer='glorot_uniform',
                                                 trainable=True,
                                                 name='bias_mask')
            self.built = True
        super(MaskedDense, self).build(input_shape)

    @staticmethod
    def _uses_fixed_masks(masks):
        """Return True when `masks` holds supplied masks (neither None nor the -1 sentinel)."""
        # Compare by value: `is not -1` relies on interpreter int caching.
        return masks is not None and not (isinstance(masks, int) and masks == -1)

    def call(self, inputs, masks=None, training=None, verbose=False, md=False):
        """Multiply `inputs` by the masked kernel and add the masked bias.

        Args:
            inputs: Input tensor passed to ``tf.matmul``.
            masks: Optional mapping of fixed masks, looked up under
                ``self.name`` (kernel) and ``self.name + '_bias_mask'`` (bias).
                ``None`` or ``-1`` means "sample the masks".
            training: When truthy, binary masks are used (sampled or supplied).
                Otherwise (including ``None``) the mask values are multiplied
                onto the weights directly, without sampling.
            verbose: Accepted for signature parity with ``MaskedConv2D``;
                nothing is printed by this layer.
            md: When true, ``tf.exp`` is applied to the mask variables instead
                of ``tf.sigmoid``.

        Returns:
            ``inputs @ masked_kernel`` with the masked bias added.
        """
        use_bias = self.dense.use_bias

        # The frozen weights never receive gradients; only the masks are trained.
        kernel_no_gradient = tf.stop_gradient(self.dense.kernel)
        bias_no_gradient = tf.stop_gradient(self.dense.bias) if use_bias else None

        # Squash the mask variables: exp when `md` is set, sigmoid otherwise.
        squash = tf.exp if md else tf.sigmoid
        mask_probs = squash(self.mask)
        bias_mask_probs = squash(self.bias_mask) if use_bias else None

        if training:
            fixed_masks = self._uses_fixed_masks(masks)

            if fixed_masks:
                binary_mask = bernoulli_dummy(mask_probs, masks[self.name])
            else:
                binary_mask = bernoulli(mask_probs)
            masked_kernel = kernel_no_gradient * binary_mask

            if use_bias:
                if fixed_masks:
                    binary_bias_mask = bernoulli_dummy(bias_mask_probs, masks[self.name + '_bias_mask'])
                else:
                    binary_bias_mask = bernoulli(bias_mask_probs)
                masked_bias = bias_no_gradient * binary_bias_mask
        else:
            # Use the mask values directly, without sampling, outside training.
            masked_kernel = kernel_no_gradient * mask_probs
            masked_bias = bias_no_gradient * bias_mask_probs if use_bias else None

        # Compute the outputs using the masked kernel and bias
        outputs = tf.matmul(inputs, masked_kernel)

        if use_bias:
            outputs = tf.nn.bias_add(outputs, masked_bias)

        return outputs


class LeNet5Masked(tf.keras.Model):
    """LeNet-5 style network built from masked layers.

    Two masked convolutions, each followed by ReLU and average pooling, then
    three masked dense layers. The output is passed through a softmax.
    """

    def __init__(self, **kwargs):
        super(LeNet5Masked, self).__init__(**kwargs)
        self.conv1 = MaskedConv2D(filters=6, kernel_size=(5, 5), input_shape=(28, 28, 1), name="masked_conv2d")
        self.pool1 = tf.keras.layers.AvgPool2D(pool_size=(2, 2), strides=2)
        self.conv2 = MaskedConv2D(filters=16, kernel_size=(5, 5), name="masked_conv2d_1", padding='valid')
        self.pool2 = tf.keras.layers.AvgPool2D(pool_size=(2, 2), strides=2)
        # Kept as an attribute for existing users; `call` flattens with tf.reshape.
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = MaskedDense(units=120, name="masked_dense")
        self.dense2 = MaskedDense(units=84, name="masked_dense_1")
        self.output_layer = MaskedDense(units=10, name="masked_dense_2")

    def call(self, inputs, masks=None, verbose=False):
        """Run the forward pass.

        Args:
            inputs: Batch of input images.
            masks: Optional fixed masks, forwarded to every masked layer.
            verbose: Forwarded to every masked layer (previously accepted but
                ignored); defaults to False, which keeps the layers silent.

        Returns:
            Softmax over the 10 outputs of the final masked dense layer.
        """
        x = self.conv1(inputs, masks=masks, verbose=verbose)
        x = tf.nn.relu(x)
        x = self.pool1(x)
        x = self.conv2(x, masks=masks, verbose=verbose)
        x = tf.nn.relu(x)
        x = self.pool2(x)

        # Flatten everything except the batch dimension.
        x = tf.reshape(x, [tf.shape(x)[0], -1])
        x = self.dense1(x, masks=masks, verbose=verbose)
        x = tf.nn.relu(x)
        x = self.dense2(x, masks=masks, verbose=verbose)
        x = tf.nn.relu(x)
        x = self.output_layer(x, masks=masks, verbose=verbose)
        return tf.nn.softmax(x)

class ResidualBlock(tf.keras.Model):
    """Two masked convolutions with batch norm plus a shortcut connection.

    When `stride` is not 1 the shortcut is a 1x1 masked convolution followed
    by batch normalization; otherwise the input is added back unchanged.
    """

    def __init__(self, filters, kernel_size=(3, 3), stride=1, block_id=""):
        super().__init__()
        conv_prefix = f"masked_conv2d_{block_id}"
        bn_prefix = f"batch_norm_{block_id}"

        self.conv1 = MaskedConv2D(filters, kernel_size, strides=stride, name=f"{conv_prefix}_1")
        self.bn1 = tf.keras.layers.BatchNormalization(name=f"{bn_prefix}_1")
        self.conv2 = MaskedConv2D(filters, kernel_size, name=f"{conv_prefix}_2")
        self.bn2 = tf.keras.layers.BatchNormalization(name=f"{bn_prefix}_2")

        if stride == 1:
            # Identity shortcut: nothing to project.
            self.downsample = lambda tensor: tensor
        else:
            projection = MaskedConv2D(filters, (1, 1), strides=stride, name=f"{conv_prefix}_downsample")
            projection_bn = tf.keras.layers.BatchNormalization(name=f"{bn_prefix}_downsample")
            self.downsample = tf.keras.Sequential([projection, projection_bn], name=f"downsample_{block_id}")

    def call(self, inputs, training=False):
        """Apply conv-bn-relu, conv-bn, add the shortcut, then ReLU.

        `training` is passed explicitly to the two batch-norm layers of the
        main path only.
        """
        out = self.conv1(inputs)
        out = tf.nn.relu(self.bn1(out, training=training))

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        shortcut = self.downsample(inputs)
        return tf.nn.relu(out + shortcut)

class ResNet18Masked(tf.keras.Model):
    """ResNet-style network built from `MaskedConv2D` / `MaskedDense` layers.

    A 7x7 stem (conv, batch norm, ReLU, max pool) is followed by four
    residual blocks of 64, 128, 256 and 512 filters, global average pooling
    and a masked dense classifier whose output goes through a softmax.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        BatchNorm = tf.keras.layers.BatchNormalization

        # Stem
        self.init_conv = MaskedConv2D(filters=64, kernel_size=(7, 7), strides=2, name="initial_masked_conv2d")
        self.init_bn = BatchNorm(name="initial_batch_norm")
        self.init_relu = tf.keras.layers.ReLU(name="initial_relu")
        self.init_maxpool = tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=2, name="initial_max_pool")

        # Every layer is a named attribute of the model (no nested Sequential).
        # Block 1: 64 filters, identity shortcut
        self.conv1_1 = MaskedConv2D(64, (3, 3), name="conv1_1")
        self.bn1_1 = BatchNorm(name="bn1_1")
        self.conv1_2 = MaskedConv2D(64, (3, 3), name="conv1_2")
        self.bn1_2 = BatchNorm(name="bn1_2")

        # Block 2: 128 filters, strided, projected shortcut
        self.conv2_1 = MaskedConv2D(128, (3, 3), strides=2, name="conv2_1")
        self.bn2_1 = BatchNorm(name="bn2_1")
        self.conv2_2 = MaskedConv2D(128, (3, 3), name="conv2_2")
        self.bn2_2 = BatchNorm(name="bn2_2")
        self.downsample2_conv = MaskedConv2D(128, (1, 1), strides=2, name="downsample2_conv")
        self.downsample2_bn = BatchNorm(name="downsample2_bn")

        # Block 3: 256 filters, strided, projected shortcut
        self.conv3_1 = MaskedConv2D(256, (3, 3), strides=2, name="conv3_1")
        self.bn3_1 = BatchNorm(name="bn3_1")
        self.conv3_2 = MaskedConv2D(256, (3, 3), name="conv3_2")
        self.bn3_2 = BatchNorm(name="bn3_2")
        self.downsample3_conv = MaskedConv2D(256, (1, 1), strides=2, name="downsample3_conv")
        self.downsample3_bn = BatchNorm(name="downsample3_bn")

        # Block 4: 512 filters, strided, projected shortcut
        self.conv4_1 = MaskedConv2D(512, (3, 3), strides=2, name="conv4_1")
        self.bn4_1 = BatchNorm(name="bn4_1")
        self.conv4_2 = MaskedConv2D(512, (3, 3), name="conv4_2")
        self.bn4_2 = BatchNorm(name="bn4_2")
        self.downsample4_conv = MaskedConv2D(512, (1, 1), strides=2, name="downsample4_conv")
        self.downsample4_bn = BatchNorm(name="downsample4_bn")

        # Head
        self.avgpool = tf.keras.layers.GlobalAveragePooling2D(name="global_avg_pool")
        self.fc = MaskedDense(num_classes, name="masked_dense_output")

    def call(self, inputs, training=False):
        """Run the stem, the four residual blocks and the classifier head."""
        x = self.init_conv(inputs)
        x = self.init_bn(x, training=training)
        x = self.init_maxpool(self.init_relu(x))

        # Block 1 has no projection, so its shortcut is the block input itself.
        x = self.apply_block(x, self.conv1_1, self.bn1_1, self.conv1_2, self.bn1_2, training)
        x = self.apply_block(x, self.conv2_1, self.bn2_1, self.conv2_2, self.bn2_2, training,
                             self.downsample2_conv, self.downsample2_bn)
        x = self.apply_block(x, self.conv3_1, self.bn3_1, self.conv3_2, self.bn3_2, training,
                             self.downsample3_conv, self.downsample3_bn)
        x = self.apply_block(x, self.conv4_1, self.bn4_1, self.conv4_2, self.bn4_2, training,
                             self.downsample4_conv, self.downsample4_bn)

        logits = self.fc(self.avgpool(x))
        return tf.nn.softmax(logits)

    def apply_block(self, inputs, conv1, bn1, conv2, bn2, training, downsample_conv=None, downsample_bn=None):
        """Apply one residual block assembled from the given layers.

        The shortcut is projected only when both `downsample_conv` and
        `downsample_bn` are given; otherwise `inputs` is added back as is.
        """
        has_projection = downsample_conv is not None and downsample_bn is not None
        if has_projection:
            shortcut = downsample_bn(downsample_conv(inputs), training=training)
        else:
            shortcut = inputs

        out = bn1(conv1(inputs), training=training)
        out = tf.nn.relu(out)
        out = bn2(conv2(out), training=training)
        return tf.nn.relu(out + shortcut)


class ResNet18MaskedNew(tf.keras.Model):
    """ResNet-18 style network made of `ResidualBlockMasked` stages.

    A 7x7 masked stem convolution with batch norm, ReLU and max pooling is
    followed by four stages of two residual blocks each (64, 128, 256, 512
    filters), global average pooling and a masked dense classifier with a
    softmax output. Fixed masks can be forwarded to the masked layers.
    """

    def __init__(self, num_classes=10, **kwargs):
        super(ResNet18MaskedNew, self).__init__(**kwargs)

        self.conv1 = MaskedConv2D(filters=64, kernel_size=(7, 7), strides=2, padding='same', name="masked_conv2d_1")
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=2, padding='same')

        # Four stages of two residual blocks each; all but the first are strided.
        self.conv2_x = self._make_residual_block(64, 2, name="conv2_x")
        self.conv3_x = self._make_residual_block(128, 2, strides=2, name="conv3_x")
        self.conv4_x = self._make_residual_block(256, 2, strides=2, name="conv4_x")
        self.conv5_x = self._make_residual_block(512, 2, strides=2, name="conv5_x")

        self.avgpool = tf.keras.layers.GlobalAvgPool2D()
        self.flatten = tf.keras.layers.Flatten(name="flatten")
        self.fc = MaskedDense(units=num_classes, name="masked_dense_output")

    def _make_residual_block(self, filters, blocks, strides=1, name=None):
        """Return a list of `blocks` residual blocks with `filters` filters.

        Only the first block uses `strides`; the remaining ones use stride 1.
        `name` is accepted but not used. A plain list is returned (rather than
        a Sequential) so `call` can pass `masks` to every block.
        """
        res_blocks = [ResidualBlockMasked(filters, strides=strides)]
        res_blocks.extend(ResidualBlockMasked(filters, strides=1) for _ in range(1, blocks))
        return res_blocks

    def call(self, inputs, masks=None, verbose=False):
        """Run the forward pass.

        Args:
            inputs: Batch of input images.
            masks: Optional fixed masks, forwarded to the masked layers.
            verbose: Accepted but not used.

        Returns:
            Softmax over the outputs of the final masked dense layer.
        """
        x = self.conv1(inputs, masks=masks)
        x = self.bn1(x)
        x = tf.nn.relu(x)
        x = self.pool1(x)

        for stage in (self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x):
            for block in stage:
                x = block(x, masks=masks)

        x = self.avgpool(x)
        # Reuse the layer created in __init__ instead of building a new
        # Flatten layer on every forward pass.
        x = self.flatten(x)
        x = self.fc(x, masks=masks)

        return tf.nn.softmax(x)


class ResidualBlockMasked(tf.keras.Model):
    """Residual block of two 3x3 masked convolutions, each with batch norm.

    The shortcut is an initially empty Sequential. When `strides` is not 1 it
    gets a 1x1 masked convolution and a batch norm so the shortcut matches the
    strided main path.
    """

    def __init__(self, filters, strides=1, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = MaskedConv2D(filters=filters, kernel_size=(3, 3), strides=strides, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = MaskedConv2D(filters=filters, kernel_size=(3, 3), strides=1, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        self.shortcut = tf.keras.Sequential()
        if strides != 1:
            # 1x1 convolution + batch norm so the shortcut matches the main path.
            projection = MaskedConv2D(filters=filters, kernel_size=(1, 1), strides=strides)
            self.shortcut.add(projection)
            self.shortcut.add(tf.keras.layers.BatchNormalization())

    def call(self, inputs, masks=None, verbose=False):
        """Apply the block; `masks` goes to the two main convolutions only."""
        # The shortcut layers are applied one by one and never receive masks.
        # With an empty shortcut this loop leaves `inputs` unchanged.
        shortcut = inputs
        for layer in self.shortcut.layers:
            shortcut = layer(shortcut)

        out = self.conv1(inputs, masks=masks)
        out = tf.nn.relu(self.bn1(out))

        out = self.conv2(out, masks=masks)
        out = self.bn2(out)

        return tf.nn.relu(out + shortcut)

def residual_block(inputs, filters, strides=1, block_name="block"):
    """Build one functional-API residual block on top of `inputs`.

    Main path: masked 3x3 conv -> batch norm -> ReLU -> masked 3x3 conv ->
    batch norm. The shortcut is `inputs`, or a 1x1 masked conv + batch norm
    when `strides` is not 1 or the channel count of `inputs` differs from
    `filters`. Both paths are added and passed through a ReLU.
    """
    main = MaskedConv2D(filters, kernel_size=(3, 3), strides=strides, name=f"{block_name}_conv1")(inputs)
    main = tf.keras.layers.BatchNormalization(name=f"{block_name}_bn1")(main)
    main = tf.keras.layers.ReLU()(main)

    main = MaskedConv2D(filters, kernel_size=(3, 3), strides=1, name=f"{block_name}_conv2")(main)
    main = tf.keras.layers.BatchNormalization(name=f"{block_name}_bn2")(main)

    needs_projection = strides != 1 or inputs.shape[-1] != filters
    if needs_projection:
        skip = MaskedConv2D(filters, kernel_size=(1, 1), strides=strides, name=f"{block_name}_shortcut")(inputs)
        skip = tf.keras.layers.BatchNormalization(name=f"{block_name}_shortcut_bn")(skip)
    else:
        skip = inputs

    merged = tf.keras.layers.add([main, skip])
    return tf.keras.layers.ReLU()(merged)

def ResNet18Maskedn2(input_shape=(32, 32, 3), num_classes=10):
    """Build a functional-API ResNet-18 style model from masked layers.

    Args:
        input_shape: Shape of one input sample (without the batch dimension).
        num_classes: Number of units in the final masked dense layer.

    Returns:
        A `tf.keras.Model` named "ResNet18Masked" with softmax outputs.
    """
    inputs = tf.keras.Input(shape=input_shape)

    # Stem: 7x7 strided masked conv, batch norm, ReLU, 3x3 strided max pool.
    x = MaskedConv2D(filters=64, kernel_size=(7, 7), strides=2, padding="same", name="initial_conv")(inputs)
    x = tf.keras.layers.BatchNormalization(name="initial_bn")(x)
    x = tf.keras.layers.ReLU(name="initial_relu")(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding="same")(x)

    # (filters, strides) for block1 .. block8, in order.
    stage_plan = [
        (64, 1), (64, 1),
        (128, 2), (128, 1),
        (256, 2), (256, 1),
        (512, 2), (512, 1),
    ]
    for index, (filters, strides) in enumerate(stage_plan, start=1):
        x = residual_block(x, filters=filters, strides=strides, block_name=f"block{index}")

    # Head: global average pooling, masked dense layer, softmax.
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    logits = MaskedDense(units=num_classes, name="final_dense")(x)
    outputs = tf.nn.softmax(logits)

    return tf.keras.Model(inputs, outputs, name="ResNet18Masked")

class Mask6CNN(tf.keras.Model):
    """
    6Conv model studied in
    https://proceedings.neurips.cc/paper/2019/file/1113d7a76ffceca1bb350bfe145467c6-Paper.pdf for Cifar-10.

    Six masked 3x3 convolutions (64, 64, 128, 128, 256, 256 filters) with a
    2x2 max pool after every second one, followed by three masked dense
    layers (256, 256, 10) and a softmax.
    """

    def __init__(self, init='ME_init', activation='relu', **kwargs):
        """Create the layers.

        Args:
            init: Stored on the instance as `self.init`; not otherwise used here.
            activation: Activation applied after every layer except the last.
                The default 'relu' maps to `tf.nn.relu`; any other value is
                resolved with `tf.keras.activations.get` (previously this
                argument was ignored and ReLU was always used).
            **kwargs: Forwarded to `tf.keras.Model`.
        """
        super(Mask6CNN, self).__init__(**kwargs)
        if activation == 'relu':
            self.activation = tf.nn.relu
        else:
            self.activation = tf.keras.activations.get(activation)
        self.init = init

        # Masked convolutional layers
        self.conv1 = MaskedConv2D(64, (3, 3), strides=1, padding='same', name="conv1")
        self.conv2 = MaskedConv2D(64, (3, 3), strides=1, padding='same', name="conv2")
        self.conv3 = MaskedConv2D(128, (3, 3), strides=1, padding='same', name="conv3")
        self.conv4 = MaskedConv2D(128, (3, 3), strides=1, padding='same', name="conv4")
        self.conv5 = MaskedConv2D(256, (3, 3), strides=1, padding='same', name="conv5")
        self.conv6 = MaskedConv2D(256, (3, 3), strides=1, padding='same', name="conv6")

        # Masked dense layers
        self.dense1 = MaskedDense(256, name="dense1")
        self.dense2 = MaskedDense(256, name="dense2")
        self.dense3 = MaskedDense(10, name="dense3")  # Output layer

        # Pooling layer, shared by all three pooling steps
        self.pool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)

    def call(self, inputs, masks=None, training=False):
        """Run the forward pass.

        Args:
            inputs: Batch of input images.
            masks: Optional fixed masks, forwarded to every masked layer.
            training: Not passed to the sub-layers explicitly.

        Returns:
            Softmax over the 10 outputs of the final masked dense layer.
        """
        act = self.activation

        x = act(self.conv1(inputs, masks=masks))
        x = self.pool(act(self.conv2(x, masks=masks)))

        x = act(self.conv3(x, masks=masks))
        x = self.pool(act(self.conv4(x, masks=masks)))

        x = act(self.conv5(x, masks=masks))
        x = self.pool(act(self.conv6(x, masks=masks)))

        # Flatten everything except the batch dimension for the dense layers
        x = tf.reshape(x, [tf.shape(x)[0], -1])

        x = act(self.dense1(x, masks=masks))
        x = act(self.dense2(x, masks=masks))
        # The last layer skips `self.activation`; its output goes straight to softmax.
        x = self.dense3(x, masks=masks)

        return tf.nn.softmax(x)

class Mask4CNN(tf.keras.Model):
    """
    4Conv model studied in
    https://proceedings.neurips.cc/paper/2019/file/1113d7a76ffceca1bb350bfe145467c6-Paper.pdf for Cifar-10.

    Four masked 3x3 convolutions (64, 64, 128, 128 filters) with a 2x2 max
    pool after every second one, followed by three masked dense layers
    (256, 256, 10) and a softmax.
    """

    def __init__(self, init='ME_init', activation='relu', **kwargs):
        """Create the layers.

        Args:
            init: Stored on the instance as `self.init`; not otherwise used here.
            activation: Activation applied after every layer except the last.
                The default 'relu' maps to `tf.nn.relu`; any other value is
                resolved with `tf.keras.activations.get` (previously this
                argument was ignored and ReLU was always used).
            **kwargs: Forwarded to `tf.keras.Model`.
        """
        super(Mask4CNN, self).__init__(**kwargs)
        if activation == 'relu':
            self.activation = tf.nn.relu
        else:
            self.activation = tf.keras.activations.get(activation)
        self.init = init

        # Masked convolutional layers
        self.conv1 = MaskedConv2D(64, (3, 3), strides=1, padding='same', name="conv1")
        self.conv2 = MaskedConv2D(64, (3, 3), strides=1, padding='same', name="conv2")
        self.conv3 = MaskedConv2D(128, (3, 3), strides=1, padding='same', name="conv3")
        self.conv4 = MaskedConv2D(128, (3, 3), strides=1, padding='same', name="conv4")

        # Masked dense layers
        self.dense1 = MaskedDense(256, name="dense1")
        self.dense2 = MaskedDense(256, name="dense2")
        self.dense3 = MaskedDense(10, name="dense3")  # Output layer

        # Pooling layer, shared by both pooling steps
        self.pool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)

    def call(self, inputs, masks=None, training=False):
        """Run the forward pass.

        Args:
            inputs: Batch of input images.
            masks: Optional fixed masks, forwarded to every masked layer.
            training: Not passed to the sub-layers explicitly.

        Returns:
            Softmax over the 10 outputs of the final masked dense layer.
        """
        act = self.activation

        x = act(self.conv1(inputs, masks=masks))
        x = self.pool(act(self.conv2(x, masks=masks)))

        x = act(self.conv3(x, masks=masks))
        x = self.pool(act(self.conv4(x, masks=masks)))

        # Flatten everything except the batch dimension for the dense layers
        x = tf.reshape(x, [tf.shape(x)[0], -1])

        x = act(self.dense1(x, masks=masks))
        x = act(self.dense2(x, masks=masks))
        # The last layer skips `self.activation`; its output goes straight to softmax.
        x = self.dense3(x, masks=masks)

        return tf.nn.softmax(x)