from .registry import register

import tensorflow as tf
from tensorflow.keras import layers

from cyclegan.models import utils


@register('identity')
def get_models(args, g_name: str = 'generator', d_name: str = 'discriminator'):
  """Build the generator/discriminator pair registered under 'identity'.

  Args:
    args: run configuration. Its `input_shape` attribute sizes both models,
      and the whole object is handed to `utils.discriminator_activation` to
      pick the discriminator's output activation.
    g_name: name given to the generator model.
    d_name: name given to the discriminator model.

  Returns:
    A `(generator, discriminator)` tuple of `tf.keras.Model` instances.
  """
  shape = args.input_shape

  # The generator always uses a linear output activation.
  gen_model = generator(input_shape=shape,
                        output_activation='linear',
                        name=g_name)

  d_activation = utils.discriminator_activation(args)
  disc_model = discriminator(input_shape=shape,
                             output_activation=d_activation,
                             name=d_name)

  return gen_model, disc_model


class Identity(layers.Layer):
  """Pass-through layer that owns a trainable weight.

  The output equals the input exactly. The weight is wired into the
  computation, but it never changes the result, and its gradient is zero.
  """

  def __init__(self, name='identity', **kwargs):
    super(Identity, self).__init__(name=name, **kwargs)

  def build(self, input_shape):
    # One weight entry per non-leading input element; it broadcasts over the
    # leading (batch) axis in `call`.
    # `name` is passed by keyword: in Keras 3 the first positional argument
    # of `add_weight` is `shape`, so a positional 'weight' would be misread.
    self.weight = self.add_weight(name='weight',
                                  shape=input_shape[1:],
                                  trainable=True)

  def call(self, inputs, **kwargs):
    # `weight - weight` is exactly zero for finite values, so the sum returns
    # `inputs` bit-for-bit. The previous `(inputs + weight) - weight` could
    # pick up floating-point rounding error. The gradient w.r.t. `weight`
    # is zero in both forms.
    return inputs + (self.weight - self.weight)


def generator(input_shape, output_activation: str = 'linear', name='generator'):
  """Build a generator that only applies an output activation to its input.

  Args:
    input_shape: per-sample input shape (batch dimension excluded), forwarded
      to `tf.keras.Input`.
    output_activation: activation identifier passed to `utils.Activation`.
    name: name of the returned model.

  Returns:
    A `tf.keras.Model` mapping the input through the activation layer.
  """
  model_input = tf.keras.Input(shape=input_shape, name='inputs')

  # dtype=tf.float32 is passed explicitly; presumably this keeps the output
  # in float32 under a mixed-precision policy -- TODO confirm in utils.
  activation = utils.Activation(output_activation,
                                dtype=tf.float32,
                                name='output/activation')
  model_output = activation(model_input)

  # Sanity check: the generator must not alter the per-sample shape.
  assert model_input.shape[1:] == model_output.shape[1:]

  return tf.keras.Model(inputs=model_input, outputs=model_output, name=name)


def discriminator(input_shape,
                  output_activation: str = 'linear',
                  name='discriminator'):
  """Build a minimal discriminator: global pool -> Dense(1) -> activation.

  Args:
    input_shape: per-sample input shape (batch dimension excluded), forwarded
      to `tf.keras.Input`.
    output_activation: activation identifier passed to `utils.Activation`.
    name: name of the returned model.

  Returns:
    A `tf.keras.Model` that emits one unit per sample (from `Dense(1)`),
    passed through the output activation.
  """
  x_in = tf.keras.Input(input_shape, name='inputs')

  # `utils.GlobalAvgPool` is a project helper; presumably it averages over
  # all non-batch spatial axes -- TODO confirm in utils.
  pooled = utils.GlobalAvgPool(name='output/global_pool')(x_in)
  logits = layers.Dense(1, name='output/dense')(pooled)

  final_activation = utils.Activation(output_activation,
                                      dtype=tf.float32,
                                      name='output/activation')
  scores = final_activation(logits)

  return tf.keras.Model(inputs=x_in, outputs=scores, name=name)
