from woods.objectives.ERM import ERM
import torch

class SD(ERM):
    """
    Gradient Starvation: A Learning Proclivity in Neural Networks
    Equation 25 from [https://arxiv.org/pdf/2011.09468.pdf]

    Spectral Decoupling (SD) objective. Each training step minimizes the mean
    of the per-domain task losses plus ``penalty_weight`` times the mean
    squared L2 norm of the model outputs.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super().__init__(model, dataset, optimizer, hparams)

        # Hyper parameters
        # Scale of the squared-output penalty added to the mean domain loss.
        self.penalty_weight = self.hparams['penalty_weight']

    def update(self):
        """Run one optimization step on the next training batch.

        The minimized objective is::

            domain_losses.mean() + penalty_weight * mean(sum(out ** 2, dim=-1))

        where ``out`` is the model prediction for the batch and the squared
        norm is taken over its last dimension.
        """
        # Put model into training mode
        self.model.train()

        # Get next batch
        X, Y = self.dataset.get_next_batch()

        # predict returns a pair; the first element is the output (logits per
        # the original comment), the second (features) is unused here.
        out, _ = self.predict(X)

        # Task loss for each training domain
        n_domains = self.dataset.get_nb_training_domains()
        domain_losses = self.dataset.loss_by_domain(out, Y, n_domains)

        # SD penalty: squared L2 norm of every output vector (last dim), then
        # averaged over all remaining dimensions. The penalty is pooled over
        # the whole batch rather than computed domain by domain.
        # Previously the outputs were also split by domain here, but that
        # result was never used, so the split was pure overhead.
        sd_penalty = torch.pow(out, 2).sum(dim=-1).mean()

        objective = domain_losses.mean() + self.penalty_weight * sd_penalty

        # Back propagate
        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()