from .registry import register

import typing as t
import tensorflow as tf
from tensorflow.keras import layers

from cyclegan.models import utils


@register('unet')
def get_unet(args, g_name: str = 'generator', d_name: str = 'discriminator'):
  """
  Build the plain U-Net generator together with its discriminator.

  Args:
    args: settings object; input_shape, num_filters, activation,
      normalization, dropout, reduction_factor, spectral_norm and patchgan
      are read from it, and it is handed to utils.generator_activation and
      utils.discriminator_activation to obtain the output activations.
    g_name: name assigned to the generator model.
    d_name: name assigned to the discriminator model.
  Returns:
    Tuple (generator model, discriminator model).
  """
  generator_kwargs = {
      'input_shape': args.input_shape,
      'num_filters': args.num_filters,
      'activation': args.activation,
      'normalization': args.normalization,
      'dropout': args.dropout,
      'reduction_factor': args.reduction_factor,
      'output_activation': utils.generator_activation(args),
      'spectral_norm': args.spectral_norm,
      'name': g_name,
  }
  g = generator(**generator_kwargs)

  discriminator_kwargs = {
      'input_shape': args.input_shape,
      'num_filters': args.num_filters,
      'activation': args.activation,
      'normalization': args.normalization,
      'dropout': args.dropout,
      'reduction_factor': args.reduction_factor,
      'patchgan': args.patchgan,
      'output_activation': utils.discriminator_activation(args),
      'spectral_norm': args.spectral_norm,
      'name': d_name,
  }
  d = discriminator(**discriminator_kwargs)

  return g, d


@register('agunet')
def get_attention_gates_unet(args,
                             g_name: str = 'generator',
                             d_name: str = 'discriminator'):
  """
  Build the U-Net generator with attention gates enabled, plus discriminator.

  Identical to the 'unet' entry except that the generator is created with
  use_attention_gate=True.

  Args:
    args: settings object; input_shape, num_filters, activation,
      normalization, dropout, reduction_factor, spectral_norm and patchgan
      are read from it, and it is handed to utils.generator_activation and
      utils.discriminator_activation to obtain the output activations.
    g_name: name assigned to the generator model.
    d_name: name assigned to the discriminator model.
  Returns:
    Tuple (generator model, discriminator model).
  """
  generator_kwargs = {
      'input_shape': args.input_shape,
      'num_filters': args.num_filters,
      'activation': args.activation,
      'normalization': args.normalization,
      'dropout': args.dropout,
      'reduction_factor': args.reduction_factor,
      'use_attention_gate': True,
      'output_activation': utils.generator_activation(args),
      'spectral_norm': args.spectral_norm,
      'name': g_name,
  }
  g = generator(**generator_kwargs)

  discriminator_kwargs = {
      'input_shape': args.input_shape,
      'num_filters': args.num_filters,
      'activation': args.activation,
      'normalization': args.normalization,
      'dropout': args.dropout,
      'reduction_factor': args.reduction_factor,
      'patchgan': args.patchgan,
      'output_activation': utils.discriminator_activation(args),
      'spectral_norm': args.spectral_norm,
      'name': d_name,
  }
  d = discriminator(**discriminator_kwargs)

  return g, d


def channel_attention(inputs,
                      filters: int,
                      reduction_ratio: int,
                      name: str = 'ChannelAttention'):
  """
  Scale inputs by a sigmoid gate computed from pooled descriptors.

  The outputs of utils.GlobalAvgPool and utils.GlobalMaxPool are each reshaped
  to (1, 1, filters) and sent through the same Dense -> ReLU -> Dense stack
  (the layer objects are shared between both branches). The two results are
  added, passed through a sigmoid and multiplied with inputs.

  Args:
    inputs: tensor to re-weight; the reshape requires each pooled tensor to
      hold exactly `filters` values per sample.
    filters: size of the pooled descriptor and of the resulting gate.
    reduction_ratio: divisor applied to filters to size the hidden Dense
      layer.
    name: prefix used for every layer name created here.
  Returns:
    The gated tensor, i.e. the sigmoid gate multiplied with inputs.
  """
  # Floor division yields 0 whenever 0 <= filters < reduction_ratio, which
  # would create a hidden layer without any units. Fall back to a single unit
  # in that case; every other value is passed through unchanged.
  hidden_units = filters // reduction_ratio
  if hidden_units == 0:
    hidden_units = 1

  # These layers are instantiated once and applied to both pooled branches so
  # that the branches share weights.
  reshape = layers.Reshape((1, 1, filters), name=f'{name}/reshape')
  dense1 = layers.Dense(hidden_units, name=f'{name}/dense_1')
  activation = utils.Activation('relu', name=f'{name}/relu')
  dense2 = layers.Dense(filters, name=f'{name}/dense_2')

  avg_pool = utils.GlobalAvgPool(name=f'{name}/average_pool')(inputs)
  avg_pool = reshape(avg_pool)
  avg_pool = dense1(avg_pool)
  avg_pool = activation(avg_pool)
  avg_pool = dense2(avg_pool)

  max_pool = utils.GlobalMaxPool(name=f'{name}/max_pool')(inputs)
  max_pool = reshape(max_pool)
  max_pool = dense1(max_pool)
  max_pool = activation(max_pool)
  max_pool = dense2(max_pool)

  outputs = layers.Add(name=f'{name}/add')([avg_pool, max_pool])
  outputs = utils.Activation('sigmoid', name=f'{name}/sigmoid')(outputs)
  outputs = layers.Multiply(name=f'{name}/multiply')([outputs, inputs])
  return outputs


def spatial_attention(inputs,
                      kernel_size: t.Union[int, t.Tuple[int, int]] = 7,
                      normalization: str = 'instancenorm',
                      name: str = 'SpatialAttention'):
  """
  Scale inputs by a single-channel sigmoid attention map.

  The mean and the maximum over the last axis (both with keepdims=True) are
  concatenated along the last axis, fed to a bias-free, stride-1, 'same'
  padded convolution with one filter, normalized, and passed through a
  sigmoid. The resulting map is multiplied with inputs.

  Args:
    inputs: tensor to re-weight.
    kernel_size: kernel size of the convolution producing the map.
    normalization: normalization name forwarded to utils.Normalization.
    name: prefix used for every layer name created here.
  Returns:
    The attention map multiplied with inputs.
  """
  channel_mean = layers.Lambda(
      lambda x: tf.reduce_mean(x, axis=-1, keepdims=True),
      name=f'{name}/average_pool')(inputs)
  channel_max = layers.Lambda(
      lambda x: tf.reduce_max(x, axis=-1, keepdims=True),
      name=f'{name}/max_pool')(inputs)
  pooled = layers.Concatenate(axis=-1, name=f'{name}/concat')(
      [channel_mean, channel_max])

  attention_map = utils.Conv(filters=1,
                             kernel_size=kernel_size,
                             strides=1,
                             padding='same',
                             use_bias=False,
                             name=f'{name}/conv')(pooled)
  attention_map = utils.Normalization(normalization,
                                      name=f'{name}/norm')(attention_map)
  attention_map = utils.Activation('sigmoid',
                                   name=f'{name}/sigmoid')(attention_map)

  return layers.Multiply(name=f'{name}/multiply')([attention_map, inputs])


def cbam(inputs,
         filters: int,
         normalization: str = 'instancenorm',
         reduction_ratio: int = 16,
         name: str = 'CBAM'):
  """
  Convolutional Block Attention Module: channel attention, then spatial
  attention applied to its result.

  Paper: https://arxiv.org/abs/1807.06521
  Official PyTorch implementation: https://github.com/Jongchan/attention-module

  Args:
    inputs: tensor to refine.
    filters: forwarded to channel_attention as its `filters` argument.
    normalization: normalization name forwarded to spatial_attention.
    reduction_ratio: forwarded to channel_attention.
    name: prefix; the sub-modules are named '<name>/channel' and
      '<name>/spatial'.
  Returns:
    Output of spatial_attention.
  """
  channel_refined = channel_attention(inputs=inputs,
                                      filters=filters,
                                      reduction_ratio=reduction_ratio,
                                      name=f'{name}/channel')
  return spatial_attention(inputs=channel_refined,
                           normalization=normalization,
                           name=f'{name}/spatial')


def attention_gate(inputs,
                   shortcut,
                   filters: int,
                   normalization: str = 'batchnorm',
                   name: str = 'attention_block'):
  """
  Additive attention gate.
  Reference: https://arxiv.org/abs/1804.03999

  Both shortcut and inputs are projected with a kernel-size-1 convolution to
  `filters` channels and normalized. Their sum goes through a ReLU, a
  kernel-size-1 convolution with one filter, a normalization and a sigmoid,
  producing the attention coefficients that are multiplied with inputs.

  NOTE(review): the final product uses `inputs`, and generator() in this file
  passes the up-sampled decoder tensor as `inputs` and the encoder output as
  `shortcut`. The encoder features therefore only reach the result through
  the coefficients, whereas the referenced paper scales the skip-connection
  features. Confirm this is intended.

  Args:
    inputs: tensor that is scaled by the attention coefficients.
    shortcut: second tensor contributing to the coefficients.
    filters: number of channels of the two projections.
    normalization: normalization name forwarded to utils.Normalization.
    name: prefix used for every layer name created here.
  Returns:
    inputs multiplied by the attention coefficients.
  """
  shortcut_proj = utils.Conv(filters, kernel_size=1,
                             name=f'{name}/g/conv')(shortcut)
  shortcut_proj = utils.Normalization(normalization,
                                      name=f'{name}/g/norm')(shortcut_proj)

  inputs_proj = utils.Conv(filters, kernel_size=1,
                           name=f'{name}/x/conv')(inputs)
  inputs_proj = utils.Normalization(normalization,
                                    name=f'{name}/x/norm')(inputs_proj)

  coefficients = layers.Add(name=f'{name}/add')([shortcut_proj, inputs_proj])
  coefficients = utils.Activation('relu', name=f'{name}/relu')(coefficients)
  coefficients = utils.Conv(filters=1, kernel_size=1,
                            name=f'{name}/psi')(coefficients)
  coefficients = utils.Normalization(normalization,
                                     name=f'{name}/norm')(coefficients)
  coefficients = utils.Activation('sigmoid',
                                  name=f'{name}/sigmoid')(coefficients)

  return layers.Multiply(name=f'{name}/multiply')([inputs, coefficients])


def conv_block(inputs,
               filters: int,
               kernel_size: int,
               normalization: str = 'batchnorm',
               activation: str = 'lrelu',
               dropout: float = 0.0,
               spectral_norm: bool = False,
               name: str = 'conv_block'):
  """
  Two consecutive conv -> normalization -> activation (-> dropout) stages.

  Each stage uses a stride-1, 'same' padded convolution with `filters`
  filters. Layers of the first stage carry the suffix '_1', those of the
  second stage '_2'.

  Args:
    inputs: tensor fed to the first convolution.
    filters: number of filters of both convolutions.
    kernel_size: kernel size of both convolutions.
    normalization: normalization name forwarded to utils.Normalization.
    activation: activation name forwarded to utils.Activation.
    dropout: dropout rate; a Dropout layer is only added when it is > 0.
    spectral_norm: forwarded to utils.Conv.
    name: prefix used for every layer name created here.
  Returns:
    Output tensor of the second stage.
  """
  outputs = inputs
  for stage in (1, 2):
    outputs = utils.Conv(filters,
                         kernel_size,
                         strides=1,
                         padding='same',
                         spectral_norm=spectral_norm,
                         name=f'{name}/conv_{stage}')(outputs)
    outputs = utils.Normalization(normalization,
                                  name=f'{name}/norm_{stage}')(outputs)
    outputs = utils.Activation(activation,
                               name=f'{name}/activation_{stage}')(outputs)
    if dropout > 0:
      outputs = layers.Dropout(dropout,
                               name=f'{name}/dropout_{stage}')(outputs)
  return outputs


def generator(input_shape,
              num_filters: int = 16,
              normalization: str = 'instancenorm',
              activation: str = 'lrelu',
              dropout: float = 0.0,
              reduction_factor: t.Union[int, t.Tuple[int, int]] = 2,
              use_attention_gate: bool = False,
              output_activation: str = 'linear',
              spectral_norm: bool = False,
              name: str = 'generator'):
  """
  Assemble the U-Net generator as a functional Keras model.

  Three encoder levels (conv_block followed by max pooling) with
  num_filters, 2 * num_filters and 4 * num_filters filters lead into a
  bottleneck conv_block. Three decoder levels mirror them: transpose
  convolution, concatenation with the encoder output of the same level, and
  a conv_block. A kernel-size-1 convolution then maps to input_shape[-1]
  channels.

  Args:
    input_shape: shape of one sample (without batch axis).
    num_filters: number of filters of the first encoder level.
    normalization: normalization name forwarded to conv_block and
      attention_gate.
    activation: activation name forwarded to conv_block.
    dropout: dropout rate forwarded to conv_block.
    reduction_factor: pool size of the max pooling layers and stride of the
      transpose convolutions.
    use_attention_gate: when True, the tensor concatenated with the decoder
      tensor is attention_gate(decoder tensor, shortcut=encoder output)
      instead of the raw encoder output.
    output_activation: activation name applied to the final output.
    spectral_norm: forwarded to the convolution layers.
    name: model name; also prefixes the name of the input layer.
  Returns:
    tf.keras.Model mapping inputs to an output of the same per-sample shape.
  Raises:
    AssertionError: if the output shape (batch axis excluded) differs from
      the input shape.
  """
  inputs = tf.keras.Input(input_shape, name=f'{name}/inputs')
  outputs = inputs

  # For 4-D inputs whose axis 2 is not a power of two, zero-pad the trailing
  # side of that axis up to utils.next_power_of_two; the same tuple is reused
  # below to crop the output back.
  needs_padding = (len(inputs.shape) == 4 and
                   not utils.is_power_of_two(inputs.shape[2]))
  padding = None
  if needs_padding:
    width = inputs.shape[2]
    padding = ((0, 0), (0, utils.next_power_of_two(width) - width))
    outputs = layers.ZeroPadding2D(padding, name='input/zero_padding')(outputs)

  # Settings shared by every conv_block in the network.
  block_kwargs = dict(kernel_size=3,
                      normalization=normalization,
                      activation=activation,
                      dropout=dropout,
                      spectral_norm=spectral_norm)

  filters = [num_filters * (2**level) for level in range(3)]

  # Encoder: remember each level's output before pooling for the decoder.
  shortcuts = []
  for i, channels in enumerate(filters):
    outputs = conv_block(outputs,
                         filters=channels,
                         name=f'down_{i}/conv_block',
                         **block_kwargs)
    shortcuts.append(outputs)
    outputs = utils.MaxPool(pool_size=reduction_factor,
                            name=f'down_{i}/max_pool')(outputs)

  outputs = conv_block(outputs,
                       filters=filters[-1],
                       name='bottleneck',
                       **block_kwargs)

  # Decoder: walk the levels from the deepest back to the first.
  for i in reversed(range(len(filters))):
    channels = filters[i]
    outputs = utils.TransposeConv(filters=channels,
                                  kernel_size=2,
                                  strides=reduction_factor,
                                  padding='same',
                                  spectral_norm=spectral_norm,
                                  name=f'up_{i}/transpose')(outputs)
    shortcut = shortcuts[i]
    if use_attention_gate:
      shortcut = attention_gate(outputs,
                                shortcut=shortcut,
                                filters=channels,
                                normalization=normalization,
                                name=f'up_{i}/attention')
    outputs = layers.Concatenate(name=f'up_{i}/concat')([outputs, shortcut])
    outputs = conv_block(outputs,
                         filters=channels,
                         name=f'up_{i}/conv_block',
                         **block_kwargs)

  outputs = utils.Conv(filters=input_shape[-1],
                       kernel_size=1,
                       spectral_norm=spectral_norm,
                       name='output/conv')(outputs)

  if padding is not None:
    outputs = layers.Cropping2D(padding, name='output/cropping')(outputs)

  # Linear activation exposed under a fixed layer name.
  # NOTE(review): presumably a hook for reading the pre-activation map by
  # name - confirm against the code that looks up 'output/attention'.
  outputs = utils.Activation('linear', name='output/attention')(outputs)

  # The final activation is explicitly created with dtype float32.
  outputs = utils.Activation(output_activation,
                             dtype=tf.float32,
                             name='output/activation')(outputs)

  assert inputs.shape[1:] == outputs.shape[1:]

  return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)


def discriminator(input_shape,
                  num_filters: int = 16,
                  normalization: str = 'instancenorm',
                  activation: str = 'lrelu',
                  dropout: float = 0.0,
                  reduction_factor: t.Union[int, t.Tuple[int, int]] = 2,
                  patchgan: bool = False,
                  output_activation: str = 'linear',
                  spectral_norm: bool = False,
                  name='discriminator'):
  """
  Assemble the discriminator as a functional Keras model.

  Three levels of conv_block followed by max pooling (num_filters,
  2 * num_filters and 4 * num_filters filters) are followed by a stride-1,
  'same' padded convolution with one filter. Unless patchgan is set, that map
  is flattened and reduced to a single unit by a dense layer.

  Args:
    input_shape: shape of one sample (without batch axis).
    num_filters: number of filters of the first level.
    normalization: normalization name forwarded to conv_block.
    activation: activation name forwarded to conv_block.
    dropout: dropout rate forwarded to conv_block.
    reduction_factor: pool size of the max pooling layers.
    patchgan: when True the single-filter convolution map is the output;
      when False it is flattened and passed through a one-unit dense layer.
    output_activation: activation name applied to the final output.
    spectral_norm: forwarded to the convolution and dense layers.
    name: model name.
  Returns:
    tf.keras.Model mapping inputs to the discriminator output.
  """
  inputs = tf.keras.Input(input_shape, name='inputs')
  outputs = inputs

  # For 4-D inputs whose axis 2 is not a power of two, zero-pad the trailing
  # side of that axis up to utils.next_power_of_two.
  needs_padding = (len(inputs.shape) == 4 and
                   not utils.is_power_of_two(inputs.shape[2]))
  if needs_padding:
    width = inputs.shape[2]
    padding = ((0, 0), (0, utils.next_power_of_two(width) - width))
    outputs = layers.ZeroPadding2D(padding, name='input/zero_padding')(outputs)

  filters = [num_filters * (2**level) for level in range(3)]
  for i, channels in enumerate(filters):
    outputs = conv_block(outputs,
                         filters=channels,
                         kernel_size=3,
                         normalization=normalization,
                         activation=activation,
                         dropout=dropout,
                         spectral_norm=spectral_norm,
                         name=f'conv_block_{i}')
    outputs = utils.MaxPool(pool_size=reduction_factor,
                            name=f'max_pool_{i}')(outputs)

  outputs = utils.Conv(filters=1,
                       kernel_size=3,
                       strides=1,
                       padding='same',
                       spectral_norm=spectral_norm,
                       name='output/conv')(outputs)

  # Linear activation exposed under a fixed layer name.
  # NOTE(review): presumably a hook for reading the pre-activation map by
  # name - confirm against the code that looks up 'output/attention'.
  outputs = utils.Activation('linear', name='output/attention')(outputs)

  if not patchgan:
    # Collapse the map to one value per sample.
    outputs = layers.Flatten(name='output/flatten')(outputs)
    outputs = utils.Dense(units=1,
                          spectral_norm=spectral_norm,
                          name='output/dense')(outputs)

  # The final activation is explicitly created with dtype float32.
  outputs = utils.Activation(output_activation,
                             dtype=tf.float32,
                             name='output/activation')(outputs)

  return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)
