import torch
from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model

class ERM(SingleModelAlgorithm):
    """Single-model algorithm whose objective is the configured loss.

    The model is created with ``initialize_model(config, d_out)`` and moved
    to ``config.device`` before being handed to ``SingleModelAlgorithm``.
    """

    def __init__(self, config, d_out, grouper, loss,
            metric, n_train_steps):
        """Build the model and forward everything to the parent constructor.

        Args:
            config: Configuration object; ``config.device`` is read here and
                the whole object is passed to ``initialize_model`` and to
                the parent class.
            d_out: Forwarded to ``initialize_model``
                (presumably the model's output dimension -- confirm).
            grouper: Forwarded unchanged to the parent class.
            loss: Forwarded unchanged to the parent class.
            metric: Forwarded unchanged to the parent class.
            n_train_steps: Forwarded unchanged to the parent class.
        """
        network = initialize_model(config, d_out)
        network = network.to(config.device)
        # Hand the device-placed model and remaining settings to the base class.
        super().__init__(
            config=config,
            model=network,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )

    def objective(self, results):
        """Compute the loss for the predictions and targets in ``results``.

        When ``results['y_pred']`` has size 1 along dim 1, ``results`` is
        modified in place: ``'y_pred'`` is replaced by its element-wise
        sigmoid with dim 1 squeezed away, and ``'y_true'`` is cast to float.
        (Presumably this targets single-output binary prediction -- confirm
        against the loss being used.)

        Args:
            results: Mapping holding at least ``'y_pred'`` and ``'y_true'``.

        Returns:
            Whatever ``self.loss.compute`` returns with ``return_dict=False``.
        """
        predictions = results['y_pred']
        has_single_column = predictions.shape[1] == 1
        if has_single_column:
            # Written back into `results`, so later readers of the dict see
            # the transformed prediction and target tensors.
            results['y_pred'] = torch.sigmoid(predictions).squeeze(1)
            results['y_true'] = results['y_true'].float()
        return self.loss.compute(
            results['y_pred'],
            results['y_true'],
            return_dict=False,
        )
