import os
import json
import numpy as np

from numpy import ndarray
from utils.wandb_logger import *
from datasets.utils.base_dataset import BaseDataset
from models.mnistdpl import MnistDPL
from enum import Enum
from utils.dpl_loss import ADDMNIST_DPL
from typing import List
import wandb
from utils.metrics import (
    evaluate_metrics,
    evaluate_mix,
    get_alpha,
    get_alpha_single,
    get_concept_probability,
    get_concept_probability_mcdropout,
    get_concept_probability_ensemble,
    get_concept_probability_laplace,
    get_concept_probability_factorized_mcdropout,
    get_concept_probability_factorized_ensemble,
    get_concept_probability_factorized_laplace,
    expected_calibration_error,
    mean_entropy,
    mean_variance,
    class_mean_entropy,
    class_mean_variance,
    expected_calibration_error_by_concept,
    laplace_p_c_x_distance,
    ensemble_p_c_x_distance,
    mcdropout_p_c_x_distance,
    concept_accuracy,
    world_accuracy
)
from utils.visualization import (
    produce_confusion_matrix,
    produce_alpha_matrix,
    produce_calibration_curve,
    produce_bar_plot,
    produce_scatter_multi_class,
    plot_grouped_entropies
)
from utils.bayes import (
    montecarlo_dropout,
    deep_ensemble,
    ensemble_predict,
    activate_dropout,
    laplace_approximation,
    laplace_prediction,
    recover_predictions_from_laplace,
    laplace_single_prediction
)
from utils import fprint
from utils.checkpoint import load_checkpoint, get_model_name
from itertools import product


class IllegalArgumentError(ValueError):
    """Raised when an argument holds a value the evaluation does not support."""

class ECEMODE(Enum):
    """Selects how the Expected Calibration Error (ECE) is computed.

    `produce_ece_curve` calls `expected_calibration_error_by_concept` for
    `FILTERED_BY_CONCEPT` and `expected_calibration_error` for any other mode.
    """
    # Mode for computing the ECE
    WHOLE = 1
    FILTERED_BY_CONCEPT = 2

class EVALUATION_TYPE(Enum):
    """Evaluation methods; each value is the string compared against `args.type`."""
    NORMAL = 'frequentist'
    LAPLACE = 'laplace'
    MC_DROPOUT = 'mcdropout'
    BEARS = 'bears'
    ENSEMBLE = 'ensemble'

class REDUCED_EVALUATION_TYPE(Enum):
    """Same members and values as `EVALUATION_TYPE`, minus the Laplace entry."""
    NORMAL = 'frequentist'
    MC_DROPOUT = 'mcdropout'
    BEARS = 'bears'
    ENSEMBLE = 'ensemble'

def euclidean_distance(w1, w2):
    """Return the L2 distance between two parallel sequences of tensors.

    The squared element-wise differences of each tensor pair are summed over
    all pairs and the square root of the total is returned as a tensor.
    Pairs are formed with `zip`, so extra entries in the longer sequence are
    ignored.
    """
    squared_total = 0
    for left, right in zip(w1, w2):
        squared_total = squared_total + torch.sum((left - right) ** 2)
    return torch.sqrt(squared_total)

def fprint_weights_distance(
    original_weights,
    ensemble,
    method_1,
    method_2
):
    """Print the mean Euclidean distance from `original_weights` to each model.

    Args:
        original_weights: sequence of tensors used as the reference weights.
        ensemble: sized iterable of models exposing `parameters()`.
        method_1: label of the reference method, used only in the message.
        method_2: label of the compared method, used only in the message.
    """
    per_model = (
        euclidean_distance(
            original_weights,
            [weight.data.clone() for weight in member.parameters()],
        )
        for member in ensemble
    )
    mean_distance = sum(per_model) / len(ensemble)
    fprint(f"Euclidean Distance between {method_1} and {method_2}: ", mean_distance.item())

def fprint_ensemble_distance(
    ensemble
):
    """Print the Euclidean weight distance for every unordered pair of models.

    Args:
        ensemble: indexable, sized collection of models exposing `parameters()`.
    """
    def _snapshot(net):
        # Detached copy of the current parameter tensors of one model.
        return [weight.data.clone() for weight in net.parameters()]

    n_models = len(ensemble)
    for i in range(n_models - 1):
        reference = _snapshot(ensemble[i])
        for j in range(i + 1, n_models):
            gap = euclidean_distance(reference, _snapshot(ensemble[j]))
            fprint(f"Euclidean Distance between #{i} and #{j}: ", gap.item())

def print_p_c_given_x_distance(
    model,
    laplace_model,
    ensemble,
    test_loader,
    recover_predictions_from_laplace,
    activate_dropout,
    type: str,
    num_ensembles: int
) -> None:
    """Print the P(C|X) distance for the requested evaluation type.

    Args:
        model: base model; used for MC dropout and for its `nr_classes` /
            `n_facts` attributes in the Laplace case.
        laplace_model: Laplace model, used only for the Laplace type.
        ensemble: ensemble members, used for the BEARS / ensemble types.
        test_loader: data loader forwarded to the distance helper.
        recover_predictions_from_laplace: callable forwarded to the Laplace
            distance helper (shadows the module-level import of that name).
        activate_dropout: callable forwarded to the MC-dropout distance helper
            (shadows the module-level import of that name).
        type: one of the `EVALUATION_TYPE` values; any other value leaves the
            distance as `None`, which is still printed.
        num_ensembles: number of samples forwarded to the Laplace and
            MC-dropout helpers.
    """
    ensemble_types = (EVALUATION_TYPE.BEARS.value, EVALUATION_TYPE.ENSEMBLE.value)

    if type == EVALUATION_TYPE.LAPLACE.value:
        dist = laplace_p_c_x_distance(
            laplace_model,
            test_loader,
            num_ensembles,
            recover_predictions_from_laplace,
            model.nr_classes,
            model.n_facts,
        )
    elif type == EVALUATION_TYPE.MC_DROPOUT.value:
        dist = mcdropout_p_c_x_distance(model, test_loader, activate_dropout, num_ensembles)
    elif type in ensemble_types:
        dist = ensemble_p_c_x_distance(ensemble, test_loader)
    else:
        dist = None

    fprint(f"Mean P(C|X) for {type} distance L2 is {dist}")

def print_distance(
    model_or_s,
    test_loader,
    recover_predictions_from_laplace,
    n_ensembles,
    n_classes,
    n_facts,
    type: str
):
    """Compute and print the P(C|X) distance for one evaluation type.

    Args:
        model_or_s: a single model, a Laplace model or an ensemble, depending
            on `type`.
        test_loader: data loader forwarded to the distance helper.
        recover_predictions_from_laplace: callable forwarded to the Laplace
            distance helper.
        n_ensembles: number of samples for the Laplace / MC-dropout helpers.
        n_classes: forwarded to the Laplace distance helper.
        n_facts: forwarded to the Laplace distance helper.
        type: `EVALUATION_TYPE` value; anything that is not BEARS, ensemble or
            Laplace falls back to the MC-dropout distance, which uses the
            module-level `activate_dropout`.
    """
    is_ensemble = type in (
        EVALUATION_TYPE.BEARS.value,
        EVALUATION_TYPE.ENSEMBLE.value,
    )

    if is_ensemble:
        dist = ensemble_p_c_x_distance(model_or_s, test_loader)
    elif type == EVALUATION_TYPE.LAPLACE.value:
        dist = laplace_p_c_x_distance(
            model_or_s,
            test_loader,
            n_ensembles,
            recover_predictions_from_laplace,
            n_classes,
            n_facts,
        )
    else:
        dist = mcdropout_p_c_x_distance(
            model_or_s, test_loader, activate_dropout, n_ensembles
        )

    fprint(f'{type} distance: {dist}')

def print_metrics(
    y_true: ndarray,
    y_pred: ndarray,
    c_true: ndarray,
    c_pred: ndarray,
    p_cs_all: ndarray,
    n_facts: int,
    mode,
):
    """Print accuracy, F1 and mean concept entropy, and return them.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        c_true: ground-truth concepts.
        c_pred: predicted concepts.
        p_cs_all: concept probabilities forwarded to `mean_entropy`.
        n_facts: number of facts; squared before the entropy call for every
            mode other than the frequentist one.
        mode: `EVALUATION_TYPE` value of the current evaluation.

    Returns:
        Tuple `(mean entropy, label accuracy, concept accuracy, concept F1,
        label F1)`.
    """
    label_acc, label_f1 = evaluate_mix(y_true, y_pred)
    concept_acc, concept_f1 = evaluate_mix(c_true, c_pred)

    is_frequentist = mode == EVALUATION_TYPE.NORMAL.value
    entropy_size = n_facts if is_frequentist else n_facts**2
    concept_entropy = mean_entropy(p_cs_all, entropy_size)

    report = (
        "Performances:",
        f"Concepts:\n    ACC: {concept_acc}, F1: {concept_f1}",
        f"Labels:\n      ACC: {label_acc}, F1: {label_f1}",
        f"Entropy:\n     H(C): {concept_entropy}",
    )
    for line in report:
        fprint(line)

    return concept_entropy, label_acc, concept_acc, concept_f1, label_f1


def produce_h_c_given_y(
    p_cs_all: ndarray, y_true: ndarray, nr_classes: int, mode: str, suffix: str
) -> None:
    """Plot the per-class mean entropy H(C|Y) as a bar plot.

    The labels are concatenated with themselves before being handed to
    `class_mean_entropy`; the plot is saved under
    `h_c_given_y_{mode}{suffix}`.
    """
    doubled_labels = np.concatenate((y_true, y_true))
    entropy_per_class = class_mean_entropy(p_cs_all, doubled_labels, nr_classes)
    plot_name = f"h_c_given_y_{mode}{suffix}"
    produce_bar_plot(
        entropy_per_class, "Groundtruth class", "Entropy", "H(C|Y)", plot_name, True
    )

def produce_var_c_given_y(
    p_cs_all: ndarray, y_true: ndarray, nr_classes: int, mode: str, suffix: str
) -> None:
    """Plot the per-class mean variance Var(C|Y) as a bar plot.

    The labels are concatenated with themselves before being handed to
    `class_mean_variance`; the plot is saved under
    `var_c_given_y_{mode}{suffix}`.
    """
    doubled_labels = np.concatenate((y_true, y_true))
    variance_per_class = class_mean_variance(p_cs_all, doubled_labels, nr_classes)
    plot_name = f"var_c_given_y_{mode}{suffix}"
    produce_bar_plot(
        variance_per_class, "Groundtruth class", "Variance", "Var(C|Y)", plot_name, True
    )

def compute_concept_factorized_entropy(
    c_fact_1: ndarray,
    c_fact_2: ndarray,
    p_w_x: ndarray,
):
    """Compute the mean one-vs-all (OVA) entropy of every concept / world.

    For each column `c` of a probability matrix, every row is reduced to the
    binary distribution `(p[c], sum of the other entries)` and the binary
    entropy (base 2) is averaged over the rows. Probabilities are smoothed
    once with `(p + 1e-5) / (1 + n * 1e-5)` so that `log2` never sees zero.

    The inputs are never modified. The previous implementation smoothed the
    rows in place through the views handed out by `np.apply_along_axis`, so
    the caller's arrays were altered and re-smoothed once per column.

    Args:
        c_fact_1: probabilities of the first concept, one row per sample.
        c_fact_2: probabilities of the second concept, one row per sample.
        p_w_x: world probabilities, one row per sample.

    Returns:
        Dict with keys `'c1'`, `'c2'`, `'(c1, c2)'` and `'c'`; each value is a
        list holding one mean OVA entropy per column. `'c'` is computed on the
        two concept matrices stacked row-wise, `'(c1, c2)'` on `p_w_x`.
    """
    eps = 1e-5

    def mean_ova_entropy(probs: ndarray) -> list:
        # float64 working copy: the arithmetic below allocates new arrays,
        # so nothing is written back into the caller's data.
        probs = np.asarray(probs, dtype=np.float64)
        smoothed = (probs + eps) / (1 + probs.shape[1] * eps)
        # Mass of "every other column" for each (row, column) pair.
        others = smoothed.sum(axis=1, keepdims=True) - smoothed
        ova = -(smoothed * np.log2(smoothed) + others * np.log2(others))
        return ova.mean(axis=0).tolist()

    return {
        'c1': mean_ova_entropy(c_fact_1),
        'c2': mean_ova_entropy(c_fact_2),
        '(c1, c2)': mean_ova_entropy(p_w_x),
        'c': mean_ova_entropy(np.vstack([c_fact_1, c_fact_2])),
    }

def compute_entropy_per_concept(
    c_fact_stacked: ndarray,
    c_true: ndarray,
):
    """Compute per-concept entropies on the samples whose ground truth is it.

    For every concept index `c` (column of `c_fact_stacked`) the rows with
    `c_true == c` are selected and two averages are produced:

    * `'c_ova_filtered'`: mean one-vs-all binary entropy (base 2) of
      `(p[c], sum of the other entries)`;
    * `'c_all_filtered'`: mean Shannon entropy (base 2) of the whole row,
      normalised by `log2(number of columns)`.

    Both are evaluated on probabilities smoothed once with
    `(p + 1e-5) / (1 + n * 1e-5)`. The smoothing is now explicit: before, the
    Shannon entropy only avoided `log2(0)` because the OVA helper had smoothed
    the filtered rows in place beforehand.

    Args:
        c_fact_stacked: concept probabilities, one row per sample.
        c_true: ground-truth concept index for each row.

    Returns:
        Dict with the two keys above; each value is a list with one float per
        concept. A concept with no matching sample yields `nan` in both lists
        (previously `np.apply_along_axis` raised `ValueError`).
    """
    eps = 1e-5
    probs = np.asarray(c_fact_stacked, dtype=np.float64)
    n_concepts = probs.shape[1]

    # New arrays are allocated here, so the caller's data is left untouched.
    smoothed = (probs + eps) / (1 + n_concepts * eps)
    shannon = -(smoothed * np.log2(smoothed)).sum(axis=1) / np.log2(n_concepts)
    row_mass = smoothed.sum(axis=1)

    conditional_entropies = {'c_ova_filtered': [], 'c_all_filtered': []}

    for c in range(n_concepts):
        rows = np.where(c_true == c)[0]

        if rows.size == 0:
            # No sample carries this concept: report "undefined", do not crash.
            conditional_entropies['c_ova_filtered'].append(float('nan'))
            conditional_entropies['c_all_filtered'].append(float('nan'))
            continue

        p_c = smoothed[rows, c]
        p_rest = row_mass[rows] - p_c
        ova = -(p_c * np.log2(p_c) + p_rest * np.log2(p_rest))

        conditional_entropies['c_ova_filtered'].append(float(ova.mean()))
        conditional_entropies['c_all_filtered'].append(float(shannon[rows].mean()))

    return conditional_entropies

def measure_execution_time(func, *args, **kwargs):
    """Call `func(*args, **kwargs)` and return how long the call took.

    The return value of `func` is discarded. Timing uses
    `time.perf_counter()`, a monotonic high-resolution clock, so the result is
    not affected by system clock adjustments the way `time.time()` is.

    Returns:
        Elapsed time in seconds as a float.
    """
    import time

    start_time = time.perf_counter()
    func(*args, **kwargs)
    return time.perf_counter() - start_time

# NOTE: every concept is treated as a Bernoulli variable
def compute_concept_factorized_variance(
    c_fact_1: ndarray,
    c_fact_2: ndarray,
    p_w_x: ndarray,
):
    """Compute a per-column dispersion score for concepts and worlds.

    For each column `c`, the Bernoulli standard deviation
    `sqrt(p[c] * (1 - p[c]))` is evaluated on every row, averaged over the
    rows, and the mean is squared.

    Args:
        c_fact_1: probabilities of the first concept, one row per sample.
        c_fact_2: probabilities of the second concept, one row per sample.
        p_w_x: world probabilities, one row per sample.

    Returns:
        Dict with keys `'c1'`, `'c2'`, `'(c1, c2)'` and `'c'`; each value is a
        list with one score per column. `'c'` uses the two concept matrices
        stacked row-wise, `'(c1, c2)'` uses `p_w_x`.
    """
    from math import sqrt

    def _indicator_std(row: ndarray, idx: int):
        prob = row[idx]
        return sqrt(prob * (1 - prob))

    sources = {
        'c1': c_fact_1,
        'c2': c_fact_2,
        '(c1, c2)': p_w_x,
        'c': np.vstack([c_fact_1, c_fact_2]),
    }

    return {
        key: [
            np.mean(np.apply_along_axis(_indicator_std, 1, matrix, idx)) ** 2
            for idx in range(matrix.shape[1])
        ]
        for key, matrix in sources.items()
    }

def produce_ece_curve(
    p: ndarray,
    pred: ndarray,
    true: ndarray,
    exp_mode: str,
    purpose: str = "labels",
    ece_mode: ECEMODE = ECEMODE.WHOLE,
    concept: int = None,
    suffix: str = ''
):
    """Compute the ECE, print it and save the matching calibration curve.

    Args:
        p: confidences forwarded to the ECE helper.
        pred: predictions forwarded to the ECE helper.
        true: ground truth forwarded to the ECE helper.
        exp_mode: evaluation name used in the message and in the plot name.
        purpose: what is being calibrated; any value other than `"labels"`
            sets the concept flag passed to the plot helper.
        ece_mode: `ECEMODE.FILTERED_BY_CONCEPT` selects the per-concept ECE
            helper; any other mode uses the whole-set helper.
        concept: concept index, only used by the per-concept helper.
        suffix: appended to the plot name.

    Returns:
        The ECE value, or `None` when the helper returns a falsy result.
    """
    # TODO: could be improved by a lot: too many arguments!
    if ece_mode == ECEMODE.FILTERED_BY_CONCEPT:
        ece_data = expected_calibration_error_by_concept(p, pred, true, concept)
    else:
        ece_data = expected_calibration_error(p, pred, true)

    # Nothing to report or plot when the helper produced no data.
    if not ece_data:
        return None

    ece, ece_bins = ece_data
    fprint(f"Expected Calibration Error (ECE) {exp_mode} on {purpose}", ece)

    is_concept_plot = purpose != "labels"
    plot_name = f"{purpose}_calibration_curve_{exp_mode}{suffix}"
    produce_calibration_curve(ece_bins, ece, plot_name, is_concept_plot)

    return ece

def generate_concept_labels(concept_labels: List[str]):
    """Build the label lists for worlds (concept pairs) and single concepts.

    Args:
        concept_labels: string labels of the single concepts; each must be
            parseable by `int`.

    Returns:
        A 4-tuple:
        * every ordered pair of labels joined into one string (the worlds);
        * the single labels;
        * the world labels normalised through `str(int(...))`, which drops
          leading zeros;
        * the single labels normalised the same way.
    """
    def _normalise(label: str) -> str:
        return str(int(label))

    # Ordered pairs with repetition: these are all the possible worlds.
    world_labels = ["".join(pair) for pair in product(concept_labels, concept_labels)]
    single_labels = ["".join(item) for item in product(concept_labels)]

    normalised_worlds = [_normalise(label) for label in world_labels]
    normalised_singles = [_normalise(label) for label in single_labels]

    return world_labels, single_labels, normalised_worlds, normalised_singles

def convert_numpy_to_list(obj):
    """Recursively convert NumPy / torch containers into plain Python objects.

    Handles arrays, NumPy scalars, tensors (including ones that require grad
    or live on another device), lists, tuples and dicts, so the result can be
    passed to `json.dump`. Any other object is returned unchanged.

    Args:
        obj: value to convert.

    Returns:
        A structure made of Python lists, dicts and scalars. Tuples come back
        as lists.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # NumPy scalars such as np.float32 are not JSON serialisable as-is.
        return obj.item()
    if isinstance(obj, torch.Tensor):
        # detach() first: numpy() refuses tensors that require grad.
        return obj.detach().cpu().numpy().tolist()
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_to_list(item) for item in obj]
    if isinstance(obj, dict):
        return {key: convert_numpy_to_list(value) for key, value in obj.items()}
    return obj

def save_dump(args, kwargs, incomplete=False, eltype=''):
    """Write the collected evaluation results to a JSON file under `dumps/`.

    Args:
        args: run configuration; `seed`, `n_ensembles`, `use_ood`, `lambda_h`
            (and `real_kl` when `incomplete`) are read to build the file name.
        kwargs: dict of results. It is not modified; the `"ORIGINAL_WEIGHTS"`
            entry, if present, is left out of the dump.
        incomplete: when true, the file name gets the
            `_incomplete_{eltype}_real-kl_{...}` suffix.
        eltype: evaluation type inserted in the "incomplete" file name.
    """
    stem = (
        f"dumps/{get_model_name(args)}-seed_{args.seed}-nens_{args.n_ensembles}"
        f"-ood_{args.use_ood}-lambda_{args.lambda_h}"
    )
    if incomplete:
        file_path = f"{stem}_incomplete_{eltype}_real-kl_{args.real_kl}.json"
    else:
        file_path = f"{stem}.json"

    # Build a fresh payload: the caller's dict stays intact and the (large)
    # weight tensors are skipped instead of being converted and then dropped.
    payload = {
        key: convert_numpy_to_list(value)
        for key, value in kwargs.items()
        if key != "ORIGINAL_WEIGHTS"
    }

    # The target directory may not exist yet on a fresh checkout.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    # Dump the dictionary into the file
    with open(file_path, 'w') as json_file:
        json.dump(payload, json_file)


def produce_confusion_matrices(
    c_true_cc: ndarray,
    c_pred_cc: ndarray,
    c_true: ndarray,
    c_pred: ndarray,
    sklearn_concept_labels: List[str],
    sklearn_concept_labels_single: List[str],
    n_facts: int,
    mode: str,
    suffix: str
):
    """Save the confusion matrices for concept pairs and for single concepts.

    Args:
        c_true_cc: ground-truth concept pairs, one `(first, second)` per row.
        c_pred_cc: predicted concept pairs, one `(first, second)` per row.
        c_true: ground-truth single concepts.
        c_pred: predicted single concepts.
        sklearn_concept_labels: ignored, see the note in the body.
        sklearn_concept_labels_single: ignored, see the note in the body.
        n_facts: currently unused.
        mode: evaluation name used in the output file names.
        suffix: appended to the output file names.
    """
    from itertools import product

    # NOTE(review): the label arguments are overridden by labels rebuilt for
    # the digits 0-9, so whatever the caller passes is not used.
    digits = [str(digit) for digit in range(10)]
    pair_labels = [''.join(pair) for pair in product(digits, repeat=2)]
    sklearn_concept_labels = [str(int(label)) for label in pair_labels]

    def _merge_pair(pair):
        # e.g. (0, 3) -> int("03") -> 3: the first concept was 0, the second 3
        first, second = pair
        return int(str(first) + str(second))

    c_extended_true = np.array([_merge_pair(pair) for pair in c_true_cc])
    c_extended_pred = np.array([_merge_pair(pair) for pair in c_pred_cc])

    # arrays of concepts one after the other eg. 0, 1, 2...
    c_extended_true_merged = np.array([int(str(value)) for value in c_true])
    c_extended_pred_merged = np.array([int(str(value)) for value in c_pred])

    fprint("--- Saving the RSs Confusion Matrix ---")

    produce_confusion_matrix(
        "RSs Confusion Matrix on Combined Concepts",
        c_extended_true,
        c_extended_pred,
        sklearn_concept_labels,
        f"confusion_matrix_combined_concept_{mode}_{suffix}",
        "true",
        1,
    )

    # NOTE(review): the pair labels are passed here as well, although the data
    # are single concepts -- confirm the single-concept labels are not meant.
    produce_confusion_matrix(
        "RSs Confusion Matrix on Concepts",
        c_extended_true_merged,
        c_extended_pred_merged,
        sklearn_concept_labels,
        f"concept_confusion_matrix_{mode}_{suffix}",
        "true",
        1,
    )


def produce_alpha(
    mode: str,
    worlds_prob: ndarray,
    c_prb_1: ndarray,
    c_prb_2: ndarray,
    c_true: ndarray,
    c_true_cc: ndarray,
    n_facts: int,
    concept_labels: List[str],
    concept_labels_single: List[str],
    type: str
):
    """Plot the alpha matrix of the worlds and, for the frequentist model,
    the alpha matrix of the single concepts.

    Args:
        mode: name used in the output plot names.
        worlds_prob: world probabilities forwarded to `get_alpha`.
        c_prb_1: probabilities of the first concept.
        c_prb_2: probabilities of the second concept.
        c_true: ground-truth single concepts.
        c_true_cc: ground-truth concept pairs.
        n_facts: number of facts forwarded to the alpha helpers.
        concept_labels: labels of the world plot.
        concept_labels_single: labels of the single-concept plot.
        type: `EVALUATION_TYPE` value of the current evaluation.
    """
    fprint("--- Computing the probability of each world... ---")

    world_alpha, _ = get_alpha(worlds_prob, c_true_cc, n_facts=n_facts)
    produce_alpha_matrix(
        world_alpha,
        "p((C1, C2)| (G1, G2))",
        concept_labels,
        f"alpha_plot_{mode}",
        n_facts,
    )

    # Only the single model produces the single ALPHA
    if type != EVALUATION_TYPE.NORMAL.value:
        return

    stacked_concept_prob = np.concatenate((c_prb_1, c_prb_2), axis=0)
    single_alpha, _ = get_alpha_single(stacked_concept_prob, c_true, n_facts=n_facts)
    produce_alpha_matrix(
        single_alpha,
        "p(C | G)",
        concept_labels_single,
        f"alpha_plot_single_{mode}",
        1,
    )


# Module-level work queue holding the value of every EVALUATION_TYPE member.
# `test` removes the first entry on each call made with `args.evaluate_all`
# set, so the list shrinks as evaluation proceeds and is not refilled.
TOTAL_METHODS = [member.value for member in EVALUATION_TYPE]

def test(model: MnistDPL, dataset: BaseDataset, args, **kwargs):
    """Evaluate `model` with the method named by `args.type`.

    When `args.evaluate_all` is set, each call takes the next method from the
    module-level `TOTAL_METHODS` queue, evaluates it and then calls itself
    again; once the queue is empty the accumulated results are dumped to JSON.
    Otherwise a single evaluation is run and dumped as "incomplete".

    The metrics of every evaluation are appended to lists stored in `kwargs`,
    which is forwarded to the recursive call, and the per-sample outputs are
    appended to a CSV file under `dumps/`.

    Args:
        model: network to evaluate; the checkpoint is (re)loaded on every call.
        dataset: provides the data loaders, the OOD loader and concept labels.
        args: run configuration (evaluation type, wandb settings, ensemble
            size, ...). `args.type` and `args.deep_ens_kl` may be overwritten.
        **kwargs: accumulated results from previous calls.

    Raises:
        IllegalArgumentError: if `args.type` is not an `EVALUATION_TYPE` value.
    """

    # If I have to evaluate all
    if args.evaluate_all:
        if not TOTAL_METHODS:
            print("Done total evaluation!...")

            os.makedirs("dumps", exist_ok=True)
            save_dump(args, kwargs)

            if args.wandb is not None:
                wandb.finish()

            # Nothing left to evaluate: stop here instead of indexing the
            # now-empty queue below.
            return

        args.type = TOTAL_METHODS.pop(0)

        print("Doing total evaluation on...", args.type, "remaining: ", TOTAL_METHODS)

    # Wandb
    if args.wandb is not None:
        fprint('\n---Wandb on\n')
        wandb.init(project=args.project, entity=args.wandb,
                   name=str(args.model)+'_lasthope_'+
                        "_n_ens_"+str(args.n_ensembles) +
                        "_lambda_" + str(args.lambda_h),
                   config=args)

    # Default Setting for Training
    model.to(model.device)
    train_loader, val_loader, test_loader = dataset.get_data_loaders()

    # override the OOD if specified
    if args.use_ood:
        test_loader = dataset.ood_loader

    fprint("Loading network....")
    model = load_checkpoint(model, args, args.checkin)
    laplace_model = None
    ensemble = None

    # Add the original weights
    if 'ORIGINAL_WEIGHTS' not in kwargs:
        kwargs['ORIGINAL_WEIGHTS'] = [param.data.clone() for param in model.parameters()]

    # Check whether to apply softmax
    apply_softmax = args.model == 'mnistsl'

    if args.type == EVALUATION_TYPE.LAPLACE.value and args.skip_laplace:
        if args.evaluate_all:
            test(model, dataset, args, **kwargs)
        # Laplace is skipped in both cases: never fall through to it.
        return

    # Retrieve the metrics according to the type of evaluation specified
    if args.type == EVALUATION_TYPE.NORMAL.value:
        fprint("## Not Bayesian model ##")
        y_true, c_true, y_pred, c_pred, p_cs, p_ys, p_cs_all, _ = evaluate_metrics(
            model, test_loader, args, last=True, apply_softmax=apply_softmax
        )
        _, c_true_cc, _, c_pred_cc, _, _, _, _ = evaluate_metrics(
            model, test_loader, args, last=True, concatenated_concepts=False, apply_softmax=apply_softmax
        )
    elif args.type == EVALUATION_TYPE.MC_DROPOUT.value:
        fprint("## Montecarlo dropout ##")
        (
            y_true,
            c_true,
            y_pred,
            c_pred,
            c_true_cc,
            c_pred_cc,
            p_cs,
            p_ys,
            p_cs_all,
            _,
        ) = montecarlo_dropout(model, test_loader, model.n_facts, 30, apply_softmax)
    elif args.type == EVALUATION_TYPE.BEARS.value or args.type == EVALUATION_TYPE.ENSEMBLE.value:
        if args.type == EVALUATION_TYPE.BEARS.value:
            fprint("### BEARS ###")
            args.deep_ens_kl = True
        else:
            fprint("### Deep Ensemble ###")
            args.deep_ens_kl = False

        fprint("Preparing the ensembles...")

        ensemble = deep_ensemble(
            seeds=[i + args.seed + 1 for i in range(args.n_ensembles)],
            dataset=dataset,
            num_epochs=args.n_epochs,
            args=args,
            val_loader=val_loader,
            epsilon=0.01,
            separate_from_others=args.deep_ens_kl,
            lambda_h=args.lambda_h,
            use_wandb=args.wandb,
            n_facts=model.n_facts,
            knowledge_aware_kl=args.knowledge_aware_kl,
            real_kl=args.real_kl
        )

        # ensemble predict
        (
            y_true,
            c_true,
            y_pred,
            c_pred,
            c_true_cc,
            c_pred_cc,
            p_cs,
            p_ys,
            p_cs_all,
            _,
        ) = ensemble_predict(ensemble, test_loader, model.n_facts, apply_softmax)
    elif args.type == EVALUATION_TYPE.LAPLACE.value:
        fprint("### Laplace Approximation ###")
        fprint("Preparing laplace model, please wait...")
        laplace_model = laplace_approximation(model, model.device, train_loader, val_loader)
        (
            y_true, c_true, y_pred, c_pred,
            c_true_cc, c_pred_cc, p_cs,
            p_ys, p_cs_all, _
        ) = laplace_prediction(
            laplace_model,
            model.device,
            test_loader,
            30,
            model.nr_classes,
            model.n_facts,
            apply_softmax
        )
    else:
        raise IllegalArgumentError("Mode argument not valid")

    (
        concept_labels,
        concept_labels_single,
        sklearn_concept_labels,
        sklearn_concept_labels_single,
    ) = generate_concept_labels(dataset.get_concept_labels())

    fprint("Evaluating", args.type)

    # Print the distances
    if args.type == EVALUATION_TYPE.LAPLACE.value:
        # Get the ensembles for the inner model
        # NOTE(review): this result is replaced by a second get_ensembles call
        # before it is read; kept in case the call has side effects -- confirm.
        ensemble = laplace_model.model.model.get_ensembles(laplace_model, args.n_ensembles)

    if args.type != EVALUATION_TYPE.NORMAL.value and args.type != EVALUATION_TYPE.MC_DROPOUT.value:
        print("Currently in", args.type)

    # metrics, h(c|y) and concept confusion matrix
    mean_h_c, yac, cac, cf1, yf1 = print_metrics(y_true, y_pred, c_true, c_pred, p_cs_all, model.n_facts, args.type)

    # Log in Wandb
    if args.wandb is not None:
        ood_string = "-ood" if args.use_ood else ""
        to_log = {
            f"{args.type}-Mean-H(C)-test{ood_string}": mean_h_c,
            f"{args.type}-Acc-Y-test{ood_string}": yac,
            f"{args.type}-Acc-C-test{ood_string}": cac,
            f"{args.type}-F1-C-test{ood_string}": cf1,
            f"{args.type}-F1-Y-test{ood_string}": yf1
        }
        wandb.log(to_log)

    if 'mean_hc' not in kwargs:
        kwargs['mean_hc'] = []
        kwargs['yac'] = []
        kwargs['cac'] = []
        kwargs['yac_hard'] = []
        kwargs['cac_hard'] = []
        kwargs['cf1'] = []
        kwargs['yf1'] = []

    kwargs['mean_hc'].append(mean_h_c)
    kwargs['yf1'].append(yf1)
    kwargs['yac'].append(yac)

    if args.type != EVALUATION_TYPE.NORMAL.value:
        # NOTE(review): 'yac_hard' receives the concept F1 (here and in the
        # frequentist branch below) -- confirm this is intended.
        kwargs['yac_hard'].append(cf1)
        kwargs['cac_hard'].append(cac)

    # label and concept ece
    ece_y = produce_ece_curve(p_ys, y_pred, y_true, args.type, "labels")

    if args.type == EVALUATION_TYPE.NORMAL.value:
        worlds_prob, c_factorized_1, c_factorized_2, worlds_groundtruth = get_concept_probability(
            model, test_loader
        )  # 2 arrays of size 256, 10 (concept 1 and concept 2 for all items)
        c_pred_normal = worlds_prob.argmax(axis=1)
        p_c_normal = worlds_prob.max(axis=1)

        ece = produce_ece_curve(p_c_normal, c_pred_normal, worlds_groundtruth, args.type, "concepts")

        mean_h_c, yac, cac, cf1, yf1 = print_metrics(y_true, y_pred, worlds_groundtruth, c_pred_normal, worlds_prob, model.n_facts, args.type)

        kwargs['yac_hard'].append(cf1)
        kwargs['cac_hard'].append(cac)

    else:
        ece = produce_ece_curve(p_cs, c_pred, c_true, args.type, "concepts")

    if 'ece' not in kwargs:
        kwargs['ece'] = []
        kwargs['ece y'] = []

    kwargs['ece'].append(ece)
    kwargs['ece y'].append(ece_y)

    # Log in Wandb
    if args.wandb is not None:
        ood_string = "-ood" if args.use_ood else ""
        to_log = {
            f"{args.type}-ECE-C-test{ood_string}": ece,
            f"{args.type}-ECE-Y-test{ood_string}": ece_y,
        }
        wandb.log(to_log)

    # Evaluate all the models in the ensemble
    if args.type == EVALUATION_TYPE.BEARS.value or args.type == EVALUATION_TYPE.LAPLACE.value or args.type == EVALUATION_TYPE.ENSEMBLE.value:
        fprint(f"{args.type} evaluation...")

        if args.type == EVALUATION_TYPE.LAPLACE.value:
            # Get the ensembles for the inner model
            ensemble = laplace_model.model.model.get_ensembles(laplace_model, 30)

        # The loop variable must not be called `model`: that would rebind the
        # base model used by the rest of this function and by the recursion.
        for i, member in enumerate(ensemble):
            fprint(f"-- Model {i} --")
            y_true_ens, c_true_ens, y_pred_ens, c_pred_ens, p_cs_ens, p_ys_ens, p_cs_all_ens, _ = evaluate_metrics(member, test_loader, args, last=True, apply_softmax=apply_softmax)
            _, c_true_cc_ens, _, c_pred_cc_ens, _, _, _, _ = evaluate_metrics(
                member, test_loader, args, last=True, concatenated_concepts=False, apply_softmax=apply_softmax
            )

            mean_sh_c, syac, scac, scf1, syf1 = print_metrics(y_true_ens, y_pred_ens, c_true_ens, c_pred_ens, p_cs_all_ens, member.n_facts, args.type)

            sece_y = produce_ece_curve(p_ys_ens, y_pred_ens, y_true_ens, args.type, "labels", ECEMODE.WHOLE, None, f'_{args.type}_{i}')
            sece = produce_ece_curve(p_cs_ens, c_pred_ens, c_true_ens, args.type, "concepts", ECEMODE.WHOLE, None, f'_{args.type}_{i}')

            if args.wandb is not None:
                ood_string = "-ood" if args.use_ood else ""
                to_log = {
                    f"{args.type}_model_{i}-Mean-H(C)-test{ood_string}": mean_sh_c,
                    f"{args.type}_model_{i}-Acc-Y-test{ood_string}": syac,
                    f"{args.type}_model_{i}-Acc-C-test{ood_string}": scac,
                    f"{args.type}_model_{i}-F1-C-test{ood_string}": scf1,
                    f"{args.type}_model_{i}-F1-Y-test{ood_string}": syf1,
                    f"{args.type}_model_{i}-ECE-C-test{ood_string}": sece,
                    f"{args.type}_model_{i}-ECE-Y-test{ood_string}": sece_y,
                }
                wandb.log(to_log)

    fprint("--- Computing the probability of each world... ---")

    c_factorized_1, c_factorized_2 = None, None

    # TODO: should be pc
    if args.type == EVALUATION_TYPE.MC_DROPOUT.value:
        worlds_prob = get_concept_probability_mcdropout(
            model,
            test_loader,
            activate_dropout,
            args.n_ensembles
        ) # 2 arrays of size 256, 10 (concept 1 and concept 2 for all items)

        # Obtain the factorized probabilities
        c_factorized_1, c_factorized_2, gt_factorized = get_concept_probability_factorized_mcdropout(
            model,
            test_loader,
            activate_dropout,
            args.n_ensembles
        )

    elif args.type == EVALUATION_TYPE.BEARS.value or args.type == EVALUATION_TYPE.ENSEMBLE.value:
        worlds_prob = get_concept_probability_ensemble(
            ensemble,
            test_loader
        ) # 2 arrays of size 256, 10 (concept 1 and concept 2 for all items)

        # Obtain the factorized probabilities
        c_factorized_1, c_factorized_2, gt_factorized = get_concept_probability_factorized_ensemble(
            ensemble,
            test_loader
        )
    elif args.type == EVALUATION_TYPE.LAPLACE.value:
        worlds_prob = get_concept_probability_laplace(
            model.device,
            test_loader,
            laplace_model,
            args.n_ensembles
        ) # 2 arrays of size 256, 10 (concept 1 and concept 2 for all items)

        # Obtain the factorized probabilities
        c_factorized_1, c_factorized_2, gt_factorized = get_concept_probability_factorized_laplace(
            model.device,
            test_loader,
            laplace_single_prediction,
            laplace_model,
            model.nr_classes,
            model.n_facts
        )
    else:
        # NORMAL MODE
        worlds_prob, c_factorized_1, c_factorized_2, worlds_groundtruth = get_concept_probability(
            model, test_loader
        )  # 2 arrays of size 256, 10 (concept 1 and concept 2 for all items)

    # Change it for the concept factorized entropy and variance
    if args.type == EVALUATION_TYPE.NORMAL.value:
        p_cs_all = worlds_prob
        gt_factorized = c_true

    # factorized probability concatenated
    c_factorized_full = np.concatenate((c_factorized_1, c_factorized_2), axis=0)
    # maximum element probability for the ECE count
    c_factorized_max_p = np.max(c_factorized_full, axis=1)
    # factorized predictions with argmax
    c_pred_factorized_full = np.argmax(c_factorized_full, axis=1)

    single_concepts_ece = []
    # ECE per concept NOTE factorized, otherwise only world is possible
    for c in concept_labels_single:
        ece_single_concept = produce_ece_curve(
            c_factorized_max_p, c_pred_factorized_full, gt_factorized, args.type, f"concepts {c}", ECEMODE.FILTERED_BY_CONCEPT, int(c)
        )
        single_concepts_ece.append(ece_single_concept)

    # add single concepts ECE
    kwargs[f"{args.type} ece single concept"] = single_concepts_ece

    cfe = compute_concept_factorized_entropy(
        c_factorized_1,
        c_factorized_2,
        p_cs_all, # equal to c_pred in ensembles [#data, facts^2]
    )

    cfvar = compute_concept_factorized_variance(
        c_factorized_1,
        c_factorized_2,
        p_cs_all, # equal to c_pred in ensembles [#data, facts^2]
    )

    cac, cf1 = evaluate_mix(c_pred_factorized_full, gt_factorized)
    kwargs['cf1'].append(cf1)
    kwargs['cac'].append(cac)

    if not any(key in kwargs for key in ['e_c1','e_c2','e_c', 'e_(c1, c2)']):
        kwargs['e_c1'] = list()
        kwargs['e_c2'] = list()
        kwargs['e_c'] = list()
        kwargs['e_(c1, c2)'] = list()

        kwargs['var_c1'] = list()
        kwargs['var_c2'] = list()
        kwargs['var_c'] = list()
        kwargs['var_(c1, c2)'] = list()

    kwargs['e_c1'].append(cfe['c1'])
    kwargs['e_c2'].append(cfe['c2'])
    kwargs['e_c'].append(cfe['c'])
    kwargs['e_(c1, c2)'].append(cfe['(c1, c2)'])

    kwargs['var_c1'].append(cfvar['c1'])
    kwargs['var_c2'].append(cfvar['c2'])
    kwargs['var_c'].append(cfvar['c'])
    kwargs['var_(c1, c2)'].append(cfvar['(c1, c2)'])

    concept_counter_list, concept_acc_list = concept_accuracy(c_factorized_1, c_factorized_2, gt_factorized)

    if args.type == EVALUATION_TYPE.NORMAL.value:
        p_cs_all = worlds_prob
        c_true = worlds_groundtruth

    world_counter_list, world_acc_list = world_accuracy(p_cs_all, c_true, model.n_facts)

    if not any(key in kwargs for key in ['c_acc_count', 'c_acc', 'w_acc_count', 'w_acc']):
        kwargs['c_acc_count'] = list()
        kwargs['c_acc'] = list()
        kwargs['w_acc_count'] = list()
        kwargs['w_acc'] = list()
        kwargs['c_ova_filtered'] = list()
        kwargs['c_all_filtered'] = list()

    kwargs['c_acc_count'].append(concept_counter_list)
    kwargs['c_acc'].append(concept_acc_list)
    kwargs['w_acc_count'].append(world_counter_list)
    kwargs['w_acc'].append(world_acc_list)

    e_per_c = compute_entropy_per_concept(c_factorized_full, gt_factorized)

    kwargs['c_ova_filtered'].append(e_per_c['c_ova_filtered'])
    kwargs['c_all_filtered'].append(e_per_c['c_all_filtered'])

    # The CSV is written before any dump creates the folder, so make sure the
    # target directory exists.
    os.makedirs("dumps", exist_ok=True)
    file_path = f"dumps/{get_model_name(args)}-seed_{args.seed}-{args.type}-nens_{args.n_ensembles}-ood_{args.use_ood}-lambda_{args.lambda_h}.csv"

    # save csv
    save_csv(y_true, c_true, y_pred, c_pred,
        c_true_cc, c_pred_cc, p_cs,
        p_ys, p_cs_all, c_factorized_1, c_factorized_2,
        worlds_prob, gt_factorized, file_path
    )

    if args.evaluate_all:
        test(model, dataset, args, **kwargs)
    else:
        save_dump(args, kwargs, incomplete=True, eltype=args.type)


def save_csv(
        y_true, c_true, y_pred, c_pred,
        c_true_cc, c_pred_cc, p_cs,
        p_ys, p_cs_all, c_factorized_1, c_factorized_2,
        worlds_prob, gt_factorized, file_path
    ):
    """Append one CSV row per sample to `file_path`.

    Each row holds, in order: `y_true`, `c_true`, `y_pred`, `c_pred`, the
    columns of `c_true_cc` and `c_pred_cc`, `p_cs`, `p_ys`, then the columns
    of `p_cs_all`, `c_factorized_1`, `c_factorized_2`, the reshaped
    `gt_factorized` and `worlds_prob`. The file is opened in append mode and
    no header is written. The number of rows is `len(y_true)`.
    """
    import csv

    # NOTE(review): this pairs consecutive entries of gt_factorized; if the
    # array is laid out as "all first concepts, then all second concepts",
    # the pairs mix different samples -- confirm the layout.
    n_pairs = gt_factorized.shape[0] // 2
    gt_factorized = np.reshape(gt_factorized, (n_pairs, 2))

    def _columns(matrix, row_idx):
        # Values of one row, read column by column.
        return [matrix[row_idx][col] for col in range(matrix.shape[1])]

    with open(file_path, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        for j in range(len(y_true)):
            row = [y_true[j], c_true[j], y_pred[j], c_pred[j]]
            row += _columns(c_true_cc, j)
            row += _columns(c_pred_cc, j)
            row += [p_cs[j], p_ys[j]]
            for matrix in (
                p_cs_all,
                c_factorized_1,
                c_factorized_2,
                gt_factorized,
                worlds_prob,
            ):
                row += _columns(matrix, j)
            writer.writerow(row)