from .dci import dci
from .explicitness import explicitness
from .independent import *
from .irs import irs
from .mig import mig
from .modularity import modularity
from .sap import sap
from .z_diff import z_diff
from .z_min_var import z_min_var


def obtain_metrics_minimality(factors, codes, reg_type="gradient_booster"):
    """Compute the minimality-oriented metric suite for a representation.

    Args:
        factors: Ground-truth factor values. ``factors.shape[0]`` is used as
            the number of samples.
        codes: Representation codes for the same samples (presumably aligned
            row-for-row with ``factors`` -- confirm against callers).
        reg_type: Regressor identifier forwarded to ``calc_minimality``.
            Defaults to ``"gradient_booster"``, the value previously
            hard-coded here.

    Returns:
        dict with keys ``'beta_vae'`` (from ``z_diff``), ``'factor_vae'``
        (from ``z_min_var``), ``'modularity'``, ``'dci_disentanglement'``
        (element 0 of ``dci``'s result) and ``'ind_minimality'``.
    """
    n_samples = factors.shape[0]
    # z_diff / z_min_var take separate training and evaluation sample counts;
    # use 3/4 of the samples for training and 1/4 for evaluation.
    nb_training = (3 * n_samples) // 4
    nb_eval = n_samples // 4

    beta_vae = z_diff(factors, codes, nb_training=nb_training, nb_eval=nb_eval)
    factor_vae = z_min_var(factors, codes, nb_training=nb_training, nb_eval=nb_eval)
    dci_val = dci(factors, codes)
    modularity_val = modularity(factors, codes)
    minimality_val = calc_minimality(factors, codes, reg_type=reg_type)
    # NOTE: 'irs' and 'ind_fi' are intentionally not reported here. The
    # factors-invariance score used to be computed and then thrown away; it is
    # no longer computed in this function.
    return {
        'beta_vae':             beta_vae,
        'factor_vae':           factor_vae,
        'modularity':           modularity_val,
        'dci_disentanglement':  dci_val[0],
        'ind_minimality':       minimality_val,
    }


def obtain_metrics_sufficiency(factors, codes, reg_type="gradient_booster"):
    """Compute the sufficiency-oriented metric suite for a representation.

    Args:
        factors: Ground-truth factor values.
        codes: Representation codes for the same samples (presumably aligned
            row-for-row with ``factors`` -- confirm against callers).
        reg_type: Regressor identifier forwarded to ``calc_sufficiency``.
            Defaults to ``"gradient_booster"``, the value previously
            hard-coded here.

    Returns:
        dict with keys ``'mig'``, ``'sap'``, ``'dci_completeness'``
        (element 1 of ``dci``'s result), ``'explicitness'`` and
        ``'ind_sufficiency'``.
    """
    mig_val = mig(factors, codes)
    sap_val = sap(factors, codes)
    dci_val = dci(factors, codes)
    explicitness_val = explicitness(factors, codes)
    sufficiency_val = calc_sufficiency(factors, codes, reg_type=reg_type)
    # NOTE: 'irs' and 'ind_ri' are intentionally not reported here. The
    # representations-invariance score used to be computed and then thrown
    # away; it is no longer computed in this function.
    return {
        'mig':                  mig_val,
        'sap':                  sap_val,
        'dci_completeness':     dci_val[1],
        'explicitness':         explicitness_val,
        'ind_sufficiency':      sufficiency_val,
    }


def obtain_metrics_fre(factors, codes, reg_type="gradient_booster"):
    """Compute the factors-invariance, representations-invariance and
    explicitness scores from the ``independent`` module.

    Args:
        factors: Ground-truth factor values.
        codes: Representation codes for the same samples (presumably aligned
            row-for-row with ``factors`` -- confirm against callers).
        reg_type: Regressor identifier forwarded to all three ``calc_*``
            helpers. Defaults to ``"gradient_booster"``, the value previously
            hard-coded here.

    Returns:
        dict with keys ``'ind_fi'`` (``calc_factors_invariance``),
        ``'ind_ri'`` (``calc_representations_invariance``) and ``'e_val'``
        (``calc_explicitness``).
    """
    fi_val = calc_factors_invariance(factors, codes, reg_type=reg_type)
    ri_val = calc_representations_invariance(factors, codes, reg_type=reg_type)
    e_val = calc_explicitness(factors, codes, reg_type=reg_type)
    return {
        'ind_fi':   fi_val,
        'ind_ri':   ri_val,
        'e_val':    e_val,
    }