import os
import sys

import gin.tf
import numpy as np

# Repository root: three directory levels above this file.
PROJECT_PATH = os.path.abspath(__file__)
for _ in range(3):
    PROJECT_PATH = os.path.dirname(PROJECT_PATH)
# Checkout of disentanglement_lib that sits next to the repository root.
DIS_PROJECT_PATH = os.path.join(os.path.dirname(PROJECT_PATH), "disentanglement_lib")
# Make both importable; they are appended, so entries already on sys.path win.
sys.path.extend([PROJECT_PATH, DIS_PROJECT_PATH])

print("PROJECT_PATH", PROJECT_PATH)
print("DIS_PROJECT_PATH", DIS_PROJECT_PATH)

# Gin config files for the gin-configured metrics (MIG, FactorVAE, modularity).
# NOTE(review): these are absolute paths and do not follow DIS_PROJECT_PATH;
# confirm they match the deployment layout.
_GIN_CONFIG_DIR = "/home/disentanglement_lib/disentanglement_lib/config/disentangling_everything/"
GIN_MIG = [_GIN_CONFIG_DIR + "mig.gin"]
GIN_FACTOR = [_GIN_CONFIG_DIR + "factor_vae_metric.gin"]
GIN_MOD = [_GIN_CONFIG_DIR + "modularity_explicitness.gin"]
from disentanglement_lib.evaluation.metrics import beta_vae, factor_vae, dci, mig, sap_score, modularity_explicitness


def _repeat_metric(compute_fn, result_keys, repetitions, gin_config_files=None):
    """Run a metric ``repetitions`` times and collect selected entries of its result dict.

    Args:
        compute_fn: Zero-argument callable returning a dict of metric results.
        result_keys: Keys to read from each returned dict.
        repetitions: Number of times ``compute_fn`` is called.
        gin_config_files: Optional list of gin files. When given, the global gin
            config is cleared and these files are parsed before the first call,
            and the config is cleared again afterwards -- also when parsing or a
            call raises, so one failing metric does not leak its bindings.

    Returns:
        Dict mapping each key in ``result_keys`` to an array of shape
        ``(repetitions,)`` holding one value per call.
    """
    try:
        if gin_config_files is not None:
            gin.clear_config()
            gin.parse_config_files_and_bindings(gin_config_files, [])
        scores = {key: np.zeros(repetitions) for key in result_keys}
        for i in range(repetitions):
            results_dictionary = compute_fn()
            for key in result_keys:
                scores[key][i] = results_dictionary[key]
    finally:
        if gin_config_files is not None:
            gin.clear_config()
    return scores


def _summarize_and_save(scores, save_path, file_name, label):
    """Reduce ``scores`` to ``[mean, std]``, save it as ``save_path/file_name``, print and return it."""
    summary = np.array([np.mean(scores), np.std(scores)])
    np.save(os.path.join(save_path, file_name), summary)
    print("{} score {} pm {}".format(label, summary[0], summary[1]))
    return summary


def evaluate_metrics(representation_function, ground_truth_data, repetitions, num_train, num_eval, save_path,
                     batch_size=16,
                     calculate_beta=True,
                     calculate_factor=True,
                     calculate_sap=True,
                     calculate_dci=True,
                     calculate_mig=True,
                     calculate_mod=True):
    """Estimate disentanglement metrics and save/return their mean and std over repetitions.

    Each enabled metric is computed ``repetitions`` times with the global
    ``np.random`` module as random state. Its ``[mean, std]`` summary is saved as
    a ``.npy`` file in ``save_path`` and printed.

    Args:
        representation_function: Forwarded unchanged to every disentanglement_lib metric.
        ground_truth_data: Forwarded unchanged to every disentanglement_lib metric.
        repetitions: Number of independent evaluations per metric.
        num_train: Training sample count passed to Beta-VAE, SAP, DCI and MIG.
        num_eval: Evaluation sample count passed to Beta-VAE (``num_eval``) and to
            SAP and DCI (as ``num_test``).
        save_path: Existing directory that receives ``beta.npy``, ``factor.npy``,
            ``sap.npy``, ``dci.npy``, ``complete.npy``, ``mig.npy`` and ``mod.npy``.
        batch_size: Batch size forwarded to every metric.
        calculate_beta, calculate_factor, calculate_sap, calculate_dci,
        calculate_mig, calculate_mod: Whether to compute the corresponding metric.

    FactorVAE, MIG and modularity run with ``GIN_FACTOR``, ``GIN_MIG`` and
    ``GIN_MOD`` parsed into gin. FactorVAE and modularity are not given
    ``num_train``/``num_eval`` here, so their sample counts are expected to come
    from those configs. The global gin config is cleared before and after each of
    these three metrics, discarding any bindings the caller had parsed.

    Returns:
        Tuple ``(beta, factor_value, sap, dci_value, mig_value, modularity_value)``.
        Each entry is a ``[mean, std]`` array, or ``None`` if that metric was
        disabled. DCI completeness is saved to ``complete.npy`` but not returned.
    """
    beta = factor_value = sap = dci_value = mig_value = modularity_value = None

    # Evaluate Beta-VAE score
    if calculate_beta:
        print("Estimating Beta Score")
        scores = _repeat_metric(
            lambda: beta_vae.compute_beta_vae_sklearn(ground_truth_data, representation_function, np.random,
                                                      artifact_dir=None,
                                                      batch_size=batch_size,
                                                      num_train=num_train,
                                                      num_eval=num_eval),
            ["eval_accuracy"], repetitions)
        beta = _summarize_and_save(scores["eval_accuracy"], save_path, "beta.npy", "Beta")

    # Evaluate FactorVAE score
    if calculate_factor:
        print("Estimating Factor Score")
        scores = _repeat_metric(
            lambda: factor_vae.compute_factor_vae(ground_truth_data, representation_function, np.random,
                                                  artifact_dir=None,
                                                  batch_size=batch_size,
                                                  # Explicit gin.REQUIRED: the value must be bound by GIN_FACTOR.
                                                  num_variance_estimate=gin.REQUIRED),
            ["eval_accuracy"], repetitions, gin_config_files=GIN_FACTOR)
        factor_value = _summarize_and_save(scores["eval_accuracy"], save_path, "factor.npy", "Factor")

    # Evaluate SAP score
    if calculate_sap:
        print("Estimating SAP Score")
        scores = _repeat_metric(
            lambda: sap_score.compute_sap(ground_truth_data, representation_function, np.random,
                                          num_train=num_train,
                                          num_test=num_eval,
                                          # Forward the caller's batch size, as for every other metric.
                                          batch_size=batch_size,
                                          continuous_factors=False),
            ["SAP_score"], repetitions)
        sap = _summarize_and_save(scores["SAP_score"], save_path, "sap.npy", "SAP")

    # Evaluate DCI disentanglement and completeness
    if calculate_dci:
        print("Estimating DCI Score")
        scores = _repeat_metric(
            lambda: dci.compute_dci(ground_truth_data, representation_function, np.random,
                                    num_train=num_train,
                                    num_test=num_eval,
                                    batch_size=batch_size),
            ["disentanglement", "completeness"], repetitions)
        dci_value = _summarize_and_save(scores["disentanglement"], save_path, "dci.npy", "DCI")
        _summarize_and_save(scores["completeness"], save_path, "complete.npy", "Completeness")

    # Evaluate MIG
    if calculate_mig:
        print("Estimating MIG Score")
        scores = _repeat_metric(
            lambda: mig.compute_mig(ground_truth_data, representation_function, np.random,
                                    num_train=num_train,
                                    batch_size=batch_size),
            ["discrete_mig"], repetitions, gin_config_files=GIN_MIG)
        mig_value = _summarize_and_save(scores["discrete_mig"], save_path, "mig.npy", "MIG")

    # Evaluate modularity
    if calculate_mod:
        print("Estimating MOD Score")
        scores = _repeat_metric(
            lambda: modularity_explicitness.compute_modularity_explicitness(ground_truth_data,
                                                                            representation_function,
                                                                            np.random,
                                                                            batch_size=batch_size),
            ["modularity_score"], repetitions, gin_config_files=GIN_MOD)
        modularity_value = _summarize_and_save(scores["modularity_score"], save_path, "mod.npy", "MOD")

    return beta, factor_value, sap, dci_value, mig_value, modularity_value
