# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03b_counterfactual_net.ipynb (unless otherwise specified).

# Names exported by a star-import of this module.
__all__ = ['pl_logger', 'DataModule', 'BaselineTrainingModule', 'CounterfactualTrainingModule', 'LossWrapper',
           'CounterfactualTrainingModuleLossWrapper', 'CounterfactualTrainingModule2Optimizers',
           'CounterfactualTrainingModulePosthoc']

# Cell
from .import_essentials import *
from .utils import *
from pytorch_lightning.metrics.functional.classification import *
from sklearn.preprocessing import StandardScaler,MinMaxScaler, OneHotEncoder
from pytorch_lightning.callbacks import EarlyStopping

# Module-level logger (the 'lightning' logger); used below to report preprocessing shapes.
pl_logger = logging.getLogger('lightning')

# Cell
class DataModule(pl.LightningModule):
    """
    Base module that owns the data, the preprocessing and the shared config.

    Despite its name this is a ``pl.LightningModule``; the training modules
    below inherit from it and add the model / loss logic.

    config[Dict]: containing configurations
        required keys: 'data_dir', 'continous_cols', 'discret_cols', 'lr',
        'batch_size'; 'encoder_dims' must be non-empty because its first
        entry is used as the input width.
        optional keys: 'imutable_cols', 'lambda_1', 'lambda_2', 'lambda_3',
        'threshold', 'smooth_y', 'loss_1', 'loss_2', 'loss_3',
        'decoder_dims', 'explainer_dims'.
    data_dir[str]: the location of the dataframe (assuming pandas dataframe)
    """
    def __init__(self, config: Dict):
        super().__init__()
        self.save_hyperparameters(config)

        # read data
        self.data = pd.read_csv(Path(config['data_dir']))
        self.continous_cols = config['continous_cols']
        self.discret_cols = config['discret_cols']
        self.imutable_cols = config.get('imutable_cols', [])
        self.check_cols()

        # set configs
        self.lr = config['lr']
        self.batch_size = config['batch_size']
        self.lambda_1 = config.get('lambda_1', 1)
        self.lambda_2 = config.get('lambda_2', 1)
        self.lambda_3 = config.get('lambda_3', 1)
        self.threshold = config.get('threshold', 0.5)
        self.smooth_y = config.get('smooth_y', True)

        # loss functions (looked up by name)
        self.loss_func_1 = get_loss_functions(config.get('loss_1', "cross_entropy"))
        self.loss_func_2 = get_loss_functions(config.get('loss_2', "l1_mean"))
        self.loss_func_3 = get_loss_functions(config.get('loss_3', "cross_entropy"))

        # set model configs
        self.enc_dims = config.get('encoder_dims', [])
        self.dec_dims = config.get('decoder_dims', [])
        self.exp_dims = config.get('explainer_dims', [])

        # log graph; NOTE: raises IndexError when 'encoder_dims' is missing/empty
        self.example_input_array = torch.randn((1, self.enc_dims[0]))

    def check_cols(self):
        """Cast the continuous columns to float and validate the immutable columns."""
        # the builtin `float` replaces the removed `np.float` alias (same dtype)
        self.data = self.data.astype({col: float for col in self.continous_cols})
        # check imutable cols
        cols = self.continous_cols + self.discret_cols
        for col in self.imutable_cols:
            assert col in cols, f"immutable column '{col}' is neither a continuous nor a discrete column"

    def training_epoch_end(self, outs):
        """Log the hyper-parameters once, at the end of the first epoch."""
        # a trainer may run without a logger; skip instead of crashing on None
        if self.current_epoch == 0 and self.logger is not None:
            self.logger.log_hyperparams(self.hparams)

    def transform(self, x, return_tensor=True):
        """Scale the continuous columns and one-hot encode the discrete ones.

        x: raw ``pd.DataFrame``; ``prepare_data`` must have fitted the
           normalizer / encoder beforehand.
        returns: ``[continuous | one-hot]`` features as a float tensor, or as
                 a numpy array when ``return_tensor`` is False.
        """
        assert isinstance(x, pd.DataFrame)
        n_rows = len(x)
        x_cont = self.normalizer.transform(x[self.continous_cols]) if self.continous_cols else np.empty((n_rows, 0))
        x_cat = self.encoder.transform(x[self.discret_cols]) if self.discret_cols else np.empty((n_rows, 0))
        x = np.concatenate((x_cont, x_cat), axis=1)
        return torch.from_numpy(x).float() if return_tensor else x

    def inverse_transform(self, x, return_tensor=True):
        """Map a transformed tensor back to the original feature space.

        x should be a transformed tensor laid out as ``[continuous | one-hot]``.

        NOTE(review): when the encoder categories are not numeric the
        concatenated array presumably has object dtype and cannot be turned
        into a tensor; pass ``return_tensor=False`` in that case.
        """
        cat_idx = len(self.continous_cols)
        n_rows = len(x)
        x_cont = x[:, :cat_idx].detach().cpu().numpy()
        x_cat = x[:, cat_idx:].detach().cpu().numpy()
        # inverse transform; mirror `transform` and skip the (unfitted)
        # normalizer / encoder when the corresponding column group is empty
        x_cont_inv = self.normalizer.inverse_transform(x_cont) if self.continous_cols else np.empty((n_rows, 0))
        x_cat_inv = self.encoder.inverse_transform(x_cat) if self.discret_cols else np.empty((n_rows, 0))
        x = np.concatenate((x_cont_inv, x_cat_inv), axis=1)
        return torch.from_numpy(x).float() if return_tensor else x

    def predict(self, x):
        """Return the predicted label for transformed input ``x`` (subclass hook)."""
        raise NotImplementedError

    def check_cont_robustness(self, x, c, c_y):
        """Check whether counterfactuals survive reverting tiny continuous changes.

        Continuous features of ``c`` whose distance to ``x`` (measured after
        undoing the normalisation) is below ``self.threshold`` are reset to
        the value in ``x``; the adjusted counterfactual is then re-predicted.

        x: transformed inputs
        c: transformed counterfactuals (left untouched; a copy is modified)
        c_y: prediction for the unmodified counterfactuals
        returns: (number of rows whose thresholded prediction changed,
                  number of rows where at least one feature was reverted)
        """
        cat_idx = len(self.continous_cols)
        # work on a copy so the caller's counterfactual tensor is not overwritten
        c = c.clone()
        total_diffs = 0
        if cat_idx > 0:
            # inverse transform
            x_cont_inv = self.normalizer.inverse_transform(x[:, :cat_idx].detach().cpu().numpy())
            c_cont_inv = self.normalizer.inverse_transform(c[:, :cat_idx].detach().cpu().numpy())
            # True where x and c differ by less than the threshold
            cont_diff = np.abs(x_cont_inv - c_cont_inv) < self.threshold
            # number of rows with at least one such small difference
            total_diffs = np.sum(cont_diff.any(axis=1))
            # new continous cf: small changes are replaced by the original value
            c_cont_hat = np.where(cont_diff, x_cont_inv, c_cont_inv)
            c_cont_new = torch.from_numpy(self.normalizer.transform(c_cont_hat))
            c[:, :cat_idx] = c_cont_new.to(dtype=c.dtype, device=c.device)
        c_y_hat = self.predict(c)
        return ((c_y_hat > .5) != (c_y > .5)).sum(), total_diffs

    def cat_normalize(self, c, hard=False):
        """Normalise the categorical part of ``c`` via the module-level ``cat_normalize``."""
        # categorical feature starting index
        cat_idx = len(self.continous_cols)
        return cat_normalize(c, self.cat_arrays, cat_idx, hard=hard)

    def prepare_data(self):
        """Fit the preprocessing on the full frame and build train / val datasets.

        The last column of the dataframe is used as the label. The split is
        not shuffled, so the validation set is the tail of the file.
        """
        def split_x_and_y(data):
            X = data[data.columns[:-1]]
            y = data[data.columns[-1]]
            return X, y

        def find_imutable_idx_list(
            cat_idx: int,
            imutable_col_names: List[str],
            discrete_col_names: List[str],
            cat_arrays: List[List[str]]
        ) -> List[int]:
            # Walk over the one-hot blocks and collect the transformed column
            # indices that belong to an immutable discrete column.
            imutable_idx_list = []
            for col_name, cols in zip(discrete_col_names, cat_arrays):
                cat_end_idx = cat_idx + len(cols)
                if col_name in imutable_col_names:
                    imutable_idx_list += list(range(cat_idx, cat_end_idx))
                cat_idx = cat_end_idx
            return imutable_idx_list

        X, y = split_x_and_y(self.data)
        n_rows = len(X)

        # preprocessing
        self.normalizer = MinMaxScaler()
        # newer scikit-learn releases renamed `sparse` to `sparse_output`;
        # try the new keyword first and fall back to the old one
        try:
            self.encoder = OneHotEncoder(sparse_output=False)
        except TypeError:
            self.encoder = OneHotEncoder(sparse=False)
        X_cont = self.normalizer.fit_transform(X[self.continous_cols]) if self.continous_cols else np.empty((n_rows, 0))
        X_cat = self.encoder.fit_transform(X[self.discret_cols]) if self.discret_cols else np.empty((n_rows, 0))
        X = np.concatenate((X_cont, X_cat), axis=1)
        self.cat_arrays = self.encoder.categories_ if self.discret_cols else []
        # imutable
        self.imutable_idx_list = find_imutable_idx_list(
            cat_idx=len(self.continous_cols), imutable_col_names=self.imutable_cols, discrete_col_names=self.discret_cols, cat_arrays=self.cat_arrays
        )
        pl_logger.info(f"x_cont: {X_cont.shape}, x_cat: {X_cat.shape}")
        pl_logger.info(X.shape)
        assert X.shape[-1] == self.enc_dims[0], f'The input dimension X (shape: {X.shape[-1]})  != encoder_dims[0]: {self.enc_dims}'

        # prepare train & test
        train_X, test_X, train_y, test_y = train_test_split(X, y.to_numpy(), shuffle=False)
        self.train_dataset = NumpyDataset(train_X, train_y)
        self.val_dataset = NumpyDataset(test_X, test_y)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          pin_memory=True, shuffle=True, num_workers=0)

    def val_dataloader(self):
        # validation data is not shuffled so per-batch metrics are reproducible
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          pin_memory=True, shuffle=False, num_workers=0)

    def test_dataloader(self):
        # there is no separate test split; the validation set is reused
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          pin_memory=True, shuffle=False, num_workers=0)

# Cell
class BaselineTrainingModule(DataModule):
    """Training / validation loop for a plain binary predictor.

    Subclasses supply the network through ``model_forward``.
    """

    def __init__(self, config: Dict):
        super().__init__(config)

    def model_forward(self, x):
        """Subclass hook; receives the tuple of inputs collected by ``forward``."""
        raise NotImplementedError

    def forward(self, *x):
        """Pack the positional inputs into a tuple and hand it to ``model_forward``."""
        return self.model_forward(x)

    def predict(self, x):
        """Freeze the module and return the rounded model output.

        ``x`` is passed to the model as-is; no transform is applied here.
        """
        self.freeze()
        return torch.round(self(x))

    def configure_optimizers(self):
        """Adam over every parameter that currently requires a gradient."""
        trainable = [param for param in self.parameters() if param.requires_grad]
        return torch.optim.Adam(trainable, lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Binary cross entropy against randomly smoothed labels."""
        *inputs, target = batch
        pred = self(*inputs)
        # label smoothing: positives are drawn from [0.8, 0.95], the rest from [0.05, 0.2]
        soft_pos = uniform(target.size(), 0.8, 0.95, device=self.device)
        soft_neg = uniform(target.size(), 0.05, 0.2, device=self.device)
        soft_target = torch.where(target == 1, soft_pos, soft_neg)
        train_loss = F.binary_cross_entropy(pred, soft_target)
        # Logging to TensorBoard by default
        self.log('train/train_loss_1', train_loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Return the batch loss and accuracy (hard labels, 0.5 cut-off)."""
        *inputs, target = batch
        pred = self(*inputs)
        return {
            'score': accuracy(pred > .5, target),
            'val_loss': F.binary_cross_entropy(pred, target),
        }

    def validation_epoch_end(self, val_outs):
        """Average the per-batch loss and accuracy and log both."""
        def batch_mean(key):
            return torch.stack([out[key] for out in val_outs]).mean()

        self.log('val/val_loss', batch_mean('val_loss'))
        self.log('val/pred_accuracy', batch_mean('score'))

# Cell

class CounterfactualTrainingModule(DataModule):
    """Joint training loop for a predictor and a counterfactual generator.

    Subclasses provide ``model_forward(x) -> (y_hat, cf)``. Three losses are
    combined: prediction loss (l_1), distance between input and
    counterfactual (l_2) and the loss that pushes the counterfactual's
    prediction towards the flipped label (l_3).
    """

    def __init__(self, config: Dict):
        super().__init__(config)

    def model_forward(self, x):
        """Subclass hook; must return ``(prediction, raw counterfactual)``."""
        raise NotImplementedError

    def forward(self, x, hard=False, imutable=True):
        """hard: categorical features in counterfactual is one-hot-encoding or not

        imutable: when True, the immutable columns of the counterfactual are
        overwritten with the values from ``x``.
        """
        y, cf = self.model_forward(x)
        cf = self.cat_normalize(cf, hard=hard)
        if imutable:
            cf[:, self.imutable_idx_list] = x[:, self.imutable_idx_list] * 1.0
        return y, cf

    def predict(self, x):
        """Return the rounded prediction for ``x`` (passed to the model as-is).

        Unlike ``generate_cf`` this does not freeze the module.
        """
        y_hat, _ = self.model_forward(x)
        return torch.round(y_hat)

    def generate_cf(self, x, clamp=False, imutable=True):
        """Freeze the module and return hard (one-hot) counterfactuals for ``x``.

        clamp: clip the counterfactual to [0, 1] before the categorical
        normalisation.
        """
        self.freeze()
        y, cf = self.model_forward(x)
        if imutable:
            cf[:, self.imutable_idx_list] = x[:, self.imutable_idx_list] * 1.0
        if clamp:
            cf = torch.clamp(cf, 0., 1.)
        return self.cat_normalize(cf, hard=True)

    def configure_optimizers(self):
        """Adam over every parameter that currently requires a gradient."""
        optimizer = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr)
        return optimizer

    def _loss_functions(self, x, c, y, y_hat, y_prime=None, y_prime_mode='predicted', is_val=False):
        """
        Compute the three raw (unweighted) losses.

        x: input value
        c: counterfactual example
        y: ground truth
        y_hat: predicted result
        y_prime: desired label for the counterfactual; derived when None
        y_prime_mode: 'label' (flip the ground truth) or 'predicted'
            (flip the thresholded prediction)
        is_val: disables label smoothing
        raises: ValueError for an unknown ``y_prime_mode``
        """
        # flip zero/one
        if y_prime is None:
            if y_prime_mode == 'label':
                # `1.0 - y` stays on y's device (torch.ones(y.shape) would be on the CPU)
                y_prime = 1.0 - y
            elif y_prime_mode == 'predicted':
                y_prime = (y_hat < .5).clone().detach().float()
            else:
                raise ValueError(
                    f"y_prime_mode should be 'label' or 'predicted', got {y_prime_mode!r}")

        # prediction of the model for the counterfactual itself
        c_y, _ = self(c)
        # loss functions
        if self.smooth_y and not is_val:
            y = smooth_y(y)
            y_prime = smooth_y(y_prime)
        l_1 = self.loss_func_1(y_hat, y)
        l_2 = self.loss_func_2(x, c)
        l_3 = self.loss_func_3(c_y, y_prime)

        return l_1, l_2, l_3

    def _loss_compute(self, l_1, l_2, l_3):
        """Weighted sum of the three losses."""
        return self.lambda_1 * l_1 + self.lambda_2 * l_2 + self.lambda_3 * l_3

    def _logging_gradient(self):
        """Log histograms of the weight gradients, grouped by sub-network name.

        Reads ``self.model`` (expected to be set by the subclass) and writes
        through ``self.logger.experiment.add_histogram``.
        """
        grads = {'encoder': [], 'predictor': [], 'explainer': []}
        for n, p in self.model.named_parameters():
            if not p.requires_grad or 'bias' in n or p.grad is None:
                continue
            # first matching group wins (same order as an if/elif chain)
            for group, collected in grads.items():
                if group in n:
                    collected.append(p.grad.detach().flatten())
                    break

        logger = self.logger.experiment
        for group, collected in grads.items():
            if collected:
                # gradients have different shapes, so they are flattened and
                # concatenated; torch.tensor(list_of_tensors) cannot build this
                logger.add_histogram(f'gradient/{group}', torch.cat(collected), self.global_step, bins='auto')

    def _logging_loss(self, l_1, l_2, l_3, stage: str, on_step: bool = False):
        """Log the three raw losses under ``{stage}/{stage}_loss_{i}``."""
        for i, loss in enumerate((l_1, l_2, l_3), start=1):
            self.log(f'{stage}/{stage}_loss_{i}', loss, on_step=on_step, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)

    def _logging_cf_results(self, x, c, y, y_hat, c_y):
        """
        Build a readable summary of the first sample of the batch.

        params:
            x: input value
            c: counterfactual example
            y: ground truth
            y_hat: predicted result
            c_y: the prediction of counterfactual example
        returns: the summary every 10th epoch, otherwise None
        """
        cat_idx = len(self.continous_cols)
        log = None
        if self.current_epoch % 10 == 0:
            x = x.cpu()
            c = c.cpu()
            x_0_cont = self.normalizer.inverse_transform(x[0, :cat_idx].reshape(1, -1))
            c_0_cont = self.normalizer.inverse_transform(c[0, :cat_idx].reshape(1, -1))
            x_0_cat = self.encoder.inverse_transform(x[0, cat_idx:].unsqueeze(dim=0)) if self.discret_cols else []
            c_0_cat = self.encoder.inverse_transform(c[0, cat_idx:].unsqueeze(dim=0)) if self.discret_cols else []

            x_log = f"x_cont: {x_0_cont}, x_cat: {x_0_cat}, y_hat: {y_hat[0]}"
            c_log = f"c_cont: {c_0_cont}, c_cat: {c_0_cat}, y_ctf: {c_y[0]}"
            label_log = f"label: {y[0]}"

            log = f"""
            {"==" * 25}
            {label_log}
            {x_log}
            {c_log}
            {"==" * 25}
            """
        return log

    def transformed_cf_results(self, x, y):
        """Generate counterfactuals for ``x`` and render them in original units.

        x: transformed inputs
        y: unused; kept for interface compatibility
        returns: ``(log, best_log)`` where ``log`` lists every input /
            counterfactual pair and ``best_log`` is the pair with the
            smallest L1 distance in the transformed space ("" when ``x`` is
            empty).
        """
        cat_idx = len(self.continous_cols)
        c = self.generate_cf(x, clamp=True)

        x = x.cpu()
        c = c.cpu()

        entries = []
        best_distance = float('inf')
        best_log = ""

        for x_row, c_row in zip(x, c):
            x_0_cont = self.normalizer.inverse_transform(x_row[:cat_idx].reshape(1, -1))
            c_0_cont = self.normalizer.inverse_transform(c_row[:cat_idx].reshape(1, -1))
            x_0_cat = self.encoder.inverse_transform(x_row[cat_idx:].unsqueeze(dim=0)) if self.discret_cols else []
            c_0_cat = self.encoder.inverse_transform(c_row[cat_idx:].unsqueeze(dim=0)) if self.discret_cols else []

            x_log = f"x_cont: {np.round(x_0_cont)}, x_cat: {x_0_cat}"
            c_log = f"c_cont: {np.round(c_0_cont)}, c_cat: {c_0_cat}"
            original_c = f"c: {c_row}"

            entry = f"""
            {"==" * 25}
            {x_log}
            {c_log}
            {original_c}
            {"==" * 25}
            """
            entries.append(entry)

            # keep the closest counterfactual; the running minimum has to be
            # updated, otherwise the last row would always win
            distance = torch.abs(x_row - c_row).sum().item()
            if distance < best_distance:
                best_distance = distance
                best_log = f"""{"==" * 25}\n{x_log}\n{c_log}\n{original_c}\n{"==" * 25}"""
        return "".join(entries), best_log

    def training_step(self, batch, batch_idx):
        """Forward pass, log the three losses and return their weighted sum."""
        # batch
        x, y = batch
        # fwd
        y_hat, c = self(x)
        # loss
        l_1, l_2, l_3 = self._loss_functions(x, c, y, y_hat)
        # logging train loss
        self._logging_loss(l_1, l_2, l_3, stage='train', on_step=True)

        return self._loss_compute(l_1, l_2, l_3)

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and counterfactual metrics for one batch."""
        # batch
        x, y = batch
        # fwd
        y_hat, c = self(x, hard=True)
        c_y, _ = self(c)
        # loss; weighted exactly like the training objective
        l_1, l_2, l_3 = self._loss_functions(x, c, y, y_hat, is_val=True)
        loss = self._loss_compute(l_1, l_2, l_3)
        # logging val loss
        self._logging_loss(l_1, l_2, l_3, stage='val')

        # metrics
        pred_acc = accuracy(y_hat > .5, y)
        # mean L1 distance between input and counterfactual
        cf_proximity = torch.abs(x - c).sum(dim=1).mean()
        # a counterfactual is valid when its prediction is the flipped one
        cf_acc = accuracy(c_y > .5, y_hat < .5)

        # printing counterfactual results is currently disabled
        log = None

        # logging robustness on manipulating small
        diffs, total_num = self.check_cont_robustness(x, c, c_y)

        return {
                'pred_acc': pred_acc,
                'cf_proximity': cf_proximity,
                'cf_acc': cf_acc,
                'val_loss': loss,
                'log': log,
                'diffs': diffs,
                'total_num': total_num
               }

    def validation_epoch_end(self, val_outs):
        """Average the per-batch metrics and log them."""
        if not val_outs:
            # nothing to aggregate (avoids a division by zero below)
            return

        loss, pred_accuracy, cf_proximity, cf_accuracy, diffs, total_diff_num = (0. for _ in range(6))
        logs = []

        for out in val_outs:
            loss += out['val_loss']
            pred_accuracy += out['pred_acc']
            cf_proximity += out['cf_proximity']
            cf_accuracy += out['cf_acc']
            diffs += out['diffs']
            total_diff_num += out['total_num']
            if out['log'] is not None:
                logs.append(out['log'])

        size = len(val_outs)
        # avoid dividing by zero when no sample had a small continuous change
        if total_diff_num == 0:
            total_diff_num = 1

        self.log('val/val_loss', loss / size, sync_dist=True)
        self.log('val/pred_accuracy', pred_accuracy / size, sync_dist=True)
        self.log('val/cf_proximity', cf_proximity / size, sync_dist=True)
        self.log('val/cf_accuracy', cf_accuracy / size, sync_dist=True)
        self.log('val/robustness', (1 - diffs / total_diff_num), sync_dist=True)
        self.log('val/total_diff_num', total_diff_num, sync_dist=True)
        # only touch the experiment writer when there is something to write
        if logs and self.logger is not None:
            self.logger.experiment.add_text('results','\n\n'.join(logs))

# Cell
class LossWrapper(pl.LightningModule):
    """Combine several losses with learnable log-variance weights.

    Each loss ``l_i`` contributes ``exp(-s_i) * l_i ** 2 + s_i`` where
    ``s_i = log_vars[i]`` is a trainable parameter initialised to zero.

    loss_num: number of losses that ``forward`` expects
    """
    def __init__(self, loss_num=3):
        super().__init__()
        self.loss_num = loss_num
        self.log_vars = nn.Parameter(torch.zeros((loss_num)))

    def forward(self, *loss_f):
        """Return the mean of the weighted sum of the given losses.

        loss_f: exactly ``loss_num`` loss tensors
        """
        assert self.loss_num == len(loss_f)

        loss = 0.
        for i, l in enumerate(loss_f):
            log_var = self.log_vars[i]
            w = torch.exp(-log_var)
            # every loss is regularised by its OWN log-variance; using
            # log_vars[0] for all of them would leave the other weights
            # free to shrink without any penalty
            loss += torch.sum(w * l ** 2 + log_var, -1)

        return loss.mean()

# Cell
class CounterfactualTrainingModuleLossWrapper(CounterfactualTrainingModule):
    """Counterfactual training where a ``LossWrapper`` combines the three losses."""

    def __init__(self, config, loss_wrapper=None):
        super().__init__(config)
        # fall back to a default wrapper when none is supplied
        if loss_wrapper is None:
            loss_wrapper = LossWrapper()
        self.loss_wrapper = loss_wrapper

    def training_step(self, batch, batch_idx):
        """Log the raw losses and return the wrapper's combined loss."""
        inputs, labels = batch
        pred, cf = self(inputs)
        losses = self._loss_functions(inputs, cf, labels, pred)

        # Logging to TensorBoard by default
        tags = ('train/train_loss_1', 'train/train_loss_2', 'train/train_loss_3')
        for tag, value in zip(tags, losses):
            self.log(tag, value, on_step=True, on_epoch=True, prog_bar=False, logger=True)

        return self.loss_wrapper(*losses)

# Cell

class CounterfactualTrainingModule2Optimizers(CounterfactualTrainingModule):
    """Counterfactual training with two optimizers.

    Optimizer 0 minimises the predictor loss and optimizer 1 the explainer
    loss; both are built over the same trainable parameters.
    """

    def configure_optimizers(self):
        """Return two Adam optimizers over the currently trainable parameters."""
        def build_adam():
            trainable = [p for p in self.parameters() if p.requires_grad]
            return torch.optim.Adam(trainable, lr=self.lr)

        first = build_adam()
        second = build_adam()
        return (first, second)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the loss that belongs to the optimizer being stepped."""
        inputs, labels = batch
        pred, cf = self(inputs)
        l_1, l_2, l_3 = self._loss_functions(inputs, cf, labels, pred)

        if optimizer_idx == 0:
            step_loss = self.predictor_step(l_1, l_3)
        elif optimizer_idx == 1:
            step_loss = self.explainer_step(l_2, l_3)
        else:
            # any other index contributes nothing
            step_loss = 0

        # Logging to TensorBoard by default
        self._logging_loss(l_1, l_2, l_3, stage='train', on_step=True)
        return step_loss

    def predictor_step(self, l_1, l_3):
        """Weighted prediction loss; ``l_3`` is accepted but not used."""
        weighted = self.lambda_1 * l_1
        self.log('train/p_loss', weighted, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        return weighted

    def explainer_step(self, l_2, l_3):
        """Weighted sum of the distance loss and the counterfactual loss."""
        weighted = self.lambda_2 * l_2 + self.lambda_3 * l_3
        self.log('train/e_loss', weighted, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        return weighted


class CounterfactualTrainingModulePosthoc(CounterfactualTrainingModule):
    """Two-phase training: predictor first, explainer afterwards.

    During the first half of ``trainer.max_epochs`` only the predictor loss
    is optimised; after that ``encoder_model``, ``predictor`` and
    ``pred_linear`` have their gradients switched off and the explainer
    loss is optimised.
    """

    def configure_optimizers(self):
        """Adam over every parameter that currently requires a gradient."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Return the predictor or explainer loss depending on the epoch."""
        inputs, labels = batch
        pred, cf = self(inputs)
        l_1, l_2, l_3 = self._loss_functions(inputs, cf, labels, pred)

        halfway = self.trainer.max_epochs // 2
        if self.current_epoch >= halfway:
            # second phase: switch off gradients of the prediction path
            frozen = [self.encoder_model, self.predictor, self.pred_linear]
            use_grad(*frozen, requires_grad=False)
            step_loss = self.explainer_step(l_2, l_3)
        else:
            # first phase: everything is trainable
            use_grad(self, requires_grad=True)
            step_loss = self.predictor_step(l_1, l_3)

        # Logging to TensorBoard by default
        self._logging_loss(l_1, l_2, l_3, stage='train', on_step=True)
        return step_loss

    def predictor_step(self, l_1, l_3):
        """Weighted prediction loss; ``l_3`` is accepted but not used."""
        weighted = self.lambda_1 * l_1
        self.log('train/p_loss', weighted, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        return weighted

    def explainer_step(self, l_2, l_3):
        """Weighted sum of the distance loss and the counterfactual loss."""
        weighted = self.lambda_2 * l_2 + self.lambda_3 * l_3
        self.log('train/e_loss', weighted, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        return weighted