from woods.objectives.ERM import ERM
import torch


class IB_ERM(ERM):
    """Information Bottleneck ERM: ERM plus a per-domain feature-variance penalty.

    On each update the model's features are split by training domain. The
    variance of each domain's features along dim 0 is averaged into one scalar
    per domain. The mean of those scalars is added to the mean per-domain task
    loss, weighted by ``hparams['ib_weight']``.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super().__init__(model, dataset, optimizer, hparams)

        # Weight of the information-bottleneck (feature variance) penalty.
        self.ib_weight = self.hparams['ib_weight']

        # Count of optimisation steps performed; incremented at the end of update().
        # Registered as a buffer (presumably ERM is a torch.nn.Module, so it
        # travels with the module's state).
        self.register_buffer('update_count', torch.tensor([0]))

    def update(self):
        """Run one optimisation step on the next training batch.

        objective = mean(per-domain task losses)
                    + ib_weight * mean_over_domains(mean_over_features(var(features, dim=0)))
        """
        self.model.train()

        X, Y = self.dataset.get_next_batch()

        # predict returns (logits, features).
        out, out_features = self.predict(X)

        n_domains = self.dataset.get_nb_training_domains()
        domain_losses = self.dataset.loss_by_domain(out, Y, n_domains)

        # Group features by domain,
        #   e.g. for source domains: (ENVS * batch_size, ...) -> (ENVS, batch_size, ...)
        domain_features, _ = self.dataset.split_tensor_by_domains(out_features, Y, n_domains)

        ib_penalty = self._ib_penalty(domain_features)

        objective = domain_losses.mean() + self.ib_weight * ib_penalty

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        self.update_count += 1

    @staticmethod
    def _ib_penalty(domain_features):
        """Return the mean over domains of the average feature variance along dim 0.

        Stacking the per-domain scalars keeps the features' dtype/device and
        avoids in-place writes into a preallocated float32 buffer.
        NOTE(review): ``Tensor.var`` is unbiased by default, so a domain with a
        single sample yields NaN — confirm batches always hold >= 2 samples per domain.
        """
        per_domain = [feats.var(dim=0).mean() for feats in domain_features]
        return torch.stack(per_domain).mean()