# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05a_baseline_algos.ipynb (unless otherwise specified).

__all__ = ['pl_logger', 'Clamp', 'ExplainerBase', 'VanillaCF', 'DiverseCF', 'ProtoCF', 'VAE_CF', 'CCHVAE',
           'CounteRGANTrainingModule', 'CounteRGAN']

# Cell
from .import_essentials import *
from .utils import *
from .train import *
from .training_module import *
from .net import *
# from counterfactual.evaluate import *

from torch.nn.parameter import Parameter
from pytorch_lightning.metrics.functional.classification import *

# Module-level handle to the logger registered under the name 'lightning'.
pl_logger = logging.getLogger('lightning')

# Cell

class Clamp(torch.autograd.Function):
    """
    Straight-through clamp of a tensor to the interval [0, 1].

    Forward saturates every value outside [0, 1]; backward returns a copy of
    the incoming gradient unchanged, so saturated entries still get gradient.
    Adapted from: https://discuss.pytorch.org/t/regarding-clamped-learnable-parameter/58474/4
    """
    @staticmethod
    def forward(ctx, input):
        # Nothing is stored on ctx: the backward pass ignores the input.
        clamped = torch.clamp(input, min=0, max=1)
        return clamped

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient, returned as a fresh tensor.
        passthrough = torch.clone(grad_output)
        return passthrough

# Cell

class ExplainerBase(nn.Module):
    """Shared scaffolding for the gradient-based counterfactual explainers.

    The methods below read ``self.cf`` (and ``self.n_cfs`` for the default
    regularization loss); neither is set here, so subclasses must define them.
    """

    def __init__(self, x: torch.tensor, model: pl.LightningModule):
        """
        Args:
            x (torch.tensor): input instance to explain
            model (pl.LightningModule): black-box model; it is frozen here
        """
        super().__init__()
        self.model = model
        self.model.freeze()
        self.x = x
        self.clamp = Clamp()

    def forward(self):
        """Return the current counterfactual; must be provided by subclasses."""
        raise NotImplementedError

    def compute_regularization_loss(self):
        """Penalize categorical groups whose entries do not sum to one.

        For each of the ``self.n_cfs`` counterfactuals and each categorical
        feature in ``self.model.cat_arrays``, adds ``(sum(group) - 1) ** 2``.
        Groups are taken as consecutive slices that start right after the
        continuous columns, one slice of ``len(col)`` entries per feature.

        Returns:
            the accumulated penalty (``0.`` when there are no categorical features)
        """
        first_cat_idx = len(self.model.continous_cols)
        regularization_loss = 0.
        for i in range(self.n_cfs):
            # Restart at the first categorical column for every counterfactual.
            cat_idx = first_cat_idx
            for col in self.model.cat_arrays:
                cat_idx_end = cat_idx + len(col)
                regularization_loss += torch.pow((torch.sum(self.cf[i][cat_idx: cat_idx_end]) - 1.0), 2)
                # Move on to the next group; without this every feature
                # would re-penalize the slice of the first one.
                cat_idx = cat_idx_end
        return regularization_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the counterfactual parameter only."""
        return torch.optim.Adam([self.cf], lr=0.001)

    def generate_cf(self, n_iters):
        """Run the optimization for `n_iters` steps; must be provided by subclasses."""
        raise NotImplementedError

# Cell

class VanillaCF(ExplainerBase):
    def __init__(self, x: torch.tensor, model: BaselineModel):
        """vanilla version of counterfactual generation
            - link: https://doi.org/10.2139/ssrn.3063289

        Args:
            x (torch.tensor): input instance
            model (BaselineModel): black-box model
        """
        super().__init__(x, model)
        # The counterfactual starts as a copy of the input and is optimized directly.
        self.cf = nn.Parameter(self.x.clone(), requires_grad=True)

    def forward(self):
        """Return the current counterfactual after `cat_normalize`.

        The last argument is passed as False (presumably `hard=False`, i.e. a
        differentiable normalization -- confirm against `cat_normalize`).
        """
        cf = self.cf * 1.0
        return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), False)

    def configure_optimizers(self):
        """Return an RMSprop optimizer over the counterfactual parameter only."""
        return torch.optim.RMSprop([self.cf], lr=0.001)

    def compute_regularization_loss(self):
        """Penalize categorical groups whose entries do not sum to one.

        Groups are consecutive slices of the feature (last) dimension that
        start right after the continuous columns.
        """
        cat_idx = len(self.model.continous_cols)
        regularization_loss = 0.
        for col in self.model.cat_arrays:
            cat_idx_end = cat_idx + len(col)
            # Slice the feature axis (not the batch axis) of the counterfactual.
            regularization_loss += torch.pow((torch.sum(self.cf[..., cat_idx: cat_idx_end]) - 1.0), 2)
            # Advance to the next categorical group.
            cat_idx = cat_idx_end
        return regularization_loss

    def _loss_functions(self, x, c):
        """Compute the validity and proximity terms.

        Args:
            x: original input
            c: candidate counterfactual

        Returns:
            (l_1, l_2): binary cross-entropy between the model output on `c`
            and the flipped prediction for `x`, and the MSE between `x` and `c`
        """
        # target: the opposite of what the model predicts for x
        y_pred = self.model.predict(x)
        # Build the ones on y_pred's device so the subtraction also works off-CPU.
        y_prime = torch.ones(y_pred.shape, device=y_pred.device) - y_pred

        c_y = self.model(c)
        l_1 = F.binary_cross_entropy(c_y, y_prime.float())
        l_2 = F.mse_loss(x, c)
        return l_1, l_2

    def _loss_compute(self, l_1, l_2):
        """Weighted sum of the validity (1.0) and proximity (0.5) terms."""
        return 1.0 * l_1 + 0.5 * l_2

    def generate_cf(self, n_iters, debug: bool = False):
        """Optimize the counterfactual for `n_iters` steps and return it.

        Args:
            n_iters: number of optimizer steps
            debug: when True, print the loss every 100 iterations

        Returns:
            the optimized counterfactual passed through `cat_normalize` with
            its last argument set to True
        """
        optim = self.configure_optimizers()
        for i in range(n_iters):
            c = self()
            l_1, l_2 = self._loss_functions(self.x, c)
            loss = self._loss_compute(l_1, l_2)
            optim.zero_grad()
            loss.backward()
            optim.step()

            if debug and i % 100 == 0:
                print(f"iter: {i}, loss: {loss.item()}")

            # Constrain to [0, 1]. The projection has to be written back into
            # the parameter; calling Clamp and dropping its result does nothing.
            with torch.no_grad():
                self.cf.clamp_(0, 1)

        # Project once more so the result is in range even when n_iters == 0.
        with torch.no_grad():
            self.cf.clamp_(0, 1)
        cf = self.cf * 1.0
        return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), True)

# Cell

class DiverseCF(ExplainerBase):
    def __init__(self, x: torch.tensor, model: CounterfactualTrainingModule):
        """diverse counterfactual explanation
            - link: https://doi.org/10.1145/3351095.3372850

        Args:
            x (torch.tensor): input instance
            model (CounterfactualTrainingModule): black-box model
        """
        # Number of counterfactuals optimized jointly.
        self.n_cfs = 5
        super().__init__(x, model)
        # Random initialization: one row per counterfactual, one column per input feature.
        self.cf = nn.Parameter(torch.rand(self.n_cfs, self.x.size(1)), requires_grad=True)

    def forward(self):
        """Return the counterfactuals clamped to [0, 1]."""
        cf = self.cf * 1.0
        return torch.clamp(cf, 0, 1)

    def configure_optimizers(self):
        """Return an RMSprop optimizer over the counterfactual parameter only."""
        return torch.optim.RMSprop([self.cf], lr=0.001)

    def _compute_dist(self, x1, x2):
        """L1 distance between two tensors, summed over dimension 0."""
        return torch.sum(torch.abs(x1 - x2), dim = 0)

    def _compute_proximity_loss(self):
        """Compute the second part (distance from the input) of the loss function.

        Returns the L1 distance between every counterfactual and the input,
        averaged over features and counterfactuals.
        """
        x_flat = self.x.reshape(-1)
        proximity_loss = 0.0
        for i in range(self.n_cfs):
            proximity_loss += self._compute_dist(self.cf[i], x_flat)
        return proximity_loss / (self.cf.size(1) * self.n_cfs)

    def _dpp_style(self, cf):
        """Determinant of the inverse-distance kernel between counterfactuals.

        Entry (i, j) is ``1 / (1 + L1(cf[i], cf[j]))``; a small multiple of the
        identity is added before taking the determinant.
        """
        det_entries = torch.ones(self.n_cfs, self.n_cfs, device=cf.device)
        for i in range(self.n_cfs):
            for j in range(self.n_cfs):
                det_entries[i, j] = self._compute_dist(cf[i], cf[j])

        # implement inverse distance
        det_entries = 1.0 / (1.0 + det_entries)
        det_entries += torch.eye(self.n_cfs, device=cf.device) * 0.0001
        return torch.det(det_entries)

    def _compute_diverse_loss(self, c):
        """Diversity term as a quantity to minimize.

        The DPP determinant grows as the counterfactuals spread apart, so it
        is negated here: the terms are summed and minimized, and adding the
        raw determinant would push the counterfactuals together.
        """
        return -self._dpp_style(c)

    def _compute_regularization_loss(self):
        """Penalize categorical groups whose entries do not sum to one.

        Groups are consecutive slices that start right after the continuous
        columns, one slice of ``len(col)`` entries per categorical feature.
        """
        first_cat_idx = len(self.model.continous_cols)
        regularization_loss = 0.
        for i in range(self.n_cfs):
            # Restart at the first categorical column for every counterfactual.
            cat_idx = first_cat_idx
            for col in self.model.cat_arrays:
                cat_idx_end = cat_idx + len(col)
                regularization_loss += torch.pow((torch.sum(self.cf[i][cat_idx: cat_idx_end]) - 1.0), 2)
                # Advance to the next categorical group.
                cat_idx = cat_idx_end
        return regularization_loss

    def _loss_functions(self, x, c):
        """Compute the four loss terms.

        Args:
            x: original input
            c: candidate counterfactuals

        Returns:
            (l_1, l_2, l_3, l_4): hinge validity loss, `l1_mean` proximity,
            negated DPP diversity, and categorical sum-to-one penalty
        """
        # target: the opposite of what the model predicts for x
        y_pred = self.model.predict(x)
        # Build the ones on y_pred's device so the subtraction also works off-CPU.
        y_prime = torch.ones(y_pred.shape, device=y_pred.device) - y_pred

        c_y = self.model(c)
        # yloss
        l_1 = hinge_loss(input=c_y, target=y_prime.float())
        # proximity loss
        l_2 = l1_mean(x, c)
        # diverse loss
        l_3 = self._compute_diverse_loss(c)
        # categorical penalty
        l_4 = self._compute_regularization_loss()
        return l_1, l_2, l_3, l_4

    def _compute_loss(self, *loss_f):
        """Unweighted sum of all loss terms."""
        return sum(loss_f)

    def generate_cf(self, n_iters, debug: bool = False):
        """Optimize the counterfactuals for `n_iters` steps and return the first one.

        Args:
            n_iters: number of optimizer steps
            debug: when True, print the loss every 100 iterations

        Returns:
            the first counterfactual, reshaped to one row and passed through
            `cat_normalize` with its last argument set to True
        """
        optim = self.configure_optimizers()
        for i in range(n_iters):
            c = self()

            l_1, l_2, l_3, l_4 = self._loss_functions(self.x, c)
            loss = self._compute_loss(l_1, l_2, l_3, l_4)
            optim.zero_grad()
            loss.backward()
            optim.step()

            if debug and i % 100 == 0:
                print(f"iter: {i}, loss: {loss.item()}")

            # Constrain to [0, 1]. The projection has to be written back into
            # the parameter; calling Clamp and dropping its result does nothing.
            with torch.no_grad():
                self.cf.clamp_(0, 1)

        cf = self.cf * 1.0
        cf = torch.clamp(cf, 0, 1)
        return cat_normalize(cf[0].view(1, -1), self.model.cat_arrays, len(self.model.continous_cols), True)

# Cell

class ProtoCF(ExplainerBase):
    def __init__(self, x: torch.tensor, model: pl.LightningModule, train_loader: DataLoader, ae: AE):
        """prototype-guided counterfactual generation

        The counterfactual is pulled towards the mean autoencoder encoding
        ("prototype") of sampled training points that the model assigns to
        the target class.

        Args:
            x (torch.tensor): input instance
            model (pl.LightningModule): black-box model
            train_loader (DataLoader): its first batch supplies the samples
                used to build the prototype
            ae (AE): autoencoder providing `encoded`; it is frozen here
        """
        super().__init__(x, model)
        self.cf = nn.Parameter(self.x.clone(), requires_grad=True)
        # Only the first batch of the loader is used.
        self.sampled_data, _ = next(iter(train_loader))
        self.sampled_label = self.model.predict(self.sampled_data)
        self.ae = ae
        self.ae.freeze()

    def forward(self):
        """Return the current counterfactual (no normalization applied)."""
        cf = self.cf * 1.0
        return cf

    def configure_optimizers(self):
        """Return an RMSprop optimizer over the counterfactual parameter only."""
        return torch.optim.RMSprop([self.cf], lr=0.001)

    def compute_regularization_loss(self):
        """Penalize categorical groups whose entries do not sum to one.

        Groups are consecutive slices of the feature (last) dimension that
        start right after the continuous columns.
        """
        cat_idx = len(self.model.continous_cols)
        regularization_loss = 0.
        for col in self.model.cat_arrays:
            cat_idx_end = cat_idx + len(col)
            # Slice the feature axis (not the batch axis) of the counterfactual.
            regularization_loss += torch.pow((torch.sum(self.cf[..., cat_idx: cat_idx_end]) - 1.0), 2)
            # Advance to the next categorical group.
            cat_idx = cat_idx_end
        return regularization_loss

    def proto(self, data):
        """Mean encoding of `data` over axis 0, reshaped to a single row."""
        return self.ae.encoded(data).mean(axis=0).view(1, -1)

    def _loss_functions(self, x, c):
        """Compute the validity, proximity and prototype terms.

        Args:
            x: original input
            c: candidate counterfactual

        Returns:
            (l_1, l_2, l_3): binary cross-entropy towards the flipped
            prediction, ``0.1 * L1 + MSE`` between `x` and `c`, and the MSE
            between the encoding of `c` and the target-class prototype
        """
        # target: the opposite of what the model predicts for x
        y_pred = self.model.predict(x)
        # Build the ones on y_pred's device so the subtraction also works off-CPU.
        y = torch.ones(y_pred.shape, device=y_pred.device) - y_pred

        # Sampled training points that the model assigns to the target class.
        data = self.sampled_data[self.sampled_label == y]

        l_1 = F.binary_cross_entropy(self.model(c), y)
        l_2 = 0.1 * F.l1_loss(x, c) + F.mse_loss(x, c)
        l_3 = F.mse_loss(self.ae.encoded(c), self.proto(data))

        return l_1, l_2, l_3

    def _loss_compute(self, l_1, l_2, l_3):
        """Unweighted sum of the three terms (the regularization loss is not included)."""
        return l_1 + l_2 + l_3

    def generate_cf(self, n_iters, debug: bool = False):
        """Optimize the counterfactual for `n_iters` steps and return it.

        Args:
            n_iters: number of optimizer steps
            debug: when True, print the loss every 100 iterations

        Returns:
            the optimized counterfactual passed through `cat_normalize` with
            its last argument set to True
        """
        optim = self.configure_optimizers()
        for i in range(n_iters):
            c = self()

            l_1, l_2, l_3 = self._loss_functions(self.x, c)
            loss = self._loss_compute(l_1, l_2, l_3)
            optim.zero_grad()
            loss.backward()
            optim.step()

            if debug and i % 100 == 0:
                print(f"iter: {i}, loss: {loss.item()}")

            # Constrain to [0, 1]. The projection has to be written back into
            # the parameter; calling Clamp and dropping its result does nothing.
            with torch.no_grad():
                self.cf.clamp_(0, 1)

        # Project once more so the result is in range even when n_iters == 0.
        with torch.no_grad():
            self.cf.clamp_(0, 1)
        cf = self.cf * 1.0
        return cat_normalize(cf, self.model.cat_arrays, len(self.model.continous_cols), True)

# Cell
class VAE_CF(CounterfactualTrainingModule):
    def __init__(self, config: Dict, model: pl.LightningModule):
        """
        config: basic configs; an optional 'validity_reg' entry sets the
            weight of the validity term
        model: the black-box model to be explained (frozen here)
        """
        super().__init__(config)
        self.model = model
        self.model.freeze()
        self.vae = VAE(input_dims=self.enc_dims[0])
        # Defaults to 1.0 when the config has no 'validity_reg'. The original
        # note cites 42.0 as the value used in
        # https://interpret.ml/DiCE/notebooks/DiCE_getting_started_feasible.html#Generate-counterfactuals-using-a-VAE-model
        self.validity_reg = config.get('validity_reg', 1.0)

    def model_forward(self, x):
        """lazy implementation since this method is actually not needed"""
        recon_err, kl_err, x_true, x_pred, cf_label = self.vae.compute_elbo(x, 1 - self.model.predict(x), self.model)
        # return y, c
        return cf_label, x_pred

    def configure_optimizers(self):
        """Return an Adam optimizer over every parameter that still requires grad."""
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)

    def predict(self, x):
        """Delegate to the black-box model's `predict`."""
        return self.model.predict(x)

    def _mc_sample_losses(self, x, x_pred, y):
        """Reconstruction and validity terms for a single decoded sample.

        Returns:
            (recon, validity): per-row negative L1 distance to `x` minus the
            sum-to-one violation of every categorical group, and the hinge
            loss of the model output on `x_pred` against `y`
        """
        # Proximity: negative L1 distance
        recon = - torch.sum(torch.abs(x - x_pred), axis=1)

        # Sum to 1 over the categorical indexes of a feature
        cat_idx = len(self.continous_cols)
        for col in self.cat_arrays:
            cat_end_idx = cat_idx + len(col)
            recon = recon - torch.abs(1.0 - x_pred[:, cat_idx: cat_end_idx].sum(axis=1))
            # Advance to the next categorical group; without this every
            # feature would re-penalize the slice of the first one.
            cat_idx = cat_end_idx

        # Validity
        validity = hinge_loss(input=self.model(x_pred), target=y.float())
        return recon, validity

    def compute_loss(self, out, x, y):
        """Negative ELBO plus the weighted validity loss.

        Args:
            out: VAE output; reads 'em', 'ev', 'x_pred' and 'mc_samples'
            x: input batch
            y: target (flipped) labels

        Returns:
            ``-mean(recon_err - kl) + validity_reg * mean_hinge`` where the
            reconstruction and hinge terms are averaged over `mc_samples`
        """
        em = out['em']
        ev = out['ev']
        dm = out['x_pred']
        mc_samples = out['mc_samples']
        # KL Divergence
        kl_divergence = 0.5*torch.mean(em**2 + ev - torch.log(ev) - 1, axis=1)

        # The first sample is always used; the rest are accumulated on top.
        recon_err, hinge = self._mc_sample_losses(x, dm[0], y)
        validity_loss = torch.zeros(1, device=self.device)
        validity_loss += hinge

        for i in range(1, mc_samples):
            sample_recon, hinge = self._mc_sample_losses(x, dm[i], y)
            recon_err = recon_err + sample_recon
            validity_loss += hinge

        recon_err = recon_err / mc_samples
        # Negated here and subtracted below, so the hinge loss ends up added.
        validity_loss = -1 * self.validity_reg * validity_loss / mc_samples

        return -torch.mean(recon_err - kl_divergence) - validity_loss


    def training_step(self, batch, batch_idx):
        """Train the VAE to reconstruct `x` while flipping the model's prediction."""
        # batch
        x, _ = batch
        # prediction
        y_hat = self.model.predict(x)
        # target
        y = 1.0 - y_hat

        out = self.vae(x, y)
        loss = self.compute_loss(out, x, y)

        self.log('train/loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the loss, the mean L1 proximity and the counterfactual accuracy."""
        # batch
        x, _ = batch
        # prediction
        y_hat = self.model.predict(x)
        # target
        y = 1.0 - y_hat

        out = self.vae(x, y)
        loss = self.compute_loss(out, x, y)

        _, _, _, x_pred, cf_label = self.vae.compute_elbo(x, y, self.model)

        cf_proximity = torch.abs(x - x_pred).sum(dim=1).mean()
        cf_accuracy = accuracy(cf_label, y)

        self.log('val/val_loss', loss)
        self.log('val/proximity', cf_proximity)
        self.log('val/cf_accuracy', cf_accuracy)

        return loss

    def validation_epoch_end(self, val_outs):
        """No epoch-level aggregation is performed."""
        return

    def generate_cf(self, x):
        """Freeze the VAE and return hard-normalized counterfactuals for `x`."""
        self.vae.freeze()
        y_hat = self.model.predict(x)
        recon_err, kl_err, x_true, x_pred, cf_label = self.vae.compute_elbo(x, 1.-y_hat, self.model)
        return self.model.cat_normalize(x_pred, hard=True)

# Cell
class CCHVAE(CounterfactualTrainingModule):
    """
    Refer to https://github.com/carla-recourse/CARLA/blob/main/carla/recourse_methods/catalog/cchvae/model.py
    """
    def __init__(self, config: Dict, model: pl.LightningModule):
        """
        config: basic configs
        model: the black-box model to be explained (frozen here)
        """
        super().__init__(config)
        self.model = model
        self.model.freeze()
        self.vae = CHVAE(input_dims=self.enc_dims[0])

    def model_forward(self, x):
        """lazy implementation as it is not needed"""
        MU_X_eval, LOG_VAR_X_eval, Z_ENC_eval, MU_Z_eval, LOG_VAR_Z_eval = self.vae(x)
        return MU_X_eval, LOG_VAR_X_eval

    def predict(self, x):
        """Delegate to the black-box model's `predict`."""
        return self.model.predict(x)

    def _hyper_sphere_coordindates(
        self, x, high: float, low: float, n_search_samples: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample points around `x` whose offsets have an L1 norm drawn from [low, high).

        :param n_search_samples: int > 0
        :param x: input point tensor; its second dimension is the feature size
        :param high: float>= 0, h>l; upper bound
        :param low: float>= 0, l<h; lower bound
        :return: candidate counterfactuals & distances
        """
        # Sample on x's device so the search also works off-CPU.
        delta_instance = torch.randn(n_search_samples, x.size(1), device=x.device)
        dist = (
            torch.rand(n_search_samples, device=x.device) * (high - low) + low
        )  # length range [l, h)
        norm_p = torch.norm(delta_instance, p=1, dim=1)
        d_norm = torch.divide(dist, norm_p).reshape(-1, 1)  # rescale/normalize factor
        delta_instance = torch.multiply(delta_instance, d_norm)
        candidate_counterfactuals = x + delta_instance
        return candidate_counterfactuals, dist

    def generate_cf(self, x):
        """Search the latent neighbourhood of `x` for a point with a flipped prediction.

        The search radius grows by `step` whenever no candidate is found.
        Returns the first candidate of the first successful round (not the
        closest one), or the first decoded sample of the last round when the
        iteration budget runs out.
        """
        # params
        n_search_samples = 300
        max_iter = 1000
        step = 0.1
        low = 0
        high = step

        self.vae.freeze()
        y_hat = self.model.predict(x)

        # vectorize z
        z = self.vae.encode(x)[0]
        z_rep = torch.repeat_interleave(
            z.reshape(1, -1), n_search_samples, dim=0
        )

        x_ce: torch.Tensor = torch.tensor([])

        # max_iter + 1 rounds, matching the original `count <= max_iter` loop.
        for _ in range(max_iter + 1):
            # STEP 1 -- SAMPLE POINTS on hyper sphere around instance
            latent_neighbourhood, _ = self._hyper_sphere_coordindates(z_rep, high, low, n_search_samples)
            x_ce = self.vae.decode(latent_neighbourhood)[0]

            x_ce = self.model.cat_normalize(x_ce, hard=True)
            x_ce = x_ce.clip(0, 1)

            # STEP 2 -- COMPUTE l1 norms
            distances = torch.abs((x_ce - x)).sum(dim=1)

            # counterfactual labels
            y_candidate = self.model.predict(x_ce)
            indeces = torch.where(y_candidate != y_hat)
            candidate_counterfactuals = x_ce[indeces]
            candidate_dist = distances[indeces]

            if len(candidate_dist) > 0:
                # certain candidates generated
                return candidate_counterfactuals[0]

            # no candidate found & push search range outside
            low = high
            high = low + step
        return x_ce[0]

    def training_step(self, batch, batch_idx):
        """VAE loss built from the MSE between the reconstruction mean and `x`."""
        x, y = batch
        MU_X_eval, LOG_VAR_X_eval, Z_ENC_eval, MU_Z_eval, LOG_VAR_Z_eval = self.vae(x)

        reconstruction = MU_X_eval
        mse_loss = F.mse_loss(reconstruction, x)
        loss = self.vae.compute_loss(mse_loss, MU_Z_eval, LOG_VAR_Z_eval)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the MSE between the normalized reconstruction and `x`."""
        x, y = batch
        MU_X_eval, LOG_VAR_X_eval, Z_ENC_eval, MU_Z_eval, LOG_VAR_Z_eval = self.vae(x)

        reconstruction = self.cat_normalize(MU_X_eval)
        mse_loss = F.mse_loss(reconstruction, x)

        self.log('val/val_loss', mse_loss)
        return mse_loss

    def validation_epoch_end(self, val_outs):
        """No epoch-level aggregation is performed."""
        return

# Cell
class CounteRGANTrainingModule(DataModule):
    def __init__(self, config: Dict, model, target_class: int):
        """
        Args:
            config (Dict): basic configs
            model: the black-box model to be explained (frozen here)
            target_class (int): class the generated counterfactuals should
                be assigned to; must be 0 or 1

        Raises:
            ValueError: if `target_class` is neither 0 nor 1
        """
        super().__init__(config)
        self.model = model
        self.model.freeze()

        if target_class in [0., 1.]:
            self.target_class = target_class
        else:
            raise ValueError('`target_class` should be either `0` or `1`.')
        self.init_rgan()

    def init_rgan(self):
        """Build the residual generator and the sigmoid discriminator."""
        # The generator maps back to the input width so its output can be added to x.
        gen_dims = self.enc_dims + self.exp_dims + [self.enc_dims[0]]
        self.generator = MultilayerPerception(gen_dims)
        self.discriminator = nn.Sequential(
            MultilayerPerception([gen_dims[0], 128]),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def model_forward(self, x):
        pass

    def discriminate(self, x):
        """Return the discriminator output with the trailing dimension squeezed."""
        y_hat = self.discriminator(x)
        return torch.squeeze(y_hat, dim=-1)

    def forward(self, x, hard=False, imutable=True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """forward pass of CounteRGAN

        Args:
            x (torch.Tensor): input
            hard (bool, optional): categorical features in counterfactual is one-hot-encoding or not.
                Defaults to False.
            imutable (bool, optional): whether to use immutable features or not. Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: outputs `cf`, `real_fake_y`, `cf_y`
        """
        # The generator predicts a residual that is added to the input.
        cf = self.generator(x)
        cf = x + cf
        cf = self.cat_normalize(cf, hard=hard)
        if imutable:
            # Copy the immutable columns back from the input.
            cf[:, self.imutable_idx_list] = x[:, self.imutable_idx_list] * 1.0
        real_fake_y = self.discriminate(cf)
        cf_y = self.model(cf)

        return cf, real_fake_y, cf_y

    def generate_cf(self, x):
        """Return counterfactuals for `x` with hard (one-hot) categorical features."""
        cf, _, _ = self.forward(x, hard=True)
        return cf

    def discriminator_step(self, batch):
        """Binary cross-entropy of the discriminator on real (1) vs generated (0) samples."""
        x_real, _ = batch
        real_disc_y = self.discriminate(x_real)
        # Generate the fakes without tracking gradients: this step must only
        # train the discriminator, and both optimizers hold the generator's
        # parameters, so a differentiable fake would update the generator too.
        with torch.no_grad():
            x_fake, _, _ = self(x_real)
        fake_disc_y = self.discriminate(x_fake)
        y_hat = torch.cat((real_disc_y, fake_disc_y))

        # Labels are created next to the predictions (same device and shape).
        y = torch.cat((torch.ones_like(real_disc_y), torch.zeros_like(fake_disc_y)))

        loss = F.binary_cross_entropy(y_hat, y)
        return loss

    def generator_step(self, batch):
        """Generator loss: fool the discriminator, reach the target class, stay close to x."""
        x, y = batch
        cf, y_disc, cf_y = self.forward(x)
        # cf loss
        y_prime = self.target_class + torch.zeros_like(cf_y)
        loss_cf = F.binary_cross_entropy(cf_y, y_prime)
        # gan loss: the generator wants its samples labelled as real
        y_fake = torch.ones_like(y_disc)
        loss_gan = F.binary_cross_entropy(y_disc, y_fake)
        # regularization loss (the L1 term is switched off by its zero weight)
        reg_loss = 0. * F.l1_loss(x, cf) + 1e-6 * F.mse_loss(x, cf)

        return loss_gan + loss_cf + reg_loss

    def configure_optimizers(self):
        """Return (Adam, RMSprop); index 0 drives the discriminator step, index 1 the generator step."""
        opt_1 = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=0.0005)
        opt_2 = torch.optim.RMSprop([p for p in self.parameters() if p.requires_grad], lr=2e-4)
        return (opt_1, opt_2)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Alternate between discriminator and generator updates.

        Out of every six batches, the first two run the discriminator step
        (optimizer 0) and the other four run the generator step (optimizer 1).
        Any other batch/optimizer combination returns None.
        """
        self.model.freeze()
        if batch_idx % 6 in [0, 1]:
            if optimizer_idx == 0:
                use_grad(self.discriminator, requires_grad=True)
                return self.discriminator_step(batch)
        else:
            if optimizer_idx == 1:
                use_grad(self.discriminator, requires_grad=False)
                return self.generator_step(batch)

    def validation_step(self, batch, batch_idx):
        """Log the generator loss on the validation batch."""
        loss = self.generator_step(batch)
        self.log('val/val_loss', loss)
        return loss

# Cell
class CounteRGAN:
    """Combine two trained CounteRGAN modules into a single explainer.

    `generate_cf` takes rows from `rgan_1` where the flipped prediction
    rounds to a non-zero value and from `rgan_0` elsewhere.
    """

    def __init__(
        self,
        rgan_0: CounteRGANTrainingModule,
        rgan_1:CounteRGANTrainingModule,
    ):
        """
        Args:
            rgan_0: module whose counterfactuals are used where the target rounds to 0
            rgan_1: module whose counterfactuals are used where the target rounds to a non-zero value
        """
        # copy attributes (shallow copy of rgan_0's instance dict)
        self.__dict__ = rgan_0.__dict__.copy()
        self.rgan_0 = rgan_0
        self.rgan_1 = rgan_1
        self.model = self.rgan_0.model
        self.rgan_0.eval()
        self.rgan_1.eval()
        use_grad(self.rgan_0, self.rgan_1, requires_grad=False)

    def predict(self, x):
        """Delegate to the black-box model's `predict`."""
        return self.model.predict(x)

    def generate_cf(self, x):
        """Return, per row of `x`, the counterfactual from the generator matching its target class."""
        cf_0 = self.rgan_0.generate_cf(x)
        cf_1 = self.rgan_1.generate_cf(x)
        y_target = 1 - self.model.predict(x)
        # torch.where expects a boolean condition (uint8 masks are deprecated);
        # one column per row lets it broadcast over the feature columns.
        use_cf_1 = torch.round(y_target).bool().reshape(-1, 1)
        return torch.where(use_cf_1, cf_1, cf_0)