from counterfactual_explanations.cf_generator import CounterfactualGenerator
from counterfactual_explanations.input_properties import InputProperties
from library.featureTweakPy import *
from models.randomforest_sklearn import RandomForestSKLearn

def distance_cost_l1(a, b):
    """Return the L1 (Manhattan) distance between ``a`` and ``b``.

    ``a`` and ``b`` may be NumPy arrays or any array-like (e.g. lists) of the
    same or broadcast-compatible shape.

    The difference is flattened before taking the norm. Without this, a single
    sample passed as a 2-D row ``(1, d)`` would make
    ``np.linalg.norm(..., ord=1)`` compute the *matrix* 1-norm (the maximum
    absolute column sum, i.e. the largest per-feature change) instead of the
    sum of absolute differences.
    """
    diff = np.ravel(np.subtract(a, b))
    return np.linalg.norm(diff, ord=1)

def distance_cost_l2(a, b):
    """Return the L2 (Euclidean) distance between ``a`` and ``b``.

    ``a`` and ``b`` may be NumPy arrays or any array-like (e.g. lists) of the
    same or broadcast-compatible shape.

    The difference is flattened before taking the norm, matching
    ``distance_cost_l1``. The result is therefore always the vector 2-norm of
    all element-wise differences. ``np.linalg.norm(..., ord=2)`` on a 2-D
    input would instead compute the spectral norm.
    """
    diff = np.ravel(np.subtract(a, b))
    return np.linalg.norm(diff, ord=2)

class FeatureTweakGenerator(CounterfactualGenerator):
    """Counterfactual generator backed by ``feature_tweaking`` (``library.featureTweakPy``).

    Only ``RandomForestSKLearn`` models are accepted. The generator hands the
    wrapped estimator (``self.model.model``) to ``feature_tweaking``; presumably
    this is the underlying sklearn forest, so confirm against
    ``models.randomforest_sklearn``.

    Config keys read here:
        epsilon: value forwarded to ``feature_tweaking`` (default ``0.1``).
            NOTE(review): presumably the tweak magnitude applied around tree
            split thresholds; confirm against ``library.featureTweakPy``.
        cost_fn: callable ``(a, b) -> float`` forwarded to ``feature_tweaking``
            (default ``distance_cost_l1``).
    """

    def __init__(self, model, input_properties: InputProperties, config, save_dir=".", use_pregenerated=True):
        """Initialise the base generator, validate the model type and read config.

        Raises:
            TypeError: if ``self.model`` (as set by the base class) is not a
                ``RandomForestSKLearn``.
        """
        super().__init__(model, input_properties, config, save_dir, use_pregenerated)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let an unsupported model through silently.
        if not isinstance(self.model, RandomForestSKLearn):
            raise TypeError(
                "FeatureTweakGenerator requires a RandomForestSKLearn model, "
                f"got {type(self.model).__name__}"
            )

        self.epsilon = self.config.get('epsilon', 0.1)
        self.cost_fn = self.config.get('cost_fn', distance_cost_l1)

    def generate_counterfactual(self, x, y_target):
        """Return the counterfactual for ``x`` toward ``y_target`` produced by ``feature_tweaking``.

        Forwards the wrapped estimator, ``x``, the labels from
        ``input_properties.get_labels()``, ``y_target``, ``self.epsilon`` and
        ``self.cost_fn``. The result of ``feature_tweaking`` is returned as-is.
        """
        x_cf = feature_tweaking(self.model.model, x, self.input_properties.get_labels(), y_target, self.epsilon, self.cost_fn)
        return x_cf