from woods.objectives.Basic import Objective


class ERM(Objective):
    """Empirical Risk Minimization (ERM).

    Trains ``model`` by minimizing the mean of the per-domain training losses
    returned by ``dataset.loss_by_domain``.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        """Store the training components.

        Args:
            model: Callable network; ``predict`` forwards inputs to it.
            dataset: Object providing ``get_nb_training_domains``,
                ``get_next_batch`` and ``loss_by_domain``.
            optimizer: Optimizer whose ``zero_grad``/``step`` are driven by
                ``update``.
            hparams: Hyper-parameter mapping; must contain a ``'device'`` key.
        """
        super().__init__(hparams)

        # NOTE(review): presumes the base Objective stores ``hparams`` on
        # ``self.hparams``; ``device`` is kept but not used inside this class.
        self.device = self.hparams['device']

        self.model, self.dataset, self.optimizer = model, dataset, optimizer

        # Cached once; forwarded to ``loss_by_domain`` on every update.
        self.nb_training_domains = dataset.get_nb_training_domains()

    def predict(self, all_x):
        """Run the model on ``all_x`` and return its raw output unchanged."""
        return self.model(all_x)

    def update(self):
        """Perform a single optimization step on the next training batch.

        Switches the model to training mode, draws one batch from the dataset,
        averages the per-domain losses into a single objective, and applies
        one optimizer step.
        """
        self.model.train()

        inputs, targets = self.dataset.get_next_batch()

        # The model output is unpacked as a (logits, features) pair; only the
        # logits feed the loss.
        logits, _features = self.predict(inputs)

        # presumably a tensor holding one loss per training domain — TODO confirm
        per_domain = self.dataset.loss_by_domain(
            logits, targets, self.nb_training_domains
        )
        loss = per_domain.mean()

        # Clear stale gradients, back-propagate, then update the parameters.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()