from lampe.inference import BNRELoss

from .base import ModelFactory
from .nre import NREModel


class BNREFactory(ModelFactory):
    """Factory for ``BNREModel`` instances at a given simulation budget.

    Picks the per-budget hyperparameters (``train_batch_size`` and
    ``weight_decay``) that match ``simulation_budget``. It then hands the
    resulting run config to ``ModelFactory`` together with ``BNREModel`` as
    the model class.
    """

    def __init__(self, config, benchmark, simulation_budget):
        """Build the run configuration for ``simulation_budget``.

        Args:
            config: Experiment configuration mapping. ``config["simulation_budgets"]``,
                ``config["train_batch_size"]`` and ``config["weight_decay"]``
                are read as parallel sequences. Entry ``i`` of the latter two
                is used for budget ``simulation_budgets[i]``.
            benchmark: Forwarded unchanged to ``ModelFactory.__init__``.
            simulation_budget: Budget whose hyperparameters should be selected.
                It must appear in ``config["simulation_budgets"]``.

        Raises:
            ValueError: If ``simulation_budget`` is not listed in
                ``config["simulation_budgets"]``. Previously the last
                budget's hyperparameters were used silently, or a
                ``NameError`` was raised when the list was empty.
        """
        budgets = config["simulation_budgets"]
        try:
            idx = list(budgets).index(simulation_budget)
        except ValueError:
            raise ValueError(
                f"simulation_budget={simulation_budget!r} not found in "
                f"config['simulation_budgets']={budgets!r}"
            ) from None

        # Shallow copy so the caller's config keeps its per-budget lists intact.
        config_run = config.copy()
        config_run["train_batch_size"] = config["train_batch_size"][idx]
        config_run["weight_decay"] = config["weight_decay"][idx]

        super().__init__(config_run, benchmark, simulation_budget, BNREModel)


class BNREModel(NREModel):
    """``NREModel`` variant whose loss is ``lampe``'s ``BNRELoss``."""

    def __init__(self, benchmark, model_path, config, normalization_constants):
        # Nothing BNRE-specific to initialise; delegate everything to NREModel.
        super().__init__(benchmark, model_path, config, normalization_constants)

    def get_loss_fct(self, config):
        """Return a callable mapping an estimator to its ``BNRELoss``.

        Args:
            config: Mapping that provides ``"regularization_strength"``,
                which is passed to ``BNRELoss`` as ``lmbda``.

        Returns:
            A one-argument callable ``f(estimator) -> BNRELoss``.
        """

        def make_loss(estimator):
            # The key is looked up when the callable is invoked, not when it
            # is created, so the value in ``config`` at call time is used.
            strength = config["regularization_strength"]
            return BNRELoss(estimator, lmbda=strength)

        return make_loss
