import typing as t
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers, initializers


def generator_activation(args):
  """Pick the name of the generator's output activation.

  Returns 'tanh' when ``args.dataset`` is 'horse2zebra'; otherwise 'sigmoid'
  when ``args.scaled_data`` is truthy, and 'linear' in every other case.
  """
  if args.dataset == 'horse2zebra':
    return 'tanh'
  return 'sigmoid' if args.scaled_data else 'linear'


def discriminator_activation(args):
  """Pick the name of the discriminator's output activation.

  Returns 'linear' when ``args.algorithm`` is 'wgangp' and 'sigmoid' for any
  other algorithm.
  """
  return 'linear' if args.algorithm == 'wgangp' else 'sigmoid'


def is_power_of_two(x):
  """Return True if x is a positive power of two, False otherwise.

  Always returns a real ``bool``: falsy inputs such as 0 yield ``False``
  rather than being passed through unchanged. Negative integers are never
  powers of two.
  """
  if not x:
    return False
  # A power of two has exactly one bit set, so clearing the lowest set bit
  # with x & (x - 1) must leave zero.
  return not x & (x - 1)


def next_power_of_two(x):
  """Return the smallest power of two that is greater than or equal to x.

  If x is already a power of two it is returned unchanged. Any
  non-positive x yields 1, the smallest power of two.
  """
  # int.bit_length() ignores the sign, so negative values must be handled
  # explicitly instead of falling through to the bit trick below.
  if x <= 0:
    return 1
  return 2**(x - 1).bit_length()


def cropping(a: tf.Tensor, b: tf.Tensor, name: str = 'cropping'):
  """Crop tensor a along axes 1 and 2 so they match the sizes of tensor b.

  The surplus on each axis is split between both sides; when it is odd, the
  extra unit is removed from the bottom / right side. Asserts that a is at
  least as large as b on both axes.
  """
  assert a.shape[1] >= b.shape[1] and a.shape[2] >= b.shape[2]
  extra_rows = a.shape[1] - b.shape[1]
  extra_cols = a.shape[2] - b.shape[2]
  top, left = extra_rows // 2, extra_cols // 2
  crop_spec = ((top, extra_rows - top), (left, extra_cols - left))
  crop_layer = layers.Cropping2D(crop_spec, name=name)
  return crop_layer(a)


def Activation(activation: str, **kwargs):
  """Build an activation layer from its name.

  'lrelu' and 'leakyrelu' map to ``layers.LeakyReLU``; every other value is
  handed to ``layers.Activation``. Extra keyword arguments are forwarded to
  the layer constructor.
  """
  if activation not in ('lrelu', 'leakyrelu'):
    return layers.Activation(activation, **kwargs)
  return layers.LeakyReLU(**kwargs)


def Normalization(normalization: str, **kwargs):
  """Build a normalization layer from its name.

  Recognised names (with or without the underscore) are layer_norm,
  batch_norm, instance_norm and group_norm. Extra keyword arguments are
  forwarded to the layer constructor.

  Raises:
    NameError: if the name does not match any known normalization layer.
  """
  if normalization in ('layer_norm', 'layernorm'):
    layer_cls = layers.LayerNormalization
  elif normalization in ('batch_norm', 'batchnorm'):
    layer_cls = layers.BatchNormalization
  elif normalization in ('instance_norm', 'instancenorm'):
    layer_cls = tfa.layers.InstanceNormalization
  elif normalization in ('group_norm', 'groupnorm'):
    layer_cls = tfa.layers.GroupNormalization
  else:
    raise NameError(f'Normalization layer {normalization} not found.')
  return layer_cls(**kwargs)


class PhaseShuffle(layers.Layer):
  """Phase Shuffle, introduced in the WaveGAN paper.

  Randomly shifts the input along axis 1 by an offset drawn uniformly from
  [-m, m] and fills the vacated positions by padding, so that the
  discriminator is less sensitive to the periodic patterns that occur
  frequently in signal data.

  Args:
    input_shape: full shape of the tensor this layer receives; index 1 is
      the length of the shifted axis. The output is pinned to this shape.
      Presumably (batch, time, channels) -- TODO confirm against callers.
    m: maximum absolute shift; 0 turns the layer into an identity.
    mode: padding mode passed to ``tf.pad``.
    **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
  """

  def __init__(self, input_shape, m=0, mode='reflect', **kwargs):
    super().__init__(**kwargs)
    self.shape = input_shape
    self.m = m
    self.mode = mode

  def call(self, inputs, **kwargs):
    if self.m == 0:
      return inputs

    w = self.shape[1]

    # one random offset per call, shared by every example in the batch
    shift = tf.random.uniform([],
                              minval=-self.m,
                              maxval=self.m + 1,
                              dtype=tf.int32)

    # Branch-free formulation: a Python `if` on a tensor only works when
    # AutoGraph rewrites this method, whereas tf.maximum works everywhere.
    # shift > 0  -> pad `shift` steps at the end and drop the first `shift`
    #               steps (content moves towards index 0).
    # shift <= 0 -> pad |shift| steps at the front and keep the first `w`
    #               steps (content moves away from index 0).
    pad_end = tf.maximum(shift, 0)
    pad_front = tf.maximum(-shift, 0)
    paddings = [[0, 0], [pad_front, pad_end], [0, 0]]

    outputs = tf.pad(inputs, paddings=paddings, mode=self.mode)
    outputs = outputs[:, pad_end:pad_end + w, :]
    # the dynamic slice loses static shape information; restore it
    return tf.ensure_shape(outputs, shape=self.shape)


class Dense(layers.Layer):
  """Wrapper around ``layers.Dense`` with optional spectral normalization.

  The wrapped layer is created in ``build``. When ``spectral_norm`` is True
  it is wrapped in ``tfa.layers.SpectralNormalization``.
  """

  def __init__(self,
               units: int,
               activation: str = None,
               use_bias: bool = True,
               kernel_initializer: str = 'glorot_uniform',
               spectral_norm: bool = False,
               **kwargs):
    super().__init__(**kwargs)
    self.units = units
    self.activation = activation
    self.use_bias = use_bias
    self.kernel_initializer = kernel_initializer
    self.spectral_norm = spectral_norm

  def build(self, input_shape):
    inner = layers.Dense(units=self.units,
                         activation=self.activation,
                         use_bias=self.use_bias,
                         kernel_initializer=self.kernel_initializer)
    if self.spectral_norm:
      inner = tfa.layers.SpectralNormalization(inner)
    self.dense = inner

  def call(self, inputs, **kwargs):
    outputs = self.dense(inputs)
    return outputs


class Conv(layers.Layer):
  """Convolution layer that selects Conv1D or Conv2D from the input rank.

  The wrapped layer is created in ``build``: an input shape with 2
  dimensions after the first uses ``layers.Conv1D``, one with 3 uses
  ``layers.Conv2D``. Kernels are initialised from a normal distribution
  with mean 0.0 and stddev 0.02. When ``spectral_norm`` is True the
  convolution is wrapped in ``tfa.layers.SpectralNormalization``.
  """

  def __init__(self,
               filters: int,
               kernel_size: int,
               strides: t.Union[int, t.Tuple[int, int]] = 1,
               padding: str = 'valid',
               dilation_rate: int = 1,
               activation: str = None,
               use_bias: bool = True,
               spectral_norm: bool = False,
               **kwargs):
    super(Conv, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = kernel_size
    self.strides = strides
    self.padding = padding
    self.dilation_rate = dilation_rate
    self.activation = activation
    self.use_bias = use_bias
    self.kernel_initializer = initializers.RandomNormal(mean=0.0, stddev=0.02)
    self.spectral_norm = spectral_norm

  def build(self, input_shape):
    # Bind ndim outside the assert: asserts are stripped under `python -O`,
    # which would otherwise leave ndim undefined on the next line.
    ndim = len(input_shape[1:])
    assert 2 <= ndim <= 3, \
      f'expected 2 or 3 dimensions after the first, got {ndim}'
    layer = layers.Conv1D if ndim == 2 else layers.Conv2D
    self.conv = layer(filters=self.filters,
                      kernel_size=self.kernel_size,
                      strides=self.strides,
                      padding=self.padding,
                      dilation_rate=self.dilation_rate,
                      activation=self.activation,
                      use_bias=self.use_bias,
                      kernel_initializer=self.kernel_initializer)
    if self.spectral_norm:
      self.conv = tfa.layers.SpectralNormalization(self.conv)

  def call(self, inputs, **kwargs):
    return self.conv(inputs)


class TransposeConv(layers.Layer):
  """Transposed convolution that selects the 1D or 2D variant by input rank.

  The wrapped layer is created in ``build``: an input shape with 2
  dimensions after the first uses ``layers.Conv1DTranspose``, one with 3
  uses ``layers.Conv2DTranspose``. Kernels are initialised from a normal
  distribution with mean 0.0 and stddev 0.02. When ``spectral_norm`` is
  True the layer is wrapped in ``tfa.layers.SpectralNormalization``.
  """

  def __init__(self,
               filters: int,
               kernel_size: t.Union[int, t.Tuple[int, int]],
               strides: t.Union[int, t.Tuple[int, int]] = 1,
               padding: str = 'valid',
               dilation_rate: int = 1,
               activation: str = None,
               use_bias: bool = True,
               spectral_norm: bool = False,
               **kwargs):
    super(TransposeConv, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = kernel_size
    self.strides = strides
    self.padding = padding
    self.dilation_rate = dilation_rate
    self.activation = activation
    self.use_bias = use_bias
    self.kernel_initializer = initializers.RandomNormal(mean=0.0, stddev=0.02)
    self.spectral_norm = spectral_norm

  def build(self, input_shape):
    # Bind ndim outside the assert: asserts are stripped under `python -O`,
    # which would otherwise leave ndim undefined on the next line.
    ndim = len(input_shape[1:])
    assert 2 <= ndim <= 3, \
      f'expected 2 or 3 dimensions after the first, got {ndim}'
    layer = layers.Conv1DTranspose if ndim == 2 else layers.Conv2DTranspose
    self.transpose_conv = layer(filters=self.filters,
                                kernel_size=self.kernel_size,
                                strides=self.strides,
                                padding=self.padding,
                                dilation_rate=self.dilation_rate,
                                activation=self.activation,
                                use_bias=self.use_bias,
                                kernel_initializer=self.kernel_initializer)
    if self.spectral_norm:
      self.transpose_conv = tfa.layers.SpectralNormalization(
          self.transpose_conv)

  def call(self, inputs, **kwargs):
    return self.transpose_conv(inputs)


class MaxPool(layers.Layer):
  """Max pooling that selects MaxPool1D or MaxPool2D from the input rank.

  The wrapped layer is created in ``build``: an input shape with 2
  dimensions after the first uses ``layers.MaxPool1D``, one with 3 uses
  ``layers.MaxPool2D``.
  """

  def __init__(self,
               pool_size: t.Union[int, t.Tuple[int, int]] = 2,
               strides: int = None,
               padding: str = 'same',
               name='MaxPool',
               **kwargs):
    super(MaxPool, self).__init__(name=name, **kwargs)
    self.pool_size = pool_size
    self.strides = strides
    self.padding = padding

  def build(self, input_shape):
    # Bind ndim outside the assert: asserts are stripped under `python -O`,
    # which would otherwise leave ndim undefined on the next line.
    ndim = len(input_shape[1:])
    assert 2 <= ndim <= 3, \
      f'expected 2 or 3 dimensions after the first, got {ndim}'
    layer = layers.MaxPool1D if ndim == 2 else layers.MaxPool2D
    self.max_pool = layer(pool_size=self.pool_size,
                          strides=self.strides,
                          padding=self.padding)

  def call(self, inputs, **kwargs):
    return self.max_pool(inputs)


class GlobalMaxPool(layers.Layer):
  """Global max pooling that selects the 1D or 2D variant by input rank.

  The wrapped layer is created in ``build``: an input shape with 2
  dimensions after the first uses ``layers.GlobalMaxPool1D``, one with 3
  uses ``layers.GlobalMaxPool2D``.
  """

  def __init__(self, name='GlobalMaxPool', **kwargs):
    super(GlobalMaxPool, self).__init__(name=name, **kwargs)

  def build(self, input_shape):
    # Bind ndim outside the assert: asserts are stripped under `python -O`,
    # which would otherwise leave ndim undefined on the next line.
    ndim = len(input_shape[1:])
    assert 2 <= ndim <= 3, \
      f'expected 2 or 3 dimensions after the first, got {ndim}'
    layer = layers.GlobalMaxPool1D if ndim == 2 else layers.GlobalMaxPool2D
    self.global_max_pool = layer()

  def call(self, inputs, **kwargs):
    return self.global_max_pool(inputs)


class GlobalAvgPool(layers.Layer):
  """Global average pooling that selects the 1D or 2D variant by input rank.

  The wrapped layer is created in ``build``: an input shape with 2
  dimensions after the first uses ``layers.GlobalAvgPool1D``, one with 3
  uses ``layers.GlobalAvgPool2D``.
  """

  def __init__(self, name='GlobalAvgPool', **kwargs):
    super(GlobalAvgPool, self).__init__(name=name, **kwargs)

  def build(self, input_shape):
    # Bind ndim outside the assert: asserts are stripped under `python -O`,
    # which would otherwise leave ndim undefined on the next line.
    ndim = len(input_shape[1:])
    assert 2 <= ndim <= 3, \
      f'expected 2 or 3 dimensions after the first, got {ndim}'
    layer = layers.GlobalAvgPool1D if ndim == 2 else layers.GlobalAvgPool2D
    self.global_avg_pool = layer()

  def call(self, inputs, **kwargs):
    return self.global_avg_pool(inputs)


class ZeroPadding(layers.Layer):
  """Zero padding that selects ZeroPadding1D or ZeroPadding2D by input rank.

  The wrapped layer is created in ``build``: an input shape with 2
  dimensions after the first uses ``layers.ZeroPadding1D``, one with 3 uses
  ``layers.ZeroPadding2D``. ``padding`` is passed through unchanged.
  """

  def __init__(self, padding=1, name='ZeroPadding', **kwargs):
    super(ZeroPadding, self).__init__(name=name, **kwargs)
    self.padding = padding

  def build(self, input_shape):
    # Bind ndim outside the assert: asserts are stripped under `python -O`,
    # which would otherwise leave ndim undefined on the next line.
    ndim = len(input_shape[1:])
    assert 2 <= ndim <= 3, \
      f'expected 2 or 3 dimensions after the first, got {ndim}'
    layer = layers.ZeroPadding1D if ndim == 2 else layers.ZeroPadding2D
    self.zero_padding = layer(padding=self.padding)

  def call(self, inputs, **kwargs):
    return self.zero_padding(inputs)


class SpatialDropout(layers.Layer):
  """Spatial dropout that selects the 1D or 2D variant by input rank.

  The wrapped layer is created in ``build``: an input shape with 2
  dimensions after the first uses ``layers.SpatialDropout1D``, one with 3
  uses ``layers.SpatialDropout2D``. Keyword arguments given to ``call`` are
  forwarded to the wrapped dropout layer.
  """

  def __init__(self, rate: float, name='SpatialDropout', **kwargs):
    super(SpatialDropout, self).__init__(name=name, **kwargs)
    self.rate = rate

  def build(self, input_shape):
    # Bind ndim outside the assert: asserts are stripped under `python -O`,
    # which would otherwise leave ndim undefined on the next line.
    ndim = len(input_shape[1:])
    assert 2 <= ndim <= 3, \
      f'expected 2 or 3 dimensions after the first, got {ndim}'
    dropout = layers.SpatialDropout1D if ndim == 2 else layers.SpatialDropout2D
    self.dropout = dropout(rate=self.rate)

  def call(self, inputs, **kwargs):
    return self.dropout(inputs, **kwargs)
