from alg.algs.ERM import ERM
import torch
import torch.nn as nn
import torch.nn.functional as F

def Classifier(in_features, out_features, is_nonlinear=False):
    """
    Build a classification head.

    Args:
        in_features: width of the input feature vector.
        out_features: number of outputs (logits).
        is_nonlinear: when true, return a three-layer MLP whose hidden widths
            are ``in_features // 2`` and ``in_features // 4`` with ReLU in
            between; otherwise return a single linear layer.

    Returns:
        A ``torch.nn.Linear`` or a ``torch.nn.Sequential`` of five modules.
    """
    if not is_nonlinear:
        return torch.nn.Linear(in_features, out_features)

    half_width = in_features // 2
    quarter_width = in_features // 4
    # Layers are created in forward order so parameter initialisation consumes
    # the random generator in the same sequence as the network is applied.
    stack = [
        torch.nn.Linear(in_features, half_width),
        torch.nn.ReLU(),
        torch.nn.Linear(half_width, quarter_width),
        torch.nn.ReLU(),
        torch.nn.Linear(quarter_width, out_features),
    ]
    return torch.nn.Sequential(*stack)

class MockBatchNorm1d(nn.Module):
    """
    Wrapper around ``nn.BatchNorm1d`` that tolerates a training batch of one.

    In training mode with a single row, only the wrapped layer's affine
    transform (``weight`` and ``bias``) is applied and the wrapped layer itself
    is not called. In every other case the input goes through the wrapped
    ``nn.BatchNorm1d`` unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(dim)

    def forward(self, x):
        single_row_in_training = self.training and x.size(0) == 1
        if not single_row_in_training:
            return self.batch_norm(x)

        # Skip the batch-statistics path for one sample; keep the learnable
        # scale and shift so the layer still contributes its parameters.
        bn = self.batch_norm
        return x * bn.weight + bn.bias

class IIB(ERM):
    """
    Invariant Information Bottleneck for Domain Generalization <https://arxiv.org/pdf/2106.06333>

    Adds to the parent's featurizer:
      * a variational encoder with a mean head and a log-variance head,
      * ``inv_classifier``, which classifies from the sampled code alone,
      * ``env_classifier``, which classifies from the sampled code concatenated
        with a one-column domain index.

    All of these, together with the featurizer, are trained by an Adam
    optimizer owned by this object (``self.optimizer``).
    """

    def __init__(self, args):
        """
        Build the IIB-specific modules and the optimizer.

        Reads from ``args``: ``num_classes``, ``domain_num``, ``enable_bn``,
        ``nonlinear_classifier``, ``weight_decay``, ``lr``, ``batch_size``,
        ``lambda_beta`` and ``lambda_inv_risks``. Expects the parent
        constructor to have set ``self.featurizer`` with an ``in_features``
        attribute.
        """
        super(IIB, self).__init__(args)

        feat_dim = self.featurizer.in_features
        num_classes = args.num_classes
        # Presumably the number of training domains (one domain held out) --
        # confirm against the caller.
        num_domains = args.domain_num - 1
        enable_bn = args.enable_bn
        nonlinear_classifier = args.nonlinear_classifier
        weight_decay = args.weight_decay
        lr = args.lr
        batch_size = args.batch_size
        self.lambda_beta = args.lambda_beta
        self.lambda_inv_risks = args.lambda_inv_risks

        def hidden_block():
            # Linear -> (optional batch norm) -> ReLU, all of width feat_dim.
            layers = [nn.Linear(feat_dim, feat_dim)]
            if enable_bn:
                layers.append(MockBatchNorm1d(feat_dim))
            layers.append(nn.ReLU(inplace=True))
            return layers

        # VIB archs: two hidden blocks followed by separate mu / logvar heads.
        self.encoder = torch.nn.Sequential(*hidden_block(), *hidden_block())
        self.fc3_mu = nn.Linear(feat_dim, feat_dim)  # mean of the latent code
        self.fc3_logvar = nn.Linear(feat_dim, feat_dim)  # log-variance of the latent code

        # Inv Risk archs: the env classifier sees one extra input column
        # carrying the domain index.
        self.inv_classifier = Classifier(feat_dim, num_classes, nonlinear_classifier)
        self.env_classifier = Classifier(feat_dim + 1, num_classes, nonlinear_classifier)

        # Retained for backward compatibility only. ``update`` builds the
        # domain-index column from the real minibatch sizes instead, because
        # this fixed-size version breaks on short or uneven batches.
        self.domain_indx = [torch.full((batch_size, 1), indx) for indx in range(num_domains)]

        trainable = (
            list(self.featurizer.parameters())
            + list(self.inv_classifier.parameters())
            + list(self.env_classifier.parameters())
            + list(self.encoder.parameters())
            + list(self.fc3_mu.parameters())
            + list(self.fc3_logvar.parameters())
        )
        self.optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=weight_decay)

    def encoder_fun(self, res_feat):
        """Encode featurizer output into the mean and log-variance of the latent code."""
        latent_z = self.encoder(res_feat)
        mu = self.fc3_mu(latent_z)
        logvar = self.fc3_logvar(latent_z)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Draw the latent code.

        In training mode returns ``mu + exp(logvar / 2) * eps`` with ``eps``
        sampled from a standard normal; in eval mode returns ``mu`` unchanged.
        """
        if not self.training:
            return mu
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        return torch.add(torch.mul(std, eps), mu)

    def update(self, minibatches, opt, sch):
        """
        Run one optimisation step over a list of per-domain minibatches.

        Args:
            minibatches: sequence whose items are indexed as ``data[0]``
                (inputs) and ``data[1]`` (labels); the position of an item in
                the sequence is used as its domain index.
            opt: accepted for interface compatibility; not used, the step is
                taken with ``self.optimizer``.
            sch: accepted for interface compatibility; not used.

        Returns:
            dict with the scalar values of ``'loss'``, ``'inv'``, ``'env'``
            and ``'ib'``.
        """
        X = torch.cat([data[0].cuda().float() for data in minibatches])
        Y = torch.cat([data[1].cuda().long() for data in minibatches])

        all_z = self.featurizer(X)
        # encode feature to sampling vector: \mu, \var
        mu, logvar = self.encoder_fun(all_z)
        all_z = self.reparameterize(mu, logvar)

        # One domain-index value per sample, sized from the actual minibatches
        # so short or uneven batches still line up with all_z, and created
        # with all_z's dtype/device so the concatenation needs no promotion
        # or device transfer.
        embeddings = torch.cat([
            torch.full((data[0].shape[0], 1), float(indx),
                       dtype=all_z.dtype, device=all_z.device)
            for indx, data in enumerate(minibatches)
        ])

        # calculate loss by parts
        # KL divergence of N(mu, exp(logvar)) from N(0, I), summed over all
        # samples and latent dimensions.
        ib_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        y = Y.squeeze(dim=-1)
        inv_loss = F.cross_entropy(self.inv_classifier(all_z), y)
        env_loss = F.cross_entropy(self.env_classifier(torch.cat([all_z, embeddings], 1)), y)

        # use beta to balance the info loss; the squared gap penalises any
        # difference between the invariant and the domain-aware risk.
        loss = inv_loss + env_loss + self.lambda_beta * ib_loss + self.lambda_inv_risks * (
                inv_loss - env_loss) ** 2

        # Back propagate
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {
            'loss': loss.item(),
            'inv': inv_loss.item(),
            'env': env_loss.item(),
            'ib': ib_loss.item()
        }

    def predict(self, x):
        """Return ``inv_classifier`` logits for ``x`` (uses the mean code in eval mode)."""
        z = self.featurizer(x)
        mu, logvar = self.encoder_fun(z)
        z = self.reparameterize(mu, logvar)
        y = self.inv_classifier(z)
        return y