from counterfactual_explanations.cf_generator import CounterfactualGenerator
from counterfactual_explanations.input_properties import InputProperties
from cfxplorer import Focus
from models.randomforest_sklearn import RandomForestSKLearn

import tensorflow as tf

class FOCUSGenerator(CounterfactualGenerator):
    """Counterfactual generator backed by cfxplorer's ``Focus`` algorithm.

    Only ``RandomForestSKLearn`` model wrappers are accepted; their
    underlying ``.model`` attribute is what gets passed to
    ``Focus.generate``.

    Config keys read by this class:
        distance_func: distance function name forwarded to ``Focus`` as
            ``distance_function`` (default ``'l1'``).
        n_iter: iteration count forwarded to ``Focus`` as ``num_iter``
            (default ``100``).
    """

    def __init__(self, model, input_properties: InputProperties, config, save_dir=".", use_pregenerated=True):
        super().__init__(model, input_properties, config, save_dir, use_pregenerated)
        # Explicit check rather than ``assert``: asserts are stripped when
        # Python runs with ``-O``, which would silently skip this validation.
        if not isinstance(self.model, RandomForestSKLearn):
            raise TypeError(
                "FOCUSGenerator requires a RandomForestSKLearn model, "
                f"got {type(self.model).__name__}"
            )
        self.distance_func = self.config.get('distance_func', 'l1')
        self.n_iter = self.config.get('n_iter', 100)

    def _build_focus(self):
        """Create a new ``Focus`` instance, store it on ``self.focus`` and return it."""
        # A brand-new Adam optimizer is created for every call instead of being
        # shared. NOTE(review): presumably because a Keras optimizer stays bound
        # to the variables it was first applied to; confirm against the
        # installed TensorFlow / cfxplorer versions before sharing one.
        self.focus = Focus(
            distance_function=self.distance_func,
            num_iter=self.n_iter,
            optimizer=tf.keras.optimizers.Adam(),
            verbose=0,
        )
        return self.focus

    def generate_counterfactual(self, x, y_target):
        """Generate a counterfactual for a single factual instance.

        Args:
            x: one factual instance; it must support ``.reshape(1, -1)`` and
                ``.astype(float)`` (e.g. a NumPy array).
            y_target: accepted for interface compatibility but NOT used;
                ``Focus.generate`` is called without any target.

        Returns:
            The first element of the result returned by ``Focus.generate``
            for the single-row input.
        """
        focus = self._build_focus()
        x_row = x.reshape(1, -1).astype(float)
        return focus.generate(self.model.model, x_row)[0]

    def generate_counterfactuals(self, x_factuals, y_targets):
        """Generate counterfactuals for a batch of factual instances.

        Args:
            x_factuals: batch of factual instances; must support
                ``.astype(float)``. NOTE(review): presumably a 2-D array of
                shape (n_samples, n_features); confirm with callers.
            y_targets: accepted for interface compatibility but NOT used;
                ``Focus.generate`` is called without any targets.

        Returns:
            The result of ``Focus.generate`` for the whole batch, unmodified.
        """
        focus = self._build_focus()
        return focus.generate(self.model.model, x_factuals.astype(float))