import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy

# Public API of this module: only the DomainMix trainer class is exported.
__all__ = ["DomainMix"]


@TRAINER_REGISTRY.register()
class DomainMix(TrainerX):
    """DomainMix.

    Dynamic Domain Generalization.

    https://github.com/MetaVisionLab/DDG

    Every training batch is blended with a re-ordered copy of itself:
    ``mixed = lam * x + (1 - lam) * x[perm]``. The loss is the same convex
    combination of the cross-entropy w.r.t. the original labels and the
    labels of the mixing partners.

    Configuration read from ``cfg.TRAINER.DOMAINMIX``:

    * ``TYPE``  -- ``"random"`` (partners come from a random permutation of
      the batch) or ``"crossdomain"`` (partners are drawn from samples whose
      domain label differs).
    * ``ALPHA``, ``BETA`` -- concentrations of the Beta distribution that
      ``lam`` is sampled from. ``ALPHA <= 0`` disables mixing (``lam`` is
      fixed to 1).
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.mix_type = cfg.TRAINER.DOMAINMIX.TYPE
        self.alpha = cfg.TRAINER.DOMAINMIX.ALPHA
        self.beta = cfg.TRAINER.DOMAINMIX.BETA
        # ``Beta`` validates that its concentrations are positive, so it can
        # only be built when mixing is enabled. With ``alpha <= 0`` the
        # distribution is never sampled (see ``domain_mix``), hence ``None``.
        self.dist_beta = (
            torch.distributions.Beta(self.alpha, self.beta)
            if self.alpha > 0 else None
        )

    def forward_backward(self, batch):
        """Run one optimisation step on ``batch`` and return its summary.

        Returns:
            dict: ``"loss"`` (the mixed cross-entropy as a float) and
            ``"acc"`` (accuracy of the predictions measured against the
            labels of the un-permuted samples only).
        """
        images, label_a, label_b, lam = self.parse_batch_train(batch)
        output = self.model(images)
        # Same convex combination as the one applied to the inputs.
        loss = lam * F.cross_entropy(
            output, label_a
        ) + (1-lam) * F.cross_entropy(output, label_b)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label_a)[0].item()
        }

        # Step the learning-rate schedule after the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move the batch to ``self.device`` and apply domain mixing.

        Args:
            batch: mapping providing the ``"img"``, ``"label"`` and
                ``"domain"`` tensors.

        Returns:
            tuple: ``(mixed_images, target_a, target_b, lam)`` exactly as
            produced by :meth:`domain_mix`.
        """
        images = batch["img"]
        target = batch["label"]
        domain = batch["domain"]
        images = images.to(self.device)
        target = target.to(self.device)
        domain = domain.to(self.device)
        images, target_a, target_b, lam = self.domain_mix(
            images, target, domain
        )
        return images, target_a, target_b, lam

    def domain_mix(self, x, target, domain):
        """Blend every sample of ``x`` with a partner from the same batch.

        Args:
            x: input batch; mixing partners are selected along dim 0.
            target: labels aligned with ``x`` along dim 0.
            domain: domain labels aligned with ``x`` along dim 0.

        Returns:
            tuple: ``(mixed_x, target_a, target_b, lam)`` where ``target_a``
            are the original labels, ``target_b`` the partners' labels and
            ``lam`` the mixing coefficient (a tensor on ``x.device``).

        Raises:
            NotImplementedError: if ``self.mix_type`` is neither
                ``"random"`` nor ``"crossdomain"``.
        """
        if self.alpha > 0:
            lam = self.dist_beta.rsample((1, )).to(x.device)
        else:
            # Mixing disabled: keep the original samples and labels. A float
            # is used so the coefficient has the same kind of dtype as a
            # sampled one.
            lam = torch.tensor(1.0, device=x.device)

        # random shuffle
        perm = torch.randperm(x.size(0), dtype=torch.int64, device=x.device)
        if self.mix_type == "crossdomain":
            domain_list = torch.unique(domain)
            # With a single domain in the batch there is nothing to cross,
            # so the random permutation above is kept as is.
            if len(domain_list) > 1:
                for idx in domain_list:
                    in_domain = domain == idx
                    cnt_a = int(in_domain.sum())
                    idx_b = (~in_domain).nonzero().squeeze(-1)
                    cnt_b = idx_b.shape[0]
                    # Uniform draw among the samples of the other domains;
                    # sample with replacement only when this domain holds
                    # more samples than all the others together.
                    perm_b = torch.ones(cnt_b).multinomial(
                        num_samples=cnt_a, replacement=cnt_a > cnt_b
                    )
                    perm[in_domain] = idx_b[perm_b.to(idx_b.device)]
        elif self.mix_type != "random":
            raise NotImplementedError(
                f"Chooses {'random', 'crossdomain'}, but got {self.mix_type}."
            )
        mixed_x = lam*x + (1-lam) * x[perm, :]
        target_a, target_b = target, target[perm]
        return mixed_x, target_a, target_b, lam
