from woods.objectives.IRM import IRM
import torch

class IB_IRM(IRM):
    """IB-IRM: IRM combined with an information-bottleneck penalty.

    Implements "Invariance Principle Meets Information Bottleneck for
    Out-of-Distribution Generalization" <https://arxiv.org/pdf/2106.06607>.

    The training objective is::

        mean(per-domain losses)
            + irm_weight * mean(per-domain IRM penalties)
            + ib_weight * mean(variance of intermediate features over dim 0)

    The IRM weight uses the same schedule as the parent ``IRM`` objective
    (1.0 during warm-up, ``self.penalty_weight`` afterwards). The IB weight
    is 0.0 until ``ib_penalty_anneal_iters`` updates have been performed,
    and ``ib_lambda`` afterwards.

    Hyperparameters read from ``hparams`` by this class:
        ib_lambda: weight of the information-bottleneck (variance) penalty.
        ib_penalty_anneal_iters: number of updates before the IB penalty
            is switched on.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super(IB_IRM, self).__init__(model, dataset, optimizer, hparams)
        # The IRM schedule attributes used in ``update`` (``penalty_weight``,
        # ``anneal_iters``) are not set here; presumably the parent IRM
        # constructor initializes them from ``hparams`` -- TODO confirm.
        self.ib_lambda = self.hparams['ib_lambda']
        self.ib_penalty_anneal_iters = self.hparams['ib_penalty_anneal_iters']

    def update(self):
        """Run one optimization step on the IB-IRM objective.

        Draws the next training batch from ``self.dataset``, builds the
        objective described in the class docstring, back-propagates it
        through ``self.optimizer`` and increments ``self.update_count``.
        """
        # IRM penalty weight: 1.0 during warm-up, configured weight afterwards
        # (same schedule as the parent IRM objective).
        irm_penalty_weight = (self.penalty_weight
                              if self.update_count >= self.anneal_iters
                              else 1.0)
        # IB penalty weight: disabled until the anneal threshold is reached.
        ib_penalty_weight = (self.ib_lambda
                             if self.update_count >= self.ib_penalty_anneal_iters
                             else 0.0)

        # Put model into training mode
        self.model.train()

        # Get next batch
        X, Y = self.dataset.get_next_batch()

        # Forward pass: predictions and intermediate features
        # (presumably the bottleneck representation -- confirm in predict()).
        out, inter_logits = self.predict(X)

        # Per-domain prediction losses.
        n_domains = self.dataset.get_nb_training_domains()
        domain_losses = self.dataset.loss_by_domain(out, Y, n_domains)

        # Create domain dimension in tensors.
        #   e.g. for source domains: (ENVS * batch_size, ...) -> (ENVS, batch_size, ...)
        #        for time domains: (batch_size, ENVS, ...) -> (ENVS, batch_size, ...)
        env_outs, env_labels = self.dataset.split_tensor_by_domains(out, Y, n_domains)

        # IRM penalty computed separately for each domain, then averaged.
        irm_penalty = torch.zeros(n_domains, device=self.device)
        for i, (env_out, env_label) in enumerate(zip(env_outs, env_labels)):
            irm_penalty[i] += self._irm_penalty(env_out, env_label)
        irm_penalty = irm_penalty.mean()

        # Information-bottleneck penalty: mean variance of the features over
        # dim 0. Only computed when it actually contributes: Tensor.var is
        # unbiased by default and yields NaN when dim 0 has a single element,
        # and 0.0 * NaN would still poison the objective during warm-up.
        # NOTE(review): the reference IB-IRM implementation averages
        # per-environment variances; whether dim 0 here mixes domains depends
        # on the layout of ``inter_logits`` -- confirm.
        if ib_penalty_weight != 0.0:
            ib_loss = ib_penalty_weight * inter_logits.var(dim=0).mean()
        else:
            ib_loss = 0.0

        objective = (domain_losses.mean()
                     + irm_penalty_weight * irm_penalty
                     + ib_loss)

        # Back propagate
        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        # Update memory
        self.update_count += 1