"""
Based on "Disentangling by Factorising" (https://github.com/nmichlo/disent/blob/main/disent/metrics/_dci.py).
"""

import logging
import scipy
import scipy.stats
import numpy as np
from src.utils.seed import set_seed
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import os

# Module-level logger; metric_dci emits its evaluation banner through it.
logger = logging.getLogger(__name__)


def metric_dci(
        train_latents,
        train_factors,
        test_latents,
        test_factors,
        args,
        show_progress=False,
        continuous_factors=False,
):
    """Log an evaluation banner and compute the DCI metric.

    Thin public entry point around ``compute_dci``. Arguments are forwarded by
    keyword, because this function takes ``args`` before the optional flags while
    ``compute_dci`` takes it last.

    Returns:
        The tuple produced by ``compute_dci``:
        ``(train_err, test_err, disentanglement, completeness, importance_matrix)``.
    """
    banner = "*********************DCI Disentanglement Evaluation*********************"
    logger.info(banner)
    return compute_dci(
        train_latents=train_latents,
        train_factors=train_factors,
        test_latents=test_latents,
        test_factors=test_factors,
        show_progress=show_progress,
        continuous_factors=continuous_factors,
        args=args,
    )


def compute_dci(
        train_latents,
        train_factors,
        test_latents,
        test_factors,
        show_progress,
        continuous_factors,
        args,
):
    """Compute the DCI scores from latent codes and ground-truth factors.

    Latents and factors are indexed with the dimension on axis 0: ``train_latents.shape[0]``
    is the number of latents and ``train_factors.shape[0]`` the number of factors.

    Returns:
        ``(train_err, test_err, disentanglement, completeness, importance_matrix)``.
        ``train_err`` / ``test_err`` are the mean per-factor train/test values reported by
        ``compute_importance_gbt``. ``importance_matrix`` has shape ``(num_latents, num_factors)``.
    """
    gbt_outputs = compute_importance_gbt(
        train_latents, train_factors, test_latents, test_factors,
        show_progress, continuous_factors, args,
    )
    importance_matrix, train_err, test_err = gbt_outputs

    # Internal sanity check: one row per latent dimension, one column per factor.
    expected_rows = train_latents.shape[0]
    expected_cols = train_factors.shape[0]
    assert importance_matrix.shape[0] == expected_rows
    assert importance_matrix.shape[1] == expected_cols

    return (
        train_err,
        test_err,
        disentangle(importance_matrix),
        complete(importance_matrix),
        importance_matrix,
    )


def compute_importance_gbt(
        train_latents,
        train_factors,
        test_latents,
        test_factors,
        show_progress,
        continuous_factors,
        args,
):
    """Fit one gradient-boosted tree model per factor and collect feature importances.

    The per-factor fits run concurrently in a thread pool via ``_process_single_factor``.

    Args:
        train_latents: training codes with latent dimensions on axis 0.
        train_factors: training factors with factor dimensions on axis 0.
        test_latents: held-out codes, same layout as ``train_latents``.
        test_factors: held-out factors, same layout as ``train_factors``.
        show_progress: show a tqdm progress bar over completed factors.
        continuous_factors: forwarded to ``_process_single_factor`` (regression vs classification).
        args: forwarded to ``_process_single_factor`` (used for seeding).

    Returns:
        ``(importance_matrix, mean_train, mean_test)``. ``importance_matrix`` has shape
        ``(num_latents, num_factors)``; column ``i`` holds the importances for factor ``i``.
        ``mean_train`` / ``mean_test`` average the per-factor ``'train_loss'`` /
        ``'test_loss'`` values over all factors.

    Raises:
        ValueError: if there are no factors to evaluate.
    """
    num_factors = train_factors.shape[0]
    num_latents = train_latents.shape[0]
    if num_factors == 0:
        # ThreadPoolExecutor rejects max_workers=0; fail early with a clearer message.
        raise ValueError("compute_importance_gbt requires at least one factor")

    importance_matrix = np.zeros(shape=[num_latents, num_factors], dtype=np.float64)
    # Scores are stored by factor index, not appended in completion order. That keeps
    # the floating-point mean independent of which thread finishes first.
    train_scores = np.zeros(num_factors, dtype=np.float64)
    test_scores = np.zeros(num_factors, dtype=np.float64)

    # os.cpu_count() returns None when the CPU count cannot be determined.
    max_workers = min(num_factors, os.cpu_count() or 1)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                _process_single_factor,
                train_latents, train_factors, test_latents, test_factors,
                factor_index, continuous_factors, args,
            )
            for factor_index in range(num_factors)
        ]

        # Each result carries its own factor index, so completion order does not matter.
        for future in tqdm(as_completed(futures), total=num_factors, disable=not show_progress):
            result = future.result()
            factor_index = result['factor_index']
            importance_matrix[:, factor_index] = result['importance']
            train_scores[factor_index] = result['train_loss']
            test_scores[factor_index] = result['test_loss']

    return importance_matrix, np.mean(train_scores), np.mean(test_scores)


def _process_single_factor(
        train_latents,
        train_factors,
        test_latents,
        test_factors,
        factor_index,
        continuous_factors,
        args
):
    """Fit one gradient-boosted tree model that predicts a single factor from all latents.

    Args:
        train_latents: training codes with latent dimensions on axis 0. They are
            transposed so that samples become rows before fitting.
        train_factors: training factors; row ``factor_index`` is the target.
        test_latents: held-out codes, same layout as ``train_latents``.
        test_factors: held-out factors, same layout as ``train_factors``.
        factor_index: index of the factor row to predict.
        continuous_factors: fit a ``GradientBoostingRegressor`` instead of a
            ``GradientBoostingClassifier``.
        args: passed to ``set_seed``. If it has a ``seed`` attribute, that value
            also seeds the model.

    Returns:
        dict with ``'factor_index'``, ``'importance'``, ``'train_loss'`` and
        ``'test_loss'``. ``'importance'`` holds absolute feature importances, one per
        latent. Despite their names, the last two are higher-is-better scores:
        accuracy for the classifier and R^2 for the regressor.
    """
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.ensemble import GradientBoostingRegressor

    set_seed(args)
    # This function runs concurrently in several threads. With random_state=None,
    # sklearn draws from numpy's process-global RNG, which every thread reseeds and
    # consumes in interleaved order, so results are not reproducible. Give each model
    # its own seed instead.
    # NOTE(review): assumes args.seed is the integer passed to set_seed — confirm.
    # When the attribute is absent this falls back to the global RNG (previous behaviour).
    random_state = getattr(args, "seed", None)
    model_cls = GradientBoostingRegressor if continuous_factors else GradientBoostingClassifier
    model = model_cls(random_state=random_state)

    x_train, y_train = train_latents.T, train_factors[factor_index, :]
    x_test, y_test = test_latents.T, test_factors[factor_index, :]

    model.fit(x_train, y_train)
    importance = np.abs(model.feature_importances_)

    if continuous_factors:
        # Regression outputs almost never equal float targets exactly, so an
        # equality-based accuracy would be ~0. Report R^2 instead.
        train_score = model.score(x_train, y_train)
        test_score = model.score(x_test, y_test)
    else:
        train_score = np.mean(model.predict(x_train) == y_train)
        test_score = np.mean(model.predict(x_test) == y_test)

    return {
        'factor_index': factor_index,
        'importance': importance,
        'train_loss': train_score,
        'test_loss': test_score
    }


def disentangle(importance_matrix):
    """Aggregate per-code disentanglement into one importance-weighted score.

    Each code (row) is weighted by its share of the total importance. If the matrix
    sums to zero, every entry is treated as 1, so all codes get equal weight.
    """
    scores = disentanglement_per_code(importance_matrix)
    weights = (
        np.ones_like(importance_matrix)
        if importance_matrix.sum() == 0.0
        else importance_matrix
    )
    row_share = weights.sum(axis=1) / weights.sum()
    return np.sum(scores * row_share)


def disentanglement_per_code(importance_matrix):
    """Return the disentanglement score of every latent code.

    Each row (code) is normalised into a distribution over factors. Its score is
    ``1 - H``, where ``H`` is that distribution's entropy in log base ``num_factors``,
    so ``H`` lies in [0, 1].

    Args:
        importance_matrix: array of shape ``(num_latents, num_factors)``.

    Returns:
        1-D float array of length ``num_latents``.
    """
    # (latents_dim, factors_dim)
    num_factors = importance_matrix.shape[1]
    if num_factors == 1:
        # A single factor absorbs all of each code's importance, so the entropy is 0.
        # Dividing by log(1) == 0 would otherwise turn that into NaN.
        return np.ones(importance_matrix.shape[0], dtype=np.float64)
    # The epsilon keeps all-zero rows from normalising to 0/0.
    return 1.0 - scipy.stats.entropy(importance_matrix.T + 1e-11, base=num_factors)


def complete(importance_matrix):
    """Aggregate per-factor completeness into one importance-weighted score.

    Each factor (column) is weighted by its share of the total importance. If the
    matrix sums to zero, every entry is treated as 1, so all factors get equal weight.
    """
    scores = completeness_per_factor(importance_matrix)
    weights = (
        np.ones_like(importance_matrix)
        if importance_matrix.sum() == 0.0
        else importance_matrix
    )
    column_share = weights.sum(axis=0) / weights.sum()
    return np.sum(scores * column_share)


def completeness_per_factor(importance_matrix):
    """Return the completeness score of every factor.

    Each column (factor) is normalised into a distribution over latent codes. Its
    score is ``1 - H``, where ``H`` is that distribution's entropy in log base
    ``num_latents``, so ``H`` lies in [0, 1].

    Args:
        importance_matrix: array of shape ``(num_latents, num_factors)``.

    Returns:
        1-D float array of length ``num_factors``.
    """
    # (latents_dim, factors_dim)
    num_latents = importance_matrix.shape[0]
    if num_latents == 1:
        # A single code captures all of each factor's importance, so the entropy is 0.
        # Dividing by log(1) == 0 would otherwise turn that into NaN.
        return np.ones(importance_matrix.shape[1], dtype=np.float64)
    # The epsilon keeps all-zero columns from normalising to 0/0.
    return 1.0 - scipy.stats.entropy(importance_matrix + 1e-11, base=num_latents)

