#!python
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pipeline that trains a model and then computes multiple scores."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from absl import app
from absl import flags
from absl import logging
from disentanglement_lib.evaluation import evaluate
from disentanglement_lib.methods.unsupervised import train
from disentanglement_lib.postprocessing import postprocess
from disentanglement_lib.utils import aggregate_results
from disentanglement_lib.visualize import visualize_model
import numpy as np
from six.moves import range
from tensorflow.compat.v1 import gfile



FLAGS = flags.FLAGS

# Root directory under which the "model", "visualizations", "postprocessed"
# and "metrics" sub-directories are written.
flags.DEFINE_string("output_directory", None,
                    "Output directory of experiments.")

# Model flags. If the model_dir flag is set, then that directory is used and
# training is skipped.
flags.DEFINE_string("model_dir", None, "Directory to take trained model from.")
# Otherwise, the model is trained using the gin bindings and the gin model
# config file in the gin model config folder.
flags.DEFINE_multi_string("gin_bindings", [],
                          "Newline separated list of Gin parameter bindings.")
flags.DEFINE_string("gin_model_config_dir", None,
                    "Path to directory with model configs.")
flags.DEFINE_string("gin_model_config_name", None,
                    "Filename of the model config.")

# Postprocessing and evaluation is done using glob patterns of gin configs.
flags.DEFINE_string("gin_postprocess_config_glob", None,
                    "Path to glob pattern to postprocessing configs.")
flags.DEFINE_string("gin_evaluation_config_glob", None,
                    "Path to glob pattern to evaluation configs.")
# Every metric is computed num_eval times; the iteration index is appended to
# the metric name so that the runs end up in separate directories.
flags.DEFINE_integer("num_eval", 1,
                     "Number of times the evaluation measures are computed.")
# Metrics listed here (by gin config basename without ".gin") are additionally
# computed once per entry of evaluation_num_samples_train.
flags.DEFINE_multi_string("scores_list", [],
                          "List of scores to be evaluated for each sample size"
                          " in evaluation_num_samples_train.")

# The values are passed as strings and converted to int before being bound to
# `<metric>.num_train`.
flags.DEFINE_multi_string("evaluation_num_samples_train", [],
                          "List of training sample sizes (num_train) used for"
                          " the scores in scores_list.")

# Other flags.
flags.DEFINE_integer("random_seed", None,
                     "Integer with random seed for whole pipeline.")

flags.DEFINE_boolean("overwrite", False,
                     "Whether to overwrite output directory.")


def main(unused_argv):
  """Trains (or loads) a model, postprocesses it and computes all metrics.

  The pipeline runs the following steps in order:
    1. Train a model with gin, unless `--model_dir` is given.
    2. Visualize the model.
    3. Extract one representation per postprocess gin config.
    4. Evaluate every metric gin config on every representation, `--num_eval`
       times. Metrics named in `--scores_list` are additionally evaluated once
       per sample size in `--evaluation_num_samples_train`.
    5. Aggregate all evaluation results into a single json file.

  Args:
    unused_argv: Positional command line arguments (ignored).

  Raises:
    ValueError: If `--scores_list` is set without
      `--evaluation_num_samples_train`, or if a flag required by the requested
      steps is not set.
  """
  if FLAGS.scores_list and not FLAGS.evaluation_num_samples_train:
    raise ValueError("How many samples should be used to compute the scores in"
                     " scores_list must be specified.")

  # Fail fast on missing flags. Otherwise a missing glob pattern would only be
  # noticed after the (potentially long) training step has finished.
  required_flags = [
      "output_directory", "gin_postprocess_config_glob",
      "gin_evaluation_config_glob"
  ]
  if FLAGS.model_dir is None:
    required_flags += ["gin_model_config_dir", "gin_model_config_name"]
  missing_flags = [
      name for name in required_flags if getattr(FLAGS, name) is None
  ]
  if missing_flags:
    raise ValueError("The following flags must be set: {}.".format(
        ", ".join(missing_flags)))

  # In this pipeline, we manually manage the random seeds of the different steps
  # as otherwise different training runs of the model (with different random
  # seeds) would be evaluated on exactly the same data (i.e., there would be no
  # randomness in evaluation). We use the random seed in the flag to seed a
  # random number generator from which random seeds for the different parts of
  # the pipeline are drawn.
  random_state = np.random.RandomState(FLAGS.random_seed)

  # Model training (if the model_dir is not provided.).

  # It is important that we sample the model random seed regardless whether
  # we actually train the model so later seeds are the same.
  model_random_seed = random_state.randint(2**32)
  if FLAGS.model_dir is None:
    logging.info("Training model...")
    model_dir = os.path.join(FLAGS.output_directory, "model")
    model_config_file = os.path.join(FLAGS.gin_model_config_dir,
                                     FLAGS.gin_model_config_name)
    model_name = FLAGS.gin_model_config_name.replace(".gin", "")
    model_bindings = [
        "model.random_seed = {}".format(model_random_seed),
        "model.name = '{}'".format(model_name)
    ] + FLAGS.gin_bindings
    train.train_with_gin(model_dir, FLAGS.overwrite, [model_config_file],
                         model_bindings)
  else:
    logging.info("Skipped training...")
    # The trained model is expected in the "model" sub-directory of model_dir.
    model_dir = os.path.join(FLAGS.model_dir, "model")

  # We visualize reconstruction, samples and latent space traversal.
  visualize_dir = os.path.join(FLAGS.output_directory, "visualizations")
  visualize_model.visualize(model_dir, visualize_dir, FLAGS.overwrite)

  # We extract the different representations and save them to disk. The
  # (name, directory) pairs are remembered so that the evaluation loop below
  # does not have to recompute them in every iteration.
  postprocess_configs = sorted(gfile.Glob(FLAGS.gin_postprocess_config_glob))
  representations = []
  for config in postprocess_configs:
    post_name = os.path.basename(config).replace(".gin", "")
    logging.info("Extracting representation %s...", post_name)
    post_dir = os.path.join(FLAGS.output_directory, "postprocessed", post_name)
    postprocess_bindings = [
        "postprocess.random_seed = {}".format(random_state.randint(2**32)),
        "postprocess.name = '{}'".format(post_name)
    ]
    postprocess.postprocess_with_gin(model_dir, post_dir, FLAGS.overwrite,
                                     [config], postprocess_bindings)
    representations.append((post_name, post_dir))

  # Iterate through metrics.
  metric_configs = sorted(gfile.Glob(FLAGS.gin_evaluation_config_glob))

  for iteration_num in range(FLAGS.num_eval):
    for post_name, post_dir in representations:
      # Now, we compute all the specified scores.
      for gin_eval_config in metric_configs:
        metric_name_gin = os.path.basename(gin_eval_config).replace(
            ".gin", "")
        metric_name = "{}_{}".format(metric_name_gin, iteration_num)
        logging.info("Computing metric '%s' on '%s'...",
                     metric_name, post_name)

        # Only the scores in FLAGS.scores_list need to be ran for multiple
        # sample sizes.
        if metric_name_gin in FLAGS.scores_list:
          for num_samples in FLAGS.evaluation_num_samples_train:
            metric_name = "{}_{}_{}".format(metric_name_gin, num_samples,
                                            iteration_num)
            metric_dir = os.path.join(FLAGS.output_directory, "metrics",
                                      post_name, metric_name)
            eval_bindings = [
                "{}.num_train = {}".format(metric_name_gin, int(num_samples)),
                "evaluation.name = '{}'".format(metric_name),
                "evaluation.random_seed = {}".format(
                    random_state.randint(2**32))]
            evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                       [gin_eval_config], eval_bindings)
        else:
          metric_dir = os.path.join(FLAGS.output_directory, "metrics",
                                    post_name, metric_name)
          # Draw a fresh seed here as well. Without an explicit binding the
          # seed would be left to the gin config file (presumably a fixed
          # value - TODO confirm), so the num_eval repetitions of this metric
          # would not be independently seeded.
          eval_bindings = [
              "evaluation.name = '{}'".format(metric_name),
              "evaluation.random_seed = {}".format(
                  random_state.randint(2**32))]
          evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                     [gin_eval_config], eval_bindings)
  # Aggregate all the results in a single json file per model.
  result_dir = os.path.join(FLAGS.output_directory, "metrics",
                            "aggregated_results.json")
  pattern = os.path.join(FLAGS.output_directory,
                         "metrics/*/*/results/aggregate/evaluation.json")
  aggregate_results.aggregate_results_to_json(pattern, result_dir)


# Run the pipeline through absl's app runner when this file is executed as a
# script (and not when it is imported as a module).
if __name__ == "__main__":
  app.run(main)
