import numpy as np
from tqdm import tqdm

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import backend as K

# Report which physical devices TensorFlow can see, one line per device type.
for device_kind in ('CPU', 'GPU'):
    print(f'{device_kind} available(s):', tf.config.list_physical_devices(device_kind))

from BesselConv2d import BesselConv2d
from GroupConv2d import GroupConv2d


# Load the data
# -------------

from sklearn.model_selection import train_test_split

# inputs
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# Append a trailing channel axis and scale pixel values to lie between 0 and 1.
train_images = train_images[..., tf.newaxis] / 255.0
test_images = test_images[..., tf.newaxis] / 255.0

# Rotate every image by its own random angle drawn uniformly from [0, 2*pi).
# The number of angles is taken from the arrays themselves instead of being
# hard-coded, so it cannot drift out of sync with the dataset size.
# The test angles are drawn first, then the training angles (same order as before).
angles = np.random.uniform(0, 2.*np.pi, size=len(test_images))
test_rotated = tfa.image.rotate(test_images, angles=angles, fill_mode='constant', fill_value=0)

angles = np.random.uniform(0, 2.*np.pi, size=len(train_images))
train_rotated = tfa.image.rotate(train_images, angles=angles, fill_mode='constant', fill_value=0)


# Model CNN
# ---------
# Plain convolutional baseline: four conv layers, one max-pool, one more conv
# layer, then a dense layer producing 10 raw logits (no softmax; the loss
# below is configured with from_logits=True).
# The model used to be named 'AllGConv', a copy-paste of the G-CNN name that
# mislabelled the summary; 'AllConv' follows the 'AllGConv'/'AllBConv' pattern.
model_cnn = keras.models.Sequential(
    [
        keras.layers.Conv2D(32, (3, 3), activation='relu', name='Conv1', input_shape=(28, 28, 1)),
        keras.layers.Conv2D(32, (5, 5), activation='relu', name='Conv2'),
        keras.layers.Conv2D(32, (5, 5), activation='relu', name='Conv3'),
        # Conv4 is model_cnn.layers[3]; keep the layer order stable for any
        # code that reads intermediate outputs by index.
        keras.layers.Conv2D(32, (5, 5), activation='relu', name='Conv4'),
        keras.layers.MaxPooling2D((2, 2), name='Maxpool1'),
        keras.layers.Conv2D(64, (7, 7), activation='relu', name='Conv5'),
        keras.layers.Flatten(name='Flatten'),
        keras.layers.Dense(10, name='Output'),
    ],
    name='AllConv',
)
model_cnn.summary()

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
model_cnn.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train on the rotated training set; the rotated test set is used as validation data.
model_cnn.fit(
    train_rotated,
    train_labels,
    epochs=15,
    validation_data=(test_rotated, test_labels),
)


# Model G-CNN
# -----------
# Same layout as the plain CNN, built from GroupConv2d layers (Z2 -> C4 for the
# first layer, C4 -> C4 afterwards). The last group-conv output is reshaped to
# (1, 1, 32, 4) and summed over the final axis before the dense classifier.
gcnn_layers = [
    GroupConv2d(k=3, n_out=16, h_input='Z2', h_output='C4', activation='relu', name='GConv1', input_shape=(28, 28, 1)),
    # GConv2..GConv4 share the same configuration and differ only by name.
    *[
        GroupConv2d(k=5, n_out=16, h_input='C4', h_output='C4', activation='relu', name=f'GConv{index}')
        for index in (2, 3, 4)
    ],
    keras.layers.MaxPooling2D((2, 2), name='Maxpool1'),
    GroupConv2d(k=7, n_out=32, h_input='C4', h_output='C4', activation='relu', name='GConv5'),
    keras.layers.Reshape((1, 1, 32, 4)),
    keras.layers.Lambda(lambda t: K.sum(t, axis=4), input_shape=(1, 1, 32, 4)),
    keras.layers.Flatten(name='Flatten'),
    keras.layers.Dense(10, name='Output'),
]

model_gcnn = keras.models.Sequential(name='AllGConv')
for gcnn_layer in gcnn_layers:
    model_gcnn.add(gcnn_layer)
model_gcnn.summary()

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
model_gcnn.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train on the rotated training set; the rotated test set is used as validation data.
model_gcnn.fit(
    train_rotated,
    train_labels,
    epochs=15,
    validation_data=(test_rotated, test_labels),
)


# Model B-CNN
# -----------
# Same layout as the other two networks, built from BesselConv2d layers with
# tanh activations and VALID padding. Each layer takes (m_max, j_max, k, n_out)
# positionally; the values are set right before the layer(s) that use them.
model_bcnn = keras.models.Sequential(name='AllBConv')

m_max, j_max, k, n_out = 11, 9, 3, 16
model_bcnn.add(BesselConv2d(m_max, j_max, k, n_out, padding='VALID', activation='tanh',
                            name='BConv1', input_shape=(28, 28, 1)))

# BConv2..BConv4 share one configuration and differ only by name.
for layer_name in ('BConv2', 'BConv3', 'BConv4'):
    m_max, j_max, k, n_out = 9, 7, 5, 16
    model_bcnn.add(BesselConv2d(m_max, j_max, k, n_out, padding='VALID', activation='tanh',
                                name=layer_name))

model_bcnn.add(keras.layers.MaxPooling2D((2, 2), name='Maxpool1'))

m_max, j_max, k, n_out = 7, 5, 7, 32
model_bcnn.add(BesselConv2d(m_max, j_max, k, n_out, padding='VALID', activation='tanh',
                            name='BConv5'))

model_bcnn.add(keras.layers.Flatten(name='Flatten'))
model_bcnn.add(keras.layers.Dense(10, name='output'))
model_bcnn.summary()

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
model_bcnn.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train on the rotated training set; the rotated test set is used as validation data.
model_bcnn.fit(
    train_rotated,
    train_labels,
    epochs=15,
    validation_data=(test_rotated, test_labels),
)


# Equivariance evaluation
# -----------------------
# For every test image, compare "rotate the feature maps of the image" with
# "feature maps of the rotated image" at an intermediate layer, and report the
# mean/std of the normalised difference for each of the three models.


def _equivariance_errors(model, images, n_channels, group_size=1, layer=3, n_top=4):
    """Measure how far the feature maps of `model` are from rotation-equivariant.

    For each image a random angle in [0, 2*pi) is drawn. The output of
    `model.layers[layer]` is computed for the image and for its rotated copy.
    Per channel, the feature map of the original image is rotated by the angle
    and back (which blanks the corners lost by the rotation), the feature map
    of the rotated image is rotated back, both are min-max normalised with a
    shared min and max, and their mean absolute difference is expressed in
    percent. Only the `n_top` channels with the largest total activation
    contribute to the result.

    Args:
        model: Keras model exposing `input` and `layers`.
        images: indexable batch of images, one image per first-axis entry;
            each entry is fed to the model with a leading batch axis added.
        n_channels: number of feature maps to compare (after group pooling).
        group_size: when > 1, the last axis of the layer output is reshaped to
            (n_channels, group_size) and summed over the group axis first.
        layer: index into `model.layers` of the layer to inspect.
        n_top: number of most-active channels kept per image.

    Returns:
        A flat list with up to `n_top` error values (in percent) per image.
    """
    # Built once: this function does not depend on the image, but used to be
    # recreated on every iteration.
    layer_output = K.function([model.input], [model.layers[layer].output])

    errors = []
    for id_image in tqdm(range(len(images))):

        angle = np.random.rand() * 2. * np.pi

        # The zero-angle rotation is kept from the original script so the
        # reference image goes through the same resampling call as the rotated one.
        x = tfa.image.rotate(images[tf.newaxis, id_image, :, :, :], angles=0,
                             fill_mode='constant', fill_value=0, interpolation='bilinear')
        x_r = tfa.image.rotate(x, angles=angle,
                               fill_mode='constant', fill_value=0, interpolation='bilinear')

        # K.function returns a list of outputs; take the single array out of it.
        out = np.asarray(layer_output(x)[0])
        out_rot = np.asarray(layer_output(x_r)[0])

        if group_size > 1:
            # Pool over the group axis with plain NumPy instead of creating
            # Keras Reshape/Lambda layers for every image.
            pooled_shape = out.shape[:-1] + (n_channels, group_size)
            out = out.reshape(pooled_shape).sum(axis=-1)
            out_rot = out_rot.reshape(pooled_shape).sum(axis=-1)

        local_diff = np.empty(n_channels)
        local_sum = np.empty(n_channels)
        for j in range(n_channels):

            # Rotate forth and back so both maps have the same blanked corners.
            feature_map = tfa.image.rotate(out[0, :, :, j], angles=angle,
                                           fill_mode='constant', fill_value=0)
            masked_fmap1 = np.asarray(tfa.image.rotate(feature_map, angles=-angle,
                                                       fill_mode='constant', fill_value=0))
            masked_fmap2 = np.asarray(tfa.image.rotate(out_rot[0, :, :, j], angles=-angle,
                                                       fill_mode='constant', fill_value=0))

            # Shared min-max normalisation so the two maps stay comparable.
            min_val = min(masked_fmap1.min(), masked_fmap2.min())
            masked_fmap1 = masked_fmap1 - min_val
            masked_fmap2 = masked_fmap2 - min_val
            max_val = max(masked_fmap1.max(), masked_fmap2.max())
            if max_val != 0:
                masked_fmap1 = masked_fmap1 / max_val
                masked_fmap2 = masked_fmap2 / max_val

            # Mean absolute difference per pixel, in percent.
            local_diff[j] = np.sum(np.abs(masked_fmap1 - masked_fmap2)) / masked_fmap1.size * 100
            local_sum[j] = np.sum(np.abs(masked_fmap1))

        # Keep the errors of the n_top most active channels. Indices come from
        # one argsort over the intact array: popping from local_sum while
        # indexing the un-popped local_diff misaligned the two lists and
        # recorded the wrong (or a duplicate) channel after the first pick.
        strongest = np.argsort(local_sum)[::-1][:n_top]
        errors.extend(local_diff[strongest])

    return errors


# B-CNN
difference = _equivariance_errors(model_bcnn, test_images, n_channels=16)
print('B-CNN:', np.mean(difference), np.std(difference))


# CNN
difference = _equivariance_errors(model_cnn, test_images, n_channels=32)
print('CNN:', np.mean(difference), np.std(difference))


# G-CNN (layer output reshaped to 16 channels x 4 group elements, summed over the group axis)
difference = _equivariance_errors(model_gcnn, test_images, n_channels=16, group_size=4)
print('G-CNN:', np.mean(difference), np.std(difference))