from .registry import register

import typing as t
import tensorflow as tf
from tensorflow.keras import layers

from cyclegan.models import utils


@register('resnet')
def get_resnet(args, g_name: str = 'generator', d_name: str = 'discriminator'):
  """Build the ResNet generator and discriminator from a settings object.

  Args:
    args: object exposing the attributes input_shape, num_filters,
      normalization, activation, reduction_factor, spectral_norm, dropout
      and patchgan. It is also passed to utils.generator_activation and
      utils.discriminator_activation to pick the output activations.
    g_name: name assigned to the generator model.
    d_name: name assigned to the discriminator model.

  Returns:
    A (generator, discriminator) tuple.
  """
  generator_config = dict(
      input_shape=args.input_shape,
      filters=args.num_filters,
      normalization=args.normalization,
      activation=args.activation,
      reduction_factor=args.reduction_factor,
      output_activation=utils.generator_activation(args),
      spectral_norm=args.spectral_norm,
      name=g_name,
  )
  g_model = generator(**generator_config)

  discriminator_config = dict(
      input_shape=args.input_shape,
      filters=args.num_filters,
      normalization=args.normalization,
      activation=args.activation,
      dropout=args.dropout,
      patchgan=args.patchgan,
      reduction_factor=args.reduction_factor,
      output_activation=utils.discriminator_activation(args),
      spectral_norm=args.spectral_norm,
      name=d_name,
  )
  d_model = discriminator(**discriminator_config)

  return g_model, d_model


@register('agresnet')
def get_attention_gate_resnet(args,
                              g_name: str = 'generator',
                              d_name: str = 'discriminator'):
  """Build the attention-gated ResNet generator and its discriminator.

  The generator is created with use_attention_gate=True; the discriminator
  receives the same settings as in the plain ResNet variant.

  Args:
    args: object exposing the attributes input_shape, num_filters,
      normalization, activation, reduction_factor, spectral_norm, dropout
      and patchgan. It is also passed to utils.generator_activation and
      utils.discriminator_activation to pick the output activations.
    g_name: name assigned to the generator model.
    d_name: name assigned to the discriminator model.

  Returns:
    A (generator, discriminator) tuple.
  """
  generator_config = dict(
      input_shape=args.input_shape,
      filters=args.num_filters,
      normalization=args.normalization,
      activation=args.activation,
      reduction_factor=args.reduction_factor,
      output_activation=utils.generator_activation(args),
      spectral_norm=args.spectral_norm,
      use_attention_gate=True,
      name=g_name,
  )
  g_model = generator(**generator_config)

  discriminator_config = dict(
      input_shape=args.input_shape,
      filters=args.num_filters,
      normalization=args.normalization,
      activation=args.activation,
      dropout=args.dropout,
      patchgan=args.patchgan,
      reduction_factor=args.reduction_factor,
      output_activation=utils.discriminator_activation(args),
      spectral_norm=args.spectral_norm,
      name=d_name,
  )
  d_model = discriminator(**discriminator_config)

  return g_model, d_model


def additive_attention_gate(inputs,
                            shortcut,
                            filters: int,
                            normalization: str = 'instancenorm',
                            name: str = 'additive_attention_gate'):
  """
  Additive attention gate applied to a skip connection.
  reference: https://arxiv.org/abs/1804.03999

  `shortcut` and `inputs` are each projected by a 1x1 convolution with
  `filters` channels followed by normalization. The projections are summed,
  passed through ReLU, reduced to a single channel by another 1x1
  convolution, normalized and squashed by a sigmoid to form a mask. The mask
  multiplies `shortcut`, and the gated shortcut is concatenated to `inputs`
  along the last axis.

  Args:
    inputs: tensor coming from the main path.
    shortcut: skip-connection tensor to be gated.
    filters: number of channels of the two intermediate projections.
    normalization: normalization identifier forwarded to utils.Normalization.
    name: prefix for the names of all layers created here.

  Returns:
    The concatenation of `inputs` and the gated `shortcut`.
  """

  def project(tensor, scope: str):
    # 1x1 convolution followed by normalization, named under `scope`.
    projected = utils.Conv(filters,
                           kernel_size=1,
                           name=f'{name}/{scope}/conv')(tensor)
    return utils.Normalization(normalization,
                               name=f'{name}/{scope}/norm')(projected)

  gate_signal = project(shortcut, 'shortcut')
  input_signal = project(inputs, 'inputs')

  attention = layers.Add(name=f'{name}/add')([gate_signal, input_signal])
  attention = utils.Activation('relu', name=f'{name}/relu')(attention)
  # Collapse to one channel so the mask is shared by every shortcut channel.
  attention = utils.Conv(filters=1, kernel_size=1,
                         name=f'{name}/conv')(attention)
  attention = utils.Normalization(normalization,
                                  name=f'{name}/norm')(attention)
  attention = utils.Activation('sigmoid', name=f'{name}/sigmoid')(attention)

  gated_shortcut = layers.Multiply(name=f'{name}/multiply')(
      [shortcut, attention])
  return layers.Concatenate(axis=-1, name=f'{name}/concat')(
      [inputs, gated_shortcut])


class ReflectionPadding(layers.Layer):
  """Pad the inner axes of a tensor using tf.pad in 'REFLECT' mode.

  The first and last axes are never padded (presumably the batch and
  channel axes -- confirm against callers); every axis in between is padded
  by `padding` elements on both sides. Inputs must have 2 or 3 axes after
  the first one.

  Args:
    padding: number of elements added on each side of every padded axis.
    **kwargs: forwarded to the base Keras layer constructor.
  """

  def __init__(self, padding: int = 1, **kwargs):
    super().__init__(**kwargs)
    self.pad = padding

  def build(self, input_shape):
    """Derive the per-axis padding spec from the rank of `input_shape`.

    Raises:
      ValueError: if the input does not have 2 or 3 axes after the first.
    """
    # Bind ndim outside of an assert: asserts are stripped under
    # `python -O`, which would otherwise leave ndim undefined below.
    ndim = len(input_shape[1:])
    if not 2 <= ndim <= 3:
      raise ValueError(
          f'ReflectionPadding expects 2 or 3 axes after the first one, '
          f'got input shape {input_shape}.')
    if ndim == 2:
      self.paddings = [
          [0, 0],
          [self.pad, self.pad],
          [0, 0],
      ]
    else:
      self.paddings = [
          [0, 0],
          [self.pad, self.pad],
          [self.pad, self.pad],
          [0, 0],
      ]

  def call(self, inputs, **kwargs):
    """Return `inputs` reflect-padded according to the spec set in build."""
    return tf.pad(inputs, self.paddings, mode='REFLECT')

  def get_config(self):
    """Return the layer config including `padding`.

    Needed so that a model containing this layer can be serialized and
    re-created from its config with the same padding size.
    """
    config = super().get_config()
    config.update({'padding': self.pad})
    return config


def residual_block(inputs,
                   kernel_size: t.Union[int, t.Tuple[int, int]] = 3,
                   strides: t.Union[int, t.Tuple[int, int]] = 1,
                   padding: str = 'valid',
                   use_bias: bool = False,
                   normalization: str = 'instancenorm',
                   activation: str = 'lrelu',
                   spectral_norm: bool = False,
                   name: str = 'block'):
  """Residual block made of two reflection-padded convolutions.

  Stage one is pad -> conv -> normalization -> activation. Stage two is
  pad -> conv, after which the block input is added back and the sum is
  normalized and activated. Both convolutions keep the channel count of
  `inputs`.

  Args:
    inputs: input tensor; its last dimension sets the number of filters.
    kernel_size: kernel size of both convolutions.
    strides: strides of both convolutions.
    padding: padding mode of both convolutions (reflection padding of the
      default size is applied before each of them regardless).
    use_bias: whether the convolutions use a bias term.
    normalization: normalization identifier forwarded to utils.Normalization.
    activation: activation identifier forwarded to utils.Activation.
    spectral_norm: forwarded to utils.Conv.
    name: prefix for the names of all layers created here.

  Returns:
    The output tensor of the block.
  """
  channels = inputs.shape[-1]
  # Both convolutions share every setting except their layer name.
  conv_options = dict(filters=channels,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding=padding,
                      use_bias=use_bias,
                      spectral_norm=spectral_norm)

  # Stage one.
  x = ReflectionPadding(name=f'{name}/pad_1')(inputs)
  x = utils.Conv(name=f'{name}/conv_1', **conv_options)(x)
  x = utils.Normalization(normalization, name=f'{name}/norm_1')(x)
  x = utils.Activation(activation, name=f'{name}/activation_1')(x)

  # Stage two: the skip connection is merged before norm and activation.
  x = ReflectionPadding(name=f'{name}/pad_2')(x)
  x = utils.Conv(name=f'{name}/conv_2', **conv_options)(x)
  x = layers.Add(name=f'{name}/add')([inputs, x])
  x = utils.Normalization(normalization, name=f'{name}/norm_2')(x)
  return utils.Activation(activation, name=f'{name}/activation_2')(x)


def downsample(inputs,
               filters: int,
               kernel_size: t.Union[int, t.Tuple[int, int]] = 3,
               strides: t.Union[int, t.Tuple[int, int]] = 2,
               padding: str = 'same',
               use_bias: bool = False,
               normalization: str = 'instancenorm',
               activation: str = 'lrelu',
               dropout: float = 0.0,
               spectral_norm: bool = False,
               name='downsample'):
  """Apply convolution, normalization, activation and spatial dropout.

  Args:
    inputs: input tensor.
    filters: number of convolution filters.
    kernel_size: convolution kernel size.
    strides: convolution strides.
    padding: convolution padding mode.
    use_bias: whether the convolution uses a bias term.
    normalization: normalization identifier forwarded to utils.Normalization.
    activation: activation identifier forwarded to utils.Activation.
    dropout: rate forwarded to utils.SpatialDropout.
    spectral_norm: forwarded to utils.Conv.
    name: prefix for the names of all layers created here.

  Returns:
    The output tensor of the block.
  """
  conv = utils.Conv(filters,
                    kernel_size=kernel_size,
                    strides=strides,
                    padding=padding,
                    use_bias=use_bias,
                    spectral_norm=spectral_norm,
                    name=f'{name}/conv')
  features = conv(inputs)
  features = utils.Normalization(normalization,
                                 name=f'{name}/norm')(features)
  features = utils.Activation(activation,
                              name=f'{name}/activation')(features)
  return utils.SpatialDropout(rate=dropout,
                              name=f'{name}/dropout')(features)


def upsample(inputs,
             filters: int,
             kernel_size: t.Union[int, t.Tuple[int, int]] = 3,
             strides: t.Union[int, t.Tuple[int, int]] = 2,
             padding: str = 'same',
             use_bias: bool = False,
             normalization: str = None,
             activation: str = None,
             spectral_norm: bool = False,
             name: str = 'upsample'):
  """Apply a transposed convolution with optional norm and activation.

  Args:
    inputs: input tensor.
    filters: number of transposed-convolution filters.
    kernel_size: transposed-convolution kernel size.
    strides: transposed-convolution strides.
    padding: transposed-convolution padding mode.
    use_bias: whether the transposed convolution uses a bias term.
    normalization: normalization identifier; the normalization layer is
      skipped when this is None.
    activation: activation identifier; the activation layer is skipped when
      this is None.
    spectral_norm: forwarded to utils.TransposeConv.
    name: prefix for the names of all layers created here.

  Returns:
    The output tensor of the block.
  """
  transpose_conv = utils.TransposeConv(filters,
                                       kernel_size=kernel_size,
                                       strides=strides,
                                       padding=padding,
                                       use_bias=use_bias,
                                       spectral_norm=spectral_norm,
                                       name=f'{name}/transpose_conv')
  result = transpose_conv(inputs)

  if normalization is not None:
    result = utils.Normalization(normalization,
                                 name=f'{name}/norm')(result)

  if activation is not None:
    result = utils.Activation(activation,
                              name=f'{name}/activation')(result)

  return result


def generator(input_shape,
              filters: int = 64,
              normalization: str = 'instancenorm',
              activation: str = 'lrelu',
              num_down_blocks: int = 2,
              num_residual_blocks: int = 9,
              num_up_blocks: int = 2,
              reduction_factor: t.Union[int, t.Tuple[int, int]] = 2,
              output_activation: str = 'linear',
              spectral_norm: bool = False,
              use_attention_gate: bool = False,
              name: str = 'generator'):
  """Build the ResNet-style generator as a Keras model.

  Layout: a reflection-padded 7x7 input convolution, `num_down_blocks`
  downsample blocks (filters doubled each time), `num_residual_blocks`
  residual blocks, `num_up_blocks` upsample blocks (filters halved each
  time) that each first merge the matching downsample output, and a
  reflection-padded 7x7 output convolution back to `input_shape[-1]`
  channels.

  Args:
    input_shape: shape of one input sample; its last entry is used as the
      number of output channels.
    filters: number of filters of the input convolution.
    normalization: normalization identifier used throughout.
    activation: activation identifier used throughout.
    num_down_blocks: number of downsample blocks.
    num_residual_blocks: number of residual blocks.
    num_up_blocks: number of upsample blocks. Shortcuts are looked up by
      index, so this must not exceed `num_down_blocks`.
    reduction_factor: strides of the downsample and upsample blocks.
    output_activation: activation identifier of the final layer.
    spectral_norm: forwarded to the convolution helpers.
    use_attention_gate: merge shortcuts through additive_attention_gate
      instead of a plain concatenation.
    name: name of the returned model.

  Returns:
    A tf.keras.models.Model whose output shape equals its input shape
    (enforced by an assert).
  """
  # Settings shared by every down/residual/up block.
  block_options = dict(normalization=normalization,
                       activation=activation,
                       spectral_norm=spectral_norm)

  inputs = tf.keras.Input(input_shape, name='inputs')

  # Input stem.
  x = ReflectionPadding(padding=3, name='input/pad')(inputs)
  x = utils.Conv(filters=filters,
                 kernel_size=7,
                 use_bias=False,
                 spectral_norm=spectral_norm,
                 name='input/conv')(x)
  x = utils.Normalization(normalization, name='input/norm')(x)
  x = utils.Activation(activation, name='input/activation')(x)

  # Encoder: remember each block output for the decoder shortcuts.
  width = filters
  skips = []
  for index in range(1, num_down_blocks + 1):
    width *= 2
    x = downsample(x,
                   filters=width,
                   strides=reduction_factor,
                   name=f'down_{index}',
                   **block_options)
    skips.append(x)

  for index in range(1, num_residual_blocks + 1):
    x = residual_block(x, name=f'block_{index}', **block_options)

  # Decoder: the deepest shortcut is consumed first.
  skips.reverse()
  for index in range(num_up_blocks):
    scope = f'up_{index + 1}'
    skip = skips[index]
    if x.shape[1:] != skip.shape[1:]:
      x = utils.cropping(x, skip, name=f'{scope}/crop')
    if use_attention_gate:
      x = additive_attention_gate(x,
                                  shortcut=skip,
                                  filters=x.shape[-1],
                                  normalization=normalization,
                                  name=f'{scope}/attention')
    else:
      x = layers.Concatenate(axis=-1, name=f'{scope}/concat')([x, skip])
    width //= 2
    x = upsample(x,
                 filters=width,
                 strides=reduction_factor,
                 name=scope,
                 **block_options)

  if x.shape[1:] != inputs.shape[1:]:
    x = utils.cropping(x, inputs, name='crop')

  # Output head.
  x = ReflectionPadding(padding=3, name='output/pad')(x)
  x = utils.Conv(filters=input_shape[-1],
                 kernel_size=7,
                 spectral_norm=spectral_norm,
                 name='output/conv')(x)

  # NOTE(review): linear activation under the name 'output/attention';
  # presumably a named tap on the pre-activation output -- confirm.
  x = utils.Activation('linear', name='output/attention')(x)

  # The final activation is explicitly created with dtype float32.
  x = utils.Activation(output_activation,
                       dtype=tf.float32,
                       name='output/activation')(x)

  assert inputs.shape[1:] == x.shape[1:]

  return tf.keras.models.Model(inputs=inputs, outputs=x, name=name)


def discriminator(input_shape,
                  filters: int = 64,
                  normalization: str = 'instancenorm',
                  activation: str = 'lrelu',
                  dropout: float = 0.0,
                  num_down_blocks: int = 3,
                  reduction_factor: t.Union[int, t.Tuple[int, int]] = 2,
                  patchgan: bool = False,
                  output_activation: str = 'linear',
                  spectral_norm: bool = False,
                  name: str = 'discriminator'):
  """Build the convolutional discriminator as a Keras model.

  Layout: a strided 4x4 input convolution with activation (no
  normalization), `num_down_blocks` downsample blocks with 4x4 kernels whose
  filters double each time, and a 4x4 output convolution with one filter.
  Only the first two downsample blocks use `reduction_factor` as strides;
  later ones use stride 1. Unless `patchgan` is set, the result is flattened
  and reduced to a single unit by a dense layer.

  Args:
    input_shape: shape of one input sample.
    filters: number of filters of the input convolution.
    normalization: normalization identifier of the downsample blocks.
    activation: activation identifier used throughout.
    dropout: dropout rate of the downsample blocks.
    num_down_blocks: number of downsample blocks.
    reduction_factor: strides of the input convolution and of the first two
      downsample blocks.
    patchgan: keep the convolutional output map instead of flattening it
      into a single value.
    output_activation: activation identifier of the final layer.
    spectral_norm: forwarded to the convolution helpers.
    name: name of the returned model.

  Returns:
    A tf.keras.models.Model.
  """
  inputs = tf.keras.Input(input_shape, name='inputs')

  # Input stem.
  x = utils.Conv(filters=filters,
                 kernel_size=4,
                 strides=reduction_factor,
                 padding='same',
                 spectral_norm=spectral_norm,
                 name='input/conv')(inputs)
  x = utils.Activation(activation, name='input/activation')(x)

  width = filters
  for index in range(num_down_blocks):
    width *= 2
    # From the third block on the resolution is no longer reduced.
    block_strides = 1 if index >= 2 else reduction_factor
    x = downsample(x,
                   filters=width,
                   kernel_size=4,
                   strides=block_strides,
                   normalization=normalization,
                   activation=activation,
                   dropout=dropout,
                   spectral_norm=spectral_norm,
                   name=f'down_{index}')

  x = utils.Conv(filters=1,
                 kernel_size=4,
                 strides=1,
                 padding='same',
                 spectral_norm=spectral_norm,
                 name='output/conv')(x)

  # NOTE(review): linear activation under the name 'output/attention';
  # presumably a named tap on the pre-activation output -- confirm.
  x = utils.Activation('linear', name='output/attention')(x)

  if not patchgan:
    x = layers.Flatten(name='output/flatten')(x)
    x = layers.Dense(1, name='output/dense')(x)

  # The final activation is explicitly created with dtype float32.
  x = utils.Activation(output_activation,
                       dtype=tf.float32,
                       name='output/activation')(x)

  return tf.keras.models.Model(inputs=inputs, outputs=x, name=name)
