from woods.objectives.ERM import ERM
import torch
import torch.nn.functional as F

class ANDMask(ERM):
    """
    Learning Explanations that are Hard to Vary [https://arxiv.org/abs/2009.00329]
    AND-Mask implementation from [https://github.com/gibipara92/learning-explanations-hard-to-vary]

    Each training environment's loss is differentiated separately. For every
    trainable parameter, only the gradient components whose signs agree across
    environments (|mean(sign(grad))| >= tau) are kept. The kept components of
    the environment-averaged gradient are divided by the fraction of components
    kept before the optimizer step.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super(ANDMask, self).__init__(model, dataset, optimizer, hparams)

        # Hyper parameters
        # Sign-agreement threshold. It is compared against |mean(sign(grad))|,
        # which lies in [0, 1].
        self.tau = self.hparams['tau']

    def mask_grads(self, tau, gradients, params):
        """Write AND-masked, environment-averaged gradients into ``param.grad``.

        Args:
            tau: Agreement threshold. A gradient component is kept when
                |mean over environments of sign(grad)| >= tau.
            gradients: One list per parameter, aligned with ``params``. Each
                list holds that parameter's gradient from every environment.
                Entries for parameters with ``requires_grad == False`` are
                not read.
            params: Parameters in the same order as ``gradients``.
        """
        for param, grads in zip(params, gradients):
            if not param.requires_grad:
                continue
            grads = torch.stack(grads, dim=0)  # (n_envs, *param.shape)
            agreement = torch.mean(torch.sign(grads), dim=0).abs()
            avg_grad = torch.mean(grads, dim=0)
            # Cast to the gradient's dtype so half/double parameters receive
            # a .grad of their own dtype.
            mask = (agreement >= tau).to(avg_grad.dtype)

            # Divide by the fraction of kept components so the masked gradient
            # keeps a comparable magnitude. The 1e-10 avoids division by zero
            # when every component is masked out.
            mask_t = mask.sum() / mask.numel()
            param.grad = mask * avg_grad * (1. / (1e-10 + mask_t))

    def update(self):
        """Run one AND-mask training step on the next training batch.

        The loss of each environment is the sum over time steps of the
        cross-entropy. Its gradient with respect to every trainable parameter
        is collected. ``mask_grads`` then combines the per-environment
        gradients into ``param.grad`` before ``optimizer.step()``.
        """
        self.model.train()
        X, Y = self.dataset.get_next_batch()

        out, out_features = self.predict(X)
        n_domains = self.dataset.get_nb_training_domains()
        out, labels = self.dataset.split_tensor_by_domains(out, Y, n_domains)
        # The indexing below implies out is (envs, batch, time, classes) and
        # labels is (envs, batch, time) after the split. TODO: confirm against
        # split_tensor_by_domains.

        # Compute loss for each environment, summed over time steps
        env_losses = torch.zeros(out.shape[0], device=self.device)
        for i in range(out.shape[0]):
            for t_idx in range(out.shape[2]):     # Number of time steps
                env_losses[i] += F.cross_entropy(out[i, :, t_idx, :], labels[i, :, t_idx])

        # Compute gradients for each env. One autograd.grad call per
        # environment covers every trainable parameter, instead of one
        # backward pass per parameter.
        params = list(self.model.parameters())
        trainable = [p for p in params if p.requires_grad]
        param_gradients = [[] for _ in params]
        for env_loss in env_losses:
            # The grads come back in the same order as `trainable`.
            env_grads = iter(
                torch.autograd.grad(env_loss, trainable, retain_graph=True)
                if trainable else ()
            )
            for param, grads in zip(params, param_gradients):
                # Frozen parameters get a None placeholder; mask_grads skips them.
                grads.append(next(env_grads) if param.requires_grad else None)

        # Back propagate
        self.optimizer.zero_grad()
        self.mask_grads(self.tau, param_gradients, params)
        self.optimizer.step()