from woods.objectives.ERM import ERM
import torch

class VREx(ERM):
    """
    V-REx Objective from http://arxiv.org/abs/2003.00688

    The training objective is the mean of the per-domain losses plus a
    weighted penalty equal to the (population) variance of those losses
    across the training domains.  The penalty weight is 1.0 until
    ``anneal_iters`` updates have been performed, after which the
    configured ``penalty_weight`` is used.
    """
    def __init__(self, model, dataset, optimizer, hparams):
        """Set up the objective.

        All four arguments are forwarded unchanged to ``ERM.__init__``.

        Args:
            model: model to train (used here via ``self.model``).
            dataset: data source (used here via ``self.dataset``).
            optimizer: optimizer driving the updates (``self.optimizer``).
            hparams: mapping that must contain ``'penalty_weight'`` and
                ``'anneal_iters'``.
        """
        super().__init__(model, dataset, optimizer, hparams)

        # Hyper parameters
        # NOTE(review): self.hparams is presumably set by ERM.__init__ -- confirm.
        self.penalty_weight = self.hparams['penalty_weight']
        self.anneal_iters = self.hparams['anneal_iters']

        # Memory: number of update() calls performed so far, kept as a buffer.
        self.register_buffer('update_count', torch.tensor([0]))

    def update(self):
        """Run one V-REx optimization step on the next training batch.

        Computes the per-domain losses, builds the objective
        ``mean + penalty_weight * variance``, back-propagates it, steps the
        optimizer and increments ``update_count``.  On the update where
        ``update_count == anneal_iters`` the optimizer is replaced by a
        fresh Adam instance before stepping.
        """
        # Before annealing ends the variance penalty is weighted by 1.0;
        # from update `anneal_iters` onwards the configured weight applies.
        penalty_weight = (self.penalty_weight
                          if self.update_count >= self.anneal_iters
                          else 1.0)

        # Put model into training mode
        self.model.train()

        # Get next batch
        X, Y = self.dataset.get_next_batch()

        # Predict; only the first returned element is used here
        # (the original comment describes the pair as (logit, features)).
        out, _ = self.predict(X)

        # Compute losses, one entry per training domain
        n_domains = self.dataset.get_nb_training_domains()
        domain_losses = self.dataset.loss_by_domain(out, Y, n_domains)

        # Compute objective: average risk + weighted variance of the risks
        mean = domain_losses.mean()
        penalty = ((domain_losses - mean).pow(2)).mean()
        objective = mean + penalty_weight * penalty

        # Reset Adam, because it doesn't like the sharp jump in gradient
        # magnitudes that happens at this step.
        if self.update_count == self.anneal_iters:
            group = self.optimizer.param_groups[0]

            # Only lr and weight_decay used to survive the reset, so an Adam
            # configured with custom betas/eps/amsgrad silently fell back to
            # the library defaults.  Carry them over when the current
            # optimizer is an Adam; for anything else keep the old behaviour.
            extra = {}
            if isinstance(self.optimizer, torch.optim.Adam):
                extra = {key: group[key]
                         for key in ('betas', 'eps', 'amsgrad')
                         if key in group}

            self.optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                **extra)

        # Back propagate
        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        # Update memory
        self.update_count += 1