#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pipeline that trains a s2 model and then computes multiple scores."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from absl import logging
from disentanglement_lib.evaluation import evaluate
from disentanglement_lib.methods.semi_supervised import train_semi_supervised_lib
from disentanglement_lib.postprocessing import postprocess
from disentanglement_lib.utils import aggregate_results
from disentanglement_lib.utils import results
from disentanglement_lib.validation import validate
from disentanglement_lib.visualize import visualize_model
import numpy as np
from tensorflow.compat.v1 import gfile

FLAGS = flags.FLAGS

# Root directory under which the "model", "visualizations", "postprocessed",
# "validation" and "metrics" sub-directories are written by main().
flags.DEFINE_string("output_directory", None,
                    "Output directory of experiments.")

# Model flags. If the model_dir flag is set, then that directory is used and
# training is skipped.
flags.DEFINE_string("model_dir", None, "Directory to take trained model from.")
# Otherwise, the model is trained using the gin bindings and the gin model
# config file in the gin model config folder.
flags.DEFINE_multi_string("gin_bindings", [],
                          "Newline separated list of Gin parameter bindings.")
flags.DEFINE_string("gin_model_config_dir", None,
                    "Path to directory with model configs.")
flags.DEFINE_string("gin_model_config_name", None,
                    "Filename of the model config.")

# Postprocessing, validation and evaluation are done using glob patterns of
# gin configs.
flags.DEFINE_string("gin_postprocess_config_glob", None,
                    "Path to glob pattern to postprocessing configs.")
flags.DEFINE_string("gin_evaluation_config_glob", None,
                    "Path to glob pattern to evaluation configs.")
flags.DEFINE_string("gin_validation_config_glob", None,
                    "Path to glob pattern to validation configs.")
# Other flags.
flags.DEFINE_integer("pipeline_seed", None,
                     "Integer with random seed for the whole pipeline.")
flags.DEFINE_integer("eval_seed", None,
                     "Integer with random seed for the evaluation pipeline.")
flags.DEFINE_boolean("overwrite", False,
                     "Whether to overwrite output directory.")

# Experiment flags. These describe the supervised data and may only be passed
# when a model is trained; when model_dir is set they are read back from the
# gin config that was saved at training time.
flags.DEFINE_integer("supervised_seed", None,
                     "Integer with random seed for the supervised data.")
flags.DEFINE_integer("num_labelled_samples", None,
                     "Integer with the number of labelled samples.")
flags.DEFINE_float("train_percentage", None,
                   "Float with the train percentage used for the model and "
                   "the validation metrics.")
flags.DEFINE_string("labeller_fn", None,
                    "Labeller function for the supervised data.")


def main(unused_argv):
  """Runs the semi-supervised pipeline end to end.

  The steps are, in order: train the model (or reuse the one found under
  --model_dir), visualize it, extract one representation per postprocessing
  config, compute every validation metric and every evaluation metric on each
  representation, and finally aggregate the evaluation results into a single
  json file.

  Args:
    unused_argv: Positional command line arguments; ignored.

  Raises:
    ValueError: If a required flag is missing, or if one of the supervision
      flags (supervised_seed, num_labelled_samples, train_percentage,
      labeller_fn) is passed together with --model_dir.
  """
  # Fail fast on missing flags. Without this check a forgotten glob flag is
  # only noticed after the (expensive) training step, as a cryptic TypeError
  # raised from the path or glob helpers.
  required_flags = [
      "output_directory",
      "gin_postprocess_config_glob",
      "gin_validation_config_glob",
      "gin_evaluation_config_glob",
  ]
  if FLAGS.model_dir is None:
    # The model config is only needed when we actually train.
    required_flags += ["gin_model_config_dir", "gin_model_config_name"]
  missing_flags = [
      name for name in required_flags if getattr(FLAGS, name) is None
  ]
  if missing_flags:
    raise ValueError("The following flags must be specified: {}.".format(
        ", ".join(missing_flags)))

  # In this pipeline, we manually manage the random seeds of the different steps
  # as otherwise different training runs of the model (with different random
  # seeds) would be evaluated on exactly the same data (i.e., there would be no
  # randomness in evaluation). We use the random seed in the flag to seed a
  # random number generator from which random seeds for the different parts of
  # the pipeline are drawn. In addition, we have a fixed seed for the
  # supervision. This ensures that the supervised data is the same across
  # different models and in different validation metrics.
  random_state = np.random.RandomState(FLAGS.pipeline_seed)

  # Model training (if the model_dir is not provided.).

  # It is important that we sample the model random seeds regardless whether
  # we actually train the model so later seeds are the same.
  model_random_seed = random_state.randint(2**32)
  unsupervised_data_seed = random_state.randint(2**32)
  if FLAGS.model_dir is None:
    logging.info("Training model...")
    model_dir = os.path.join(FLAGS.output_directory, "model")
    model_config_file = os.path.join(FLAGS.gin_model_config_dir,
                                     FLAGS.gin_model_config_name)
    # The model is named after its config file, without the gin extension.
    model_name = FLAGS.gin_model_config_name.replace(".gin", "")
    # User supplied bindings come last in the list.
    model_bindings = [
        "model.model_seed = {}".format(model_random_seed),
        "model.unsupervised_data_seed = {}".format(unsupervised_data_seed),
        "model.name = '{}'".format(model_name),
        "model.supervised_data_seed = {}".format(FLAGS.supervised_seed),
        "model.num_labelled_samples = {}".format(FLAGS.num_labelled_samples),
        "model.train_percentage = {}".format(FLAGS.train_percentage),
        "labeller.labeller_fn = {}".format(FLAGS.labeller_fn)
    ] + FLAGS.gin_bindings
    train_semi_supervised_lib.train_with_gin(
        model_dir,
        overwrite=FLAGS.overwrite,
        gin_config_files=[model_config_file],
        gin_bindings=model_bindings)
    supervised_seed = FLAGS.supervised_seed
    num_labelled_samples = FLAGS.num_labelled_samples
    train_percentage = FLAGS.train_percentage
    labeller_fn = FLAGS.labeller_fn
  else:
    logging.info("Skipped training...")
    model_dir = os.path.join(FLAGS.model_dir, "model")
    # Automatically set the proper parameters for the supervised data. We load
    # the parameters that were used in the training step so that they will be
    # identical in the validation step. Passing them again on the command line
    # is therefore rejected.
    for flag_name in ("supervised_seed", "num_labelled_samples",
                      "train_percentage", "labeller_fn"):
      if getattr(FLAGS, flag_name) is not None:
        raise ValueError(
            "The {} flag should not be specified when re-evaluating.".format(
                flag_name))
    gin_config_file = os.path.join(model_dir, "results", "gin", "train.gin")
    gin_dict = results.gin_dict(gin_config_file)
    supervised_seed = gin_dict["model.supervised_data_seed"]
    num_labelled_samples = gin_dict["model.num_labelled_samples"]
    train_percentage = gin_dict["model.train_percentage"]
    labeller_fn = gin_dict["labeller.labeller_fn"]

  # We visualize reconstruction, samples and latent space traversal.
  visualize_dir = os.path.join(FLAGS.output_directory, "visualizations")
  visualize_model.visualize(model_dir, visualize_dir, FLAGS.overwrite)

  # We extract the different representations and save them to disk. Each
  # postprocessing config draws its own seed from the pipeline random state.
  postprocess_configs = sorted(gfile.Glob(FLAGS.gin_postprocess_config_glob))
  for config in postprocess_configs:
    post_name = os.path.basename(config).replace(".gin", "")
    logging.info("Extracting representation %s...", post_name)
    post_dir = os.path.join(FLAGS.output_directory, "postprocessed", post_name)
    postprocess_bindings = [
        "postprocess.random_seed = {}".format(random_state.randint(2**32)),
        "postprocess.name = '{}'".format(post_name)
    ]
    postprocess.postprocess_with_gin(model_dir, post_dir, FLAGS.overwrite,
                                     [config], postprocess_bindings)

  # Iterate through metrics for validation.
  validation_configs = sorted(gfile.Glob(FLAGS.gin_validation_config_glob))
  for config in postprocess_configs:
    post_name = os.path.basename(config).replace(".gin", "")
    post_dir = os.path.join(FLAGS.output_directory, "postprocessed", post_name)
    # Now, we compute all the specified scores.
    for gin_eval_config in validation_configs:
      config_name = os.path.basename(gin_eval_config).replace(".gin", "")
      validation_fn_name = config_name + "_validation"
      metric_bindings = []
      # Every validation metric except "mig_validation" gets the train
      # percentage bound.
      if validation_fn_name != "mig_validation":
        metric_bindings.append("{}.train_percentage = {}".format(
            validation_fn_name, train_percentage))
      # The supervision parameters are part of the metric name (and thus of
      # the output directory).
      metric_name = "{}_{}_{}".format(validation_fn_name, supervised_seed,
                                      num_labelled_samples)
      logging.info("Computing metric '%s' on '%s'...", metric_name, post_name)
      metric_dir = os.path.join(FLAGS.output_directory, "validation", post_name,
                                metric_name)
      # Validation is seeded with the fixed supervision seed, not with a seed
      # drawn from the pipeline random state.
      metric_bindings += [
          "validation.random_seed = {}".format(supervised_seed),
          "validation.name = '{}'".format(metric_name),
          "validation.num_labelled_samples = {}".format(
              num_labelled_samples),
          "labeller.labeller_fn = {}".format(labeller_fn)
      ]
      validate.validate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                 [gin_eval_config], metric_bindings)

  # Set the random state for the evaluation.
  random_state_eval = np.random.RandomState(FLAGS.eval_seed)
  # Iterate through metrics for test.
  metric_configs = sorted(gfile.Glob(FLAGS.gin_evaluation_config_glob))
  for config in postprocess_configs:
    post_name = os.path.basename(config).replace(".gin", "")
    post_dir = os.path.join(FLAGS.output_directory, "postprocessed",
                            post_name)
    # Now, we compute all the specified scores.
    for gin_eval_config in metric_configs:
      config_name = os.path.basename(gin_eval_config).replace(".gin", "")
      metric_name = "{}_{}".format(config_name, supervised_seed)
      logging.info("Computing metric '%s' on '%s'...", metric_name, post_name)
      metric_dir = os.path.join(FLAGS.output_directory, "metrics", post_name,
                                metric_name)
      eval_bindings = [
          "evaluation.random_seed = {}".format(
              random_state_eval.randint(2**32)),
          "evaluation.name = '{}'".format(metric_name)
      ]
      evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                 [gin_eval_config], eval_bindings)

  # Aggregate all the results in a single json file per model.
  result_path = os.path.join(FLAGS.output_directory, "metrics",
                             "aggregated_results.json")
  pattern = os.path.join(FLAGS.output_directory,
                         "metrics/*/*/results/aggregate/evaluation.json")
  aggregate_results.aggregate_results_to_json(pattern, result_path)


# Script entry point: hand control to absl, which invokes main().
if __name__ == "__main__":
  app.run(main)
