import torch
from dl.src.utils.utils import set_seed
from dl.src.dl_metrics.beta import compute_beta_vae
from dl.src.dl_metrics.fvm import FactorVAEMetric
from dl.src.dl_metrics.sap import SAP
from dl.src.dl_metrics.dci import metric_dci
from dl.src.dl_metrics.mig import MIGMetric
from dl.src.dl_metrics.utils import latents_and_factors

import pdb


def estimate_all_distenglement(
    dataset,
    model,
    loss_fn,
    continuous_factors,
    args,
    results,
):
    """Compute a suite of disentanglement metrics for ``model`` on ``dataset``.

    Runs, in order: the beta-VAE metric, the FactorVAE metric, SAP, DCI and
    MIG. ``set_seed(args)`` is called before beta-VAE, before FactorVAE,
    before drawing the shared latent samples, before SAP and before DCI, so
    each of those stages is reproducible on its own. SAP, DCI and MIG all
    reuse one set of train/test latents produced by ``latents_and_factors``.

    Gradients are disabled and the model is put in eval mode for the duration
    of the call. The model's previous train/eval mode is restored before
    returning (also when a metric raises). This means calling this helper in
    the middle of training does not leave dropout/batch-norm layers stuck in
    eval mode.

    Args:
        dataset: Dataset forwarded to every metric helper.
        model: Model under evaluation (a ``torch.nn.Module``; ``.eval()`` is
            called on it).
        loss_fn: Forwarded unchanged to the metric helpers. Presumably used
            to encode inputs into latents; TODO confirm.
        continuous_factors: Forwarded to SAP and DCI. Presumably flags
            whether ground-truth factors are treated as continuous; verify.
        args: Run configuration, passed to ``set_seed`` and to every helper.
        results: Dict updated in place with the keys ``"beta_vae"``,
            ``"factor_disent"``, ``"sap"``, ``"dci_disent"``, ``"dci_comple"``
            and ``"mig"``.

    Returns:
        The same ``results`` dict that was passed in.
    """
    # Remember the caller's mode so evaluation does not leak eval mode back
    # into a training loop. getattr keeps this tolerant of model wrappers.
    was_training = getattr(model, "training", None)
    try:
        with torch.no_grad():
            model.eval()

            set_seed(args)
            results["beta_vae"] = compute_beta_vae(
                dataset,
                model=model,
                batch_size=64,
                num_train=500,
                num_eval=50,
                loss_fn=loss_fn,
                args=args,
            )

            set_seed(args)
            factor_vae_result = FactorVAEMetric(
                dataset,
                model=model,
                batch_size=100,
                num_train=800,
                loss_fn=loss_fn,
                args=args,
            )
            results["factor_disent"] = factor_vae_result["disentanglement_accuracy"]

            # SAP, DCI and MIG consume the same encoded samples, so they are
            # drawn once here. A single seed covers both the train and test
            # draws. ("interation" is the helper's own keyword name.)
            set_seed(args)
            train_latents, train_factors = latents_and_factors(
                dataset=dataset,
                model=model,
                batch_size=64,
                interation=100,
                loss_fn=loss_fn,
                args=args,
            )
            test_latents, test_factors = latents_and_factors(
                dataset=dataset,
                model=model,
                batch_size=64,
                interation=50,
                loss_fn=loss_fn,
                args=args,
            )

            # TODO: SAP should be skipped for the MPI3D dataset. An earlier,
            # disabled version wrote 0.0 for the SAP/DCI/MIG entries when
            # args.dataset == "mpi3d"; decide on the intended behavior.
            set_seed(args)
            results["sap"] = SAP(
                train_latents,
                train_factors,
                test_latents,
                test_factors,
                args,
                continuous_factors=continuous_factors,
            )

            set_seed(args)
            dci_result = metric_dci(
                train_latents,
                train_factors,
                test_latents,
                test_factors,
                args,
                continuous_factors=continuous_factors,
            )
            # Indices 0 and 1 appear to be the train/test errors (per the
            # older nested-dict code); only disentanglement (2) and
            # completeness (3) are reported.
            results["dci_disent"] = dci_result[2]
            results["dci_comple"] = dci_result[3]

            # MIG is intentionally not reseeded, matching the original flow.
            results["mig"] = MIGMetric(train_latents, train_factors)
    finally:
        if was_training is not None:
            model.train(was_training)

    return results
