from woods.objectives.ERM import ERM
import torch
import torch.nn.functional as F
import torch.autograd as autograd


class IRM(ERM):
    """
    Invariant Risk Minimization (IRM)

    Adds a per-domain gradient penalty to the mean of the per-domain losses.
    The penalty is multiplied by 1.0 while ``update_count`` is below
    ``anneal_iters`` and by ``penalty_weight`` from then on.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        """
        Build the objective on top of the ERM constructor.

        Args:
            model: forwarded unchanged to the ERM constructor.
            dataset: forwarded unchanged to the ERM constructor.
            optimizer: forwarded unchanged to the ERM constructor.
            hparams: forwarded unchanged to the ERM constructor. The entries
                ``'penalty_weight'`` and ``'anneal_iters'`` are then read
                from ``self.hparams``.
        """
        super().__init__(model, dataset, optimizer, hparams)

        # Hyper parameters
        self.anneal_iters = self.hparams['anneal_iters']
        self.penalty_weight = self.hparams['penalty_weight']

        # Memory
        self.penalty = 0
        # Number of completed calls to update(), stored as a buffer.
        self.register_buffer('update_count', torch.tensor([0]))

    @staticmethod
    def _irm_penalty(logits, y):
        """
        Compute the IRM gradient penalty for one domain.

        The samples are split into two interleaved halves (even and odd
        indices along the first dimension). For each half, the cross entropy
        of the logits multiplied by a dummy scale of 1.0 is differentiated
        with respect to that scale; the penalty is the product of the two
        gradients.

        Args:
            logits: model outputs for one domain; dimension 1 is squeezed
                away when it has size 1.
                # assumes (batch_size, 1, n_classes) -- TODO confirm
            y: targets for the same samples; dimension 1 is squeezed away
                when it has size 1.

        Returns:
            Scalar tensor holding the penalty. It is built with
            ``create_graph=True`` so it can itself be back propagated.
        """
        # Dummy multiplier living on the same device as the targets
        scale = torch.tensor(1.0, device=y.device, requires_grad=True)

        logits = logits.squeeze(dim=1)
        y = y.squeeze(dim=1)

        # One gradient per interleaved half: offset 0 -> even, offset 1 -> odd
        half_grads = []
        for offset in (0, 1):
            half_loss = F.cross_entropy(logits[offset::2] * scale, y[offset::2])
            (grad,) = autograd.grad(half_loss, [scale], create_graph=True)
            half_grads.append(grad)

        even_grad, odd_grad = half_grads
        return torch.sum(even_grad * odd_grad)

    def update(self):
        """
        Perform one optimization step on the next training batch.

        The objective is the mean of the per-domain losses plus the weighted
        mean of the per-domain IRM penalties. On the step where
        ``update_count`` equals ``anneal_iters`` the optimizer is replaced by
        a new Adam instance before back propagation. ``update_count`` is
        incremented at the end.
        """
        # Annealing: the configured weight only applies once enough updates ran
        if self.update_count >= self.anneal_iters:
            weight = self.penalty_weight
        else:
            weight = 1.0

        # Put model into training mode
        self.model.train()

        # Fetch the next batch and run the forward pass (features are unused)
        X, Y = self.dataset.get_next_batch()
        logits, _ = self.predict(X)

        # Per-domain losses
        n_domains = self.dataset.get_nb_training_domains()
        domain_losses = self.dataset.loss_by_domain(logits, Y, n_domains)

        # Create domain dimension in tensors.
        #   e.g. for source domains: (ENVS * batch_size, ...) -> (ENVS, batch_size, ...)
        #        for time domains: (batch_size, ENVS, ...) -> (ENVS, batch_size, ...)
        env_logits, env_targets = self.dataset.split_tensor_by_domains(logits, Y, n_domains)

        # One penalty entry per domain
        penalties = torch.zeros(n_domains, device=self.device)
        for idx, (domain_logits, domain_targets) in enumerate(zip(env_logits, env_targets)):
            penalties[idx] += self._irm_penalty(domain_logits, domain_targets)

        # Combine mean loss and weighted mean penalty
        mean_penalty = penalties.mean()
        objective = domain_losses.mean() + (weight * mean_penalty)

        # The penalty weight jumps at this step, which produces a sharp change
        # in gradient magnitudes that Adam handles poorly, so start again
        # from a fresh Adam using the first parameter group's settings.
        if self.update_count == self.anneal_iters:
            first_group = self.optimizer.param_groups[0]
            self.optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=first_group['lr'],
                weight_decay=first_group['weight_decay'],
            )

        # Back propagate
        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        # Update memory
        self.update_count += 1