from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class MiRe(TrainerX):
    """MiRe baseline trainer.

    Each training step feeds the images, their domain labels and their
    class labels to ``self.model.backbone``, which returns the class
    logits together with an extra loss term. The optimised objective is
    the cross-entropy on the logits plus that extra term.
    """

    def forward_backward(self, batch):
        """Run one optimisation step on ``batch`` and report its metrics.

        Args:
            batch: Mapping that provides the ``"img"``, ``"label"`` and
                ``"domain"`` entries; each is moved to ``self.device``.

        Returns:
            dict: ``"loss"`` holds the scalar value of the combined loss
            and ``"acc"`` the first entry returned by
            ``compute_accuracy`` (presumably top-1 accuracy -- confirm
            against ``dassl.metrics``).
        """
        images, targets = self.parse_batch_train(batch)
        domains = batch["domain"].to(self.device)

        # The backbone yields both the logits and its own auxiliary loss.
        logits, loss_ccr = self.model.backbone(images, domains, targets)
        total_loss = F.cross_entropy(logits, targets) + loss_ccr
        self.model_backward_and_update(total_loss)

        summary = {
            "loss": total_loss.item(),
            "acc": compute_accuracy(logits, targets)[0].item(),
        }

        # Step the learning-rate schedule only once the final batch
        # (batch_idx + 1 == num_batches) has been processed.
        reached_last_batch = self.batch_idx + 1 == self.num_batches
        if reached_last_batch:
            self.update_lr()

        return summary

    def parse_batch_train(self, batch):
        """Extract the images and labels from ``batch`` on ``self.device``.

        Args:
            batch: Mapping with ``"img"`` and ``"label"`` entries.

        Returns:
            tuple: ``(images, labels)``, both moved to ``self.device``.
        """
        images, targets = batch["img"], batch["label"]
        return images.to(self.device), targets.to(self.device)
