import jax.numpy as jnp
from flax.typing import VariableDict

from jadex.distributions.base_distribution import Sample
from jadex.networks.variational.variational_network import VariationalNetwork


class BaselineModel(VariationalNetwork):
    """Variational network whose output is used as a baseline prediction.

    NOTE(review): the role of the "baseline" (e.g. a control variate for a
    gradient estimator) is inferred from the name only — confirm.
    """

    @property
    def scale(self):
        """Return the ``scale`` entry of this model's config (``self.cfg.scale``).

        Assumes ``cfg`` is provided by ``VariationalNetwork`` — TODO confirm.
        """
        return self.cfg.scale

    def fit_predict(
        self, variables: VariableDict, samples: dict[str, Sample], target: jnp.ndarray, train: bool = True
    ):
        """Apply the baseline network to ``samples`` and return its output.

        Args:
            variables: Flax variable collection passed straight to ``self.apply``.
            samples: Mapping from name to ``Sample`` fed to the network.
            target: Currently unused by this implementation.
                NOTE(review): presumably the value the baseline should be
                fitted against — confirm and wire it in if so.
            train: Currently unused by this implementation.

        Returns:
            Whatever ``self.apply(variables, samples)`` returns.
        """
        baseline = self.apply(variables, samples)
        # The prediction must be handed back to the caller; without this return
        # the computed value was discarded and the method yielded None.
        return baseline


def register_baseline_models():
    """Call ``register()`` on each vision baseline model class.

    The classes are imported inside the function rather than at module level.
    NOTE(review): presumably this avoids a circular import with
    ``vision_baseline`` — confirm.
    """
    from .vision_baseline import VisionFeedForwardBaselineModel, VisionResNetBaselineModel

    # Registration order: ResNet baseline first, then feed-forward baseline.
    for model_cls in (VisionResNetBaselineModel, VisionFeedForwardBaselineModel):
        model_cls.register()
