# -*- coding: utf-8 -*-
"""
Created on Wed Apr 17 16:10:46 2024

@author: CatC_
"""



import tensorflow as tf
import larq as lq
import numpy as np
from tensorflow.keras.constraints import MinMaxNorm
import json

#import scipy.io






# Fetch MNIST through Keras (downloads the dataset on first use).
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
train_images, train_labels = mnist_train
test_images, test_labels = mnist_test

# Lay every image out as a 784x1 column; the sample counts are fixed by MNIST.
train_images = train_images.reshape(60000, 784, 1)
test_images = test_images.reshape(10000, 784, 1)

# Rescale pixel values so they lie between -1 and 1 (x / 127.5 - 1).
train_images = train_images / 127.5 - 1
test_images = test_images / 127.5 - 1


from tensorflow.keras.constraints import Constraint
import tensorflow as tf

class ClipWeights(Constraint):
    """Keras weight constraint that saturates weights to a closed interval.

    Every entry of the constrained tensor is clipped to
    ``[min_value, max_value]``; with the defaults the weights stay
    between -1 and 1.
    """

    def __init__(self, min_value=-1.0, max_value=1.0):
        # Lower and upper clipping bounds, stored for __call__ and get_config.
        self.min_value, self.max_value = min_value, max_value

    def __call__(self, weights):
        """Return ``weights`` with each entry clipped to the configured bounds."""
        return tf.clip_by_value(
            weights,
            clip_value_min=self.min_value,
            clip_value_max=self.max_value,
        )

    def get_config(self):
        """Return the constructor arguments as a dict."""
        return dict(min_value=self.min_value, max_value=self.max_value)

# NOTE: this constraint instance is created here but is not passed to any
# layer in the model definition below.
weight_constraint = ClipWeights(min_value=-1.0, max_value=1.0)

# One seeded Glorot-uniform initializer instance, shared by all dense layers.
initializer = tf.keras.initializers.GlorotUniform(seed=123)


def _ternary_dense(units, threshold, input_quantizer=None, use_bias=False,
                   activation=None):
    """Build a QuantDense layer whose kernel is quantized with SteTern.

    Args:
        units: Number of output units.
        threshold: ``threshold_value`` handed to the SteTern kernel quantizer.
        input_quantizer: Quantizer applied to the layer input, or None to
            feed the input through unquantized.
        use_bias: Whether the layer has a bias vector.
        activation: Activation applied to the layer output, or None.
    """
    kernel_quantizer = lq.quantizers.SteTern(
        threshold_value=threshold,
        ternary_weight_networks=False,
        clip_value=1.0,
    )
    return lq.layers.QuantDense(
        units,
        input_quantizer=input_quantizer,
        kernel_quantizer=kernel_quantizer,
        use_bias=use_bias,
        activation=activation,
        kernel_initializer=initializer,
    )


# Define the model
network_layers = [tf.keras.layers.Flatten(input_shape=(784, 1))]

# Three hidden stages: 350-unit ternary-kernel dense layer (no bias) followed
# by batch normalization. The first stage receives the unquantized input;
# the second and third quantize their input with 'ste_sign'.
for stage_input_quantizer in (None, 'ste_sign', 'ste_sign'):
    network_layers.append(
        _ternary_dense(350, 0.03, input_quantizer=stage_input_quantizer)
    )
    network_layers.append(
        tf.keras.layers.BatchNormalization(scale=True, center=True)
    )

# Output layer: 10 classes, zero quantizer threshold, softmax activation.
# Bias can be used here since it's the output layer.
network_layers.append(
    _ternary_dense(
        10,
        0.00,
        input_quantizer='ste_sign',
        use_bias=True,
        activation='softmax',
    )
)

model = tf.keras.models.Sequential(network_layers)

# Print larq's per-layer summary of the model.
lq.models.summary(model)

# Learning rate starts at 1e-4 and decays by a factor of 0.9 per 10000 steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4, decay_steps=10000, decay_rate=0.9
)

# Use this schedule in the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# Labels are integer class ids, hence the sparse cross-entropy loss.
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train with 10% of the training data held out for validation.
model.fit(
    x=train_images,
    y=train_labels,
    batch_size=256,
    epochs=300,
    validation_split=0.1,
)

# Final score on the untouched test split.
test_loss, test_acc = model.evaluate(x=test_images, y=test_labels)

print(f"Test accuracy {test_acc * 100:.2f} %")

# Record the output of every layer for one training sample.
A = model.layers
k = 0  # index of the training sample that is pushed through the network

# The sample with a leading batch axis of size one.
probe_input = train_images[k].reshape(1, *train_images[k].shape)

# T[j] holds the output of layer j when the model is fed `probe_input`.
T = [
    tf.keras.backend.function(model.input, layer.output)(probe_input)
    for layer in A
]
  
  
# Outside the quantized scope the model exposes its full-precision latent
# weights: save them and keep a copy.
model.save("full_precision_model_debug.h5")
fp_weights = model.get_weights()

# Inside the quantized scope the same calls return the quantized weights.
with lq.context.quantized_scope(True):
    model.save("binary_3x50_debug.h5")
    weights = model.get_weights()

print(weights)
    
    
    
    
N = 4  # BNN layers: three hidden (dense + batch-norm) stages and the output layer
W = weights

# Epsilon added to the variance before the square root. 0.001 is the Keras
# BatchNormalization default; it has to match the epsilon of the model layers.
_BN_EPSILON = 0.001

# Arrays contributed to the weight list by one hidden stage:
# bias-free dense kernel, then gamma, beta, moving mean, moving variance.
_ARRAYS_PER_HIDDEN_STAGE = 5


def _fold_stage(kernel, gamma, beta, mean, variance, epsilon=_BN_EPSILON):
    """Fold a bias-free dense kernel and its batch normalization into A, b.

    For an input x the stage computes
    ``gamma * (kernel.T @ x - mean) / sqrt(variance + epsilon) + beta``,
    which equals ``A @ x + b`` with the returned ``A`` and ``b``.
    """
    inv_std = 1 / np.sqrt(epsilon + variance)
    scale = gamma * inv_std
    folded_matrix = np.matmul(np.diag(scale), kernel.transpose())
    folded_offset = -scale * mean + beta
    return folded_matrix, folded_offset


def _fold_network(weight_list, num_layers):
    """Turn the flat weight list into per-layer matrices and offsets.

    Args:
        weight_list: Flat list of arrays holding ``num_layers - 1`` hidden
            stages (five arrays each) followed by the output kernel and bias.
        num_layers: Number of dense layers, output layer included.

    Returns:
        ``(matrices, offsets)``, two lists with one entry per dense layer.

    Raises:
        ValueError: If ``weight_list`` does not contain exactly the number of
            arrays this layout implies. Without the check a changed
            architecture would be indexed with the wrong stride and silently
            produce wrong matrices.
    """
    output_start = (num_layers - 1) * _ARRAYS_PER_HIDDEN_STAGE
    expected = output_start + 2  # output kernel + output bias
    if len(weight_list) != expected:
        raise ValueError(
            f"expected {expected} weight arrays for {num_layers} layers, "
            f"got {len(weight_list)}"
        )

    matrices = []
    offsets = []
    for start in range(0, output_start, _ARRAYS_PER_HIDDEN_STAGE):
        stage_arrays = weight_list[start:start + _ARRAYS_PER_HIDDEN_STAGE]
        folded_matrix, folded_offset = _fold_stage(*stage_arrays)
        matrices.append(folded_matrix)
        offsets.append(folded_offset)

    # Last layer: no batch normalization, so kernel and bias are used directly.
    matrices.append(weight_list[output_start].transpose())
    offsets.append(weight_list[output_start + 1])
    return matrices, offsets


A_mat, b = _fold_network(W, N)


norm = np.linalg.norm



#JSON-Write


# The first 500 training images, each as a 784x1 column, for the JSON export.
Img_data = [train_images[idx].reshape(784, 1) for idx in range(500)]
    


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also understands NumPy arrays and NumPy scalars.

    Arrays are written as (nested) lists. NumPy scalar types such as
    ``np.float32`` or ``np.int64`` are written as the matching Python
    scalar. Anything else is left to ``json.JSONEncoder``, which raises
    ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalars are not ndarrays; item() gives the Python scalar.
            return obj.item()
        return super().default(obj)

# Bundle the per-layer matrices, offsets and the sample images, then write
# them out as one JSON document (NumPy arrays are converted by NumpyEncoder).
net_dict = dict(A=A_mat, b=b, img=Img_data)

with open('net_data_test.txt', 'w') as json_file:
    json.dump(obj=net_dict, fp=json_file, cls=NumpyEncoder)