from .registry import register

import typing as t
import tensorflow as tf
from tensorflow.keras import layers

from cyclegan.models import utils


@register('pix2pix')
def get_models(args, g_name: str = 'generator', d_name: str = 'discriminator'):
  """Build the pix2pix (generator, discriminator) pair from parsed arguments.

  Both networks share the architecture hyper-parameters read from `args`
  (input_shape, num_filters, activation, normalization, dropout,
  reduction_factor, spectral_norm). The discriminator additionally reads
  `args.patchgan`. Each network's output activation is chosen by
  `utils.generator_activation(args)` / `utils.discriminator_activation(args)`.

  Args:
    args: parsed configuration object exposing the attributes listed above.
    g_name: name given to the generator model.
    d_name: name given to the discriminator model.

  Returns:
    A `(generator, discriminator)` tuple of Keras models.
  """
  shared_kwargs = dict(input_shape=args.input_shape,
                       num_filters=args.num_filters,
                       activation=args.activation,
                       normalization=args.normalization,
                       dropout=args.dropout,
                       reduction_factor=args.reduction_factor,
                       spectral_norm=args.spectral_norm)

  # The generator is built first, then the discriminator.
  gen_model = generator(output_activation=utils.generator_activation(args),
                        name=g_name,
                        **shared_kwargs)
  disc_model = discriminator(
      patchgan=args.patchgan,
      output_activation=utils.discriminator_activation(args),
      name=d_name,
      **shared_kwargs)
  return gen_model, disc_model


def downsample(inputs,
               filters: int,
               kernel_size: int,
               strides: t.Union[int, t.Tuple[int, int]] = 2,
               normalization: str = None,
               activation: str = None,
               dropout: float = 0.,
               spectral_norm: bool = False,
               name: str = 'downsample'):
  """Apply one encoder stage: strided conv -> [norm] -> [activation] -> [dropout].

  Args:
    inputs: Keras tensor to transform.
    filters: number of convolution filters.
    kernel_size: convolution kernel size.
    strides: convolution stride; the convolution uses 'same' padding.
    normalization: name passed to `utils.Normalization`; None skips the layer.
    activation: name passed to `utils.Activation`; None skips the layer.
    dropout: dropout rate; a Dropout layer is added only when it is > 0.
    spectral_norm: forwarded to `utils.Conv`.
    name: prefix used for the names of every created layer.

  Returns:
    The transformed Keras tensor.
  """
  # The convolution never has a bias term, even when no normalization follows.
  stack = [
      utils.Conv(filters,
                 kernel_size,
                 strides=strides,
                 padding='same',
                 use_bias=False,
                 spectral_norm=spectral_norm,
                 name=f'{name}/conv')
  ]
  if normalization is not None:
    stack.append(utils.Normalization(normalization, name=f'{name}/norm'))
  if activation is not None:
    stack.append(utils.Activation(activation, name=f'{name}/activation'))
  if dropout > 0:
    stack.append(layers.Dropout(dropout, name=f'{name}/dropout'))

  x = inputs
  for layer in stack:
    x = layer(x)
  return x


def upsample(inputs,
             filters: int,
             kernel_size: int,
             strides: t.Union[int, t.Tuple[int, int]] = 2,
             normalization: str = None,
             activation: str = None,
             dropout: float = 0.,
             spectral_norm: bool = False,
             name='upsample'):
  """Apply one decoder stage: transpose conv -> [norm] -> [activation] -> [dropout].

  Args:
    inputs: Keras tensor to transform.
    filters: number of transpose-convolution filters.
    kernel_size: transpose-convolution kernel size.
    strides: transpose-convolution stride; 'same' padding is used.
    normalization: name passed to `utils.Normalization`; None skips the layer.
    activation: name passed to `utils.Activation`; None skips the layer.
    dropout: dropout rate; a Dropout layer is added only when it is > 0.
    spectral_norm: forwarded to `utils.TransposeConv`.
    name: prefix used for the names of every created layer.

  Returns:
    The transformed Keras tensor.
  """
  # Bias is disabled unconditionally, mirroring `downsample`.
  pipeline = [
      utils.TransposeConv(filters,
                          kernel_size,
                          strides=strides,
                          padding='same',
                          use_bias=False,
                          spectral_norm=spectral_norm,
                          name=f'{name}/transpose_conv')
  ]
  if normalization is not None:
    pipeline.append(utils.Normalization(normalization, name=f'{name}/norm'))
  if activation is not None:
    pipeline.append(utils.Activation(activation, name=f'{name}/activation'))
  if dropout > 0:
    pipeline.append(layers.Dropout(dropout, name=f'{name}/dropout'))

  tensor = inputs
  for stage in pipeline:
    tensor = stage(tensor)
  return tensor


def generator(input_shape,
              num_filters: int = 16,
              activation: str = 'lrelu',
              normalization: str = 'instancenorm',
              dropout: float = 0.0,
              reduction_factor: t.Union[int, t.Tuple[int, int]] = 2,
              output_activation: str = 'linear',
              spectral_norm: bool = False,
              name: str = 'generator'):
  """Build a U-Net style pix2pix generator.

  The encoder has six `downsample` stages, each strided by `reduction_factor`.
  The decoder has five `upsample` stages, each followed by concatenation with
  the encoder output of matching depth. A final transpose convolution (also
  strided by `reduction_factor`) restores the input resolution and channel
  count.

  Args:
    input_shape: shape of one sample, excluding the batch dimension; its last
      entry is used as the number of output channels.
    num_filters: base filter count; encoder stages use
      num_filters * [1, 2, 4, 8, 8, 8] filters.
    activation: activation name used in every encoder/decoder stage.
    normalization: normalization name used in every encoder/decoder stage.
    dropout: dropout rate applied after every encoder/decoder stage.
    reduction_factor: stride of every down/up-sampling stage, including the
      final output transpose convolution.
    output_activation: activation applied (computed in float32) to the output.
    spectral_norm: forwarded to every convolution / transpose convolution.
    name: model name.

  Returns:
    A `tf.keras.Model` whose output shape equals its input shape.
  """
  inputs = tf.keras.Input(input_shape, name=f'{name}/inputs')
  outputs = inputs

  # For (batch, H, W, C) inputs whose width is not a power of two, zero-pad the
  # right edge up to the next power of two (presumably so repeated striding
  # divides the width evenly). The padding is cropped off again after the
  # output convolution.
  padding = None
  if (len(inputs.shape) == 4) and not utils.is_power_of_two(inputs.shape[2]):
    padding = ((0, 0),
               (0, utils.next_power_of_two(inputs.shape[2]) - inputs.shape[2]))
    outputs = layers.ZeroPadding2D(padding, name='input/zero_padding')(outputs)

  # Encoder: keep each stage's output for the skip connections.
  skips = []
  filters = [num_filters * (2**i) for i in [0, 1, 2, 3, 3, 3]]
  for i, stage_filters in enumerate(filters):
    outputs = downsample(outputs,
                         filters=stage_filters,
                         kernel_size=4,
                         strides=reduction_factor,
                         normalization=normalization,
                         activation=activation,
                         dropout=dropout,
                         spectral_norm=spectral_norm,
                         name=f'down_{i}')
    skips.append(outputs)

  # Decoder: up_4 .. up_0. The deepest encoder output (down_5) is the
  # bottleneck and has no mirrored decoder stage. Each up_i output is
  # concatenated with the output of down_i.
  for i in reversed(range(len(filters) - 1)):
    outputs = upsample(outputs,
                       filters=filters[i],
                       kernel_size=4,
                       strides=reduction_factor,
                       normalization=normalization,
                       activation=activation,
                       dropout=dropout,
                       spectral_norm=spectral_norm,
                       name=f'up_{i}')
    outputs = layers.Concatenate(name=f'{name}/concat_{i}')([outputs, skips[i]])

  # The last stage undoes the stride of `down_0`, so it must use the same
  # `reduction_factor` as the encoder. A hard-coded stride of 2 only restores
  # the input resolution when reduction_factor == 2.
  outputs = utils.TransposeConv(filters=input_shape[-1],
                                kernel_size=4,
                                strides=reduction_factor,
                                padding='same',
                                spectral_norm=spectral_norm,
                                name='output/conv')(outputs)

  if padding is not None:
    outputs = layers.Cropping2D(padding, name='output/cropping')(outputs)

  # Identity ('linear') layer with a fixed name; presumably a hook for
  # inspecting this tensor (e.g. attention maps) -- verify against utils.
  outputs = utils.Activation('linear', name='output/attention')(outputs)

  outputs = utils.Activation(output_activation,
                             dtype=tf.float32,
                             name='output/activation')(outputs)

  # Internal sanity check: padding/cropping and matching strides must yield an
  # output with exactly the input's per-sample shape.
  assert inputs.shape[1:] == outputs.shape[1:]

  return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)


def discriminator(input_shape,
                  num_filters: int = 16,
                  activation: str = 'lrelu',
                  normalization: str = 'instancenorm',
                  dropout: float = 0.0,
                  reduction_factor: t.Union[int, t.Tuple[int, int]] = 2,
                  patchgan: bool = False,
                  output_activation: str = 'linear',
                  spectral_norm: bool = False,
                  name: str = 'discriminator'):
  """Build the pix2pix discriminator.

  The network is four `downsample` stages (num_filters * [1, 2, 4, 8]
  filters, strided by `reduction_factor`), followed by a head:
  zero-pad -> 4x4 conv (256 filters) -> [norm] -> [activation] -> zero-pad ->
  4x4 conv (1 filter).

  Args:
    input_shape: shape of one sample, excluding the batch dimension.
    num_filters: base filter count for the downsampling stages.
    activation: activation name; None skips activation layers.
    normalization: normalization name; None skips normalization layers.
    dropout: dropout rate applied after every downsampling stage.
    reduction_factor: stride of every downsampling stage.
    patchgan: if True, return the 1-channel per-patch map produced by the
      head. Otherwise flatten it and reduce to one unit with a Dense layer.
    output_activation: activation applied (computed in float32) to the output.
    spectral_norm: forwarded to every Conv / Dense layer.
    name: model name.

  Returns:
    A `tf.keras.Model` mapping an input sample to its discriminator output.
  """
  inputs = tf.keras.Input(input_shape, name='inputs')
  outputs = inputs

  filters = [num_filters * (2**i) for i in range(4)]
  for i, stage_filters in enumerate(filters):
    outputs = downsample(outputs,
                         filters=stage_filters,
                         kernel_size=4,
                         strides=reduction_factor,
                         normalization=normalization,
                         activation=activation,
                         dropout=dropout,
                         spectral_norm=spectral_norm,
                         name=f'down_{i}')

  # Head. Note the 256 filters are fixed and do not scale with num_filters.
  outputs = utils.ZeroPadding(padding=1, name='output/zero_padding_1')(outputs)
  outputs = utils.Conv(filters=256,
                       kernel_size=4,
                       strides=1,
                       use_bias=False,
                       spectral_norm=spectral_norm,
                       name='output/conv_1')(outputs)
  # Treat None as "skip this layer", consistent with `downsample`, instead of
  # passing None into the utils constructors.
  if normalization is not None:
    outputs = utils.Normalization(normalization, name='output/norm_1')(outputs)
  if activation is not None:
    outputs = utils.Activation(activation, name='output/activation_1')(outputs)
  outputs = utils.ZeroPadding(padding=1, name='output/zero_padding_2')(outputs)
  outputs = utils.Conv(filters=1,
                       kernel_size=4,
                       strides=1,
                       spectral_norm=spectral_norm,
                       name='output/conv_2')(outputs)

  if not patchgan:
    # Collapse the patch map into a single unit per sample.
    outputs = layers.Flatten(name='output/flatten')(outputs)
    outputs = utils.Dense(units=1,
                          spectral_norm=spectral_norm,
                          name='output/dense')(outputs)

  # Identity ('linear') layer with a fixed name; presumably a hook for
  # inspecting this tensor -- verify against utils.
  outputs = utils.Activation('linear', name='output/attention')(outputs)

  outputs = utils.Activation(output_activation,
                             dtype=tf.float32,
                             name='output/activation')(outputs)

  return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)
