from baselines.DP_FERMI.dp_fermi.models import LogisticRegression
from keras.layers import TorchModuleWrapper
import keras


class LogisticRegrressionKeras(keras.Model):
    """Keras ``Model`` adapter around the DP-FERMI PyTorch ``LogisticRegression``.

    The PyTorch module is wrapped in ``keras.layers.TorchModuleWrapper`` so it
    can be used as a Keras layer. The class name keeps its original spelling
    ("Regrression") because code elsewhere may import it under that name.

    NOTE(review): Keras documents ``TorchModuleWrapper`` as usable only with
    the torch backend -- confirm the backend is configured accordingly.

    Args:
        input_num_attr: Forwarded unchanged to ``LogisticRegression`` as its
            ``input_num_attr`` keyword argument (presumably the number of
            input features -- TODO confirm against dp_fermi.models).
        **kwargs: Passed straight through to ``keras.Model.__init__``.
    """

    def __init__(self, input_num_attr, **kwargs):
        super().__init__(**kwargs)
        torch_module = LogisticRegression(input_num_attr=input_num_attr)
        # Keep the attribute name ``lr``: renaming it could change how Keras
        # tracks/saves the wrapped layer and break code that reads ``model.lr``.
        self.lr = TorchModuleWrapper(torch_module)

    def call(self, inputs, **kwargs):
        """Run the wrapped module on ``inputs`` and return element ``[1]`` of its output.

        Any extra keyword arguments Keras supplies (e.g. ``training``) are
        accepted and ignored.

        NOTE(review): the wrapped module's output is indexed, so it appears to
        return a sequence; which quantity sits at index 1 (e.g. probabilities
        vs. logits) is not visible here -- TODO confirm against
        ``LogisticRegression.forward``.
        """
        outputs = self.lr(inputs)
        return outputs[1]