import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Add, Activation, Lambda, concatenate, Dropout


#-------------------------------------------------------------------------------
def SubpixelDownsampling(x, block_name=''):
  """Halve the spatial size of ``x`` and (for even C) double its channels.

  A 1x1 convolution first squeezes the channels to ``C // 2``. A
  ``tf.nn.space_to_depth`` shuffle with block size 2 then folds every 2x2
  spatial patch into the channel axis. The result is
  ``(H/2, W/2, 4 * (C // 2))``, which equals ``C * 2`` when C is even.

  Args:
    x: Keras tensor whose last axis is channels (presumably NHWC; H and W
      must be even for space_to_depth).
    block_name: Prefix used for the two layer names.

  Returns:
    The shuffled, downsampled tensor.
  """
  def space_to_depth(t):
    return tf.nn.space_to_depth(t, 2)

  squeezed_channels = tf.keras.backend.int_shape(x)[-1] // 2

  squeezed = Conv2D(filters=squeezed_channels,
                    kernel_size=(1, 1),
                    strides=1,
                    padding='same',
                    name=block_name + '_conv_sp')(x)
  shuffled = Lambda(space_to_depth, name=block_name + '_shuffle_sp')(squeezed)
  return shuffled

def SubpixelUpsampling(x, block_name=''):
  """Double the spatial size of ``x`` and halve its channels.

  A 1x1 convolution first expands the channels to ``C * 2``. A
  ``tf.nn.depth_to_space`` shuffle with block size 2 then spreads groups of 4
  channels over 2x2 spatial patches. The result is ``(H*2, W*2, C/2)``.
  ``C * 2`` must be divisible by 4, i.e. C must be even, or depth_to_space
  rejects the tensor.

  Args:
    x: Keras tensor whose last axis is channels (presumably NHWC).
    block_name: Prefix used for the two layer names.

  Returns:
    The shuffled, upsampled tensor.
  """
  depth = tf.keras.backend.int_shape(x)[-1]

  def depth_to_space(x):
    return tf.nn.depth_to_space(x, 2)

  x = Conv2D(filters=depth*2, kernel_size=(1,1), strides=1, padding='same', name=block_name+'_conv_sp')(x)
  # Use the same '_shuffle_sp' suffix as SubpixelDownsampling; the separator
  # was previously missing (e.g. 'UP1shuffle_sp'). Lambda holds no weights,
  # so this rename does not affect weight loading.
  x = Lambda(depth_to_space, name=block_name+'_shuffle_sp')(x)
  return x



#-------------------------------------------------------------------------------
def ResidualBlock(input_tensor, filters, strides, block_name=''):
  """Two-convolution residual block: (conv-BN-ReLU-conv-BN) + shortcut, then ReLU.

  With ``strides == 2`` both the main path and the shortcut are downsampled
  by SubpixelDownsampling (H/2, W/2, 2*C) instead of a strided convolution.
  When the shortcut's channel count differs from ``filters``, a 1x1 conv + BN
  projection is applied to it so the residual ``Add`` is well-defined.
  Previously such a mismatch crashed inside ``Add``.

  Args:
    input_tensor: Keras tensor whose last axis is channels.
    filters: Number of output channels of the block.
    strides: 1 (keep spatial size) or 2 (halve it). A tuple/list of two equal
      values such as ``(2, 2)`` is also accepted.
    block_name: Prefix used for all layer names in the block.

  Returns:
    Output tensor with ``filters`` channels.

  Raises:
    ValueError: If ``strides`` is not 1 or 2.
  """
  if isinstance(strides, (tuple, list)):
    if len(set(strides)) != 1:
      raise ValueError('ResidualBlock only supports equal strides, got %r' % (strides,))
    strides = strides[0]
  if strides not in (1, 2):
    # Previously any value other than 2 was silently treated as stride 1.
    raise ValueError('ResidualBlock strides must be 1 or 2, got %r' % (strides,))

  skip = input_tensor
  main = input_tensor

  if strides==2:
    skip = SubpixelDownsampling(skip, block_name+'_skip')
    main = SubpixelDownsampling(main, block_name+'_main')

  main = Conv2D(filters=filters, kernel_size=(3,3), strides=1, padding='same', name=block_name+'_conv1')(main)
  main = BatchNormalization(name=block_name+'_bn_1')(main)
  main = Activation('relu', name=block_name+'_relu')(main)

  main = Conv2D(filters=filters, kernel_size=(3,3), strides=1, padding='same', name=block_name+'_conv2')(main)
  main = BatchNormalization(name=block_name+'_bn_2')(main)

  # Project the shortcut only when channels disagree. Configurations that
  # already built (matching channels) get exactly the same graph as before.
  if tf.keras.backend.int_shape(skip)[-1] != filters:
    skip = Conv2D(filters=filters, kernel_size=(1,1), strides=1, padding='same', name=block_name+'_skip_proj')(skip)
    skip = BatchNormalization(name=block_name+'_skip_proj_bn')(skip)

  out = Add(name=block_name+'_add')([main, skip])
  out = Activation('relu', name=block_name+'_relu_out')(out)

  return out



#-------------------------------------------------------------------------------
def _encoder_stage(x, filters, strides, first_block, drop):
  """Two ResidualBlocks (the first with `strides`), each followed by Dropout."""
  x = ResidualBlock(x, filters, strides=strides, block_name='Block%d' % first_block)
  x = Dropout(drop)(x)
  x = ResidualBlock(x, filters, strides=1, block_name='Block%d' % (first_block + 1))
  x = Dropout(drop)(x)
  return x


def _decoder_stage(x, skip, filters, block_name, drop, final_dropout=True):
  """Upsample `x` 2x, concatenate `skip`, then apply two conv-BN(-Dropout) steps."""
  y = SubpixelUpsampling(x, block_name=block_name)
  y = concatenate([y, skip])
  for step in range(2):
    y = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(y)
    y = BatchNormalization()(y)
    # The very last decoder conv feeds the output head without dropout.
    if step == 0 or final_dropout:
      y = Dropout(drop)(y)
  return y


def get_model(input_shape=(512,512,3), drop=0.5):
  """Build the U-Net-style segmentation model with subpixel down/upsampling.

  Encoder: a 7x7 stem conv, then five stages of two ResidualBlocks with
  36/72/144/288/576 filters. Each stage after the first halves H and W.
  Decoder: four stages that upsample with SubpixelUpsampling, concatenate the
  matching encoder output, and apply two 3x3 conv + BN layers. The head is a
  1x1 conv with a sigmoid, giving a single-channel map with the input's H, W.

  Layers are created in the same order as the original flat implementation,
  so auto-generated layer names and weight ordering are unchanged.

  Args:
    input_shape: (H, W, C) of the input, channels last. H and W must be
      divisible by 16 (four 2x downsamplings), or None for dynamic size.
    drop: Dropout rate used after every residual block and decoder conv
      except the last one.

  Returns:
    An uncompiled ``tf.keras.Model``.

  Raises:
    ValueError: If a static spatial dimension is not divisible by 16.
  """
  downsample_factor = 2 ** 4  # four stride-2 encoder stages
  for dim in tuple(input_shape)[:-1]:
    if dim is not None and dim % downsample_factor != 0:
      raise ValueError(
          'get_model: spatial dims of input_shape must be divisible by %d, got %r'
          % (downsample_factor, tuple(input_shape)))

  inputs = Input(input_shape)

  x0 = Conv2D(filters=36, activation='relu', kernel_size=(7,7), strides=1, padding='same', name='Block0')(inputs)

  x1 = _encoder_stage(x0, 36, 1, 1, drop)
  x2 = _encoder_stage(x1, 72, 2, 3, drop)
  x3 = _encoder_stage(x2, 144, 2, 5, drop)
  x4 = _encoder_stage(x3, 288, 2, 7, drop)
  x5 = _encoder_stage(x4, 576, 2, 9, drop)

  y4 = _decoder_stage(x5, x4, 288, 'UP1', drop)
  y3 = _decoder_stage(y4, x3, 144, 'UP2', drop)
  y2 = _decoder_stage(y3, x2, 72, 'UP3', drop)
  y1 = _decoder_stage(y2, x1, 36, 'UP4', drop, final_dropout=False)

  outputs = Conv2D(1, (1, 1), activation='sigmoid')(y1)

  model = Model(inputs=[inputs], outputs=[outputs])

  return model
