from .registry import register

import typing as t
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow_addons.layers import GELU

from cyclegan.models import utils


@register('mlp_mixer')
def get_mixer(args, g_name: str = 'generator', d_name: str = 'discriminator'):
  """Build the MLP-Mixer (generator, discriminator) pair from parsed args.

  Args:
    args: namespace providing at least `input_shape`, `num_filters`,
      `reduction_factor` and `patchgan`, plus whatever
      `utils.generator_activation` / `utils.discriminator_activation` read.
    g_name: name given to the generator model.
    d_name: name given to the discriminator model.

  Returns:
    A `(generator, discriminator)` tuple of Keras models.
  """
  gen_model = generator(
      input_shape=args.input_shape,
      filters=args.num_filters,
      reduction_factor=args.reduction_factor,
      output_activation=utils.generator_activation(args),
      # This registry entry always enables the attention-gated decoder.
      use_attention_gate=True,
      name=g_name,
  )
  disc_model = discriminator(
      input_shape=args.input_shape,
      filters=args.num_filters,
      patchgan=args.patchgan,
      output_activation=utils.discriminator_activation(args),
      name=d_name,
  )
  return gen_model, disc_model


def attention_gate(inputs,
                   shortcut,
                   filters: int,
                   normalization: str = 'instancenorm',
                   name: str = 'attention_block'):
  """Additive attention gate.

  Reference: https://arxiv.org/abs/1804.03999

  `shortcut` and `inputs` are each projected to `filters` channels by a 1x1
  convolution followed by normalization. The two projections are summed and
  passed through ReLU. A 1x1 convolution with normalization then reduces the
  result to a single channel, and a sigmoid maps it to a coefficient map.
  Note that the map scales `inputs` (not `shortcut`).

  Args:
    inputs: feature map to be gated.
    shortcut: gating feature map; its spatial size must match `inputs` for
      the element-wise add to succeed.
    filters: channel count of the intermediate 1x1 projections.
    normalization: normalization type forwarded to `utils.Normalization`.
    name: prefix for every layer name created here.

  Returns:
    `inputs` multiplied element-wise by the single-channel coefficient map.
  """

  def _project(tensor, branch):
    # One branch of the gate: 1x1 conv then normalization, named
    # '<name>/<branch>/conv' and '<name>/<branch>/norm'.
    projected = utils.Conv(filters, kernel_size=1,
                           name=f'{name}/{branch}/conv')(tensor)
    return utils.Normalization(normalization,
                               name=f'{name}/{branch}/norm')(projected)

  gating = _project(shortcut, 'g')
  signal = _project(inputs, 'x')

  coefficients = layers.Add(name=f'{name}/add')([gating, signal])
  coefficients = utils.Activation('relu', name=f'{name}/relu')(coefficients)
  coefficients = utils.Conv(filters=1, kernel_size=1,
                            name=f'{name}/psi')(coefficients)
  coefficients = utils.Normalization(normalization,
                                     name=f'{name}/norm')(coefficients)
  coefficients = utils.Activation('sigmoid',
                                  name=f'{name}/sigmoid')(coefficients)
  return layers.Multiply(name=f'{name}/multiply')([inputs, coefficients])


def mlp_block(inputs,
              units: int,
              dropout: float = 0.4,
              name: str = 'mlp_block'):
  """Two-layer feed-forward MLP used by the token- and channel-mixing steps.

  Dense(`units`) -> GELU -> Dense(size of the input's last axis) ->
  Dropout(`dropout`). The output therefore has the same last-axis size as
  `inputs`.

  Args:
    inputs: tensor whose last axis is mixed by the MLP.
    units: hidden width of the first dense layer.
    dropout: rate of the trailing dropout layer.
    name: prefix for every layer name created here.
  """
  output_dim = inputs.shape[-1]
  pipeline = [
      layers.Dense(units, name=f'{name}/dense_1'),
      GELU(name=f'{name}/gelu'),
      layers.Dense(output_dim, name=f'{name}/dense_2'),
      layers.Dropout(dropout, name=f'{name}/dropout'),
  ]
  hidden = inputs
  for layer in pipeline:
    hidden = layer(hidden)
  return hidden


def token_mixing(inputs, units: int, name: str = 'token_mixing'):
  """Layer-normalize, swap the two non-batch axes, then apply `mlp_block`.

  Assumes `inputs` is (N, tokens, channels). After the swap the MLP acts
  along the token axis. The result is returned in the swapped
  (N, channels, tokens) layout; `mixer_block` permutes it back.
  """
  normalized = layers.LayerNormalization(epsilon=1e-6,
                                         name=f'{name}/layer_norm')(inputs)
  transposed = layers.Permute(dims=[2, 1], name=f'{name}/permute')(normalized)
  return mlp_block(transposed, units, name=f'{name}/mlp')


def channel_mixing(inputs, units: int, name: str = 'channel_mixing'):
  """Layer-normalize `inputs` and apply `mlp_block` along the last axis.

  Assumes `inputs` is (N, tokens, channels), so the MLP mixes channels
  independently for each token.
  """
  normalized = layers.LayerNormalization(epsilon=1e-6,
                                         name=f'{name}/layer_norm')(inputs)
  return mlp_block(normalized, units, name=f'{name}/mlp')


def mixer_block(inputs,
                token_units: int,
                channel_units: int,
                name: str = 'mixer_block'):
  """One MLP-Mixer block: token mixing, then channel mixing, both residual.

  Args:
    inputs: token sequence, presumably (N, tokens, channels) as produced by
      the `Reshape` layers in `generator` / `discriminator`.
    token_units: hidden width of the token-mixing MLP.
    channel_units: hidden width of the channel-mixing MLP.
    name: prefix for every layer name created here.

  Returns:
    Tensor with the same shape as `inputs`.
  """
  # token_mixing returns its result in the swapped-axes layout; swap back so
  # it can be added to `inputs`.
  mixed_tokens = token_mixing(inputs, token_units,
                              name=f'{name}/token_mixing')
  mixed_tokens = layers.Permute(dims=[2, 1],
                                name=f'{name}/permute')(mixed_tokens)
  residual = layers.Add(name=f'{name}/add_1')([inputs, mixed_tokens])

  mixed_channels = channel_mixing(residual, channel_units,
                                  name=f'{name}/channel_mixing')
  return layers.Add(name=f'{name}/add_2')([mixed_channels, residual])


class ReflectionPadding(layers.Layer):
  """Reflection-pad every axis between the batch axis and the last axis.

  A 3D input (N, L, C) is padded along L. A 4D input (N, H, W, C) is padded
  along H and W. The batch and last axes are never padded. Padding is done
  with `tf.pad(..., mode='REFLECT')`, which requires `padding` to be smaller
  than the size of each padded axis.

  Args:
    padding: number of elements added on each side of every padded axis.
  """

  def __init__(self, padding: int = 1, **kwargs):
    super(ReflectionPadding, self).__init__(**kwargs)
    self.pad = padding

  def build(self, input_shape):
    # Number of non-batch axes; only 3D and 4D inputs are supported.
    ndim = len(input_shape) - 1
    if ndim not in (2, 3):
      raise ValueError(
          'ReflectionPadding expects a 3D or 4D input (including the batch '
          f'axis), got shape {input_shape}')
    # One [pad, pad] entry per axis strictly between batch and last axis.
    middle = [[self.pad, self.pad] for _ in range(ndim - 1)]
    self.paddings = [[0, 0], *middle, [0, 0]]
    super(ReflectionPadding, self).build(input_shape)

  def call(self, inputs, **kwargs):
    return tf.pad(inputs, self.paddings, mode='REFLECT')

  def get_config(self):
    # Include `padding` so the layer can be re-created from its config
    # (e.g. when a saved model is reloaded); without it the default of 1
    # would silently be used.
    config = super(ReflectionPadding, self).get_config()
    config.update({'padding': self.pad})
    return config


def downsample(inputs,
               filters: int,
               kernel_size: t.Union[int, t.Tuple[int, int]] = 3,
               strides: t.Union[int, t.Tuple[int, int]] = 2,
               padding: str = 'same',
               use_bias: bool = False,
               normalization: str = 'instancenorm',
               activation: str = 'lrelu',
               dropout: float = 0.0,
               spectral_norm: bool = False,
               name='downsample'):
  """Strided convolution -> normalization -> activation, optional dropout.

  Args:
    inputs: feature map to downsample.
    filters: number of convolution filters.
    kernel_size, strides, padding, use_bias, spectral_norm: forwarded to
      `utils.Conv`.
    normalization: normalization type forwarded to `utils.Normalization`.
    activation: activation name forwarded to `utils.Activation`.
    dropout: if > 0, a `Dropout` layer with this rate is appended.
    name: prefix for every layer name created here.
  """
  stages = [
      utils.Conv(filters,
                 kernel_size=kernel_size,
                 strides=strides,
                 padding=padding,
                 use_bias=use_bias,
                 spectral_norm=spectral_norm,
                 name=f'{name}/conv'),
      utils.Normalization(normalization, name=f'{name}/norm'),
      utils.Activation(activation, name=f'{name}/activation'),
  ]
  if dropout > 0:
    stages.append(layers.Dropout(rate=dropout, name=f'{name}/dropout'))

  features = inputs
  for stage in stages:
    features = stage(features)
  return features


def upsample(inputs,
             filters: int,
             kernel_size: t.Union[int, t.Tuple[int, int]] = 3,
             strides: t.Union[int, t.Tuple[int, int]] = 2,
             padding: str = 'same',
             use_bias: bool = False,
             normalization: str = None,
             activation: str = None,
             spectral_norm: bool = False,
             name: str = 'upsample'):
  """Transposed convolution, optionally followed by normalization/activation.

  Args:
    inputs: feature map to upsample.
    filters: number of transposed-convolution filters.
    kernel_size, strides, padding, use_bias, spectral_norm: forwarded to
      `utils.TransposeConv`.
    normalization: if not None, a `utils.Normalization` of this type is
      applied after the transposed convolution.
    activation: if not None, a `utils.Activation` of this name is applied
      last.
    name: prefix for every layer name created here.
  """
  stages = [
      utils.TransposeConv(filters,
                          kernel_size=kernel_size,
                          strides=strides,
                          padding=padding,
                          use_bias=use_bias,
                          spectral_norm=spectral_norm,
                          name=f'{name}/transpose_conv'),
  ]
  if normalization is not None:
    stages.append(utils.Normalization(normalization, name=f'{name}/norm'))
  if activation is not None:
    stages.append(utils.Activation(activation, name=f'{name}/activation'))

  features = inputs
  for stage in stages:
    features = stage(features)
  return features


def generator(input_shape,
              filters: int = 64,
              normalization: str = 'instancenorm',
              activation: str = 'lrelu',
              patch_dims: tuple = (8, 8),
              token_units: int = 64,
              channel_units: int = 128,
              num_down_blocks: int = 2,
              num_mixer_blocks: int = 9,
              num_up_blocks: int = 2,
              reduction_factor: t.Union[int, t.Tuple[int, int]] = 2,
              output_activation: str = 'linear',
              use_attention_gate: bool = False,
              name: str = 'generator'):
  """Build an image-to-image generator with an MLP-Mixer bottleneck.

  Layout:
    * input stem: 3-pixel reflection padding, then a 7x7 'valid' convolution
      (spatial size unchanged), normalization and activation;
    * `num_down_blocks` `downsample` blocks, each doubling the filter count;
      their outputs are kept as shortcuts;
    * a non-overlapping `patch_dims` convolution to `4 * filters` channels,
      flattened into a (num_patches, channels) token sequence;
    * `num_mixer_blocks` `mixer_block`s;
    * reshape back to the patch grid, then a `patch_dims` transposed
      convolution back to `filters` channels;
    * `num_up_blocks` `upsample` blocks, each halving the filter count. When
      `use_attention_gate` is set, each is preceded by an `attention_gate`
      driven by the matching shortcut;
    * output stem: 3-pixel reflection padding and a 7x7 'valid' convolution
      to `input_shape[-1]` channels, then `output_activation`.

  Args:
    input_shape: shape of one sample without the batch axis (presumably
      (H, W, C)).
    filters: filter count of the input stem.
    normalization: normalization type used throughout.
    activation: activation name used throughout.
    patch_dims: patch size (and stride) of the patch-embedding convolution.
    token_units: hidden width of each token-mixing MLP.
    channel_units: hidden width of each channel-mixing MLP.
    num_down_blocks: number of downsampling blocks.
    num_mixer_blocks: number of mixer blocks.
    num_up_blocks: number of upsampling blocks; indexes the shortcut list, so
      it should not exceed `num_down_blocks`.
    reduction_factor: stride of every down/upsampling block.
    output_activation: activation applied to the final output.
    use_attention_gate: gate the decoder features with the shortcuts.
    name: model name.

  Returns:
    A `tf.keras.Model` whose output shape equals its input shape.

  Raises:
    AssertionError: if the output shape differs from the input shape.
  """
  inputs = tf.keras.Input(input_shape, name='inputs')

  # Input stem. The 3-pixel reflection padding is exactly what a 7x7 'valid'
  # convolution consumes, mirroring the output stem below. padding='valid'
  # is explicit so the stem does not depend on utils.Conv's default padding.
  outputs = ReflectionPadding(padding=3,
                              name='input/reflection_padding')(inputs)
  outputs = utils.Conv(filters=filters,
                       kernel_size=7,
                       padding='valid',
                       use_bias=False,
                       name='input/conv')(outputs)
  outputs = utils.Normalization(normalization, name='input/norm')(outputs)
  outputs = utils.Activation(activation, name='input/activation')(outputs)

  shortcuts = []
  for i in range(num_down_blocks):
    filters *= 2
    outputs = downsample(outputs,
                         filters=filters,
                         strides=reduction_factor,
                         normalization=normalization,
                         activation=activation,
                         name=f'down_{i+1}')
    shortcuts.append(outputs)

  # Patch embedding: (N, H', W', C) -> (N, H'/ph, W'/pw, 4 * filters), a grid
  # of non-overlapping patches where (ph, pw) = patch_dims. With 'valid'
  # padding any remainder rows/columns are dropped.
  outputs = layers.Conv2D(filters * 4,
                          kernel_size=patch_dims,
                          strides=patch_dims,
                          padding='valid',
                          name='conv')(outputs)
  grid_shape = outputs.shape[1:]
  # Flatten the patch grid into a token sequence: (N, num_patches, channels).
  outputs = layers.Reshape(target_shape=(outputs.shape[1] * outputs.shape[2],
                                         outputs.shape[3]),
                           name='reshape_1')(outputs)

  for i in range(num_mixer_blocks):
    outputs = mixer_block(outputs,
                          token_units=token_units,
                          channel_units=channel_units,
                          name=f'mixer_{i+1}')

  # Restore the patch grid, then undo the patch embedding so the spatial
  # size matches the deepest shortcut (when divisible by patch_dims).
  outputs = layers.Reshape(target_shape=grid_shape, name='reshape_2')(outputs)
  outputs = layers.Conv2DTranspose(filters=filters,
                                   kernel_size=patch_dims,
                                   strides=patch_dims,
                                   padding='valid',
                                   name='deconv')(outputs)

  shortcuts = shortcuts[::-1]
  for i in range(num_up_blocks):
    shortcut = shortcuts[i]
    # Crop only when the decoder features and the shortcut differ spatially.
    # Compare against the decoder tensor itself: the network input never
    # has the resolution of a downsampled shortcut.
    if outputs.shape[1:-1] != shortcut.shape[1:-1]:
      outputs = utils.cropping(outputs, shortcut, name=f'up_{i+1}/crop')
    if use_attention_gate:
      outputs = attention_gate(outputs,
                               shortcut=shortcut,
                               filters=outputs.shape[-1],
                               normalization=normalization,
                               name=f'up_{i+1}/attention')
    filters //= 2
    outputs = upsample(outputs,
                       filters=filters,
                       strides=reduction_factor,
                       normalization=normalization,
                       activation=activation,
                       name=f'up_{i+1}')

  # Output stem: reflection padding + 7x7 'valid' convolution keeps the
  # spatial size and maps back to the input channel count.
  outputs = ReflectionPadding(padding=3,
                              name='output/reflection_padding')(outputs)
  outputs = utils.Conv(filters=input_shape[-1],
                       kernel_size=7,
                       padding='valid',
                       name='output/conv')(outputs)

  # 'linear' activation layer named 'output/attention' (presumably an
  # identity hook point for inspecting the pre-activation output — TODO
  # confirm).
  outputs = utils.Activation('linear', name='output/attention')(outputs)

  # Created with dtype float32, presumably so the model output stays float32
  # under mixed precision.
  outputs = utils.Activation(output_activation,
                             dtype=tf.float32,
                             name='output/activation')(outputs)

  assert inputs.shape[1:] == outputs.shape[1:], (
      f'generator output shape {outputs.shape[1:]} does not match input '
      f'shape {inputs.shape[1:]}')

  return tf.keras.models.Model(inputs=inputs, outputs=outputs, name=name)


def discriminator(input_shape,
                  filters: int = 64,
                  patch_dims: tuple = (8, 8),
                  token_units: int = 64,
                  channel_units: int = 128,
                  num_down_blocks: int = 3,
                  patchgan: bool = False,
                  output_activation: str = 'linear',
                  name: str = 'discriminator'):
  """Build an MLP-Mixer discriminator that emits one score per sample.

  The input is split into non-overlapping `patch_dims` patches by a strided
  convolution and flattened into a token sequence. It is then processed by
  `num_down_blocks` mixer blocks, layer-normalized, averaged over tokens and
  projected to a single unit.

  NOTE(review): `patchgan` is accepted but not used by this implementation;
  the output is always (N, 1) regardless of its value.

  Args:
    input_shape: shape of one sample without the batch axis (presumably
      (H, W, C)).
    filters: channel count of the patch embedding.
    patch_dims: patch size (and stride) of the embedding convolution.
    token_units: hidden width of each token-mixing MLP.
    channel_units: hidden width of each channel-mixing MLP.
    num_down_blocks: number of mixer blocks.
    patchgan: currently ignored (see note above).
    output_activation: activation applied to the final score.
    name: model name.

  Returns:
    A `tf.keras.Model` mapping inputs to scores of shape (N, 1).
  """
  inputs = tf.keras.Input(input_shape, name='inputs')

  # Patch embedding with Conv2D's default 'valid' padding.
  patches = layers.Conv2D(filters,
                          kernel_size=patch_dims,
                          strides=patch_dims,
                          name='conv')(inputs)
  _, rows, cols, channels = patches.shape
  tokens = layers.Reshape(target_shape=(rows * cols, channels),
                          name='reshape_1')(patches)

  for block_index in range(num_down_blocks):
    tokens = mixer_block(tokens,
                         token_units=token_units,
                         channel_units=channel_units,
                         name=f'down_{block_index}')

  pooled = layers.LayerNormalization(epsilon=1e-6, name='layer_norm')(tokens)
  pooled = layers.GlobalAveragePooling1D(name='avg_pooling')(pooled)
  scores = layers.Dense(1, name='dense')(pooled)

  # 'linear' activation layer named 'output/attention' (presumably an
  # identity hook point — TODO confirm).
  scores = utils.Activation('linear', name='output/attention')(scores)

  # Created with dtype float32, presumably to keep the output float32 under
  # mixed precision.
  scores = utils.Activation(output_activation,
                            dtype=tf.float32,
                            name='output/activation')(scores)

  return tf.keras.models.Model(inputs=inputs, outputs=scores, name=name)
