import random
import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.

    During training, re-normalizes each sample's features with a convex
    combination of its own per-instance statistics (mean / std) and those of
    another sample in the batch.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix='random'):
        """
        Args:
          p (float): probability of applying MixStyle on a given forward call.
          alpha (float): concentration of the symmetric Beta(alpha, alpha)
            distribution the per-sample mixing weights are drawn from.
          eps (float): added to the variance before the square root so the
            normalization never divides by zero.
          mix (str): how mixing partners are chosen, 'random' (any sample in
            the batch) or 'crossdomain' (a sample from another of three
            equally sized, contiguous groups of the batch).
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})'

    def set_activation_status(self, status=True):
        """Enable (True) or disable (False) mixing while staying in train mode."""
        self._activated = status

    def update_mix_method(self, mix='random'):
        """Switch the partner-selection strategy ('random' or 'crossdomain')."""
        self.mix = mix

    @staticmethod
    def _crossdomain_perm(batch_size):
        """Return batch indices that pair every sample with one from another group.

        The batch is treated as three contiguous, equally sized groups a, b, c.
        Indices are shuffled within each group. The groups are then rotated so
        that either (a <- b, b <- c, c <- a) or (a <- c, b <- a, c <- b), each
        with probability 0.5.

        Raises:
          ValueError: if ``batch_size`` is not a multiple of 3.
        """
        if batch_size % 3 != 0:
            raise ValueError(
                "MixStyle 'crossdomain' mode expects the batch to consist of "
                "three equally sized domain groups, but got batch size "
                f"{batch_size}, which is not divisible by 3."
            )
        domain_batch_size = batch_size // 3
        # An exact (3, d) view guarantees three groups of identical size,
        # unlike Tensor.chunk, which yields uneven or fewer chunks.
        groups = torch.arange(batch_size).view(3, domain_batch_size)
        perm_a, perm_b, perm_c = (
            group[torch.randperm(domain_batch_size)] for group in groups
        )

        if random.random() < 0.5:
            return torch.cat([perm_b, perm_c, perm_a], dim=0)
        return torch.cat([perm_c, perm_a, perm_b], dim=0)

    def forward(self, x):
        """Apply MixStyle to ``x`` with probability ``p`` while training.

        Args:
          x (Tensor): 4-D feature tensor, laid out as B x K x C x L.
            Statistics are computed per sample over the last axis (dim 3).

        Returns:
          A tensor with the same shape and dtype as ``x``. ``x`` itself is
          returned unchanged in eval mode, when deactivated, or when the
          random draw skips mixing for this call.

        Raises:
          ValueError: in 'crossdomain' mode, if the batch size is not a
            multiple of 3.
          NotImplementedError: if ``self.mix`` is not a supported mode.
        """
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        # Per-sample statistics over the last axis, detached so that no
        # gradient flows through the style statistics themselves.
        mu = x.mean(dim=3, keepdim=True)
        var = x.var(dim=3, keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig

        # Match x's device *and* dtype so the mixing arithmetic below does not
        # promote e.g. half-precision inputs to float32.
        lmda = self.beta.sample((B, 1, 1, 1)).to(device=x.device, dtype=x.dtype)

        if self.mix == 'random':
            perm = torch.randperm(B)
        elif self.mix == 'crossdomain':
            perm = self._crossdomain_perm(B)
        else:
            raise NotImplementedError(
                f"Unsupported MixStyle mix mode {self.mix!r}; "
                "expected 'random' or 'crossdomain'."
            )

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix