import keras
from library.vern_core import *
from models.abstract_model import AbstractModel

class KerasModelEncoding(AbstractModel):
    """Wrap a pre-trained Keras model and encode it as solver constraints.

    The wrapped model is stored in ``self.model`` (set by :meth:`load`).
    Only stacks of Keras ``Dense`` layers with ``"relu"`` or ``"linear"``
    activations can be encoded; anything else raises ``ValueError``.

    NOTE(review): ``gp``, ``Dense``, ``ReLU``, ``VModel`` and
    ``InputLayerPredefined`` are not imported explicitly in this file; they
    presumably come from ``from library.vern_core import *`` -- confirm.
    """

    def train(self, X_train, y_train):
        """Training is not supported for this model type.

        Raises:
            NotImplementedError: always. The message points to
                ``load_external``, which is not defined in this class
                (presumably provided by ``AbstractModel`` -- TODO confirm).
        """
        raise NotImplementedError("Use load_external instead")

    def load(self, save_path):
        """Load a Keras model from ``save_path`` into ``self.model``."""
        self.model = keras.saving.load_model(save_path)

    def save(self, save_path):
        """Save the wrapped Keras model to ``save_path``."""
        self.model.save(save_path)

    def predict(self, x):
        """Run the wrapped model on ``x`` and return the result via ``.numpy()``.

        A one-dimensional ``x`` is treated as a single sample: it is reshaped
        to ``(1, -1)`` before the call and the first (only) output row is
        returned. Any other input is passed through unchanged and the full
        output is returned.
        """
        is_single_sample = len(x.shape) == 1
        if is_single_sample:
            x = x.reshape(1, -1)
        outputs = self.model(x).numpy()
        return outputs[0] if is_single_sample else outputs

    def gp_set_model_constraints(self, grb_model: gp.Model, input_mvar: gp.MVar) -> gp.MVar:
        """Encode the wrapped network on ``grb_model`` and return its output variable.

        Walks ``self.model.layers`` in order. Every Keras ``Dense`` layer
        becomes a ``Dense`` encoding layer built from its kernel and bias,
        followed by a ``ReLU`` encoding layer when its activation is
        ``"relu"``. The resulting layer list is wrapped in a ``VModel`` and
        ``forward()`` is called on it.

        Args:
            grb_model: solver model passed to every encoding layer.
            input_mvar: variables used as the network input.

        Returns:
            The ``var`` attribute of the last layer of the built ``VModel``.

        Raises:
            ValueError: if a layer is not a Keras ``Dense`` layer, or if a
                ``Dense`` layer uses an activation other than ``"relu"`` or
                ``"linear"``.
        """
        grb_layers = [InputLayerPredefined(grb_model, input_mvar)]
        # Running index handed to each encoding layer; the input layer
        # occupies position 0.
        idx = 1

        for layer in self.model.layers:
            # Use the public Keras API name rather than the private
            # ``keras.src`` module path.
            if not isinstance(layer, keras.layers.Dense):
                raise ValueError(f"Unsupported layer type: {type(layer)}")

            activation = layer.get_config()["activation"]
            # Fetch the weights once; index 0 is the kernel, index 1 the bias.
            weights = layer.get_weights()

            # Here ``Dense`` is the encoding layer from the star import, not
            # the Keras class checked above.
            grb_layers.append(Dense(grb_model, weights[0], weights[1], idx))
            idx += 1

            if activation == "relu":
                grb_layers.append(ReLU(grb_model, idx))
                idx += 1
            elif activation != "linear":
                raise ValueError("Unsupported activation function")

        vmodel = VModel(grb_layers)
        vmodel.forward()

        return vmodel.layers[-1].var

    def gp_set_classification_constraint(self, grb_model: gp.Model, output_vars: gp.MVar, target_class: int, db_distance=1e-6) -> list:
        """Add constraints on ``output_vars`` for ``target_class``.

        Args:
            grb_model: solver model the constraints are added to.
            output_vars: network output variables; ``shape[0]`` selects
                between the single-output and the multi-output branch.
            target_class: class index the constraints are built for.
            db_distance: margin by which the target output must exceed every
                other output in the multi-output branch. Not used when there
                is a single output.

        Returns:
            List of the objects returned by ``grb_model.addConstr``: one
            entry in the single-output case, one per non-target class
            otherwise.

        Raises:
            ValueError: if there is a single output and ``target_class`` is
                neither 0 nor 1.
        """
        classification_constrs = []

        if output_vars.shape[0] == 1:
            # Single output.
            # Raise explicitly instead of ``assert`` so the check still runs
            # under ``python -O``.
            if target_class not in (0, 1):
                raise ValueError("Target class must be 0 or 1 for a single logit output")

            # NOTE(review): target class 0 requires output >= 0.5 and target
            # class 1 requires output <= 0.5. The multi-output branch below
            # forces ``target_class`` to score highest, so this orientation
            # looks inverted unless the single output scores class 0 --
            # TODO confirm against the callers.
            # NOTE(review): ``db_distance`` is not applied in this branch, so
            # a point exactly at 0.5 satisfies either constraint -- confirm
            # this is intended.
            if target_class == 0:
                constr = grb_model.addConstr(output_vars[0] >= 0.5, name="Output class 0 constraint")
            else:
                constr = grb_model.addConstr(output_vars[0] <= 0.5, name="Output class 1 constraint")
            classification_constrs.append(constr)
            return classification_constrs

        # One output per class: the target output must beat every other
        # output by at least ``db_distance``.
        for i in range(output_vars.shape[0]):
            if i == target_class:
                continue
            constr = grb_model.addConstr(output_vars[target_class] >= output_vars[i] + db_distance)
            classification_constrs.append(constr)

        return classification_constrs