import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import backend as k


# This class uses the legacy Keras optimizer API (`_set_hyper`, `add_slot`,
# `_decayed_lr`, `_resource_apply_dense`). From TF 2.11 on, the name
# `tf.keras.optimizers.Optimizer` refers to a new optimizer class without
# those hooks, and the legacy one lives under `tf.keras.optimizers.legacy`.
# Fall back to the old location on TF versions that lack the `legacy` namespace.
class MySGDOptimizer(getattr(tf.keras.optimizers, "legacy", tf.keras.optimizers).Optimizer):
    """Experimental plain-SGD optimizer built on the legacy Keras optimizer API.

    Each step applies ``var -= lr_t * grad``, where ``lr_t`` is the learning
    rate after the base class's built-in decay. The step is staged through a
    per-variable ``"var2"`` slot.

    A ``momentum`` hyperparameter and a ``"momentum"`` slot are created and
    serialized, but the current update rule does not read them.
    """

    # Class-level default; each instance overwrites it in __init__.
    model = None

    def __init__(self, model=None, learning_rate=0.001, momentum=0.9,
                 name="MySGDOptimizer", **kwargs):
        """Register hyperparameters with ``_set_hyper`` so Keras tracks them.

        Args:
            model: Optional model reference stored on the optimizer. The update
                rule does not use it. It defaults to ``None`` because
                ``get_config`` does not serialize it. Without a default,
                rebuilding the optimizer from its config (``from_config``)
                would fail with a missing-argument ``TypeError``.
            learning_rate: Step size. A legacy ``lr=`` keyword in ``kwargs``
                takes precedence over this value.
            momentum: Stored and serialized only; not used by the update rule.
            name: Optimizer name passed to the base-class constructor.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(name, **kwargs)
        # Honor the legacy `lr=` spelling if the caller used it.
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        # NOTE(review): `_initial_decay` is presumably set by the legacy base
        # class from a `decay=` kwarg; confirm against the TF version in use.
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("momentum", momentum)
        self.model = model

    def _create_slots(self, var_list):
        """Create the per-variable optimizer state ("slots").

        ``"momentum"`` is created but not read by the current update rule.
        ``"var2"`` is the staging buffer used by ``_resource_apply_dense``.
        """
        for var in var_list:
            self.add_slot(var, "momentum")
            self.add_slot(var, "var2")

    @tf.function
    def _resource_apply_dense(self, grad, var):
        """Apply one plain-SGD step to a single model variable.

        The new value is computed in the ``"var2"`` slot and then copied back
        into ``var``. The net effect equals ``var.assign_sub(grad * lr_t)``,
        but it keeps one extra copy of every weight.
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)  # learning rate after built-in decay

        staged = self.get_slot(var, "var2")
        staged.assign(var)                 # start from the current weights
        staged.assign_sub(grad * lr_t)     # SGD step on the staging copy
        var.assign(staged)                 # publish the updated weights

    def _resource_apply_sparse(self, grad, var):
        """Sparse-gradient updates are not implemented for this optimizer."""
        raise NotImplementedError

    def get_config(self):
        """Return the base config plus this optimizer's hyperparameters.

        ``model`` is intentionally not serialized.
        """
        base_config = super().get_config()
        return {
            **base_config,
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        }


# Load the MNIST train and test splits through TFDS as (image, label) pairs,
# together with the dataset info object (used below for the shuffle buffer size).
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)
def normalize_img(image, label):
    """Cast a `uint8` image to `float32` and divide by 255; pass the label through."""
    image_f32 = tf.cast(image, tf.float32)
    return image_f32 / 255.0, label

# Debugging switch: execute `tf.function`-decorated code eagerly.
# This typically slows training noticeably.
tf.config.run_functions_eagerly(True)

# Training pipeline: normalize, cache the normalized examples, then shuffle with
# a buffer as large as the whole train split, batch, and prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation pipeline: no shuffling. Batching happens before caching, so the
# cache holds ready-made batches.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)


model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation=None),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dense(10),
])

# Debug prints: apply each throwaway layer's activation function to the
# scalar 2, comparing 'relu' against the default/no activation.
print(tf.keras.layers.Dense(128, activation='relu').activation(2))
print(tf.keras.layers.Dense(128, activation=None).activation(2))
print(tf.keras.layers.Dense(10).activation(2))

model.compile(
    # Alternatives tried: MySGDOptimizer(model, 0.1), tf.keras.optimizers.SGD(0.001).
    optimizer=tf.keras.optimizers.Adam(0.001),
    # The last Dense layer has no activation, so the loss is told to expect logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
import numpy as np  # NOTE(review): unused in the visible code

print(type(ds_train))

# Sanity checks on one batch. Taking a single element avoids materializing every
# batch of the shuffled training set into a Python list, which the previous
# `list(map(lambda x: x[0], ds_train))` did twice just to read index 0.
sample_images, _ = next(iter(ds_train))
print(sample_images[0].shape)
print(model.predict(sample_images))

model.fit(
    ds_train,
    epochs=2,
    validation_data=ds_test,
)