
from abc import ABC, abstractmethod
from counterfactual_explanations.input_properties import InputProperties
import numpy as np

from models.gradientboosting_sklearn import GradientBoostingSKLearn
from models.decisiontree_sklearn import DecisionTreeSKLearn
from models.randomforest_sklearn import RandomForestSKLearn
from models.abstract_model import AbstractModel
from counterfactual_explanations.dim_reduction import DimensionalityReduction
from typing import Type

from tqdm import tqdm
import json

class CounterfactualGenerator(ABC):
    """Abstract base class for counterfactual-explanation generators.

    Subclasses implement :meth:`generate_counterfactual` for a single input.
    This base class supplies a batch driver, a config-derived name, and a
    helper that nudges a solver solution until the model predicts the target.
    """

    def __init__(self, model: AbstractModel, input_properties: InputProperties, config, save_dir=".", use_pregenerated=True):
        """Store the collaborators and settings on the instance.

        Args:
            model: Model being explained; ``check_solution`` calls its
                ``predict`` method.
            input_properties: Description of the input features (not used by
                the base class itself).
            config: Mapping of generator settings; serialised by ``name``.
            save_dir: Stored as ``self.save_dir`` (not used by the base class).
            use_pregenerated: Stored as ``self.use_pregenerated`` (not used by
                the base class).
        """
        self.model = model
        self.input_properties = input_properties
        self.save_dir = save_dir
        self.config = config
        self.use_pregenerated = use_pregenerated

    @abstractmethod
    def generate_counterfactual(self, x, y_target):
        """Return a counterfactual for the single input ``x`` and ``y_target``."""
        pass

    def name(self, exclude=None):
        """Return the class name followed by the compact JSON of the config.

        Args:
            exclude: Optional single config key to leave out of the name.
                Falsy values (``None``, ``""``, ``0``) exclude nothing.

        Returns:
            ``"<ClassName>" + json`` where the JSON uses ``,`` and ``:``
            separators without whitespace.
        """
        class CustomEncoder(json.JSONEncoder):
            """Encode classes via ``str`` and reducers via their ``name()``."""

            def default(self, obj):
                # Builtin ``type`` matches every class object, which is what
                # the previous ``typing.Type`` instance check resolved to.
                if isinstance(obj, type):
                    return str(obj)
                if isinstance(obj, DimensionalityReduction):
                    return obj.name()
                return super().default(obj)

        config = self.config

        if exclude:
            config = {k: v for k, v in config.items() if k != exclude}

        return self.__class__.__name__ + json.dumps(config, separators=(',', ':'), cls=CustomEncoder)

    def setup(self, X_train, y_train, X_calib, y_calib):
        """Hook for subclasses that need the data up front; a no-op here."""
        pass

    def generate_counterfactuals(self, x_factuals, y_targets):
        """Generate one counterfactual per row of ``x_factuals``.

        When the instance has a ``grb_model`` attribute, the rows are
        processed ordered by target and then by first feature (presumably so
        that consecutive solves are similar -- TODO confirm), the model's
        ``OutputFlag`` parameter is set to 0 for the duration of the batch,
        and the results are returned in the caller's original row order.

        Args:
            x_factuals: 2-D array-like of inputs, one row per instance.
            y_targets: Array-like of target labels aligned with the rows.

        Returns:
            Array shaped like ``x_factuals`` holding the counterfactuals.
        """
        x_factuals = np.asarray(x_factuals)
        # NOTE(review): the output inherits the dtype of ``x_factuals``, so
        # integer inputs would truncate real-valued counterfactuals.
        x_cfs = np.zeros_like(x_factuals)

        # Decide once: a subclass that creates ``grb_model`` lazily inside
        # ``generate_counterfactual`` must not flip us into the un-sorting
        # path after the loop, where no sort order would exist.
        uses_solver = hasattr(self, 'grb_model')

        if not uses_solver:
            for i in tqdm(range(x_factuals.shape[0])):
                x_cfs[i] = self.generate_counterfactual(x_factuals[i], y_targets[i])
            return x_cfs

        # Fancy indexing below needs an array, not a plain list.
        y_targets = np.asarray(y_targets)

        # lexsort uses the LAST key as primary: target first, then feature 0.
        order = np.lexsort((x_factuals[:, 0], y_targets))
        sorted_x = x_factuals[order]
        sorted_y = y_targets[order]

        self.grb_model.setParam('OutputFlag', 0)
        try:
            for i in tqdm(range(sorted_x.shape[0])):
                x_cfs[i] = self.generate_counterfactual(sorted_x[i], sorted_y[i])
        finally:
            # Re-enable output even if a solve raised part-way through.
            self.grb_model.setParam('OutputFlag', 1)

        # Scatter the sorted results back to the caller's row order.
        restored = np.empty_like(x_cfs)
        restored[order] = x_cfs
        return restored

    def check_solution(self, input_mvar, y_target, step=0.01):
        """Return the solver solution, nudged if the model disagrees with it.

        The solution is read from ``input_mvar.X``. If the model already
        predicts ``y_target`` for it, it is returned unchanged. Otherwise each
        coordinate in turn is shifted by ``-step`` and then ``+step``; the
        first variant the model classifies as ``y_target`` is returned. If no
        single-coordinate shift works, the unmodified solution is returned.

        Args:
            input_mvar: Object exposing the solution values as ``.X``.
            y_target: Class index the prediction's argmax must equal.
            step: Size of the single-coordinate perturbation.

        Returns:
            A new array; the array behind ``input_mvar.X`` is never modified.
        """
        # Work on a copy so trial perturbations cannot leak into the caller's
        # data if ``.X`` hands out a shared array.
        candidate = np.array(input_mvar.X)

        if np.argmax(self.model.predict(candidate)) == y_target:
            return candidate

        for i in range(len(candidate)):
            original_value = candidate[i]
            for perturbation in (-step, step):
                candidate[i] = original_value + perturbation
                if np.argmax(self.model.predict(candidate)) == y_target:
                    return candidate
            # Neither direction helped: undo before trying the next coordinate.
            candidate[i] = original_value

        return candidate