from woods.objectives.ERM import ERM
import torch
import torch.nn as nn
import torch.nn.functional as F
from woods.objectives.objective_utils import Classifier

class MockBatchNorm1d(nn.Module):
    """
    Wrapper around ``nn.BatchNorm1d`` that works around the batch-size-1
    problem of BatchNorm1d in training mode.

    When the module is in training mode and the input holds a single sample,
    the wrapped layer is not called: only its learned affine transform
    (``weight`` and ``bias``) is applied, so no batch statistics are computed
    and the running statistics are left untouched. In every other case the
    input goes through the regular batch-norm layer.
    """

    def __init__(self, dim):
        super().__init__()
        # Attribute name is kept as-is since it determines the state_dict keys.
        self.batch_norm = nn.BatchNorm1d(dim)

    def forward(self, x):
        single_sample_training = self.training and x.size(0) == 1
        if not single_sample_training:
            return self.batch_norm(x)

        # Single training sample: skip normalization, keep the affine part.
        bn = self.batch_norm
        return x * bn.weight + bn.bias

class IIB(ERM):
    """
    Invariant Information Bottleneck for Domain Generalization <https://arxiv.org/pdf/2106.06333>

    On top of the backbone ``model`` this objective owns:
      * a variational encoder (``encoder`` followed by ``fc3_mu`` and
        ``fc3_logvar``) mapping backbone features to the mean and
        log-variance of a latent code of the same width,
      * ``inv_classifier``, which predicts the label from the latent code,
      * ``env_classifier``, which predicts the label from the latent code
        concatenated with a one-column domain index.

    Args:
        model: backbone network. It must expose ``feat_dim`` and return a
            ``(output, features)`` pair when called.
        dataset: dataset object providing ``OUTPUT_SIZE``,
            ``get_nb_training_domains()`` and ``get_next_batch()``.
        optimizer: forwarded to ``ERM``; ``self.optimizer`` is afterwards
            replaced by an Adam optimizer covering the backbone and all the
            modules created here.
        hparams: dict read for the keys ``enable_bn``,
            ``nonlinear_classifier``, ``batch_size``, ``lr``,
            ``weight_decay``, ``lambda_beta`` and ``lambda_inv_risks``.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super(IIB, self).__init__(model, dataset, optimizer, hparams)
        feat_dim = self.model.feat_dim
        self.dataset = dataset
        num_classes = self.dataset.OUTPUT_SIZE
        num_domains = self.dataset.get_nb_training_domains()

        # VIB archs
        self.encoder = self._build_encoder(feat_dim, hparams['enable_bn']).to(self.device)
        self.fc3_mu = nn.Linear(feat_dim, feat_dim).to(self.device)  # mean of the latent code
        self.fc3_logvar = nn.Linear(feat_dim, feat_dim).to(self.device)  # log-variance of the latent code

        # Inv Risk archs
        self.inv_classifier = Classifier(
            feat_dim, num_classes, self.hparams['nonlinear_classifier']
        ).to(self.device)
        # "+ 1" input feature: the domain index column appended in update().
        self.env_classifier = Classifier(
            feat_dim + 1, num_classes, self.hparams['nonlinear_classifier']
        ).to(self.device)

        # One (batch_size, 1) column per training domain, filled with that
        # domain's index. Explicit float dtype so the concatenation with the
        # latent code does not depend on implicit type promotion.
        self.domain_indx = [
            torch.full((hparams['batch_size'], 1), float(indx), dtype=torch.float)
            for indx in range(num_domains)
        ]
        # The stacked column never changes, so build it once instead of on
        # every update step. Kept as a plain attribute (not a buffer) so the
        # state_dict layout is unchanged.
        self._domain_embeddings = torch.cat(self.domain_indx).to(self.device)

        # Parameter order is preserved: backbone, both classifiers, encoder,
        # then the two latent heads.
        trainable_modules = (
            self.model,
            self.inv_classifier,
            self.env_classifier,
            self.encoder,
            self.fc3_mu,
            self.fc3_logvar,
        )
        self.optimizer = torch.optim.Adam(
            [param for module in trainable_modules for param in module.parameters()],
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay'],
        )

    @staticmethod
    def _build_encoder(feat_dim, enable_bn):
        """Build the two-layer MLP encoder, optionally with batch norm after each linear layer."""
        layers = []
        for _ in range(2):
            layers.append(nn.Linear(feat_dim, feat_dim))
            if enable_bn:
                layers.append(MockBatchNorm1d(feat_dim))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    @staticmethod
    def _squeeze_features(features):
        """Drop dimension 1 from backbone features that have more than two dimensions."""
        if features.dim() > 2:
            # features: [B, 1, H] -> [B, H]
            features = features.squeeze(dim=1)
        return features

    def encoder_fun(self, res_feat):
        """Encode backbone features into the latent mean and log-variance.

        Args:
            res_feat: feature tensor accepted by ``self.encoder``.

        Returns:
            Tuple ``(mu, logvar)``.
        """
        latent_z = self.encoder(res_feat)
        mu = self.fc3_mu(latent_z)
        logvar = self.fc3_logvar(latent_z)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample a latent code with the reparameterization trick.

        In training mode returns ``mu + exp(logvar / 2) * eps`` with ``eps``
        drawn from a standard normal; otherwise returns ``mu`` unchanged.
        """
        if not self.training:
            return mu
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        return torch.add(torch.mul(std, eps), mu)

    def update(self):
        """Run one optimization step on the next training batch.

        Fetches a batch from the dataset, computes the IIB objective and
        steps ``self.optimizer``. Returns nothing.
        """
        # Put the whole objective into training mode, not only the backbone:
        # reparameterize() and the encoder's batch-norm wrappers branch on
        # this module's own training flag.
        self.train()
        X, Y = self.dataset.get_next_batch()

        # assumes the batch is hparams['batch_size'] samples from each training
        # domain, concatenated in domain order — TODO confirm against the dataset
        embeddings = self._domain_embeddings.to(self.device)

        _, all_z = self.model(X)
        all_z = self._squeeze_features(all_z)

        # encode feature to sampling vector: \mu, \var
        mu, logvar = self.encoder_fun(all_z)
        all_z = self.reparameterize(mu, logvar)

        # calculate loss by parts
        # KL divergence between N(mu, exp(logvar)) and a standard normal.
        # NOTE(review): summed over every element of the batch, so its scale
        # grows with batch size — confirm this is intended.
        ib_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        y = Y.squeeze(dim=-1)
        inv_loss = F.cross_entropy(self.inv_classifier(all_z), y)
        env_loss = F.cross_entropy(self.env_classifier(torch.cat([all_z, embeddings], 1)), y)

        # use beta to balance the info loss.
        risk_gap = (inv_loss - env_loss) ** 2
        objective = (
            inv_loss
            + env_loss
            + self.hparams['lambda_beta'] * ib_loss
            + self.hparams['lambda_inv_risks'] * risk_gap
        )

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

    def predict(self, x):
        """Predict from input ``x`` using the invariant classifier.

        Returns:
            Tuple ``(y, z)`` where ``y`` is the invariant classifier output
            with an extra dimension inserted at position 1 and ``z`` is the
            latent code it was computed from.
        """
        _, z = self.model(x)
        z = self._squeeze_features(z)
        mu, logvar = self.encoder_fun(z)
        z = self.reparameterize(mu, logvar)
        y = self.inv_classifier(z)
        return y.unsqueeze(dim=1), z