import torch
from algorithms.single_model_algorithm import SingleModelAlgorithm
from models.initializer import initialize_model


class GroupDRO(SingleModelAlgorithm):
    """
    Group distributionally robust optimization.

    Keeps a probability distribution ``self.group_weights`` over groups. It is
    initialized uniformly over the groups selected by ``is_group_in_train``.
    On every update step the weights are multiplied by exponentiated group
    losses and renormalized, so high-loss groups gain weight. The model then
    minimizes the weighted average of group losses.

    Original paper:
        @inproceedings{sagawa2019distributionally,
          title={Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization},
          author={Sagawa, Shiori and Koh, Pang Wei and Hashimoto, Tatsunori B and Liang, Percy},
          booktitle={International Conference on Learning Representations},
          year={2019}
        }
    """

    def __init__(
        self, config, d_out, grouper, loss, metric, n_train_steps, is_group_in_train
    ):
        """
        Args:
            - config: experiment config. ``config.uniform_over_groups`` must be
              truthy, and ``config.group_dro_step_size`` is the step size of the
              group-weight update.
            - d_out: output dimension passed to ``initialize_model``.
            - grouper: grouper object; ``grouper.n_groups`` fixes the number of
              group weights.
            - loss, metric, n_train_steps: forwarded to SingleModelAlgorithm.
            - is_group_in_train: index/mask into the groups that start with
              non-zero weight (presumably a boolean tensor of length
              ``grouper.n_groups`` -- TODO confirm against callers).
        Raises:
            - ValueError: if ``is_group_in_train`` selects no group. Otherwise
              the initial weights would be 0/0 = NaN.
        """
        # check config
        assert config.uniform_over_groups, "GroupDRO requires config.uniform_over_groups"
        # initialize model
        model = initialize_model(config, d_out)
        # initialize module
        super().__init__(
            config=config,
            model=model,
            grouper=grouper,
            loss=loss,
            metric=metric,
            n_train_steps=n_train_steps,
        )
        # additional logging
        self.logged_fields.append("group_weight")
        # step size (eta) of the exponentiated group-weight update
        self.group_weights_step_size = config.group_dro_step_size
        # initialize adversarial weights: uniform over training groups, 0 elsewhere
        group_weights = torch.zeros(grouper.n_groups)
        group_weights[is_group_in_train] = 1
        n_selected = group_weights.sum()
        if n_selected == 0:
            raise ValueError(
                "GroupDRO: is_group_in_train selects no groups; "
                "cannot initialize group weights"
            )
        self.group_weights = (group_weights / n_selected).to(self.device)

    def process_batch(self, batch, unlabeled_batch=None):
        """
        Run SingleModelAlgorithm.process_batch on the labeled ``batch`` and
        attach the current group weights under ``results["group_weight"]``.
        ``unlabeled_batch`` is accepted for interface compatibility but ignored.
        """
        results = super().process_batch(batch)
        results["group_weight"] = self.group_weights
        return results

    def _group_losses(self, results):
        """
        Return the per-group losses for ``results``.

        There is one entry per group, indexed like ``self.group_weights``.
        The value comes from ``self.loss.compute_group_wise`` over the
        ``y_pred``, ``y_true`` and ``g`` entries.
        """
        group_losses, _, _ = self.loss.compute_group_wise(
            results["y_pred"],
            results["y_true"],
            results["g"],
            self.grouper.n_groups,
            return_dict=False,
        )
        return group_losses

    def _update_group_weights(self, group_losses):
        """
        Apply one exponentiated-gradient step to ``self.group_weights``.

        Computes w_g <- w_g * exp(eta * loss_g) / sum_h(w_h * exp(eta * loss_h)).
        """
        # The weight update is not part of the model's autograd graph.
        losses = group_losses.detach()
        # Same update written in log space: softmax(log w + eta * loss).
        # softmax subtracts the max internally, so torch.exp can no longer
        # overflow to inf, which previously produced NaN weights. Groups
        # with weight 0 have log-weight -inf and therefore stay exactly 0.
        log_weights = torch.log(self.group_weights) + self.group_weights_step_size * losses
        self.group_weights = torch.softmax(log_weights, dim=-1)

    def objective(self, results):
        """
        Takes an output of SingleModelAlgorithm.process_batch() and computes the
        optimized objective. For group DRO, the objective is the weighted average
        of losses, where groups have weights groupDRO.group_weights.
        Args:
            - results (dictionary): output of SingleModelAlgorithm.process_batch()
        Output:
            - objective (Tensor): optimized objective; a 0-dim (scalar) tensor
              when the group losses are 1-D.
        """
        return self._group_losses(results) @ self.group_weights

    def _update(self, results, should_step=True):
        """
        Update the group weights from this batch's group losses, then update
        the model via SingleModelAlgorithm._update.

        The weights are updated *before* the model step, so the model step
        optimizes against the freshly updated weights. This presumably happens
        through ``self.objective`` inside the parent ``_update``.
        Args:
            - results (dictionary): output of process_batch(); must contain
              ``y_pred``, ``y_true`` and ``g``. Modified in place:
              ``results["group_weight"]`` is set to the updated weights.
            - should_step (bool): forwarded to SingleModelAlgorithm._update.
        Output:
            - None
        """
        # compute group losses and update group weights
        self._update_group_weights(self._group_losses(results))
        # save updated group weights
        results["group_weight"] = self.group_weights
        # update model
        super()._update(results, should_step=should_step)
