import itertools
import numpy as np
from scipy.special import comb
from scipy.stats import betabinom
from itertools import chain, combinations

class SemiValuationM:
    """Exact semivalues of a vector-valued game over player datasets, with memoization.

    For every player ``i`` the semivalue is::

        sum over S subset of players \\ {i} of
            weights[|S|] / C(n - 1, |S|) * (v(S + {i}) - v(S))

    All coalitions are enumerated, so the value function is evaluated once for
    each of the ``2 ** n`` distinct coalitions (results are cached afterwards).

    Two caches are kept, both keyed by the sorted tuple of player indices:
    ``memo`` holds coalition values and ``memomodel`` holds whatever second
    item ``v`` returned for that coalition (handed back to ``v`` as
    ``trained=`` when the value has to be recomputed).
    """

    def __init__(self, v, datasets, weights):
        """
        :param v: Value function. It is called as ``v(list_of_datasets)`` or
                  ``v(list_of_datasets, trained=model)`` and must return a pair
                  ``(value, model)``. ``value`` must support ``len()`` and
                  subtraction (e.g. a NumPy vector of length b); it is also
                  evaluated on the empty list to discover b.
        :param datasets: A sequence with one dataset per player.
        :param weights: A weight vector indexed by coalition size (0 .. n-1)
                        specifying the importance of coalitions of each size.
                        The weights should sum to 1.
        :raises ValueError: If ``weights`` does not have one entry per player
                            or does not sum to 1.
        """
        self.v = v
        # list(...) always builds a fresh, mutable list. Slicing does not:
        # a tuple slice is still immutable and a NumPy slice is a view onto
        # the caller's array, so later per-player updates would fail or leak.
        self.datasets = list(datasets)
        self.num_players = len(self.datasets)
        self.players = list(range(self.num_players))
        self.memo = {}  # Dictionary to store memoized values
        self.memomodel = {}  # Dictionary to store memoized models

        # Validate the weights (explicit raises survive `python -O`, asserts do not).
        self.weights = np.array(weights)
        if len(self.weights) != self.num_players:
            raise ValueError("Weight vector length must equal the number of players.")
        if not np.isclose(np.sum(self.weights), 1.0):
            raise ValueError("Weights must sum to 1.")

        # Precompute per-coalition weights: the weight of a size is shared
        # equally by the C(n - 1, size) coalitions of that size.
        self.adjusted_weights = [
            self.weights[c] / comb(self.num_players - 1, c, exact=True)
            for c in range(self.num_players)
        ]
        # Semivalues are computed eagerly at construction time.
        self.semivalues = self.calculate_semivalue()

    def _compute_v(self, subset):
        """Evaluate ``v`` on a coalition without reading or writing the value cache.

        :param subset: Iterable of player indices.
        :return: Tuple of the sorted coalition key followed by the items
                 returned by ``v`` (unpacked).
        """
        subset = tuple(sorted(subset))
        datasets_subset = [self.datasets[j] for j in subset]
        trained_model = self.memomodel.get(subset)  # None if no model is cached
        return subset, *self.v(datasets_subset, trained=trained_model)

    def get_value(self, subset):
        """Return the (memoized) value of a coalition.

        :param subset: Iterable of player indices; order does not matter.
        :return: The first item returned by ``v`` for this coalition.
        """
        subset = tuple(sorted(subset))
        if subset not in self.memo:
            datasets_subset = [self.datasets[j] for j in subset]
            if subset not in self.memomodel:
                # First evaluation: remember both the value and the model.
                self.memo[subset], self.memomodel[subset] = self.v(datasets_subset)
            else:
                # A model survived a value-cache reset (see update_v): pass it
                # back to v and keep the stored model untouched.
                self.memo[subset], _ = self.v(datasets_subset, trained=self.memomodel[subset])
        return self.memo[subset]

    def calculate_semivalue(self):
        """Compute the semivalue of every player by full coalition enumeration.

        :return: Array of shape (num_players, b), where b is the length of
                 the value returned for the empty coalition.
        """
        n = self.num_players
        b = len(self.get_value(()))  # Determine the length of the value vector
        semivalues = np.zeros((n, b))

        for i in self.players:
            others = self.players[:i] + self.players[i + 1:]
            # Coalitions of the other players, by increasing size 0 .. n-1.
            for size in range(n):
                weight = self.adjusted_weights[size]
                for subset in itertools.combinations(others, size):
                    marginal_contribution = self.get_value(subset + (i,)) - self.get_value(subset)
                    semivalues[i] += weight * marginal_contribution

        return semivalues

    def update_player_dataset(self, i, new_dataset):
        """Replace one player's dataset and recompute the semivalues.

        Cached values and models of every coalition containing the player are
        dropped; all other cache entries are reused.

        :param i: Player index. Negative indices count from the end.
        :param new_dataset: The replacement dataset.
        :raises IndexError: If ``i`` is out of range.
        """
        # Resolve to the canonical non-negative id used in the cache keys;
        # a raw negative index would never match a key and leave stale entries.
        player = self.players[i]
        self.datasets[player] = new_dataset
        for cache in (self.memo, self.memomodel):
            for subset in [key for key in cache if player in key]:
                del cache[subset]
        self.semivalues = self.calculate_semivalue()

    def update_v(self, v, recompute=False):
        """Swap the value function and discard all cached values.

        Cached models are kept, so the next evaluation of a coalition passes
        its stored model to the new ``v`` via ``trained=``.

        :param v: The new value function (same calling convention as before).
        :param recompute: If True, recompute ``self.semivalues`` immediately;
                          otherwise ``self.semivalues`` keeps the result
                          obtained with the previous value function.
        """
        self.v = v
        self.memo = dict()
        if recompute:
            self.semivalues = self.calculate_semivalue()


class ShapleyValuationM(SemiValuationM):
    """SemiValuationM with a uniform weight vector (1/n for every coalition size)."""

    def __init__(self, v, datasets):
        """
        :param v: Value function, forwarded unchanged to SemiValuationM.
        :param datasets: One dataset per player, forwarded unchanged to SemiValuationM.
        """
        n_players = len(datasets)
        uniform_weights = np.ones(n_players) / n_players
        super().__init__(v, datasets, uniform_weights)

class BetaShapleyValuationM(SemiValuationM):
    """SemiValuationM whose coalition-size weights follow a beta-binomial distribution."""

    def __init__(self, v, datasets, a=16, b=1):
        """
        :param v: Value function, forwarded unchanged to SemiValuationM.
        :param datasets: One dataset per player, forwarded unchanged to SemiValuationM.
        :param a: Passed as the SECOND shape parameter of ``betabinom.pmf``.
        :param b: Passed as the FIRST shape parameter of ``betabinom.pmf``.
        """
        n_players = len(datasets)
        coalition_sizes = np.arange(n_players)
        # The shape parameters are handed over as (b, a), not (a, b). With the
        # defaults this puts most of the mass on small coalition sizes.
        # NOTE(review): presumably this mirrors the Beta(a, b)-Shapley
        # convention; confirm before "fixing" the order.
        size_weights = betabinom.pmf(coalition_sizes, n_players - 1, b, a)
        super().__init__(v, datasets, size_weights)
