# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Borrowed from https://github.com/AIcrowd/neurips2019_disentanglement_challenge_starter_kit/blob/master/evaluate.py
Evaluation protocol to compute metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import gin.tf
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

from disentanglement_lib.evaluation.metrics import beta_vae  # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import dci  # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import downstream_task  # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import factor_vae  # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import irs  # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import mig  # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import modularity_explicitness  # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import reduced_downstream_task  # pylint: disable=unused-import
from disentanglement_lib.evaluation.metrics import sap_score  # pylint: disable=unused-import
from aicrowd.metrics import unsupervised_metrics  # pylint: disable=unused-import
from aicrowd.metrics import max_corr

from disentanglement_lib.data.ground_truth import named_data
from disentanglement_lib.utils import results
from tensorflow.python.framework.errors_impl import NotFoundError


# Some more redundant code, but this allows us to not import utils_pytorch
def get_dataset_name():
    """Return the dataset name configured through the environment.

    Returns:
      The value of the `AICROWD_DATASET_NAME` environment variable, or
      "cars3d" when that variable is not set.
    """
    dataset_name = os.environ.get("AICROWD_DATASET_NAME", "cars3d")
    return dataset_name


def evaluate_with_gin(model_dir,
                      output_dir,
                      overwrite=False,
                      gin_config_files=None,
                      gin_bindings=None):
    """Evaluate a representation based on the provided gin configuration.

    This function will set the provided gin bindings, call the evaluate()
    function and clear the gin config. Please see the evaluate() for required
    gin bindings.

    The gin config is cleared even when evaluate() raises, so that a failed
    evaluation does not leak its bindings into subsequent calls made from the
    same process.

    Args:
      model_dir: String with path to directory where the representation is saved.
      output_dir: String with the path where the evaluation should be saved.
      overwrite: Boolean indicating whether to overwrite output directory.
      gin_config_files: List of gin config files to load. `None` is treated as
        an empty list.
      gin_bindings: List of gin bindings to use. `None` is treated as an empty
        list.
    """
    if gin_config_files is None:
        gin_config_files = []
    if gin_bindings is None:
        gin_bindings = []
    try:
        gin.parse_config_files_and_bindings(gin_config_files, gin_bindings)
        evaluate(model_dir, output_dir, overwrite)
    finally:
        # Always reset the global gin state; any exception still propagates to
        # the caller unchanged.
        gin.clear_config()


@gin.configurable(
    "evaluation", blacklist=["model_dir", "output_dir", "overwrite"])
def evaluate(model_dir,
             output_dir,
             overwrite=False,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name=""):
    """Loads a saved representation and computes disentanglement metrics.

    The representation is looked up inside `model_dir` in this order: a TFHub
    module (`tfhub`), a PyTorch model (`pytorch_model.pt`), or a dilled Python
    function (`python_model.dill`). The first one that exists is evaluated.

    Args:
      model_dir: String with path to directory where the representation function
        is saved.
      output_dir: String with the path where the results should be saved.
      overwrite: Boolean indicating whether to overwrite output directory.
      evaluation_fn: Function used to evaluate the representation (see metrics/
        for examples). It is called as
        `evaluation_fn(dataset, representation_function, random_state=...)`.
      random_seed: Integer with random seed; used to seed the
        `np.random.RandomState` handed to `evaluation_fn`.
      name: Optional string with name of the metric (can be used to name metrics).

    Returns:
      The dictionary returned by `evaluation_fn`, with an additional
      "elapsed_time" entry holding the wall-clock seconds spent in this call.

    Raises:
      ValueError: If `output_dir` already exists and `overwrite` is False.
      RuntimeError: If `model_dir` contains none of the supported model files.
    """
    # We do not use the variable 'name'. Instead, it can be used to name scores
    # as it will be part of the saved gin config.
    del name

    # Delete the output directory if it already exists.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError("Directory already exists and overwrite is False.")

    # Set up time to keep track of elapsed time in results.
    experiment_timer = time.time()

    try:
        # Automatically set the proper data set if necessary. We replace the active
        # gin config as this will lead to a valid gin config file where the data set
        # is present.
        print(gin.query_parameter("dataset.name"))
        if gin.query_parameter("dataset.name") == "auto":
            # Obtain the dataset name from the gin config of the previous step.
            gin_config_file = os.path.join(model_dir, "results", "gin",
                                           "postprocess.gin")
            gin_dict = results.gin_dict(gin_config_file)
            with gin.unlock_config():
                gin.bind_parameter("dataset.name", gin_dict["dataset.name"].replace(
                    "'", ""))
        dataset = named_data.get_named_ground_truth_data()
    except NotFoundError:
        # If we did not train with disentanglement_lib, there is no "previous step",
        # so we'll have to rely on the environment variable.
        if gin.query_parameter("dataset.name") == "auto":
            with gin.unlock_config():
                gin.bind_parameter("dataset.name", get_dataset_name())
        # The multitask variants share their ground-truth data with the
        # corresponding plain dataset, so rebind to that dataset's name.
        # NOTE(review): this aliasing only happens on the NotFoundError path;
        # confirm whether it should also apply when the try-branch succeeds.
        multitask_aliases = {
            "dsprites_multitask": "dsprites_full",
            "shapes3d_multitask": "shapes3d",
            "mpi3d_multitask": "mpi3d_real",
        }
        ground_truth_name = multitask_aliases.get(
            gin.query_parameter("dataset.name"))
        if ground_truth_name is not None:
            with gin.unlock_config():
                gin.bind_parameter("dataset.name", ground_truth_name)
        dataset = named_data.get_named_ground_truth_data()

    # Supported model artifacts, in lookup priority order, paired with the
    # backend-specific evaluator that knows how to load them.
    model_backends = (
        ("tfhub", _evaluate_with_tensorflow),
        ("pytorch_model.pt", _evaluate_with_pytorch),
        ("python_model.dill", _evaluate_with_numpy),
    )
    for artifact_name, backend_evaluate in model_backends:
        module_path = os.path.join(model_dir, artifact_name)
        if os.path.exists(module_path):
            results_dict = backend_evaluate(module_path, evaluation_fn,
                                            dataset, random_seed)
            break
    else:
        raise RuntimeError(
            "`model_dir` must contain a TFHub module ('tfhub'), a pytorch model "
            "('pytorch_model.pt') or a dilled python function "
            "('python_model.dill'); none found in {!r}.".format(model_dir))

    # Save the results (and all previous results in the pipeline) on disk.
    original_results_dir = os.path.join(model_dir, "results")
    results_dir = os.path.join(output_dir, "results")
    results_dict["elapsed_time"] = time.time() - experiment_timer
    results.update_result_directory(results_dir, "evaluation", results_dict,
                                    original_results_dir)
    return results_dict


def _evaluate_with_tensorflow(module_path, evaluation_fn, dataset, random_seed):
    """Scores a representation stored as a TFHub module.

    Args:
      module_path: Path to the TFHub module holding the representation.
      evaluation_fn: Callable invoked as
        `evaluation_fn(dataset, representation_function, random_state=...)`.
      dataset: Ground-truth data passed through to `evaluation_fn`.
      random_seed: Seed for the `np.random.RandomState` given to
        `evaluation_fn`.

    Returns:
      Whatever `evaluation_fn` returns.
    """
    with hub.eval_function_for_module(module_path) as module_fn:

        def _representation_function(x):
            """Computes representation vector for input images."""
            module_outputs = module_fn(
                {"images": x}, signature="representation", as_dict=True)
            return np.array(module_outputs["default"])

        # The metric has to run while the module is still open, hence the call
        # inside the `with` block.
        rng = np.random.RandomState(random_seed)
        scores = evaluation_fn(dataset, _representation_function,
                               random_state=rng)
    return scores


def _evaluate_with_pytorch(module_path, evalulation_fn, dataset, random_seed):
    """Scores a representation stored as a PyTorch model file.

    Args:
      module_path: Path to the saved PyTorch model.
      evalulation_fn: Callable invoked as
        `evalulation_fn(dataset, representation_function, random_state=...)`.
      dataset: Ground-truth data passed through to `evalulation_fn`.
      random_seed: Seed for the `np.random.RandomState` given to
        `evalulation_fn`.

    Returns:
      Whatever `evalulation_fn` returns.
    """
    # Imported here rather than at module level so that this file can be
    # imported without pulling in the pytorch utilities.
    from aicrowd import utils_pytorch

    # Turn the stored model into a representation function.
    loaded_model = utils_pytorch.import_model(path=module_path)
    representor = utils_pytorch.make_representor(loaded_model)

    # Compute the scores with a freshly seeded random state.
    rng = np.random.RandomState(random_seed)
    return evalulation_fn(dataset, representor, random_state=rng)


def _evaluate_with_numpy(module_path, evalulation_fn, dataset, random_seed):
    """Scores a representation stored as a serialized Python function.

    Args:
      module_path: Path to the serialized function file.
      evalulation_fn: Callable invoked as
        `evalulation_fn(dataset, representation_function, random_state=...)`.
      dataset: Ground-truth data passed through to `evalulation_fn`.
      random_seed: Seed for the `np.random.RandomState` given to
        `evalulation_fn`.

    Returns:
      Whatever `evalulation_fn` returns.
    """
    # NOTE(review): imported as a top-level module, unlike `utils_pytorch`
    # which comes from the `aicrowd` package - confirm this is intended.
    import utils_numpy

    # NOTE(review): presumably this deserializes the file with dill; only
    # point it at trusted model files - confirm against utils_numpy.
    loaded_fn = utils_numpy.import_function(path=module_path)
    representor = utils_numpy.make_representor(loaded_fn)

    # Compute the scores with a freshly seeded random state.
    rng = np.random.RandomState(random_seed)
    return evalulation_fn(dataset, representor, random_state=rng)
