import torch
from src.utils.seed import set_seed
from disent.metrics.beta import compute_beta_vae
from disent.metrics.fvm import FactorVAEMetric
from disent.metrics.sap import SAP
from disent.metrics.dci import metric_dci
from disent.metrics.mig import MIGMetric
from disent.metrics.utils import latents_and_factors

import pdb


def estimate_all_distenglement(
    dataset,
    model,
    loss_fn,
    continuous_factors,
    args,
    results,
):
    """Compute a suite of disentanglement metrics for ``model`` on ``dataset``.

    Metrics computed, and the ``results`` keys they are written to:

    * ``"beta_vae"``      -- ``compute_beta_vae`` result, stored as returned.
    * ``"factor_disent"`` -- the ``"disentanglement_accuracy"`` entry of the
      ``FactorVAEMetric`` result.
    * ``"sap"``           -- ``SAP`` result, stored as returned.
    * ``"dci_disent"`` / ``"dci_comple"`` -- elements ``[2]`` and ``[3]`` of
      the ``metric_dci`` result (presumably disentanglement and completeness,
      going by the key names -- confirm against ``metric_dci``).
    * ``"mig"``           -- ``MIGMetric`` result, computed on the training
      latents/factors only.

    ``set_seed(args)`` is called before each metric (and once before latent
    extraction) so every metric is reproducible on its own.

    All work runs under ``torch.no_grad()`` with the model in eval mode. If
    the model exposes a ``training`` flag (as ``torch.nn.Module`` does), its
    previous train/eval mode is restored before returning, including when a
    metric raises.

    Args:
        dataset: Dataset passed to the metric and latent-extraction helpers.
        model: Model to evaluate; must support ``.eval()``.
        loss_fn: Forwarded to the beta-VAE, FactorVAE and latent-extraction
            helpers.
        continuous_factors: Forwarded to ``SAP`` and ``metric_dci``.
        args: Run configuration, passed to ``set_seed`` and to the helpers.
        results: Dict-like object updated in place with the keys above.

    Returns:
        The same ``results`` object, after it has been updated.
    """
    # Remember the caller's mode so evaluation doesn't leak eval mode into a
    # training loop that calls this function between epochs.
    was_training = getattr(model, "training", None)
    try:
        with torch.no_grad():
            model.eval()

            # --- beta-VAE metric ---
            set_seed(args)
            results["beta_vae"] = compute_beta_vae(
                dataset,
                model=model,
                batch_size=64,
                num_train=500,
                num_eval=50,
                loss_fn=loss_fn,
                args=args,
            )

            # --- FactorVAE metric ---
            set_seed(args)
            factor_vae = FactorVAEMetric(
                dataset,
                model=model,
                batch_size=100,
                num_train=800,
                loss_fn=loss_fn,
                args=args,
            )
            results["factor_disent"] = factor_vae["disentanglement_accuracy"]

            # --- Shared latent/factor samples for SAP, DCI and MIG ---
            # NOTE: `interation` is the helper's own keyword name (spelling
            # kept); presumably the number of batches to draw -- confirm.
            # The test draw continues the RNG stream right after the train
            # draw, without reseeding in between.
            set_seed(args)
            train_latents, train_factors = latents_and_factors(
                dataset=dataset,
                model=model,
                batch_size=64,
                interation=100,
                loss_fn=loss_fn,
                args=args,
            )
            test_latents, test_factors = latents_and_factors(
                dataset=dataset,
                model=model,
                batch_size=64,
                interation=50,
                loss_fn=loss_fn,
                args=args,
            )

            # --- SAP score ---
            set_seed(args)
            results["sap"] = SAP(
                train_latents,
                train_factors,
                test_latents,
                test_factors,
                args,
                continuous_factors=continuous_factors,
            )

            # --- DCI ---
            set_seed(args)
            dci = metric_dci(
                train_latents,
                train_factors,
                test_latents,
                test_factors,
                args,
                continuous_factors=continuous_factors,
            )
            results["dci_disent"] = dci[2]
            results["dci_comple"] = dci[3]

            # --- MIG (training split only) ---
            # Reseed like every other metric so MIG is reproducible on its own.
            set_seed(args)
            results["mig"] = MIGMetric(train_latents, train_factors)
    finally:
        if was_training is not None:
            model.train(was_training)

    return results
