import os
from cg.src.analysis_tools.dci_utils import latents_and_factors, metric_dci, hinton


def save_dci_matrix(
    dataset,
    model,
    batch_size,
    iteration,
    loss_fn,
    matrix_dir,
    args,
):
    """Run the DCI metric for ``model`` and pass one of its outputs to ``hinton``.

    Latents and factors are extracted twice via ``latents_and_factors`` (once
    for the "train" pair, once for the "test" pair), handed to ``metric_dci``,
    and element ``4`` of the result is forwarded to ``hinton`` together with
    the path ``<matrix_dir>/DCI_Hinton.png``.

    Args:
        dataset: Forwarded unchanged to ``latents_and_factors``.
        model: Forwarded unchanged to ``latents_and_factors``.
        batch_size: Forwarded unchanged to ``latents_and_factors``.
        iteration: Forwarded unchanged to ``latents_and_factors``.
        loss_fn: Forwarded unchanged to ``latents_and_factors``.
        matrix_dir: Directory component of the output path; it is joined with
            the fixed file name ``DCI_Hinton.png``.
        args: Forwarded positionally to ``metric_dci``.

    Returns:
        None.
    """
    extraction_kwargs = dict(
        dataset=dataset,
        model=model,
        batch_size=batch_size,
        iteration=iteration,
        loss_fn=loss_fn,
    )

    # NOTE(review): both extractions receive exactly the same arguments, so
    # the "train" and "test" sets only differ if latents_and_factors draws a
    # fresh sample on each call -- confirm this is the intended split.
    latents_train, factors_train = latents_and_factors(**extraction_kwargs)
    latents_test, factors_test = latents_and_factors(**extraction_kwargs)

    dci_result = metric_dci(
        latents_train,
        factors_train,
        latents_test,
        factors_test,
        args,
        continuous_factors=False,
    )

    # Index 4 is treated here as the importance matrix -- presumably that is
    # what metric_dci returns in that position; verify against its definition.
    importance = dci_result[4]

    # Despite the keyword name `dir`, hinton receives a full file path.
    hinton_path = os.path.join(matrix_dir, 'DCI_Hinton.png')
    hinton(importance, dir=hinton_path)