# -*- coding: utf-8 -*-

import numpy as np

from scipy.ndimage.morphology import distance_transform_edt

## add keras packages
from keras.models import Model
from keras.layers import Input, Conv2D, Conv3D, MaxPooling2D, MaxPooling3D, Conv2DTranspose, Conv3DTranspose, BatchNormalization, Activation, AveragePooling3D, Add
from keras.layers.merge import concatenate
from keras.optimizers import Adam
from keras_radam import RAdam
from keras.utils.generic_utils import get_custom_objects
from keras import backend as K
import tensorflow as tf


# Define Swish activation function
class Swish(Activation):
    """Keras ``Activation`` layer whose instance carries the name ``'swish'``.

    The wrapped callable is passed straight through to ``Activation``; the
    only addition is a ``__name__`` attribute set on the instance.
    """

    def __init__(self, activation, **kwargs):
        """Wrap ``activation`` in an ``Activation`` layer.

        Args:
            activation: Activation function (or identifier) forwarded to
                ``Activation.__init__``.
            **kwargs: Additional keyword arguments forwarded unchanged to
                ``Activation.__init__``.
        """
        super(Swish, self).__init__(activation, **kwargs)
        # Instances do not have a __name__ by default; give this one a fixed label.
        self.__name__ = 'swish'

def swish(x):
    """Swish activation: the input scaled element-wise by its own sigmoid.

    Args:
        x: Value accepted by the Keras backend function ``K.sigmoid``.

    Returns:
        The product ``K.sigmoid(x) * x``.
    """
    gate = K.sigmoid(x)
    return gate * x

# Register a Swish layer under the key 'swish' in Keras' custom-object table.
get_custom_objects()['swish'] = Swish(swish)



########################################################################################################################
''' HARMONICS '''
########################################################################################################################


def harmonic_loss3D_wrapper(dets_per_region=3, num_coefficients=81, weight_class=0.3, weight_pos=0.3, weight_shape=0.3, max_weight_scale=100):
    """Create the combined classification / position / shape loss.

    The returned closure reads the last axis of ``y_true`` and ``y_pred`` as:

    * ``[0, dets_per_region)``: one detection score per detection,
    * ``[dets_per_region, 4*dets_per_region)``: three position values per
      detection, stored detection by detection,
    * ``[4*dets_per_region, end)``: shape coefficients, ``num_coefficients``
      per detection, stored detection by detection.

    Args:
        dets_per_region: Number of detections encoded per output location.
        num_coefficients: Number of shape coefficients per detection.
        weight_class: Factor applied to the classification term.
        weight_pos: Factor applied to the positional term.
        weight_shape: Factor applied to the shape term.
        max_weight_scale: Upper bound for the foreground/background
            balancing weights of the classification term.

    Returns:
        A function ``harmonic_loss3D(y_true, y_pred)`` that returns the mean
        of the weighted sum of the three loss terms.
    """

    def _to_float(tensor):
        # tf.cast(..., tf.float32) is the supported equivalent of tf.to_float,
        # which is deprecated in TF 1.x and no longer available in TF 2.x.
        return tf.cast(tensor, tf.float32)

    def _per_detection_weights(detection_scores, values_per_detection):
        # Repeat each detection entry `values_per_detection` times along the
        # last axis so every value belonging to detection n is weighted by
        # the ground-truth score of detection n.
        indices = [i // values_per_detection for i in range(dets_per_region * values_per_detection)]
        return tf.gather(detection_scores, indices, axis=-1)

    def harmonic_loss3D(y_true, y_pred):
        # classification loss (MSLE)
        true_scores = _to_float(y_true[..., :dets_per_region])
        pred_scores = _to_float(y_pred[..., :dets_per_region])

        # Balance foreground against background by inverse frequency; the clip
        # bounds the ratio, e.g. when one of the two classes is absent.
        is_foreground = _to_float(true_scores > 0)
        foreground_count = K.sum(is_foreground)
        background_count = K.sum(1 - is_foreground)
        total_count = background_count + foreground_count
        foreground_weight = K.clip(total_count / foreground_count, 1e-15, max_weight_scale)
        background_weight = K.clip(total_count / background_count, 1e-15, max_weight_scale)
        class_weights = foreground_weight * true_scores + background_weight * (1 - true_scores)

        log_pred = K.log(K.clip(pred_scores, K.epsilon(), None) + 1.)
        log_true = K.log(K.clip(true_scores, K.epsilon(), None) + 1.)
        loss_classify = K.sum(class_weights * K.square(log_pred - log_true), axis=-1, keepdims=True)

        # positional loss (L1), weighted by the ground-truth detection score
        # (same weight for x_n, y_n, z_n)
        position_weights = _per_detection_weights(true_scores, 3)
        true_position = _to_float(y_true[..., dets_per_region:dets_per_region * 4])
        pred_position = _to_float(y_pred[..., dets_per_region:dets_per_region * 4])
        loss_positional = K.sum(position_weights * K.abs(true_position - pred_position), axis=-1, keepdims=True)

        # shape loss (L1), weighted by the ground-truth detection score
        # (same weight for each coefficient of detection n)
        shape_weights = _per_detection_weights(true_scores, num_coefficients)
        true_shape = _to_float(y_true[..., dets_per_region * 4:])
        pred_shape = _to_float(y_pred[..., dets_per_region * 4:])
        loss_shape = K.sum(shape_weights * K.abs(true_shape - pred_shape), axis=-1, keepdims=True)

        # return weighted sum of all losses
        return K.mean(weight_class * loss_classify + weight_pos * loss_positional + weight_shape * loss_shape)

    return harmonic_loss3D






def harmonic_rcnn_3D(input_shape=(112,112,112), activation_fcn='relu', input_channels=1, dets_per_region=3, num_coefficients=81, weight_class=0.3, weight_pos=0.3, weight_shape=0.3, max_weight_scale=100, verbose=True, **kwargs):
    """Build and compile the 3D detection / shape-regression network.

    Structure:

    * a stem of two 3x3x3 convolutions with 32 filters,
    * three stages, each a stride-2 convolution followed by two residual
      modules (overall spatial down-sampling by a factor of 8),
    * the raw input, average-pooled by 8, concatenated to the features,
    * three heads (classification, position, shape) whose outputs are
      concatenated along the channel axis, giving
      ``dets_per_region * (4 + num_coefficients)`` output channels.

    Args:
        input_shape: Spatial size of the input volume; any sequence of three
            entries (tuple or list).
        activation_fcn: Activation identifier used in all hidden layers.
        input_channels: Number of channels of the input volume.
        dets_per_region: Number of detections predicted per output location.
        num_coefficients: Number of shape coefficients per detection.
        weight_class: Classification weight passed to the loss.
        weight_pos: Position weight passed to the loss.
        weight_shape: Shape weight passed to the loss.
        max_weight_scale: Class-balancing bound passed to the loss.
        verbose: If true, print ``model.summary()`` after compiling.
        **kwargs: Accepted and ignored.

    Returns:
        The compiled Keras ``Model``.
    """
    # NOTE: Keras auto-generates layer names from creation order, so the
    # helpers below are called in a fixed sequence (stem, stages, skip,
    # classification head, position head, shape head).

    def bn_act(tensor):
        # Batch normalisation followed by the configured activation.
        tensor = BatchNormalization()(tensor)
        return Activation(activation_fcn)(tensor)

    def VoxRes_module(input_layer, num_channels=64):
        # Pre-activation residual unit: two (BN, activation, 3x3x3 conv)
        # steps added back onto the unit's input.
        vrn = input_layer
        for _ in range(2):
            vrn = bn_act(vrn)
            vrn = Conv3D(num_channels, (3,3,3), strides=(1,1,1), padding='same')(vrn)
        return Add()([input_layer, vrn])

    def down_stage(tensor):
        # Halve the spatial resolution, then refine with two residual units.
        tensor = Conv3D(64, (3,3,3), strides=(2,2,2), padding='same')(tensor)
        tensor = VoxRes_module(tensor, num_channels=64)
        tensor = VoxRes_module(tensor, num_channels=64)
        return bn_act(tensor)

    def head(features, res_channels, mid_channels, out_channels, out_activation=None):
        # Shared layout of the three output branches; only the widths, the
        # number of output channels and the final activation differ.
        branch = Conv3D(res_channels, (3,3,3), padding='same')(features)
        branch = VoxRes_module(branch, num_channels=res_channels)
        branch = VoxRes_module(branch, num_channels=res_channels)
        branch = bn_act(branch)
        branch = Conv3D(mid_channels, (3,3,3), padding='same')(branch)
        branch = bn_act(branch)
        branch = Conv3D(out_channels, (1,1,1), padding='same')(branch)
        if out_activation is not None:
            branch = Activation(out_activation)(branch)
        return branch

    # tuple() lets callers pass the spatial shape as a list as well.
    inputs = Input(tuple(input_shape) + (input_channels,))

    # stem
    l = inputs
    for _ in range(2):
        l = Conv3D(32, (3,3,3), padding='same')(l)
        l = bn_act(l)

    # three down-sampling stages
    for _ in range(3):
        l = down_stage(l)

    # Skip connection from the raw input. 'same' padding makes the pooled
    # size ceil(n/8), matching the three stride-2 'same' convolutions even
    # when a spatial size is not divisible by 8; for divisible sizes no
    # padding is applied.
    inputs_down = AveragePooling3D((8,8,8), padding='same')(inputs)
    l = concatenate([l, inputs_down])

    # split network for classification, positioning and shape prediction
    out_cla = head(l, 64, 32, dets_per_region, out_activation='sigmoid')
    out_pos = head(l, 64, 32, dets_per_region*3, out_activation='sigmoid')
    out_shape = head(l, 128, 256, dets_per_region*num_coefficients)

    outputs = concatenate([out_cla, out_pos, out_shape])

    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(
        optimizer=RAdam(lr=0.001),
        loss=harmonic_loss3D_wrapper(
            dets_per_region=dets_per_region,
            num_coefficients=num_coefficients,
            weight_class=weight_class,
            weight_pos=weight_pos,
            weight_shape=weight_shape,
            max_weight_scale=max_weight_scale,
        ),
    )

    if verbose:
        model.summary()

    return model




