import os
import tensorflow as tf

from cyclegan.utils import utils
from cyclegan.utils.tensorboard import Summary

# Registry mapping a model name (matched against `args.model` in
# `get_models`) to the builder callable registered under that name.
_MODELS = {}


def register(name):
  """Return a decorator that registers a model builder under `name`.

  The decorated callable is stored in `_MODELS[name]` and returned
  unchanged, so the decorator can be stacked or applied directly. A later
  registration with the same name replaces the earlier one.

  `get_models` invokes registered builders as
  `fn(args, g_name=..., d_name=...)` and unpacks the result as a
  `(generator, discriminator)` pair.

  Args:
    name: key under which the builder is stored.

  Returns:
    A decorator that registers its argument and returns it.
  """

  def add_to_dict(fn):
    # Mutating the module-level dict needs no `global` declaration; the
    # name `_MODELS` itself is never rebound.
    _MODELS[name] = fn
    return fn

  return add_to_dict


def get_models(args,
               strategy: tf.distribute.Strategy,
               summary: Summary = None,
               g_name: str = 'generator',
               d_name: str = 'discriminator',
               write_summary: bool = True):
  """Build the registered generator/discriminator pair for `args.model`.

  The builder registered under `args.model` is called inside
  `strategy.scope()` as `builder(args, g_name=g_name, d_name=d_name)`.

  Args:
    args: run configuration. `args.model` selects the registered builder and
      `args.verbose` controls printing. The object is also passed through to
      the builder and to `utils.model_summary`.
    strategy: distribution strategy whose scope the models are created in.
    summary: optional Summary writer. When given, the trainable-parameter
      count of each model is logged as the scalar
      `model/trainable_parameters/<name>`.
    g_name: name passed to the builder for the generator and used in its tag.
    d_name: name passed to the builder for the discriminator and used in its
      tag.
    write_summary: when True, `utils.model_summary` is computed for both
      models. The results are printed only when `args.verbose == 2`.

  Returns:
    A `(generator, discriminator)` tuple.

  Raises:
    KeyError: if no builder is registered under `args.model`.
  """
  model_name = args.model
  if model_name not in _MODELS:
    raise KeyError(f'model {model_name} not found')
  build_fn = _MODELS[model_name]

  # Create the variables under the strategy so they are placed/mirrored by it.
  with strategy.scope():
    generator, discriminator = build_fn(args, g_name=g_name, d_name=d_name)

  named_models = ((g_name, generator), (d_name, discriminator))

  if summary is not None:
    for model_label, model in named_models:
      summary.scalar(f'model/trainable_parameters/{model_label}',
                     utils.count_trainable_params(model))

  if write_summary:
    gen_summary, disc_summary = [
        utils.model_summary(args, model) for _, model in named_models
    ]
    if args.verbose == 2:
      # Generator summary, one blank line, then discriminator summary.
      print(gen_summary, '', disc_summary, sep='\n')

  return generator, discriminator
