from typing import Dict, Optional

import numpy as np
from scipy.special import beta, betaln, binom

from valuation_func.base_valuation_func import (
    BaseValuationFunc,
)
from valuation_func.sampler import SystematicSampler


class ShapleyValuationFunc(BaseValuationFunc):
    """Valuation function that assigns the same weight to every position.

    ``compute_weight`` produces ``num_points`` entries, each equal to
    ``1 / num_points``.
    """

    def __init__(
        self,
        sampler: Optional[SystematicSampler] = None,
        marg_contrib_dict: Optional[Dict] = None,
    ):
        """Pass both arguments straight through to ``BaseValuationFunc``.

        Args:
            sampler: Sampler handed to the base class; may be ``None``.
            marg_contrib_dict: Dictionary handed to the base class; may be
                ``None``.
        """
        super().__init__(sampler, marg_contrib_dict)

    def compute_weight(self) -> np.ndarray:
        """Return a 1-D float array of ``num_points`` uniform weights.

        ``self.num_points`` is not set in this class; presumably it is
        provided by ``BaseValuationFunc`` -- confirm against the base class.
        """
        count = self.num_points
        # The division sits inside the generator so that a count of zero
        # yields an empty array instead of dividing by zero.
        return np.fromiter(
            (1 / count for _ in range(count)),
            dtype=np.float64,
        )


class BetaShapleyValuationFunc(BaseValuationFunc):
    """Valuation function with Beta-function-shaped weights.

    For ``n = num_points`` and ``j = 0, ..., n - 1`` the weight is

        C(n - 1, j) * B(j + beta, n - j + alpha - 1) / B(alpha, beta)

    where ``B`` is the Beta function. With ``alpha == beta == 1`` every
    weight reduces to ``1 / n``.
    """

    def __init__(
        self,
        alpha: int,
        beta: int,
        marg_contrib_dict: Dict,
        sampler: SystematicSampler = None,
    ):
        """Store the two shape parameters and initialise the base class.

        Args:
            alpha: First shape parameter; assumed to be positive.
            beta: Second shape parameter; assumed to be positive.
            marg_contrib_dict: Dictionary forwarded to ``BaseValuationFunc``.
            sampler: Sampler forwarded to ``BaseValuationFunc``; may be
                ``None``.
        """
        # NOTE: inside this method the ``beta`` parameter shadows the
        # ``scipy.special.beta`` function imported at module level.
        super().__init__(sampler, marg_contrib_dict)
        self.alpha = alpha
        self.beta = beta

    def compute_weight(self):
        """Return the ``num_points`` Beta-shaped weights as a 1-D float array.

        The weights are evaluated in log space. Multiplying ``binom`` by
        ``beta`` directly overflows to ``inf`` and underflows to ``0`` once
        ``num_points`` reaches roughly a thousand, which yields ``nan``
        weights; summing logarithms keeps every intermediate value finite.
        Results agree with the direct product up to floating-point rounding.

        ``self.num_points`` is not set in this class; presumably it is
        provided by ``BaseValuationFunc`` -- confirm against the base class.

        Returns:
            Array of length ``num_points`` (empty when ``num_points <= 0``).
        """
        n = self.num_points
        if n <= 0:
            # Matches the empty result of iterating over range(n).
            return np.array([])

        j = np.arange(n, dtype=np.float64)
        # Identity: C(n - 1, j) == 1 / (n * B(j + 1, n - j)).
        log_binom = -np.log(n) - betaln(j + 1.0, n - j)
        log_weights = (
            log_binom
            + betaln(j + self.beta, n - j + self.alpha - 1.0)
            - betaln(self.alpha, self.beta)
        )
        return np.exp(log_weights)


class BanzhafValuationFunc(BaseValuationFunc):
    """Valuation function with binomial weights.

    For ``n = num_points`` and ``j = 0, ..., n - 1`` the weight is
    ``C(n - 1, j) / 2 ** (n - 1)``, so the weights sum to one.
    """

    def __init__(
        self,
        sampler: SystematicSampler = None,
        marg_contrib_dict: Optional[Dict] = None,
    ):
        """Forward both arguments unchanged to ``BaseValuationFunc``.

        Args:
            sampler: Sampler handed to the base class; may be ``None``.
            marg_contrib_dict: Dictionary handed to the base class; may be
                ``None``.
        """
        super().__init__(sampler, marg_contrib_dict)

    def compute_weight(self):
        """Return the ``num_points`` binomial weights as a 1-D float array.

        The weights are evaluated in log space for two reasons:

        * ``2 ** -(n - 1)`` raises ``ValueError`` when ``n`` is a NumPy
          integer, because NumPy rejects negative integer powers.
        * For ``n`` beyond roughly a thousand the power underflows to ``0``
          while ``binom`` overflows to ``inf``, so the product is ``nan``.

        Results agree with the direct product up to floating-point rounding.

        ``self.num_points`` is not set in this class; presumably it is
        provided by ``BaseValuationFunc`` -- confirm against the base class.

        Returns:
            Array of length ``num_points`` (empty when ``num_points <= 0``).
        """
        n = self.num_points
        if n <= 0:
            # Matches the empty result of iterating over range(n).
            return np.array([])

        j = np.arange(n, dtype=np.float64)
        # Identity: C(n - 1, j) == 1 / (n * B(j + 1, n - j)).
        log_binom = -np.log(n) - betaln(j + 1.0, n - j)
        log_weights = log_binom - (n - 1) * np.log(2.0)
        return np.exp(log_weights)