# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03a_net.ipynb (unless otherwise specified).

# Names exported by `from <module> import *`. The underscore-prefixed classes
# defined in this file (_LinearBlock, _MultilayerPerception) are not listed here.
__all__ = ['LinearBlock', 'MultilayerPerception', 'BaselineModel', 'ConvBlock', 'MultilayerConv', 'CounterfactualModel',
           'CounterfactualModel2Optimizers', 'CounterfactualModel2OptsNoPass', 'CounterfactualModelSeparate',
           'CounterfactualModelPosthoc', 'ConvCounterNet', 'Embed', 'TransCounterNet', 'AE', 'VAE', 'CHVAE']

# Cell

from .import_essentials import *
from .training_module import CounterfactualTrainingModulePosthoc
from .utils import *
from .training_module import *
from pytorch_lightning.callbacks import EarlyStopping

# Comes from 03b_counterfactual_net.ipynb, cell

class _LinearBlock(nn.Module):
    """ICML version: Linear -> BatchNorm1d -> LeakyReLU -> Dropout.

    Args:
        input_dim: number of input features of the linear layer.
        out_dim: number of output features (also the batch-norm width).
        dropout: drop probability handed to `nn.Dropout`.
    """
    def __init__(self, input_dim, out_dim, dropout=0.3):
        super().__init__()
        # The position of each stage determines its numeric key in the state
        # dict (block.0, block.1, ...), so the order must not change.
        stages = [
            nn.Linear(input_dim, out_dim),
            nn.BatchNorm1d(num_features=out_dim),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the four stages in order and return the result."""
        out = self.block(x)
        return out


class LinearBlock(nn.Module):
    """Linear -> LeakyReLU -> Dropout.

    Unlike `_LinearBlock`, this block contains no batch normalisation.

    Args:
        input_dim: number of input features of the linear layer.
        out_dim: number of output features.
        dropout: drop probability handed to `nn.Dropout`.
    """
    def __init__(self, input_dim, out_dim, dropout=0.3):
        super().__init__()
        # The position of each stage determines its numeric key in the state
        # dict (block.0, block.1, ...), so the order must not change.
        stages = [
            nn.Linear(input_dim, out_dim),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the three stages in order and return the result."""
        out = self.block(x)
        return out

class _MultilayerPerception(nn.Module):
    """ICML version: a stack of `_LinearBlock`s.

    Args:
        dims: layer widths; one `_LinearBlock` is created for every pair of
            consecutive entries, so `len(dims) - 1` blocks in total.
    """
    def __init__(self, dims=[3, 100, 10]):
        super().__init__()
        # Pair each width with its successor: (dims[0], dims[1]), (dims[1], dims[2]), ...
        blocks = [
            _LinearBlock(fan_in, fan_out)
            for fan_in, fan_out in zip(dims, dims[1:])
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Feed `x` through every block in sequence."""
        out = self.model(x)
        return out

class MultilayerPerception(nn.Module):
    """A stack of `LinearBlock`s.

    Args:
        dims: layer widths; one `LinearBlock` is created for every pair of
            consecutive entries, so `len(dims) - 1` blocks in total.
    """
    def __init__(self, dims=[3, 100, 10]):
        super().__init__()
        # Pair each width with its successor: (dims[0], dims[1]), (dims[1], dims[2]), ...
        blocks = [
            LinearBlock(fan_in, fan_out)
            for fan_in, fan_out in zip(dims, dims[1:])
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Feed `x` through every block in sequence."""
        out = self.model(x)
        return out

class BaselineModel(BaselineTrainingModule):
    """Predictor built from two ICML-style MLP stacks and a single-logit head.

    `self.enc_dims` and `self.dec_dims` are read after the base-class
    constructor runs (presumably populated from `config` -- confirm in
    `BaselineTrainingModule`).
    """
    def __init__(self, config):
        super().__init__(config)
        # The encoder output feeds the decoder directly, so the widths must agree.
        assert self.enc_dims[-1] == self.dec_dims[0]
        encoder = _MultilayerPerception(self.enc_dims)
        decoder = _MultilayerPerception(self.dec_dims)
        head = nn.Linear(self.dec_dims[-1], 1)
        self.model = nn.Sequential(encoder, decoder, head)

    def model_forward(self, x):
        """Return sigmoid outputs for the one-element batch tuple `x`.

        `x` must unpack to exactly one item; the trailing size-1 dimension of
        the head output is squeezed away.
        """
        (features,) = x
        logits = self.model(features)
        return torch.sigmoid(logits).squeeze(-1)

# Comes from 03b_counterfactual_net.ipynb, cell

class ConvBlock(nn.Module):
    """Conv1d (kernel 3, padding 1) -> BatchNorm1d -> LeakyReLU.

    Args:
        input_dim: number of input channels.
        out_dim: number of output channels.
        dropout: accepted but not used by this block (no dropout layer is built).
    """
    def __init__(self, input_dim, out_dim, dropout=0.3):
        super().__init__()
        # The position of each stage determines its numeric key in the state
        # dict (block.0, block.1, ...), so the order must not change.
        stages = [
            nn.Conv1d(input_dim, out_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(num_features=out_dim),
            nn.LeakyReLU(),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the three stages in order and return the result."""
        out = self.block(x)
        return out

class MultilayerConv(nn.Module):
    """A stack of `ConvBlock`s.

    Args:
        dims: channel counts; one `ConvBlock` is created for every pair of
            consecutive entries, so `len(dims) - 1` blocks in total.
    """
    def __init__(self, dims=[3, 100, 10]):
        super().__init__()
        # Pair each channel count with its successor.
        blocks = [
            ConvBlock(ch_in, ch_out)
            for ch_in, ch_out in zip(dims, dims[1:])
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Feed `x` through every block in sequence."""
        out = self.model(x)
        return out

# Comes from 03b_counterfactual_net.ipynb, cell

class CounterfactualModel(CounterfactualTrainingModule):
    """Shared encoder with a prediction head and a counterfactual head.

    `self.enc_dims`, `self.dec_dims` and `self.exp_dims` are read after the
    base-class constructor runs (presumably populated from `config` --
    confirm in `CounterfactualTrainingModule`).
    """
    def __init__(self, config):
        super().__init__(config)
        # Both heads consume the encoder output, so their first width must match it.
        assert self.enc_dims[-1] == self.dec_dims[0]
        assert self.enc_dims[-1] == self.exp_dims[0]

        self.encoder_model = MultilayerPerception(self.enc_dims)

        pred_body = MultilayerPerception(self.dec_dims)
        pred_head = nn.Linear(self.dec_dims[-1], 1)
        self.predictor = nn.Sequential(pred_body, pred_head)

        exp_body = MultilayerPerception(self.exp_dims)
        # The explainer maps back to the width of the original input.
        exp_head = nn.Linear(self.exp_dims[-1], self.enc_dims[0])
        self.explainer = nn.Sequential(exp_body, exp_head)

    def model_forward(self, x):
        """Return `(y_hat, c)`: the sigmoid prediction and the explainer output."""
        latent = self.encoder_model(x)
        # Predictor runs before the explainer, as in the original statement order.
        y_hat = torch.sigmoid(self.predictor(latent)).squeeze(-1)
        cf = self.explainer(latent)
        return y_hat, cf

class CounterfactualModel2Optimizers(CounterfactualTrainingModule2Optimizers):
    """Counterfactual network whose explainer also sees the predictor's hidden output.

    The explainer input is the encoder output concatenated with the output of
    the predictor's MLP stack (before the final logit layer).

    `self.enc_dims`, `self.dec_dims` and `self.exp_dims` are read after the
    base-class constructor runs (presumably populated from `config` --
    confirm in `CounterfactualTrainingModule2Optimizers`).

    Raises:
        AssertionError: if the encoder output width differs from the first
            predictor or explainer width.
    """
    def __init__(self, config):
        super().__init__(config)
        assert self.enc_dims[-1] == self.dec_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})"
        assert self.enc_dims[-1] == self.exp_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})"

        self.encoder_model = MultilayerPerception(self.enc_dims)
        # predictor
        self.predictor = MultilayerPerception(self.dec_dims)
        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)
        # explainer: its first layer is widened by the predictor's hidden width
        # because model_forward concatenates that output onto the encoding.
        exp_dims = list(self.exp_dims)
        exp_dims[0] = self.exp_dims[0] + self.dec_dims[-1]

        self.explainer = nn.Sequential(
            MultilayerPerception(exp_dims),
            # Sized from the local (widened) list so a single-entry exp_dims still lines up.
            nn.Linear(exp_dims[-1], self.enc_dims[0])
        )

    def model_forward(self, x):
        """Return `(y_hat, c)`: the sigmoid prediction and the explainer output."""
        x = self.encoder_model(x)
        # predicted y_hat
        pred = self.predictor(x)
        y_hat = torch.sigmoid(self.pred_linear(pred))
        # counterfactual example: encoding plus predictor hidden output
        x = torch.cat((x, pred), -1)
        c = self.explainer(x)
        return torch.squeeze(y_hat, -1), c

class CounterfactualModel2OptsNoPass(CounterfactualTrainingModule2Optimizers):
    """Counterfactual network whose explainer sees only the encoder output.

    In contrast to `CounterfactualModel2Optimizers`, the predictor's hidden
    output is not concatenated onto the explainer input.

    `self.enc_dims`, `self.dec_dims` and `self.exp_dims` are read after the
    base-class constructor runs (presumably populated from `config` --
    confirm in `CounterfactualTrainingModule2Optimizers`).

    Raises:
        AssertionError: if the encoder output width differs from the first
            predictor or explainer width.
    """
    def __init__(self, config):
        super().__init__(config)
        assert self.enc_dims[-1] == self.dec_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})"
        assert self.enc_dims[-1] == self.exp_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})"

        self.encoder_model = MultilayerPerception(self.enc_dims)
        # predictor
        self.predictor = MultilayerPerception(self.dec_dims)
        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)
        # explainer
        self.explainer = nn.Sequential(
            MultilayerPerception(self.exp_dims),
            nn.Linear(self.exp_dims[-1], self.enc_dims[0])
        )

    def model_forward(self, x):
        """Return `(y_hat, c)`: the sigmoid prediction and the explainer output."""
        x = self.encoder_model(x)
        # predicted y_hat
        pred = self.predictor(x)
        y_hat = torch.sigmoid(self.pred_linear(pred))
        # counterfactual example: computed from the encoding alone
        c = self.explainer(x)
        return torch.squeeze(y_hat, -1), c


class CounterfactualModelSeparate(CounterfactualTrainingModule2Optimizers):
    """Counterfactual network with no layers shared between predictor and explainer.

    The predictor uses `encoder_model`; the explainer gets its own stack built
    from `enc_dims + exp_dims[1:]` and is applied to the raw input.

    `self.enc_dims`, `self.dec_dims` and `self.exp_dims` are read after the
    base-class constructor runs (presumably populated from `config` --
    confirm in `CounterfactualTrainingModule2Optimizers`).

    Raises:
        AssertionError: if the encoder output width differs from the first
            predictor or explainer width.
    """
    def __init__(self, config):
        super().__init__(config)
        assert self.enc_dims[-1] == self.dec_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})"
        assert self.enc_dims[-1] == self.exp_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})"

        self.encoder_model = MultilayerPerception(self.enc_dims)
        # predictor
        self.predictor = MultilayerPerception(self.dec_dims)
        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)
        # explainer: prepend a private copy of the encoder widths; exp_dims[0]
        # is dropped because it equals enc_dims[-1] (asserted above).
        exp_dims = self.enc_dims + self.exp_dims[1:]
        self.explainer = nn.Sequential(
            MultilayerPerception(exp_dims),
            nn.Linear(self.exp_dims[-1], self.enc_dims[0])
        )

    def model_forward(self, x):
        """Return `(y_hat, cf)`: the sigmoid prediction and the explainer output."""
        p = self.encoder_model(x)
        # predicted y_hat
        pred = self.predictor(p)
        y_hat = torch.sigmoid(self.pred_linear(pred))
        # counterfactual example: explainer works on the raw input, not on `p`
        cf = self.explainer(x)
        return torch.squeeze(y_hat, -1), cf


class CounterfactualModelPosthoc(CounterfactualTrainingModulePosthoc):
    """Train in a post-hoc fashion, i.e., train predictive model first, then train explainer.

    The explainer input is the encoder output concatenated with the output of
    the predictor's MLP stack (before the final logit layer).

    `self.enc_dims`, `self.dec_dims` and `self.exp_dims` are read after the
    base-class constructor runs (presumably populated from `config` --
    confirm in `CounterfactualTrainingModulePosthoc`).

    Raises:
        AssertionError: if the encoder output width differs from the first
            predictor or explainer width.
    """
    def __init__(self, config):
        super().__init__(config)
        assert self.enc_dims[-1] == self.dec_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})"
        assert self.enc_dims[-1] == self.exp_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})"

        self.encoder_model = MultilayerPerception(self.enc_dims)
        # predictor
        self.predictor = MultilayerPerception(self.dec_dims)
        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)
        # explainer: its first layer is widened by the predictor's hidden width
        # because model_forward concatenates that output onto the encoding.
        exp_dims = list(self.exp_dims)
        exp_dims[0] = self.exp_dims[0] + self.dec_dims[-1]

        self.explainer = nn.Sequential(
            MultilayerPerception(exp_dims),
            # Sized from the local (widened) list so a single-entry exp_dims still lines up.
            nn.Linear(exp_dims[-1], self.enc_dims[0])
        )

    def model_forward(self, x):
        """Return `(y_hat, c)`: the sigmoid prediction and the explainer output."""
        x = self.encoder_model(x)
        # predicted y_hat
        pred = self.predictor(x)
        y_hat = torch.sigmoid(self.pred_linear(pred))
        # counterfactual example: encoding plus predictor hidden output
        x = torch.cat((x, pred), -1)
        c = self.explainer(x)
        return torch.squeeze(y_hat, -1), c

# Comes from 03b_counterfactual_net.ipynb, cell
class ConvCounterNet(CounterfactualTrainingModule2Optimizers):
    """Counterfactual network with Conv1d encoder and predictor stacks.

    The input gets a trailing length-1 axis so the features act as Conv1d
    channels. The explainer is a plain MLP fed with the encoder output
    concatenated with the predictor's hidden output.

    `self.enc_dims`, `self.dec_dims` and `self.exp_dims` are read after the
    base-class constructor runs (presumably populated from `config` --
    confirm in `CounterfactualTrainingModule2Optimizers`).

    Raises:
        AssertionError: if the encoder output width differs from the first
            predictor or explainer width.
    """
    def __init__(self, config):
        super().__init__(config)
        assert self.enc_dims[-1] == self.dec_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})"
        assert self.enc_dims[-1] == self.exp_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})"

        self.encoder_model = MultilayerConv(self.enc_dims)
        # predictor
        self.predictor = MultilayerConv(self.dec_dims)
        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)
        # explainer: its first layer is widened by the predictor's channel count
        # because model_forward concatenates that output onto the encoding.
        exp_dims = list(self.exp_dims)
        exp_dims[0] = self.exp_dims[0] + self.dec_dims[-1]

        self.explainer = nn.Sequential(
            MultilayerPerception(exp_dims),
            # Sized from the local (widened) list so a single-entry exp_dims still lines up.
            nn.Linear(exp_dims[-1], self.enc_dims[0])
        )

    def model_forward(self, x):
        """Return `(y_hat, c)`: the sigmoid prediction and the explainer output."""
        # Add a trailing length-1 axis so Conv1d treats the features as channels.
        x = x.unsqueeze(dim=-1)
        x = self.encoder_model(x)
        # predicted y_hat
        pred = self.predictor(x)
        # Drop the length-1 axis again before the linear head.
        y_hat = torch.sigmoid(self.pred_linear(pred.squeeze(-1)))
        # counterfactual example: concatenate along the channel axis, then flatten
        x = torch.cat((x, pred), 1).squeeze(-1)
        c = self.explainer(x)
        return torch.squeeze(y_hat, -1), c

# Comes from 03b_counterfactual_net.ipynb, cell
class Embed(nn.Module):
    """Project a trailing size-1 axis onto `emb_dims` values with one learned row.

    Args:
        emb_dims: width of the embedding produced for each scalar.
    """
    def __init__(self, emb_dims: int):
        super().__init__()
        # Initialise the raw tensor first, then register it as a parameter.
        initial = torch.empty((1, emb_dims))
        nn.init.kaiming_uniform_(initial, a=math.sqrt(5))
        self.weight = nn.Parameter(initial)

    def forward(self, x):
        """Return `x` matrix-multiplied by the `(1, emb_dims)` weight."""
        return torch.matmul(x, self.weight)

# Comes from 03b_counterfactual_net.ipynb, cell
class TransCounterNet(CounterfactualTrainingModule2Optimizers):
    """Counterfactual network that encodes features with a transformer encoder layer.

    Every scalar feature becomes one token (embedded to width 8 by `Embed`). A
    special token with value -1 is appended, and its encoded representation is
    the summary handed to the predictor and the explainer.

    `self.enc_dims`, `self.dec_dims` and `self.exp_dims` are read after the
    base-class constructor runs (presumably populated from `config` --
    confirm in `CounterfactualTrainingModule2Optimizers`).

    Raises:
        AssertionError: if `enc_dims[-1]` differs from `dec_dims[0]` or
            `exp_dims[0]`.
    """
    def __init__(self, config):
        super().__init__(config)
        assert self.enc_dims[-1] == self.dec_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (dec_dims[0]={self.dec_dims[0]})"
        assert self.enc_dims[-1] == self.exp_dims[0], f"(enc_dims[-1]={self.enc_dims[-1]}) != (exp_dims[0]={self.exp_dims[0]})"

        self.emb = Embed(emb_dims=8)
        self.encoder_model = nn.TransformerEncoderLayer(d_model=8, nhead=4)

        # The predictor consumes the 8-wide token summary, so prepend that width.
        self.dec_dims = [8] + self.dec_dims
        # predictor
        self.predictor = MultilayerPerception(self.dec_dims)
        self.pred_linear = nn.Linear(self.dec_dims[-1], 1)
        # explainer: input is the token summary plus the predictor's hidden output
        exp_dims = list(self.exp_dims)
        exp_dims[0] = self.dec_dims[0] + self.dec_dims[-1]

        self.explainer = nn.Sequential(
            MultilayerPerception(exp_dims),
            # Sized from the local list so a single-entry exp_dims still lines up.
            nn.Linear(exp_dims[-1], self.enc_dims[0])
        )

    def model_forward(self, x):
        """Return `(y_hat, c)`: the sigmoid prediction and the explainer output."""
        # append special token (-1); new_full keeps x's device and dtype so the
        # concatenation also works when x is not a CPU float32 tensor.
        special = x.new_full((x.size(0), 1), -1.0)
        x = torch.cat((x, special), dim=-1)
        x = x.unsqueeze(dim=-1)
        x = self.emb(x)
        # The encoder layer was built without batch_first, so it attends along
        # dim 0. Swap to (tokens, batch, emb) so attention runs over a sample's
        # own tokens rather than across the batch, then swap back.
        x = self.encoder_model(x.transpose(0, 1)).transpose(0, 1)
        # keep only the special token's representation
        x = x[:, -1, :]
        # predicted y_hat
        pred = self.predictor(x)
        y_hat = torch.sigmoid(self.pred_linear(pred))
        # counterfactual example
        x = torch.cat((x, pred), -1)
        c = self.explainer(x)
        return torch.squeeze(y_hat, -1), c

# Comes from 05a_baseline_algos.ipynb, cell

class AE(DataModule):
    """Autoencoder trained with a mean-squared reconstruction loss.

    Args:
        configs: mapping forwarded to the base class; `configs['encoder_dims'][0]`
            is used as the input width.
        encoded_size: width of the latent code.

    `self.lr` and `self.log` are expected from the base class (not visible
    here -- confirm in `DataModule`).
    """
    def __init__(self, configs, encoded_size=5):
        super().__init__(configs)
        input_dim = configs['encoder_dims'][0]
        # The decoder mirrors the encoder's hidden widths in reverse order.
        hidden = [20, 16, 14, 12]
        self.encoder_model = MultilayerPerception([input_dim, *hidden, encoded_size])
        self.decoder_model = MultilayerPerception([encoded_size, *reversed(hidden), input_dim])

    def encoded(self, x):
        """Return the latent code for `x`."""
        return self.encoder_model(x)

    def forward(self, x):
        """Encode `x` and decode it again, returning the reconstruction."""
        return self.decoder_model(self.encoded(x))

    def configure_optimizers(self):
        """Return an Adam optimizer over the parameters that require gradients."""
        trainable = list(filter(lambda p: p.requires_grad, self.parameters()))
        return torch.optim.Adam(trainable, lr=self.lr)

    def _reconstruction_loss(self, batch):
        """Mean-squared error between the batch inputs and their reconstruction."""
        inputs, _ = batch
        return F.mse_loss(self(inputs), inputs, reduction='mean')

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training reconstruction loss."""
        loss = self._reconstruction_loss(batch)
        self.log('train/loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation reconstruction loss."""
        loss = self._reconstruction_loss(batch)
        self.log('val/val_loss', loss)
        return loss

# Comes from 05a_baseline_algos.ipynb, cell
class VAE(pl.LightningModule):
    """Conditional VAE whose encoder and decoder both receive the target label.

    The label `c` is concatenated as one extra column onto the encoder input
    and onto the latent code, which is why the first layer widths carry `+ 1`.

    Args:
        input_dims: number of input features (without the label column).
        encoded_size: width of the latent code.
    """
    def __init__(self, input_dims, encoded_size=5):
        super().__init__()
        self.encoder_mean = MultilayerPerception([input_dims + 1, 20, 16, 14, 12, encoded_size])
        self.encoder_var = MultilayerPerception([input_dims + 1, 20, 16, 14, 12, encoded_size])
        self.decoder_mean = MultilayerPerception([encoded_size + 1, 12, 14, 16, 20, input_dims])

    def encoder(self, x):
        """Return `(mean, logvar)` for the label-augmented input `x`.

        NOTE(review): despite its name, `logvar` is consumed as a variance by
        the other methods of this class (`torch.sqrt(logvar)`,
        `torch.log(logvar)`), not as a log-variance. Nothing here guarantees it
        is positive -- confirm the intended parameterisation.
        """
        mean = self.encoder_mean(x)
        logvar = 0.5 + self.encoder_var(x)
        return mean, logvar

    def decoder(self, z):
        """Return the decoder output for the label-augmented latent code `z`."""
        return self.decoder_mean(z)

    def sample_latent_code(self, mean, logvar):
        """Draw `mean + sqrt(logvar) * eps` with `eps` standard normal noise."""
        eps = torch.randn_like(logvar)
        return mean + torch.sqrt(logvar) * eps

    def normal_likelihood(self, x, mean, logvar, raxis=1):
        """Sum `-0.5 * ((x - mean)^2 / logvar + log(logvar))` over dimension `raxis`.

        `raxis` was previously accepted but ignored (the reduction was always
        over dimension 1); the default keeps that behaviour.
        """
        diff = x - mean
        return torch.sum(-.5 * (diff * (1. / logvar) * diff + torch.log(logvar)), dim=raxis)

    def forward(self, x, c, mc_samples=50):
        """Sample latent codes and decode them.

        Args:
            x: input instance
            c: target y
            mc_samples: number of latent samples drawn (defaults to 50).

        Returns:
            dict with the encoder outputs (`em`, `ev`), the lists of sampled
            codes (`z`) and decoded outputs (`x_pred`), and `mc_samples`.
        """
        c = c.view(c.shape[0], 1).float()
        em, ev = self.encoder(torch.cat((x, c), 1))
        res = {
            'em': em,
            'ev': ev,
            'z': [],
            'x_pred': [],
            'mc_samples': mc_samples,
        }
        for _ in range(mc_samples):
            z = self.sample_latent_code(em, ev)
            x_pred = self.decoder(torch.cat((z, c), 1))
            res['z'].append(z)
            res['x_pred'].append(x_pred)
        return res

    def compute_elbo(self, x, c, model):
        """Return the ELBO ingredients for one latent sample.

        Args:
            x: input instance
            c: target y
            model: object whose `predict` method is applied to the decoded output.

        Returns:
            `(log_px_z, kl, x, x_pred, model.predict(x_pred))` where `log_px_z`
            is a constant zero and `kl` is the batch mean of the KL term.
        """
        c = c.clone().detach().float()
        c = c.view(c.shape[0], 1)
        em, ev = self.encoder(torch.cat((x, c), 1))
        kl_divergence = 0.5 * torch.mean(em**2 + ev - torch.log(ev) - 1, dim=1)

        z = self.sample_latent_code(em, ev)
        x_pred = self.decoder(torch.cat((z, c), 1))
        # Constant placeholder; created on x's device so callers can combine it
        # with the other returned tensors without a device mismatch.
        log_px_z = torch.tensor(0.0, device=x.device)

        return torch.mean(log_px_z), torch.mean(kl_divergence), x, x_pred, model.predict(x_pred)

# Comes from 05a_baseline_algos.ipynb, cell
class CHVAE(pl.LightningModule):
    """
    VAE with separate mean / log-variance heads on both encoder and decoder side.

    https://github.com/carla-recourse/CARLA/blob/main/carla/recourse_methods/autoencoder/models.py

    Args:
        input_dims: number of input features.
        encoded_size: width of the latent code.
    """
    def __init__(self, input_dims, encoded_size=5):
        super().__init__()
        # One trunk object per side; each is passed to both heads of its side,
        # so the two heads use the same trunk module.
        enc_trunk = MultilayerPerception([input_dims, 20, 16, 14, 12])
        dec_trunk = MultilayerPerception([encoded_size, 12, 14, 16, 20])

        self._mu_enc = nn.Sequential(enc_trunk, nn.Linear(12, encoded_size))
        self._log_var_enc = nn.Sequential(enc_trunk, nn.Linear(12, encoded_size))
        self.mu_dec = nn.Sequential(
            dec_trunk,
            nn.Linear(20, input_dims),
            nn.BatchNorm1d(input_dims),
            nn.Sigmoid(),
        )
        self.log_var_dec = nn.Sequential(
            dec_trunk,
            nn.Linear(20, input_dims),
            nn.BatchNorm1d(input_dims),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return `(mu_z, log_var_z)` for input `x`."""
        mu_z = self._mu_enc(x)
        log_var_z = self._log_var_enc(x)
        return mu_z, log_var_z

    def decode(self, z):
        """Return `(mu_x, log_var_x)` for latent code `z`."""
        mu_x = self.mu_dec(z)
        log_var_x = self.log_var_dec(z)
        return mu_x, log_var_x

    def __reparametrization_trick(self, mu, log_var):
        """Sample `mu + exp(0.5 * log_var) * noise` with standard normal noise."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)  # the Gaussian random noise
        return mu + std * noise

    def forward(self, x):
        """Encode, sample and decode; return `(mu_x, log_var_x, z_rep, mu_z, log_var_z)`."""
        mu_z, log_var_z = self.encode(x)
        z_rep = self.__reparametrization_trick(mu_z, log_var_z)
        mu_x, log_var_x = self.decode(z_rep)
        return mu_x, log_var_x, z_rep, mu_z, log_var_z

    def regenerate(self, z):
        """Decode `z` and return only the mean output."""
        return self.decode(z)[0]

    def compute_loss(self, mse_loss, mu, logvar):
        """Return `mse_loss` plus the summed KL term computed from `mu` and `logvar`."""
        kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
        return mse_loss + kld