from models.model_interface import ModelInterface
from data.data_interface import DataInterface
import pandas as pd
from recourse_interface import RecourseInterface
import itertools


class FairRecourse:
    """Build per-subset interfaces over the groups of a protected attribute.

    The train dataset obtained from ``data_interface`` is partitioned by the
    distinct values of ``protected_attr_col``. For every non-empty subset of
    those values, a restricted data interface and a recourse interface are
    created (see ``generate_subset_interfaces``).
    """

    def __init__(self, model: "ModelInterface", data_interface: "DataInterface",
                 recourse_interface: "RecourseInterface", protected_attr_col: str) -> None:
        """Store the collaborators and cache the train/test split.

        Args:
            model: Model under analysis; only stored by the code in this class.
            data_interface: Source of the data. ``get_train_test_split()`` is
                called once and must return four values. The first two are
                kept as the train and test datasets.
            recourse_interface: Called with no arguments once per group subset
                in ``generate_subset_interfaces``. Presumably a class or
                factory despite the instance annotation — TODO confirm.
            protected_attr_col: Name of the protected-attribute column in the
                train dataset.
        """
        self._model = model
        self._data_interface = data_interface
        self._recourse_interface = recourse_interface
        train_dataset, test_dataset, _, _ = data_interface.get_train_test_split()
        self._train_dataset = train_dataset
        self._test_dataset = test_dataset
        self._prot_attr = protected_attr_col

    def generate_subset_interfaces(self):
        """Create a data interface and a recourse interface for every group subset.

        Returns:
            A list of ``(subset, data_interface_g, recourse_interface_g)``
            tuples. There is one tuple per non-empty subset of the distinct
            protected-attribute values in the train dataset, in the order
            produced by ``all_group_subsets``.
        """
        prot_col = self._train_dataset[self._prot_attr]
        prot_groups = pd.unique(prot_col)
        subset_interfaces = []
        for subset in self.all_group_subsets(prot_groups):
            # Element-wise membership: ``series in list`` compares the whole
            # Series against each list item and raises on the ambiguous
            # truth value, so ``isin`` is required here.
            df_feats_g = self._train_dataset[prot_col.isin(subset)]
            # NOTE(review): rows are looked up in the *test* dataset using
            # train-row labels. This assumes the second value returned by
            # get_train_test_split() holds labels aligned with the train rows;
            # if it is a held-out split instead, .loc raises KeyError — confirm.
            df_label_g = self._test_dataset.loc[df_feats_g.index]
            # Features and labels are joined column-wise, aligned on the index.
            di_g = self._data_interface.copy_change_data(
                pd.concat([df_feats_g, df_label_g], axis=1))
            recourse_interface_g = self._recourse_interface()
            subset_interfaces.append((subset, di_g, recourse_interface_g))
        return subset_interfaces

    def all_group_subsets(self, li: list):
        """Return every non-empty subset of ``li`` as a list of lists.

        Subsets are ordered by size and, within a size, in
        ``itertools.combinations`` order. That gives ``2**len(li) - 1``
        subsets in total. Duplicate items are treated as distinct positions.
        Any iterable is accepted; an empty input yields ``[]``.
        """
        items = list(li)
        return [
            list(combo)
            for size in range(1, len(items) + 1)
            for combo in itertools.combinations(items, size)
        ]
