from alg.algs.IRM import IRM
import torch
import torch.nn.functional as F

class IB_IRM(IRM):
    """IRM augmented with an information-bottleneck (IB) variance penalty.

    Invariance Principle Meets Information Bottleneck for
    Out-of-Distribution Generalization <https://arxiv.org/pdf/2106.06607>

    The IB penalty is the feature variance measured *within* each training
    domain, averaged over domains (as in the paper's reference
    implementation). Pooling all domains before taking the variance would
    additionally penalize the spread between domain means.
    """

    def __init__(self, args):
        super(IB_IRM, self).__init__(args)
        # Weight applied to the IB penalty once it has been switched on.
        self.ib_lambda = args.ib_lambda
        # Number of updates during which the IB penalty weight is 0.
        self.ib_penalty_anneal_iters = args.ib_penalty_anneal_iters

    @staticmethod
    def _ib_penalty(domain_features):
        """Return the mean, over domains, of the mean per-dimension variance.

        Args:
            domain_features: non-empty sequence of 2-D feature tensors, one
                per domain (samples along dim 0).

        Returns:
            A scalar tensor. Domains with fewer than two samples are skipped
            because the unbiased variance of a single sample is NaN (and a NaN
            would survive even a zero penalty weight). If no domain has at
            least two samples, the penalty is zero.
        """
        terms = [f.var(dim=0).mean() for f in domain_features if f.size(0) > 1]
        if not terms:
            return domain_features[0].new_zeros(())
        return torch.stack(terms).mean()

    def update(self, minibatches, opt, sch):
        """Run one optimization step on a list of per-domain minibatches.

        Args:
            minibatches: sequence of per-domain batches where ``data[0]`` are
                the inputs and ``data[1]`` the class labels.
            opt: optimizer; zeroed, back-propagated into and stepped once.
            sch: optional LR scheduler, stepped after ``opt`` when truthy.

        Returns:
            dict with the total ``'loss'``, the mean cross-entropy over domains
            (``'domain'``), the mean IRM penalty (``'irm_penalty'``) and the
            weighted IB loss (``'ib'``), all as Python floats.
        """
        # IRM penalty weight is 1.0 during warm-up, then the configured value
        # (``penalty_weight`` / ``anneal_iters`` come from the IRM base class).
        penalty_weight = (self.penalty_weight
                          if self.update_count >= self.anneal_iters else 1.0)
        # The IB penalty is disabled entirely until its own anneal step.
        ib_penalty_weight = (self.ib_lambda
                             if self.update_count >= self.ib_penalty_anneal_iters
                             else 0.0)
        # NOTE(review): the paper's reference code re-creates its Adam optimizer
        # when a penalty switches on; ``opt`` is owned by the caller here, so
        # no such reset happens.

        # One forward pass over all domains, split per domain afterwards.
        all_x = torch.cat([data[0].cuda().float() for data in minibatches])
        features = self.featurizer(all_x)
        flattened_features = features.reshape(features.size(0), -1)
        all_logits = self.classifier(flattened_features)

        sizes = [data[0].shape[0] for data in minibatches]
        domain_logits = torch.split(all_logits, sizes)
        domain_features = torch.split(flattened_features, sizes)

        nlls, irm_terms = [], []
        for logits, data in zip(domain_logits, minibatches):
            y = data[1].cuda().long()
            nlls.append(F.cross_entropy(logits, y))
            irm_terms.append(self._irm_penalty(logits, y))
        domain_losses = torch.stack(nlls)
        irm_penalty = torch.stack(irm_terms)

        # IB loss: within-domain feature variance, averaged over domains.
        var_loss = self._ib_penalty(domain_features)
        ib_loss = ib_penalty_weight * var_loss
        loss = domain_losses.mean() + (penalty_weight * irm_penalty.mean()) + ib_loss

        # Back propagate
        opt.zero_grad()
        loss.backward()
        opt.step()
        if sch:
            sch.step()

        # Update memory
        self.update_count += 1

        return {
            'loss': loss.item(),
            'domain': domain_losses.mean().item(),
            'irm_penalty': irm_penalty.mean().item(),
            'ib': ib_loss.item()
        }