#!python
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pipeline to train and evaluate weakly-supervised models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from absl import logging
from disentanglement_lib.evaluation import evaluate
from disentanglement_lib.methods.weak import train_weak_lib
from disentanglement_lib.postprocessing import postprocess
from disentanglement_lib.visualize import visualize_model
import numpy as np
from tensorflow import gfile


FLAGS = flags.FLAGS

flags.DEFINE_string("output_directory", None,
                    "Output directory of experiments.")

# Model flags. If the model_dir flag is set, then that directory is used and
# training is skipped.
flags.DEFINE_string("model_dir", None, "Directory to take trained model from.")
# Otherwise, the model is trained using the gin bindings and the gin model
# config file in the gin model config folder.
flags.DEFINE_multi_string("gin_bindings", [],
                          "Newline separated list of Gin parameter bindings.")
flags.DEFINE_string("gin_model_config_dir", None,
                    "Path to directory with model configs.")
flags.DEFINE_string("gin_model_config_name", None,
                    "Filename of the model config.")

# Postprocessing and evaluation are done using glob patterns of gin configs.
flags.DEFINE_string("gin_postprocess_config_glob", None,
                    "Path to glob pattern to postprocessing configs.")
flags.DEFINE_string("gin_evaluation_config_glob", None,
                    "Path to glob pattern to evaluation configs.")
flags.DEFINE_string("gin_validation_config_glob", None,
                    "Path to glob pattern to validation configs.")

# Other flags.
flags.DEFINE_integer("pipeline_seed", None,
                     "Integer with random seed for whole pipeline.")
# NOTE(review): eval_seed is not read by main() in this file; presumably it is
# consumed elsewhere -- confirm before relying on it.
flags.DEFINE_integer("eval_seed", None,
                     "Integer with random seed for the evaluation.")
flags.DEFINE_boolean("overwrite", False,
                     "Whether to overwrite output directory.")

# Experiment flags.
# NOTE(review): none of the experiment flags below are read by main() in this
# file; their help strings are derived from the flag names -- confirm the
# exact semantics against the code that consumes them.
flags.DEFINE_integer("supervised_seed", None,
                     "Integer with random seed for the supervised data.")
flags.DEFINE_integer("num_labelled_samples", None,
                     "Number of labelled samples.")
flags.DEFINE_float("train_percentage", None,
                   "Percentage of the data used for training.")
flags.DEFINE_string("labeller_fn", None,
                    "Labeller function for the supervised data.")


def main(unused_argv):
  """Trains (or reuses) a model, then visualizes, postprocesses and evaluates.

  The steps are, in order:
    1. Train a model into `<output_directory>/model`, unless `--model_dir` is
       set, in which case `<model_dir>/model` is used and training is skipped.
    2. Write visualizations to `<output_directory>/visualizations`.
    3. For every gin config matching `--gin_postprocess_config_glob`, extract
       a representation into `<output_directory>/postprocessed/<name>`.
    4. For every extracted representation and every gin config matching
       `--gin_evaluation_config_glob` whose file name does not contain
       "unsupervised", compute the metric into
       `<output_directory>/metrics/<name>/<metric>_0`.

  Args:
    unused_argv: Remaining command-line arguments handed over by `app.run`;
      they are ignored.
  """
  # In this pipeline, we manually manage the random seeds of the different steps
  # as otherwise different training runs of the model (with different random
  # seeds) would be evaluated on exactly the same data (i.e., there would be no
  # randomness in evaluation). We use the random seed in the flag to seed a
  # random number generator from which random seeds for the different parts of
  # the pipeline are drawn.
  # NOTE(review): a fixed seed for the supervision (see --supervised_seed) is
  # not set anywhere in this function; presumably it is provided through the
  # gin bindings -- confirm.
  random_state = np.random.RandomState(FLAGS.pipeline_seed)

  # Model training (if the model_dir is not provided.).

  # It is important that we sample the model random seed regardless whether
  # we actually train the model so later seeds are the same. Hence the draw
  # happens before (not inside) the branch below.
  model_seed = random_state.randint(2**32)

  if FLAGS.model_dir is None:
    logging.info("Training model...")
    model_dir = os.path.join(FLAGS.output_directory, "model")
    model_config_file = os.path.join(FLAGS.gin_model_config_dir,
                                     FLAGS.gin_model_config_name)
    # The model name is the config file name without the ".gin" suffix.
    model_name = FLAGS.gin_model_config_name.replace(".gin", "")
    model_bindings = [
        "model.random_seed = {}".format(model_seed),
        "model.name = '{}'".format(model_name)
    ] + FLAGS.gin_bindings
    train_weak_lib.train_with_gin(model_dir,
                                  FLAGS.overwrite,
                                  [model_config_file],
                                  model_bindings)
  else:
    logging.info("Skipped training...")
    model_dir = os.path.join(FLAGS.model_dir, "model")

  # We visualize reconstruction, samples and latent space traversal.
  visualize_dir = os.path.join(FLAGS.output_directory, "visualizations")
  visualize_model.visualize(model_dir, visualize_dir, FLAGS.overwrite)

  # We extract the different representations and save them to disk. The
  # configs are sorted so that the order of the seed draws is deterministic.
  postprocess_configs = sorted(gfile.Glob(FLAGS.gin_postprocess_config_glob))
  postprocessed = []  # (name, directory) of every extracted representation.
  for config in postprocess_configs:
    post_name = os.path.basename(config).replace(".gin", "")
    logging.info("Extracting representation %s...", post_name)
    post_dir = os.path.join(FLAGS.output_directory, "postprocessed", post_name)
    postprocess_bindings = [
        "postprocess.random_seed = {}".format(random_state.randint(2**32)),
        "postprocess.name = '{}'".format(post_name)
    ]
    postprocess.postprocess_with_gin(model_dir, post_dir, FLAGS.overwrite,
                                     [config], postprocess_bindings)
    postprocessed.append((post_name, post_dir))

  # Iterate through metrics for test.
  logging.info("Evaluation config glob: %s", FLAGS.gin_evaluation_config_glob)
  metric_configs = sorted(gfile.Glob(FLAGS.gin_evaluation_config_glob))
  logging.info("Evaluation configs: %s", metric_configs)

  # All postprocessing seeds are drawn before any evaluation seed, so the
  # evaluation runs in a second pass over the representations.
  for post_name, post_dir in postprocessed:
    # Now, we compute all the specified scores.
    for gin_eval_config in metric_configs:
      metric_name = os.path.basename(gin_eval_config).replace(".gin", "")
      # Configs with "unsupervised" in their name are skipped, and no seed is
      # drawn for them.
      if "unsupervised" in metric_name:
        continue
      metric_name = "{}_{}".format(metric_name, 0)
      logging.info("Computing metric '%s' on '%s'...",
                   metric_name, post_name)
      metric_dir = os.path.join(FLAGS.output_directory, "metrics",
                                post_name, metric_name)
      eval_bindings = [
          "evaluation.random_seed = {}".format(random_state.randint(2**32)),
          "evaluation.name = '{}'".format(metric_name)
      ]
      evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                 [gin_eval_config], eval_bindings)


# Script entry point: hand control to absl, which calls main once the
# command-line flags defined in this file are available through FLAGS.
if __name__ == "__main__":
  app.run(main)
