from abc import ABC, abstractmethod
from typing import Dict, Optional

import numpy as np

from valuation_func.sampler import SystematicSampler


class BaseValuationFunc(ABC):
    """Base class for valuation functions built from weighted marginal contributions.

    Subclasses implement ``compute_weight``; this class combines those weights
    with marginal-contribution matrices to produce one value per data point.

    Marginal contributions either come pre-computed via ``marg_contrib_dict`` or
    are computed from ``sampler`` on the first call to ``compute_data_values``.
    """

    def __init__(
        self,
        sampler: "Optional[SystematicSampler]" = None,
        marg_contrib_dict: Optional[Dict] = None,
    ):
        """Store the contribution source and infer the number of data points.

        Args:
            sampler: Object exposing ``num_points`` and
                ``compute_marginal_contributions_for_all()``. Required when
                ``marg_contrib_dict`` is ``None`` or empty.
            marg_contrib_dict: Mapping from a config key to a 2-D array of
                marginal contributions whose columns correspond to data points
                (``num_points`` is read from ``shape[1]``).

        Raises:
            ValueError: if neither a non-empty ``marg_contrib_dict`` nor a
                ``sampler`` is provided, so ``num_points`` cannot be determined.
        """
        self.sampler = sampler
        self.marg_contrib_dict = marg_contrib_dict
        if marg_contrib_dict:
            # Presumably every matrix has the same number of columns; the
            # first entry is used as the reference. NOTE(review): not checked.
            first_matrix = next(iter(marg_contrib_dict.values()))
            self.num_points = first_matrix.shape[1]
        elif sampler is not None:
            self.num_points = sampler.num_points
        else:
            raise ValueError(
                "BaseValuationFunc needs either a non-empty marg_contrib_dict "
                "or a sampler to determine num_points."
            )

    @abstractmethod
    def compute_weight(self):
        """Return the per-row weights applied to every marginal-contribution matrix.

        The result must be a 1-D array-like with one weight per row of the
        matrices in ``marg_contrib_dict``; it is broadcast across columns.
        """

    def compute_data_values(self) -> Dict:
        """Compute a weighted value for every data point, separately per config.

        If no marginal contributions were supplied, they are computed from the
        sampler and cached on ``self.marg_contrib_dict`` for later calls.

        Returns:
            Dict mapping each config key to a 1-D array of length ``num_points``
            equal to ``sum_i weight[i] * marg_contrib[i, :]``.
        """
        if self.marg_contrib_dict is None:
            self.marg_contrib_dict = (
                self.sampler.compute_marginal_contributions_for_all()
            )
        # The weights do not depend on the config key, so compute them once
        # instead of once per dictionary entry. asarray also accepts lists.
        weights = np.asarray(self.compute_weight())[:, np.newaxis]
        return {
            config_key: np.sum(marg_contrib * weights, axis=0)
            for config_key, marg_contrib in self.marg_contrib_dict.items()
        }