from .bbox import AbstractBBox
import pandas as pd
import numpy as np

from .dataset import TabularDataset, Dataset
from .encoder_decoder import ColumnTransformerEnc, EncDec
from .neighgen import GeneticGenerator
from .neighgen.neighborhood_generator import NeighborhoodGenerator
from .neighgen.random import RandomGenerator
from .surrogate import DecisionTreeSurrogate, Surrogate
from .lore import Lore

class ConceptLore(Lore):
    """
    LORE explainer variant whose surrogate is trained in the encoder's space.

    This constructor takes no ``dataset``: ``descriptor`` is left as ``None``
    and no ``class_name`` attribute is set.
    """

    def __init__(self, bbox: AbstractBBox, encoder: EncDec,
                 generator: NeighborhoodGenerator, surrogate: Surrogate):
        """
        Creates a new instance of the LORE method.

        :param bbox: The black box model to be explained wrapped in a ``AbstractBBox`` object.
        :param encoder: the encoder/decoder used to map instances to and from the space in
            which the neighborhood is generated and the surrogate is trained.
        :param generator: the neighborhood generator invoked around the encoded instance.
        :param surrogate: the interpretable model trained on the neighborhood; it provides
            both the factual rule and the counterfactual rules.
        """
        # NOTE(review): ``Lore.__init__`` is not called, presumably because the base
        # constructor expects a dataset that this variant does not receive. As a
        # consequence ``descriptor`` is None and ``class_name`` is never set — confirm
        # that no inherited method relies on them.
        self.bbox = bbox
        self.descriptor = None
        self.encoder = encoder
        self.generator = generator
        self.surrogate = surrogate

    def explain(self, x: np.ndarray, num_instances: int = 1000):
        """
        Explains a single instance of the dataset.

        The instance is encoded and a neighborhood is generated around it in the encoded
        space. The neighborhood is decoded and labelled by the black box. The surrogate is
        then trained on the encoded neighborhood against the encoded labels. Finally, the
        factual rule and the counterfactual rules are read from the surrogate.

        :param x: an array with the values of the instance to explain (the target class is not included)
        :param num_instances: size of the neighborhood requested from the generator
            (defaults to 1000, the value that used to be hard-coded).
        :return: a dict with keys ``'x'`` (the input instance), ``'rule'`` (the factual rule
            for ``x``), ``'counterfactuals'`` and ``'deltas'`` (both as returned by
            ``surrogate.get_counterfactual_rules``).

        Side effect: the counterfactual rules and deltas are also stored on
        ``self.crules`` and ``self.deltas``.
        """
        # Map the single record in input to the encoded space.
        [z] = self.encoder.encode([x])

        # Generate a neighborhood of instances around `z`. The generator receives its own
        # copy because `z` itself is reused below for rule and counterfactual extraction.
        neighbour = self.generator.generate(z.copy(), num_instances)

        # The black box works on decoded records, so label the decoded neighborhood.
        # asarray lets predict() return any array-like before the reshape.
        dec_neighbour = self.encoder.decode(neighbour)
        neighb_y = np.asarray(self.bbox.predict(dec_neighbour))
        neighb_yb = self.encoder.encode_target_class(neighb_y.reshape(-1, 1)).squeeze()

        # Train the surrogate on the encoded neighborhood against the encoded labels.
        self.surrogate.train(neighbour, neighb_yb)

        # Get the rule for the instance `z`, decoded using the encoder.
        rule = self.surrogate.get_rule(z, self.encoder)

        self.crules, self.deltas = self.surrogate.get_counterfactual_rules(
            z, neighbour, neighb_yb, self.encoder)

        return {'x': x, 'rule': rule, 'counterfactuals': self.crules, 'deltas': self.deltas}


