from ..cf_generator import CounterfactualGenerator
import gurobipy as gp
from gurobipy import GRB
from counterfactual_explanations.input_properties import InputProperties
import numpy as np
from models.abstract_model import AbstractModel
from models.gradientboosting_sklearn import GradientBoostingSKLearn
from models.decisiontree_sklearn import DecisionTreeSKLearn
from models.randomforest_sklearn import RandomForestSKLearn

class NearestNeighbourCF(CounterfactualGenerator):
    """Counterfactual generator that returns the closest suitable training point.

    A training point is a candidate when its training label equals the target
    class AND the model also predicts the target class for it. The candidate
    nearest to the query input is returned as the counterfactual. Distances
    are computed with ``np.linalg.norm`` using the ``ord`` value from the
    config (default 2).
    """

    def __init__(self, model, input_properties: InputProperties, config, save_dir=".", use_pregenerated=True):
        super().__init__(model, input_properties, config, save_dir, use_pregenerated)
        # Order of the vector norm handed to np.linalg.norm (e.g. 1, 2, np.inf).
        self.ord = self.config.get('ord', 2)

    def setup(self, X_train, y_train, X_calib, y_calib):
        """Cache the training data and the model's predicted class per training row.

        Args:
            X_train: Training inputs, one row per sample.
            y_train: Training labels aligned with the rows of ``X_train``.
            X_calib: Not used by this generator.
            y_calib: Not used by this generator.
        """
        # generate_counterfactual indexes self.X_train / self.y_train with a mask
        # built from self.predictions, so all three must describe the same rows.
        self.X_train = np.asarray(X_train)
        self.y_train = np.asarray(y_train)
        # model.predict is treated as returning one score per class per row;
        # argmax over axis 1 yields the predicted class index.
        self.predictions = np.argmax(self.model.predict(X_train), axis=1)

    @staticmethod
    def _labels_as_indices(y):
        """Return ``y`` as a 1-D array of class indices.

        A 2-D ``y`` with more than one column is treated as one-hot / per-class
        scores and reduced with argmax; any other shape is flattened.
        """
        y = np.asarray(y)
        if y.ndim == 2 and y.shape[1] > 1:
            return np.argmax(y, axis=1)
        return y.reshape(-1)

    def generate_counterfactual(self, x, y_target):
        """Return the training row nearest to ``x`` that is labelled and predicted as ``y_target``.

        Args:
            x: Query input; it is broadcast against the rows of ``self.X_train``.
            y_target: Desired class index.

        Returns:
            The selected row of ``self.X_train``.

        Raises:
            ValueError: If no training row is both labelled and predicted as
                ``y_target``.
        """
        labels = self._labels_as_indices(self.y_train)
        # Element-wise ``&`` is required: Python ``and`` between two np.where
        # tuples just returns the second tuple and drops the first condition.
        mask = (labels == y_target) & (np.asarray(self.predictions) == y_target)
        candidates = np.asarray(self.X_train)[mask]
        if candidates.shape[0] == 0:
            raise ValueError(
                f"No training point is both labelled and predicted as class {y_target!r}; "
                "cannot build a nearest-neighbour counterfactual."
            )
        distances = np.linalg.norm(candidates - x, ord=self.ord, axis=1)
        return candidates[np.argmin(distances)]

        