#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import time
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Input, AveragePooling2D
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, ReLU, Reshape, Conv2DTranspose
from tensorflow.keras import regularizers
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.layers import LeakyReLU
import DonaldDuckDataset
import os
import DonaldDuckFunc

@tf.custom_gradient
def floor_func(y):
    """Round ``y`` down to one decimal place, passing gradients straight through.

    Forward value: ``floor(y * 10) / 10``. Because ``floor`` has zero gradient
    almost everywhere, the backward pass returns the upstream gradient
    unchanged (straight-through estimator) so training can still proceed.
    """
    quantized = tf.floor(y * 10) / 10

    def straight_through(upstream):
        return upstream

    return quantized, straight_through

def floor_relu(y):
    """Apply ReLU, then quantize the result down to one decimal via ``floor_func``."""
    rectified = tf.nn.relu(y)
    return floor_func(rectified)

@tf.custom_gradient
def step_func(y):
    """Return ``1 + y`` where ``y > 0`` and ``0`` elsewhere, with an identity gradient.

    Forward value: a unit step at zero added to ``relu(y)``. The backward pass
    is a straight-through estimator that returns the upstream gradient
    unchanged.
    """
    def backward(dy):
        return dy
    zeros = tf.zeros_like(y)
    # Cast the step to the input's own dtype so the sum with relu(y) also works
    # for non-float32 inputs (e.g. float16 under mixed precision); a hard-coded
    # float32 cast would raise a dtype-mismatch error there.
    step = tf.cast(y > zeros, dtype=zeros.dtype)
    return step + tf.nn.relu(y), backward

def step_relu(y):
    """Apply ReLU, then the shifted step activation ``step_func``."""
    rectified = tf.nn.relu(y)
    return step_func(rectified)

# Register the custom activations in Keras' global custom-object registry so
# they can be looked up by name (e.g. as an ``act_func`` string).
get_custom_objects().update({
    'step_relu': Activation(step_relu),
    'floor_relu': Activation(floor_relu),
})


def lr_schedule(epoch):
    """Return the learning rate for ``epoch`` (as supplied by LearningRateScheduler).

    Starts at 1e-3 and is scaled by 1e-1 after epoch 80, 1e-2 after 120,
    1e-3 after 160 and 0.5e-3 after 180. The chosen rate is printed.
    """
    rate = 1e-3
    # Thresholds are checked from the latest to the earliest; the first match wins.
    for threshold, factor in ((180, 0.5e-3), (160, 1e-3), (120, 1e-2), (80, 1e-1)):
        if epoch > threshold:
            rate *= factor
            break
    print('Learning rate: ', rate)
    return rate

class DonaldDuckModel():
    """Base wrapper pairing a Keras model with the dataset it is trained on.

    Subclasses override ``setArchitecture`` to build ``self.model`` and to set
    ``self.opt``, ``self.loss`` and ``self.metrics``, which ``fitModel`` uses
    to compile the model.
    """

    def __init__(
            self,
            dataset,
            batch_size=32,
            epochs=200,
            act_func='relu',
            kernel_size=(3,3),
            custom_act_flag=False,
            build_dir=True
    ):
        """Store hyper-parameters and load the data held by ``dataset``.

        Args:
            dataset: object exposing ``num_classes``, ``input_shape`` and a
                ``getData()`` method returning
                ``((x_train, y_train), (x_test, y_test))``.
            batch_size: mini-batch size passed to ``Model.fit``.
            epochs: number of epochs passed to ``Model.fit``.
            act_func: activation (name or callable) for subclasses' hidden layers.
            kernel_size: convolution kernel size made available to subclasses.
            custom_act_flag: stored on the instance; not read by this class.
            build_dir: when True, create the timestamped output folders
                (see ``buildFolder``); saving requires them.
        """
        self.dataset=dataset
        self.custom_act_flag=custom_act_flag
        self.act_func=act_func
        self.batch_size = batch_size
        self.epochs = epochs
        self.num_classes=self.dataset.num_classes
        self.input_shape=dataset.input_shape
        self.learning_rate=1e-3
        self.kernel_size=kernel_size
        # Training configuration, filled in by subclasses' setArchitecture().
        self.opt=None
        self.loss=None
        self.metrics=None
        self.model=None
        self.importDataset()
        if build_dir:
            self.buildFolder()

    def setModel(
            self,
            model_path=None,
            weight_path=None
    ):
        """Create ``self.model`` from a saved file or by calling ``setArchitecture``.

        Args:
            model_path: path of a saved Keras model. When given, the model is
                loaded from it instead of being built (note: ``opt``/``loss``/
                ``metrics`` are then not set by this method).
            weight_path: optional weights file, loaded after the model has been
                loaded or built.
        """
        if model_path is not None:
            self.loadModel(model_path=model_path)
        else:
            self.setArchitecture()
            self.model.summary()
        # Weights are applied in both cases so a given weight_path is never
        # silently ignored (matching load_model()).
        if weight_path is not None:
            self.loadWeights(weights_path=weight_path)

    def setArchitecture(self):
        """Hook for subclasses: build ``self.model`` and set opt/loss/metrics.

        The base implementation builds nothing and returns None.
        """
        return None

    def importDataset(self):
        """Load the train/test splits from ``self.dataset.getData()``."""
        (self.x_train,self.y_train),(self.x_test,self.y_test)=self.dataset.getData()

    def fitModel(self):
        """Compile ``self.model`` with the stored configuration and train it.

        Trains for ``self.epochs`` epochs with ``self.batch_size``, validating
        on the test split and applying ReduceLROnPlateau and ``lr_schedule``.

        Returns:
            The History object returned by ``Model.fit``.
        """
        self.model.compile(
            loss=self.loss,
            optimizer=self.opt,
            metrics=self.metrics
        )

        # NOTE(review): lr_schedule hard-codes its 1e-3 base rate and ignores
        # self.learning_rate. LearningRateScheduler also resets the rate at the
        # start of each epoch, which overrides ReduceLROnPlateau's end-of-epoch
        # reductions — confirm both callbacks are wanted together.
        lr_scheduler = LearningRateScheduler(lr_schedule)
        lr_reducer = ReduceLROnPlateau(
            factor=np.sqrt(0.1),
            cooldown=0,
            patience=5,
            min_lr=0.5e-6
        )
        callbacks = [lr_reducer, lr_scheduler]

        return self.model.fit(
            self.x_train,
            self.y_train,
            batch_size=self.batch_size,
            epochs=self.epochs,
            validation_data=(self.x_test, self.y_test),
            shuffle=True,
            callbacks=callbacks
        )

    def saveModel(self,modelName):
        """Save the full model as ``modelName`` inside ``self.saveModelPath``.

        ``self.saveModelPath`` is only set by ``buildFolder``.
        """
        self.model.save(os.path.join(self.saveModelPath, modelName))

    def saveModelWeights(self,weightName):
        """Save only the model weights as ``weightName`` inside ``self.saveModelPath``."""
        self.model.save_weights(os.path.join(self.saveModelPath, weightName))

    def evaluateModel(self, x_test=None, y_test=None):
        """Evaluate the model, print loss and accuracy, and return the scores.

        When ``x_test`` is None, both inputs and labels default to the stored
        test split.

        Returns:
            The list returned by ``Model.evaluate``. Index 0 is the loss, and
            index 1 is the first compiled metric (accuracy for VGG).
        """
        # `is None`, not `== None`: comparing an ndarray to None is elementwise,
        # and using that array in an `if` raises ValueError.
        if x_test is None:
            x_test=self.x_test
            y_test=self.y_test
        scores = self.model.evaluate(x_test, y_test, verbose=1)
        print('Test loss:', scores[0])
        print('Test accuracy:', scores[1])
        return scores

    def load_model(self, model_path=None, weights_path=None):
        """Load a saved model and/or a weights file; either argument may be omitted."""
        if model_path is not None:
            self.loadModel(model_path)
        if weights_path is not None:
            self.loadWeights(weights_path)

    def loadModel(self,model_path):
        """Replace ``self.model`` with the Keras model saved at ``model_path``."""
        # Resolves to the module-level keras ``load_model`` function, not the
        # method of the same name (method bodies don't see class attributes).
        self.model=load_model(model_path)

    def loadWeights(self,weights_path):
        """Load weights from ``weights_path`` into the existing ``self.model``."""
        self.model.load_weights(weights_path)

    def buildFolder(self):
        """Create timestamped image and model output folders and remember their paths."""
        self.timeStamp = DonaldDuckFunc.getTimeStamp()
        self.saveImgPath = "images/" + self.timeStamp
        self.saveModelPath = "savedModels/" + self.timeStamp
        DonaldDuckFunc.buildDirs(self.saveModelPath)
        DonaldDuckFunc.buildDirs(self.saveImgPath)

    def predict_label(self, data):
        """Return the arg-max class index of the model's prediction for each row of ``data``."""
        return np.argmax(self.model.predict(data), axis=1)

class VGG(DonaldDuckModel):
    """VGG-style CNN classifier with batch norm, dropout and L2 weight decay."""

    # One entry per convolutional stage: (filters, dropout rate after each conv).
    # Every conv is Conv2D -> Activation -> BatchNormalization, followed by
    # Dropout when its rate is not None. Each stage ends with 2x2 max pooling.
    _CONV_STAGES = (
        (64, (0.3, None)),
        (128, (0.4, None)),
        (256, (0.4, 0.4, None)),
        (512, (0.4, 0.4, None)),
        (512, (0.4, 0.4, None)),
    )

    def setArchitecture(self):
        """Build the network into ``self.model`` and set the training configuration.

        Convolutions use ``self.kernel_size`` (default ``(3, 3)``) with 'same'
        padding and ``self.act_func`` activations. This method also sets
        ``self.opt`` (Nesterov SGD starting at ``self.learning_rate``),
        ``self.loss`` (categorical cross-entropy, so labels are presumably
        one-hot encoded — confirm against the dataset) and ``self.metrics``.
        """
        weight_decay = 0.0005
        model = Sequential()

        is_first_conv = True
        for filters, dropout_rates in self._CONV_STAGES:
            for rate in dropout_rates:
                # Only the first layer of a Sequential model declares the input shape.
                extra = {'input_shape': self.input_shape} if is_first_conv else {}
                is_first_conv = False
                model.add(Conv2D(filters, self.kernel_size, padding='same',
                                 kernel_regularizer=regularizers.l2(weight_decay),
                                 **extra))
                model.add(Activation(self.act_func))
                model.add(BatchNormalization())
                if rate is not None:
                    model.add(Dropout(rate))
            model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.5))

        # Classifier head.
        model.add(Flatten())
        model.add(Dense(512, kernel_regularizer=regularizers.l2(weight_decay)))
        model.add(Activation(self.act_func))
        model.add(BatchNormalization())

        model.add(Dropout(0.5))
        model.add(Dense(self.num_classes))
        model.add(Activation('softmax'))

        lr_decay = 1e-6
        # NOTE(review): `decay` is only accepted by the legacy Keras optimizers.
        # The optimizer API introduced in TF 2.11 rejects it, so on newer TF use
        # tf.keras.optimizers.legacy.SGD or a LearningRateSchedule instead.
        self.opt = SGD(
            learning_rate=self.learning_rate,
            decay=lr_decay,
            momentum=0.9,
            nesterov=True
        )
        self.loss='categorical_crossentropy'
        self.metrics=['accuracy']

        model.summary()
        self.model= model

if __name__=='__main__':
    # NOTE(review): DonaldDuckModel.__init__ requires a `dataset` argument, so
    # this call raises TypeError as written. Even with one, the base class's
    # setArchitecture() builds no model. Presumably a dataset object (likely
    # from DonaldDuckDataset) and a concrete subclass such as VGG were
    # intended — confirm the intended entry point.
    model=DonaldDuckModel()