import numpy as np
import math

from ..base_estimator import BaseEstimator
from ..utils.p_generator import get_p


class WeightedSHAPEstimator(BaseEstimator):
    """
    Permutation‑based Weighted‑SHAP estimator.

        φ_i  =  ∑_{k=0}^{n-1}  m_k · E[ U(S_k ∪ {i}) − U(S_k) ]
        m_k  =  C(n-1, k) · p_k

    where S_k is a uniformly random size-k subset of the other features and
    p_k comes from ``get_p(n, weighting)`` (assumed to be the per-subset
    weight for coalitions of size k — confirm against ``get_p``).

    Estimation: each sampled permutation switches features from the baseline
    to the explicand one at a time. The feature at position k of a random
    permutation sees a preceding coalition of size k; that position is uniform
    on {0, …, n-1} (probability 1/n each) and, given k, the coalition is a
    uniform size-k subset. Weighting each marginal contribution by n · m_k
    therefore yields an unbiased estimate of φ_i, which is why ``m_vec``
    carries the extra factor n.
    """

    def __init__(self, model, baseline: np.ndarray, weighting: str, seed=None):
        """
        Args:
            model: object with a ``predict`` method accepting a 2-D array with
                a single row (presumably stored as ``self.model`` by
                ``BaseEstimator``).
            baseline: reference input; features outside a coalition take
                their value from here.
            weighting: name of the weighting scheme, forwarded to ``get_p``.
            seed: optional seed for the permutation sampler (anything
                ``np.random.default_rng`` accepts). ``None`` keeps the
                previous non-deterministic behaviour.
        """
        super().__init__(model, baseline, weighting)

        n = self.baseline.size if self.baseline.ndim == 1 else self.baseline.shape[1]
        self.n_feat = n
        p = np.asarray(get_p(n, weighting), dtype=float)        # length n
        combs = np.array([math.comb(n - 1, k) for k in range(n)], dtype=float)
        # Sampling weights: m_k = C(n-1, k) · p_k, multiplied by n to undo the
        # 1/n probability that a feature lands at position k of a permutation.
        self.m_vec = n * p * combs
        # NOTE(review): a 2-D baseline with more than one row flattens to more
        # than n_feat values and will fail in ``_mask_instance`` — confirm that
        # only single-row baselines are intended.
        self.base_vec = self.baseline.reshape(-1).astype(float)
        self.rng = np.random.default_rng(seed)

    def _mask_instance(self, x_vec, mask):
        """Return a 1-row batch: ``x_vec`` where ``mask`` is True, baseline elsewhere."""
        return (np.where(mask, x_vec, self.base_vec))[None, :]

    def _predict(self, x_vec, mask):
        """Model output for the masked instance, as a Python float.

        The model must return exactly one value for the single-row input;
        ``.item()`` raises ``ValueError`` otherwise.
        """
        out = self.model.predict(self._mask_instance(x_vec, mask))
        return np.asarray(out, dtype=float).item()

    def explain(self, explicand: np.ndarray, num_samples: int) -> np.ndarray:
        """
        Estimate the Weighted-SHAP value of every feature of ``explicand``.

        Args:
            explicand: input to explain; once flattened it must contain
                ``n_feat`` values.
            num_samples: model-evaluation budget. A permutation costs roughly
                ``n_feat`` evaluations, so ``num_samples // n_feat``
                permutations are drawn — but always at least one.

        Returns:
            Float array of shape ``(n_feat,)`` with the estimated attributions.

        Raises:
            ValueError: if ``explicand`` does not contain ``n_feat`` values.
        """
        n = self.n_feat
        x_vec = np.asarray(explicand).reshape(-1).astype(float)
        if x_vec.size != n:
            # Without this check a size-1 explicand would silently broadcast.
            raise ValueError(f"explicand has {x_vec.size} features, expected {n}")

        # A budget below n_feat used to give zero permutations and a 0/0 NaN
        # result; draw at least one permutation instead.
        num_perms = max(1, num_samples // n)

        # The empty coalition (pure baseline) is identical for every permutation.
        empty_pred = self._predict(x_vec, np.zeros(n, dtype=bool))

        phi_acc = np.zeros(n, dtype=float)
        for _ in range(num_perms):
            mask = np.zeros(n, dtype=bool)
            prev_pred = empty_pred
            for k, feat in enumerate(self.rng.permutation(n)):
                mask[feat] = True
                curr_pred = self._predict(x_vec, mask)
                # Marginal contribution of `feat` to a coalition of size k.
                phi_acc[feat] += self.m_vec[k] * (curr_pred - prev_pred)
                prev_pred = curr_pred

        return phi_acc / num_perms