import numpy as np
import torch

from src.utils.shapley_procedure.weights import compute_kernelshap_weights_for_pre_coalitions


def build_coalitions(num_features: int, num_coalitions: int, sampling_method: str = "kernelshap_weights") -> np.ndarray:
    max_range = min(2 ** num_features, 2 ** 63 - 1)

    if num_coalitions >= max_range:
        configs = np.arange(max_range)
        return _generate_coalitions_from_integers(configs, num_features)
    if sampling_method == "subsampling":
        return _sample_coalitions_from_binomial(num_coalitions=num_coalitions, num_features=num_features)
    elif sampling_method == "kernelshap_weights":
        return _sample_coalitions_from_kernelshap_weights(num_features=num_features, num_coalitions=num_coalitions)


def _generate_coalitions_from_integers(indices: np.array, num_features: int = 10) -> np.ndarray:
    """sample integers and then turn them into binary code"""
    Z = np.zeros((indices.shape[0], num_features))
    rest = indices
    valid_rows = rest > 0
    while True:
        set_to_1 = np.floor(np.log2(rest)).astype(int)
        set_to_1_prime = set_to_1[valid_rows][:, np.newaxis]
        p = Z[valid_rows, :]
        np.put_along_axis(p, set_to_1_prime, 1, axis=1)
        Z[valid_rows, :] = p
        rest = rest - 2 ** (np.clip(set_to_1, 0, np.inf))
        valid_rows = rest > 0
        if valid_rows.sum() == 0:
            return Z


def _sample_coalitions_from_binomial(num_coalitions: int, num_features: int) -> np.ndarray:
    """sampling coalitions using binomial distribution and remove duplicates afterward
    """
    Z = np.random.binomial(size=(num_coalitions, num_features), n=1, p=0.5)
    b = 2 ** np.arange(0, num_features)
    unique_ref = (b * Z).sum(axis=1)
    _, idx = np.unique(unique_ref, return_index=True)
    return Z[idx, :]


def _sample_coalitions_from_kernelshap_weights(num_features: int, num_coalitions: int):
    """sampling coalition based on the kernelshap weights on the coalitions"""
    max_range = min(2 ** num_features, 2 ** 63 - 1)

    pre_coalitions = torch.from_numpy(
        _generate_coalitions_from_integers(np.arange(max_range), num_features)[1:-1]  # remove edge case
    ).bool()

    weights = compute_kernelshap_weights_for_pre_coalitions(num_features, pre_coalitions)

    coalition_indices = np.random.choice(range(len(pre_coalitions)),
                                         num_coalitions,
                                         p=weights / weights.sum(),
                                         )

    return pre_coalitions.numpy()[coalition_indices]
