from .registry import register

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

from cyclegan.models import utils


@register('mlp')
def get_models(args, g_name: str = 'generator', d_name: str = 'discriminator'):
  """Build the MLP generator/discriminator pair registered under 'mlp'.

  Args:
    args: configuration object. This function reads `input_shape`,
      `num_filters` (used as the hidden-layer width), `activation` and
      `dropout`. It also passes `args` to `utils.generator_activation` and
      `utils.discriminator_activation` to choose the output activations.
    g_name: name given to the generator model.
    d_name: name given to the discriminator model.

  Returns:
    A `(generator, discriminator)` tuple of `tf.keras.Model`s.
  """
  # Hyper-parameters that both networks receive unchanged.
  shared = dict(input_shape=args.input_shape,
                num_units=args.num_filters,
                activation=args.activation,
                dropout=args.dropout)

  gen_model = generator(output_activation=utils.generator_activation(args),
                        name=g_name,
                        **shared)
  disc_model = discriminator(
      output_activation=utils.discriminator_activation(args),
      name=d_name,
      **shared)
  return gen_model, disc_model


def generator(input_shape,
              num_units: int = 16,
              activation: str = 'lrelu',
              dropout: float = 0.0,
              output_activation: str = 'linear',
              name='generator',
              num_layers: int = 3):
  """Build an MLP generator whose output has the same shape as its input.

  The network stacks `num_layers` hidden blocks of
  `Dense(num_units) -> utils.Activation(activation) -> Dropout(dropout)`.
  These are followed by a `Dense` projection back to `input_shape[-1]` units
  and the output activation. `Dense` acts on the last axis only, so the
  output keeps the full per-example input shape.

  Args:
    input_shape: shape of one example, excluding the batch dimension.
    num_units: width of every hidden Dense layer.
    activation: hidden-layer activation, resolved by `utils.Activation`.
    dropout: dropout rate applied after every hidden activation.
    output_activation: activation applied to the final projection.
    name: name of the returned model.
    num_layers: number of hidden blocks. Defaults to 3, the previously
      fixed depth.

  Returns:
    A `tf.keras.Model` named `name`.

  Raises:
    ValueError: if `num_layers` is negative, or if the output shape does not
      match the input shape.
  """
  if num_layers < 0:
    raise ValueError(f'num_layers must be non-negative, got {num_layers}')

  inputs = tf.keras.Input(input_shape, name='inputs')

  outputs = inputs
  for _ in range(num_layers):
    outputs = layers.Dense(num_units)(outputs)
    outputs = utils.Activation(activation)(outputs)
    outputs = layers.Dropout(dropout)(outputs)

  outputs = layers.Dense(input_shape[-1])(outputs)
  # The output activation is pinned to float32. Presumably this keeps the
  # result in float32 under a mixed-precision policy (confirm in utils).
  outputs = utils.Activation(output_activation,
                             dtype=tf.float32,
                             name='output/activation')(outputs)

  # A generator must preserve the example shape. Raise explicitly rather than
  # `assert` so the check survives `python -O`. Compare plain lists so that
  # unknown (None) dimensions compare equal.
  in_shape = tf.TensorShape(inputs.shape[1:]).as_list()
  out_shape = tf.TensorShape(outputs.shape[1:]).as_list()
  if in_shape != out_shape:
    raise ValueError(f'generator output shape {out_shape} does not match '
                     f'input shape {in_shape}')

  return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)


def discriminator(input_shape,
                  num_units: int = 16,
                  activation: str = 'lrelu',
                  dropout: float = 0.0,
                  output_activation: str = 'linear',
                  name='discriminator',
                  num_layers: int = 3):
  """Build an MLP discriminator with a single output unit.

  The network stacks `num_layers` hidden blocks of
  `Dense(num_units) -> utils.Activation(activation) -> Dropout(dropout)`.
  These are followed by a `Dense(1)` layer and the output activation.
  `Dense` acts on the last axis only, so the output shape is the input shape
  with its last dimension replaced by 1.

  Args:
    input_shape: shape of one example, excluding the batch dimension.
    num_units: width of every hidden Dense layer.
    activation: hidden-layer activation, resolved by `utils.Activation`.
    dropout: dropout rate applied after every hidden activation.
    output_activation: activation applied to the final Dense(1) output.
    name: name of the returned model.
    num_layers: number of hidden blocks. Defaults to 3, the previously
      fixed depth.

  Returns:
    A `tf.keras.Model` named `name`.

  Raises:
    ValueError: if `num_layers` is negative.
  """
  if num_layers < 0:
    raise ValueError(f'num_layers must be non-negative, got {num_layers}')

  inputs = tf.keras.Input(input_shape, name='inputs')

  outputs = inputs
  for _ in range(num_layers):
    outputs = layers.Dense(num_units)(outputs)
    outputs = utils.Activation(activation)(outputs)
    outputs = layers.Dropout(dropout)(outputs)

  outputs = layers.Dense(1)(outputs)
  # The output activation is pinned to float32. Presumably this keeps the
  # result in float32 under a mixed-precision policy (confirm in utils).
  outputs = utils.Activation(output_activation,
                             dtype=tf.float32,
                             name='output/activation')(outputs)

  return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)
