import itertools
import numpy as np
from scipy.special import comb, beta, comb
from scipy.stats import betabinom

class SemiValuation:
    """Value each dataset ("player") by a semivalue of the utility ``v``.

    A coalition is a set of player indices and its utility is ``v`` applied
    to the coalition's datasets.  Player ``i``'s semivalue is the weighted sum
    of its marginal contributions ``v(S + {i}) - v(S)`` over every coalition
    ``S`` that excludes ``i``; all coalitions of size ``c`` together share the
    total weight ``weights[c]``.

    Attributes set by the constructor: ``v``, ``datasets``, ``n``,
    ``players``, ``memo``, ``w``, ``adj_w``, ``num_samples``, ``seed``,
    ``method`` and ``semivalues`` (float array of shape ``(n, b)``, where
    ``b`` is the length of the empty-coalition utility ``v()``).
    """

    def __init__(self, v, datasets, weights, memo=None, mc_count=100, seed=42, method="exact"):
        """Store the game and immediately compute ``self.semivalues``.

        Args:
            v: Utility function, called as ``v(*coalition_datasets)`` with the
                datasets in ascending player-index order (and with no
                arguments for the empty coalition).  Its result must support
                ``len()`` and subtraction.
            datasets: Iterable with one dataset per player; it is copied into
                a list, so one-shot iterables are accepted.
            weights: Sequence of length ``n`` summing to 1; ``weights[c]`` is
                the total weight of coalitions of size ``c``.
            memo: Optional dict caching utilities, keyed by sorted tuples of
                player indices.  It is filled in place by ``get_value``.
            mc_count: Number of sampling rounds used by the ``"svarm"`` method.
            seed: Seed for the ``"svarm"`` random generator.
            method: ``"svarm"`` selects the sampling estimator (valid for
                Shapley weights only); any other value selects exact
                enumeration.
        """
        self.v = v
        self.datasets = list(datasets)
        # Measure the materialised copy: `datasets` itself may be a generator
        # or another iterable that has no len().
        self.n = len(self.datasets)
        self.players = list(range(self.n))
        self.memo = {} if memo is None else memo

        w = np.array(weights, float)
        self.w = w
        assert w.size == self.n, "weights must contain one entry per player"
        assert np.isclose(w.sum(), 1.0), "weights must sum to 1"
        # Per-coalition weight: the size-c weight is split evenly among the
        # C(n-1, c) coalitions of that size that exclude a given player.
        self.adj_w = w / np.array([comb(self.n - 1, c, exact=True) for c in range(self.n)])

        self.num_samples = mc_count
        self.seed = seed

        self.method = method
        self.semivalues = self._compute_semivalues()

    def _compute_semivalues(self):
        """Run the estimator selected by ``self.method`` and return its result."""
        if self.method == "svarm":  # only works for shapley value
            return self._svarm_semivalue()
        return self._exact_semivalue()

    def get_value(self, subset):
        """Return the (memoised) utility of the coalition ``subset``.

        Args:
            subset: Iterable of player indices, in any order.

        Returns:
            ``self.v`` applied to the coalition's datasets in ascending index
            order; computed once per distinct coalition and cached in
            ``self.memo``.
        """
        key = tuple(sorted(subset))
        if key not in self.memo:
            data = [self.datasets[i] for i in key]
            self.memo[key] = self.v(*data)
        return self.memo[key]

    def _exact_semivalue(self):
        """Compute the semivalues by enumerating every coalition.

        Returns:
            Float array of shape ``(n, b)``; row ``i`` is the sum over all
            coalitions ``S`` without ``i`` of
            ``adj_w[len(S)] * (v(S + {i}) - v(S))``.
        """
        b = len(self.get_value(()))
        phi = np.zeros((self.n, b))
        for i in self.players:
            others = [j for j in self.players if j != i]
            for size, weight in enumerate(self.adj_w):
                for S in itertools.combinations(others, size):
                    phi[i] += weight * (self.get_value((*S, i)) - self.get_value(S))
        return phi

    def _svarm_semivalue(self):
        """Estimate Shapley values by sampling coalitions (SVARM-style).

        Each of the ``self.num_samples`` rounds draws two coalitions.  ``A+``
        has size ``s`` in ``1..n`` with probability ``1 / (s * H_n)`` and its
        utility is credited to the "plus" mean of every member.  ``A-`` has
        size ``s`` in ``0..n-1`` with probability ``1 / ((n - s) * H_n)`` and
        its utility is credited to the "minus" mean of every non-member.
        ``H_n`` is the n-th harmonic number.  A player's estimate is its plus
        mean minus its minus mean; a mean with no samples stays at zero.

        Returns:
            Float array of shape ``(n, b)``.
        """
        rng = np.random.default_rng(self.seed)
        b = len(self.get_value(()))

        # The size distributions do not change between rounds: build them once.
        H_n = sum(1 / s for s in range(1, self.n + 1))
        plus_sizes = range(1, self.n + 1)
        plus_probs = [1 / (s * H_n) for s in plus_sizes]
        minus_sizes = range(0, self.n)
        minus_probs = [1 / ((self.n - s) * H_n) for s in minus_sizes]

        phi_plus = np.zeros((self.n, b))
        phi_minus = np.zeros((self.n, b))
        c_plus = np.zeros(self.n)
        c_minus = np.zeros(self.n)

        for _ in range(self.num_samples):
            # Sample A+ and credit its utility to every member.  Members are
            # distinct (replace=False), so fancy-index accumulation is safe.
            s_plus = rng.choice(plus_sizes, p=plus_probs)
            A_plus = rng.choice(self.players, size=s_plus, replace=False)
            v_plus = self.get_value(A_plus)
            phi_plus[A_plus] += v_plus
            c_plus[A_plus] += 1

            # Sample A- and credit its utility to every NON-member.
            s_minus = rng.choice(minus_sizes, p=minus_probs)
            A_minus = rng.choice(self.players, size=s_minus, replace=False)
            v_minus = self.get_value(A_minus)
            outside = np.ones(self.n, dtype=bool)
            outside[A_minus] = False
            phi_minus[outside] += v_minus
            c_minus[outside] += 1

        # Turn the running sums into means wherever at least one sample landed.
        seen_plus = c_plus > 0
        seen_minus = c_minus > 0
        phi_plus[seen_plus] /= c_plus[seen_plus][:, None]
        phi_minus[seen_minus] /= c_minus[seen_minus][:, None]

        # Final Shapley estimate = E[v(S + {i})] - E[v(S)]
        return phi_plus - phi_minus

    def update_player_dataset(self, i, new_data):
        """Replace player ``i``'s dataset and recompute ``self.semivalues``.

        Cached utilities of coalitions containing ``i`` are discarded.  The
        cache is rebuilt as a new dict rather than purged in place, so a memo
        dict supplied by the caller is left untouched from this point on.

        Args:
            i: Index of the player whose dataset changes.
            new_data: The replacement dataset.
        """
        self.datasets[i] = new_data
        self.memo = {k: val for k, val in self.memo.items() if i not in k}
        self.semivalues = self._compute_semivalues()

class ShapleyValuation(SemiValuation):
    """Shapley value: every coalition size gets the same total weight ``1/n``."""

    def __init__(self, v, datasets, memo=None, mc_count=100, seed=42, method="exact"):
        """Build uniform size weights and delegate to ``SemiValuation``.

        Args:
            v: Utility function, forwarded to ``SemiValuation``.
            datasets: Iterable with one dataset per player.
            memo: Optional utility cache, forwarded to ``SemiValuation``.
            mc_count: Forwarded to ``SemiValuation`` (sampling rounds for
                ``method="svarm"``).
            seed: Forwarded to ``SemiValuation`` (seed for ``method="svarm"``).
            method: Forwarded to ``SemiValuation``; ``"svarm"`` selects the
                sampling estimator, anything else the exact computation.
        """
        # Materialise first so that one-shot iterables can be measured.
        datasets = list(datasets)
        weights = np.ones(len(datasets)) / len(datasets)
        super().__init__(v, datasets, weights, memo, mc_count=mc_count, seed=seed, method=method)

class BetaShapleyValuation(SemiValuation):
    """Semivalue whose coalition-size weights follow a beta-binomial law."""

    def __init__(self, v, datasets, a=16, b=1, memo=None):
        """Build beta-binomial size weights and delegate to ``SemiValuation``.

        The weight of coalition size ``c`` (``0 <= c <= n - 1``) is
        ``betabinom.pmf(c, n - 1, b, a)``.  Note that ``b`` is passed as the
        first shape parameter of ``betabinom`` and ``a`` as the second.
        NOTE(review): the swapped order looks intentional (it matches the
        Beta(alpha, beta) naming used for Beta Shapley) -- confirm.

        Args:
            v: Utility function, forwarded to ``SemiValuation``.
            datasets: Iterable with one dataset per player.
            a: Second shape parameter handed to ``betabinom.pmf``.
            b: First shape parameter handed to ``betabinom.pmf``.
            memo: Optional utility cache, forwarded to ``SemiValuation``.
        """
        # Materialise first so that one-shot iterables can be measured.
        datasets = list(datasets)
        n = len(datasets)
        weights = betabinom.pmf(np.arange(n), n - 1, b, a)
        super().__init__(v, datasets, weights, memo)

class IndivValuation(SemiValuation):
    """Individual value: only the marginal contribution to the empty coalition counts."""

    def __init__(self, v, datasets, memo=None):
        """Put all weight on coalition size 0 and delegate to ``SemiValuation``.

        Args:
            v: Utility function, forwarded to ``SemiValuation``.
            datasets: Iterable with one dataset per player.
            memo: Optional utility cache, forwarded to ``SemiValuation``.
        """
        # Materialise first so that one-shot iterables can be measured.
        datasets = list(datasets)
        # One-hot weight vector; no normalisation is needed for zeros.
        weights = np.zeros(len(datasets))
        weights[0] = 1
        super().__init__(v, datasets, weights, memo)