from woods.objectives.ERM import ERM
import torch


class GroupDRO(ERM):
    """
    GroupDRO objective.

    Keeps one weight per training domain in the buffer ``q``. At every
    update each weight is multiplied by ``exp(eta * domain_loss)``, the
    weights are renormalised to sum to one, and the model is trained on the
    ``q``-weighted sum of the per-domain losses.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        """
        Args:
            model: network being trained; called on a batch in ``predict``.
            dataset: must provide ``get_nb_training_domains()``,
                ``get_next_batch()`` and ``loss_by_domain(out, Y, n_domains)``.
            optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
            hparams: mapping that contains at least ``'eta'``; ``'device'`` is
                read from ``self.hparams`` (presumably set by the parent
                constructor from this same mapping -- confirm in ERM).
        """
        super(GroupDRO, self).__init__(model, dataset, optimizer, hparams)

        # Save hparams
        self.device = self.hparams['device']
        self.eta = hparams['eta']
        # Start with an empty buffer; the real weights are created lazily on
        # the first call to `update`.
        self.register_buffer("q", torch.Tensor())

        # Save training components
        self.model = model
        self.dataset = dataset
        self.optimizer = optimizer

        # Get some other useful info
        self.nb_training_domains = dataset.get_nb_training_domains()

    def predict(self, all_x):
        """Return the output of the wrapped model for the batch ``all_x``."""
        return self.model(all_x)

    def update(self):
        """
        Run one optimisation step on the next training batch.

        Side effects: puts the model in training mode, updates ``self.q`` in
        place and steps the optimizer.
        """
        # Put model into training mode
        self.model.train()

        # Get next batch
        X, Y = self.dataset.get_next_batch()

        # Lazily create the domain weights, directly on the target device.
        if self.q.numel() == 0:
            self.q = torch.ones(self.nb_training_domains, device=self.device)

        # The model output is unpacked into two values; only the first is
        # used for the loss.
        out, _ = self.predict(X)

        # Compute losses (1-D tensor, one entry per training domain, as
        # required by the dot product below).
        domain_losses = self.dataset.loss_by_domain(out, Y, self.nb_training_domains)

        # Update weights. The losses are detached so the weight update itself
        # is never part of the autograd graph.
        self.q *= torch.exp(self.eta * domain_losses.detach())
        self.q /= self.q.sum()

        # Compute objective: q-weighted sum of the per-domain losses.
        objective = torch.dot(domain_losses, self.q)

        # Back propagate
        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()