import tensorflow as tf

import numpy as np
import seaborn as sns

# visualization tools
import matplotlib.pyplot as plt
import gen_data
import pickle

# Spatial size of the model input; the network is built with
# input_shape=(d, d, 1).
d = 64


def _load_pickle(path):
    """Unpickle and return the single object stored in the file at *path*.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading,
    so this must only be used on trusted, locally produced data files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Training inputs and labels. Both are read relative to the current working
# directory, so the script is expected to be run from the project root.
x_train = _load_pickle('./data/X_train.pickle')
y_train = _load_pickle('./data/Y_train.pickle')


def create_classical_model():
    """Build the small convolutional binary classifier.

    The network accepts a single-channel ``(d, d, 1)`` input. Its final
    Dense layer is sigmoid-activated, so it emits one probability in
    ``[0, 1]`` per example, not a logit.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    keras_layers = tf.keras.layers
    # NOTE(review): neither Conv2D below passes ``activation``, so both
    # convolutions are linear. Confirm this is intended.
    stack = [
        # 'same' padding keeps the d x d spatial size.
        keras_layers.Conv2D(4, [3, 3], input_shape=(d, d, 1), padding='same'),
        # 2x2 pooling halves the spatial resolution.
        keras_layers.MaxPooling2D(pool_size=(2, 2)),
        keras_layers.Conv2D(4, [3, 3], padding='same'),
        # NOTE(review): with the default nearest-neighbour interpolation this
        # only repeats pixels, so it does not change the global average below.
        keras_layers.UpSampling2D(size=(2, 2)),
        # Reduce each of the 4 feature maps to its mean value.
        keras_layers.GlobalAveragePooling2D(),
        keras_layers.Dense(1, activation='sigmoid'),
    ]
    return tf.keras.Sequential(stack)


model = create_classical_model()

# The model's final Dense layer already applies a sigmoid, so its outputs are
# probabilities, not logits. The loss must therefore use from_logits=False.
# With from_logits=True, Keras would apply a second sigmoid inside the loss,
# squashing the effective predictions into roughly [0.5, 0.73] and distorting
# the gradients. Probability outputs also match the 0.5 threshold that the
# 'accuracy' metric uses for binary cross-entropy.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

model.summary()

# Save the model to ./model/brain.h5. With ModelCheckpoint's defaults
# (save_best_only=False, save_freq='epoch'), this happens after every epoch
# and overwrites the same file each time.
checkpointer = tf.keras.callbacks.ModelCheckpoint('./model/brain.h5', verbose=1)

# NOTE(review): no validation data is passed, so the reported accuracy is
# training accuracy only.
model.fit(x_train,
          y_train,
          batch_size=128,
          epochs=50,
          verbose=1,
          callbacks=[checkpointer])