from alg.algs.ERM import ERM
import torch
import torch.nn.functional as F
import torch.autograd as autograd


class IRM(ERM):
    """
    Invariant Risk Minimization (IRM)

    Extends ERM with a gradient-based invariance penalty that is added to the
    mean per-domain cross-entropy loss. For the first ``anneal_iters`` calls
    to :meth:`update` the penalty is weighted by 1.0; afterwards it is
    weighted by ``penalty_weight``.

    Args:
        args: configuration object. This class reads ``args.penalty_weight``
            and ``args.anneal_iters``; the object is also forwarded unchanged
            to the ``ERM`` constructor.
    """

    def __init__(self, args):
        super().__init__(args)

        # Hyper parameters
        self.penalty_weight = args.penalty_weight
        self.anneal_iters = args.anneal_iters

        # Memory
        # ``penalty`` is not read anywhere in this class; it is kept so that
        # any external code inspecting the attribute keeps working.
        self.penalty = 0
        # Registered as a buffer so the step counter is saved in the
        # state_dict and follows the module across devices.
        self.register_buffer('update_count', torch.tensor([0]))

    def _compute_device(self):
        """Return the device on which the network inputs should be placed.

        Uses the device of the first parameter of ``self.network``. If the
        network exposes no parameters, falls back to the default CUDA device,
        which is where this class historically placed its tensors.
        """
        try:
            return next(self.network.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device('cuda')

    @staticmethod
    def _irm_penalty(logits, y):
        """Compute the IRM penalty for one minibatch.

        The batch is split into even- and odd-indexed samples. For each half
        the cross-entropy loss is differentiated with respect to a dummy
        multiplicative scale fixed at 1.0, and the product of the two
        gradients is returned.

        Args:
            logits: network outputs for the minibatch.
            y: integer class labels, on the same device as ``logits``.

        Returns:
            A scalar tensor that remains attached to the autograd graph.
        """
        scale = torch.tensor(1., device=y.device, requires_grad=True)
        # No-op unless dimension 1 has size 1.
        logits = logits.squeeze(dim=1)
        loss_1 = F.cross_entropy(logits[::2] * scale, y[::2])
        loss_2 = F.cross_entropy(logits[1::2] * scale, y[1::2])
        # create_graph=True keeps the gradients differentiable so the penalty
        # itself can be back-propagated through the network.
        grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0]
        grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0]
        return torch.sum(grad_1 * grad_2)

    def update(self, minibatches, opt, sch):
        """Run one optimisation step over a list of per-domain minibatches.

        Args:
            minibatches: sequence with one entry per domain; ``entry[0]`` is
                the input tensor and ``entry[1]`` the label tensor.
            opt: optimizer; ``zero_grad()`` and ``step()`` are called on it.
            sch: learning-rate scheduler; ``step()`` is called when it is
                truthy, otherwise it is skipped.

        Returns:
            dict with Python floats under ``'loss'`` (total objective),
            ``'nll'`` (mean per-domain cross-entropy) and ``'penalty'``
            (weighted mean IRM penalty).
        """
        # Define penalty value (Annealing)
        if self.update_count.item() >= self.anneal_iters:
            penalty_weight = self.penalty_weight
        else:
            penalty_weight = 1.0

        # Single forward pass over all domains concatenated along the batch
        # dimension. Inputs go to the network's device instead of a
        # hard-coded ``.cuda()``.
        in_device = self._compute_device()
        all_x = torch.cat([data[0].to(in_device).float()
                           for data in minibatches])
        all_logits = self.network(all_x)

        # Labels and accumulators must live where the logits live.
        out_device = all_logits.device
        n_domains = len(minibatches)
        domain_losses = torch.zeros(n_domains, device=out_device)
        irm_penalty = torch.zeros(n_domains, device=out_device)

        # Cut the concatenated logits back into per-domain chunks.
        sizes = [data[0].shape[0] for data in minibatches]
        logits_per_domain = torch.split(all_logits, sizes)

        for i, (data, logits) in enumerate(zip(minibatches,
                                               logits_per_domain)):
            # Move and cast the labels once; both terms reuse them.
            targets = data[1].to(out_device).long()
            domain_losses[i] = F.cross_entropy(logits, targets)
            irm_penalty[i] = self._irm_penalty(logits, targets)

        # Compute objective
        mean_nll = domain_losses.mean()
        mean_irm_penalty = irm_penalty.mean()
        loss = mean_nll + (penalty_weight * mean_irm_penalty)

        opt.zero_grad()
        loss.backward()
        opt.step()
        if sch:
            sch.step()

        self.update_count += 1
        return {'loss': loss.item(), 'nll': mean_nll.item(),
                'penalty': penalty_weight * mean_irm_penalty.item()}