import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams

# Current seaborn colour palette; colors[0] is used for the learning-curve plot.
colors = sns.color_palette()

# Global Matplotlib styling applied to every figure produced by this script.
# NOTE: 'text.usetex' requires a working LaTeX installation on the machine.
params = {
    # Axes appearance
    'axes.labelsize': 28,
    'axes.grid': True,
    'axes.linewidth': 1.6,
    'axes.titlepad': 20,
    'axes.xmargin': 0.05,
    'axes.ymargin': 0.05,
    # Faint dash-dot grid
    'grid.alpha': 0.2,
    'grid.color': '#666666',
    'grid.linestyle': '-.',
    # Legend
    'legend.fontsize': 16,
    'legend.loc': 'lower right',
    # x-axis ticks
    'xtick.labelsize': 28,
    'xtick.major.width': 1.6,
    'xtick.major.size': 10,
    'xtick.minor.width': 1.0,
    'xtick.minor.size': 4,
    # y-axis ticks
    'ytick.labelsize': 28,
    'ytick.major.width': 1.6,
    'ytick.major.size': 10,
    'ytick.minor.width': 1.0,
    'ytick.minor.size': 4,
    # Text, figure and line defaults
    'text.usetex': True,
    'figure.figsize': [12, 8],
    'font.size': 32.0,
    'lines.markersize': np.sqrt(20) * 2.5,
    'figure.autolayout': True,
}
rcParams.update(params)

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from BesselConv2d import BesselConv2d

# Report the physical devices TensorFlow can see, CPUs first, then GPUs.
for device_label, device_kind in (('CPU available(s):', 'CPU'),
                                  ('GPU available(s):', 'GPU')):
    print(device_label, tf.config.list_physical_devices(device_kind))

# Uncomment the next line to hide all GPUs from TensorFlow (forces CPU-only execution).
#tf.config.set_visible_devices([], 'GPU')

# Load MNIST through Keras (train/test images plus integer labels).
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# Add a trailing channel axis -> (n_data, x, y, n_channels) and rescale pixels to [0, 1].
train_images = train_images[:, :, :, tf.newaxis] / 255.0
test_images = test_images[:, :, :, tf.newaxis] / 255.0

# Training images go straight to [-1, 1]. Test images stay in [0, 1] until after rotation,
# so the constant fill value 0 used for the uncovered corners equals the black background
# and is mapped to -1 together with it.
train_images = train_images * 2. - 1

# Build the rotated evaluation set: one uniformly random angle per test image.
# tfa.image.rotate interprets `angles` in RADIANS, so sample from [0, 2*pi).
angles = np.random.uniform(0, 2 * np.pi, size=len(test_images))
test_rotated = tfa.image.rotate(test_images, angles=angles, fill_mode='constant', fill_value=0)
test_rotated = test_rotated * 2. - 1

# Build the model: three BesselConv2d stages (each followed by batch normalization),
# then 2x2 max-pooling, flattening, and a dense head ending in 10 un-activated outputs.
model = keras.models.Sequential()

n_filters = 32
# One row per BesselConv2d stage: (layer name, m_max, j_max, k).
bessel_stages = (
    ('Layer1', 10, 5, 15),
    ('Layer2', 10, 5, 7),
    ('Layer3', 5, 3, 7),
)
for stage_name, m_max, j_max, k in bessel_stages:
    model.add(BesselConv2d(m_max=m_max, j_max=j_max, k=k, n_out=n_filters,
                           strides=1, padding='VALID', activation='relu', name=stage_name))
    model.add(keras.layers.BatchNormalization())

model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(64, activation='relu'))
model.add(keras.layers.Dense(10))

# Inputs are single-channel 28x28 images with an unspecified batch size.
model.build(input_shape=(None, 28, 28, 1))

model.summary()

# Train with Adam on the upright training digits, reporting accuracy on the randomly
# rotated test digits as validation data after each epoch.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)  # model emits raw logits
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

history = model.fit(
    train_images,
    train_labels,
    epochs=1,
    validation_data=(test_rotated, test_labels),
    verbose=1,
)

# Plot per-epoch learning curves. Markers are drawn so the curves stay visible even when
# training ran for a single epoch (a one-point line without markers renders nothing).
epoch_numbers = np.arange(1, len(history.history['accuracy']) + 1)
plt.plot(epoch_numbers, history.history['val_accuracy'], c=colors[0], marker='o',
         label='Validation accuracy')
plt.plot(epoch_numbers, history.history['accuracy'], c=colors[0], linestyle='--', marker='o',
         label='Training accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.ylim([0, 1])
plt.savefig('learning_curves.png')