import logging
from copy import deepcopy
from typing import List, Union

import numpy as np
import ray
import torch
import torch.nn.functional as F
import torch.optim as optim
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningModule, Trainer
from ray import ray_constants, tune
from torch.utils.data import DataLoader, Dataset
from torch_ema import ExponentialMovingAverage

from src.data import RealDatasetCollection, SyntheticDatasetCollection
from src.models.utils import AlphaRise, BRTreatmentOutcomeHead, bce, grad_reverse
from src.models.utils_causal_cpc import ICLUB

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Overrides ray's FUNCTION_SIZE_ERROR_THRESHOLD constant; the same value is also passed as
# `_redis_max_memory` to `ray.init` in `TimeVaryingCausalModel.finetune`.
# NOTE(review): presumably raised so that large datasets can be shipped to tune workers - confirm.
ray_constants.FUNCTION_SIZE_ERROR_THRESHOLD = 10**8  # ~ 100 MB


def train_eval_factual(
    args: dict,
    train_f: Dataset,
    val_f: Dataset,
    orig_hparams: DictConfig,
    input_size: int,
    model_cls,
    tuning_criterion="rmse",
    **kwargs,
):
    """
    Module-level trial function used for ray tuning: builds one model from a sampled
    configuration, fits it on the factual train split and reports validation metrics
    through ``tune.report``
    :param args: Hyperparameter configuration
    :param train_f: Factual train dataset
    :param val_f: Factual val dataset
    :param orig_hparams: DictConfig of original hyperparameters (deep-copied, never modified)
    :param input_size: Input size of model, influences concrete hyperparameter configuration
    :param model_cls: class of model
    :param tuning_criterion: "rmse" or "bce"; any other value raises NotImplementedError
        (only after the model has been trained)
    :param kwargs: Other args, forwarded to the model constructor
    """
    OmegaConf.register_new_resolver("sum", lambda x, y: x + y, replace=True)
    trial_hparams = deepcopy(orig_hparams)
    model_type = model_cls.model_type
    model_cls.set_hparams(trial_hparams.model, args, input_size, model_type)

    if model_type != "decoder":
        model = model_cls(trial_hparams, **kwargs).double()
    else:
        # Passing encoder takes too much memory, so only its representation size is handed over
        encoder_cfg = trial_hparams.model.encoder
        # Using either br_size or Memory adapter
        if "br_size" in encoder_cfg:
            encoder_r_size = encoder_cfg.br_size
        else:
            encoder_r_size = encoder_cfg.seq_hidden_units
        model = model_cls(trial_hparams, encoder_r_size=encoder_r_size, **kwargs).double()

    # Sub-config of the concrete sub-model, read after the model has been constructed
    sub_cfg = trial_hparams.model[model_type]
    train_loader = DataLoader(
        train_f, shuffle=True, batch_size=sub_cfg["batch_size"], drop_last=True
    )
    clip_val = sub_cfg["max_grad_norm"] if "max_grad_norm" in sub_cfg else None

    # NOTE(review): `eval` on a configuration value executes arbitrary code; if `exp.gpus`
    # can ever come from an untrusted source, switch to `ast.literal_eval`.
    trainer = Trainer(
        gpus=eval(str(trial_hparams.exp.gpus))[:1],
        logger=None,
        max_epochs=trial_hparams.exp.max_epochs,
        progress_bar_refresh_rate=0,
        gradient_clip_val=clip_val,
        callbacks=[AlphaRise(rate=trial_hparams.exp.alpha_rate)],
    )
    trainer.fit(model, train_dataloader=train_loader)

    if tuning_criterion == "rmse":
        metric_orig, metric_all = model.get_normalised_masked_rmse(val_f)
        metrics = {"val_rmse_orig": metric_orig, "val_rmse_all": metric_all}
    elif tuning_criterion == "bce":
        metric_orig, metric_all = model.get_masked_bce(val_f)
        metrics = {"val_bce_orig": metric_orig, "val_bce_all": metric_all}
    else:
        raise NotImplementedError()
    tune.report(**metrics)


class TimeVaryingCausalModel(LightningModule):
    """
    Abstract class for models, estimating counterfactual outcomes over time
    """

    model_type = None  # Will be defined in subclasses
    possible_model_types = None  # Will be defined in subclasses
    tuning_criterion = None

    def __init__(
        self,
        args: DictConfig,
        dataset_collection: Union[RealDatasetCollection, SyntheticDatasetCollection] = None,
        autoregressive: bool = None,
        has_vitals: bool = None,
        bce_weights: np.array = None,
        **kwargs,
    ):
        """
        Args:
            args: DictConfig of hyperparameters; ``args.model.dim_*`` are read here and the
                whole config is stored via ``save_hyperparameters``
            dataset_collection: Dataset collection; when given, it is the source of the
                ``autoregressive`` / ``has_vitals`` flags and the explicit arguments are ignored
            autoregressive: Flag of including previous outcomes to modelling
                (used only without a dataset collection)
            has_vitals: Flag of vitals in dataset (used only without a dataset collection)
            bce_weights: Re-weight BCE if used (used only without a dataset collection)
            **kwargs: Other arguments (not used by this constructor)
        """
        super().__init__()
        self.dataset_collection = dataset_collection
        if dataset_collection is None:
            self.autoregressive = autoregressive
            self.has_vitals = has_vitals
            self.bce_weights = bce_weights
        else:
            self.autoregressive = dataset_collection.autoregressive
            self.has_vitals = dataset_collection.has_vitals
            # Will be calculated, when calling preparing data
            self.bce_weights = None

        # General datasets parameters
        model_args = args.model
        self.dim_treatments = model_args.dim_treatments
        self.dim_vitals = model_args.dim_vitals
        self.dim_static_features = model_args.dim_static_features
        self.dim_outcome = model_args.dim_outcomes

        # Will be defined in subclasses
        self.input_size = None

        # Will be logged to mlflow
        self.save_hyperparameters(args)

    def _get_optimizer(self, param_optimizer: list):
        no_decay = ["bias", "layer_norm"]
        sub_args = self.hparams.model[self.model_type]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": sub_args["optimizer"]["weight_decay"],
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        lr = sub_args["optimizer"]["learning_rate"]
        optimizer_cls = sub_args["optimizer"]["optimizer_cls"]
        if optimizer_cls.lower() == "adamw":
            optimizer = optim.AdamW(optimizer_grouped_parameters, lr=lr)
        elif optimizer_cls.lower() == "adam":
            optimizer = optim.Adam(optimizer_grouped_parameters, lr=lr)
        elif optimizer_cls.lower() == "sgd":
            optimizer = optim.SGD(
                optimizer_grouped_parameters, lr=lr, momentum=sub_args["optimizer"]["momentum"]
            )
        else:
            raise NotImplementedError()

        return optimizer

    def _get_lr_schedulers(self, optimizer):
        if not isinstance(optimizer, list):
            lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
            return [optimizer], [lr_scheduler]
        else:
            lr_schedulers = []
            for opt in optimizer:
                lr_schedulers.append(optim.lr_scheduler.ExponentialLR(opt, gamma=0.99))
            return optimizer, lr_schedulers

    def configure_optimizers(self):
        optimizer = self._get_optimizer(list(self.named_parameters()))
        if self.hparams.model[self.model_type]["optimizer"]["lr_scheduler"]:
            return self._get_lr_schedulers(optimizer)
        return optimizer

    def train_dataloader(self) -> DataLoader:
        sub_args = self.hparams.model[self.model_type]
        return DataLoader(
            self.dataset_collection.train_f,
            shuffle=True,
            batch_size=sub_args["batch_size"],
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.dataset_collection.val_f, batch_size=self.hparams.dataset.val_batch_size
        )

    def get_predictions(self, dataset: Dataset) -> np.array:
        """
        Abstract hook, always raises NotImplementedError in this class.

        The metric helpers of this class call it and combine the result element-wise with
        ``dataset.data["outputs"]`` / ``dataset.data["active_entries"]``.
        NOTE(review): presumably overriders return normalised outcome predictions of the
        same shape as those arrays - confirm.
        """
        raise NotImplementedError()

    def get_propensity_scores(self, dataset: Dataset) -> np.array:
        """
        Abstract hook, always raises NotImplementedError in this class.

        ``get_masked_bce`` wraps the result in ``torch.tensor`` and scores it against
        ``dataset.data["current_treatments"]``.
        """
        raise NotImplementedError()

    def get_representations(self, dataset: Dataset) -> np.array:
        """Abstract hook, always raises NotImplementedError in this class."""
        raise NotImplementedError()

    def get_autoregressive_predictions(self, dataset: Dataset) -> np.array:
        """
        Multi-step prediction for decoder models: step ``t`` of every pass is stored and fed
        back as the previous output of step ``t + 1``.

        Side effect: ``dataset.data["prev_outputs"]`` is overwritten in place.

        Returns:
            Array of shape (len(dataset), projection_horizon, dim_outcome)

        Raises:
            NotImplementedError: when ``model_type`` is not "decoder"
        """
        logger.info(f"Autoregressive Prediction for {dataset.subset_name}.")
        if self.model_type != "decoder":  # only CRNDecoder / EDCTDecoder / RMSN Decoder
            raise NotImplementedError()

        horizon = self.hparams.dataset.projection_horizon
        predicted_outputs = np.zeros((len(dataset), horizon, self.dim_outcome))

        for step in range(horizon):
            logger.info(f"t = {step + 2}")

            outputs_scaled = self.get_predictions(dataset)
            predicted_outputs[:, step] = outputs_scaled[:, step]

            # The last step has no successor to feed
            if step + 1 < horizon:
                dataset.data["prev_outputs"][:, step + 1, :] = outputs_scaled[:, step, :]

        return predicted_outputs

    def get_masked_bce(self, dataset: Dataset):
        """
        Masked binary cross-entropy of the predicted treatments on ``dataset``.

        Returns:
            ``(bce_orig, bce_all)`` - ``bce_orig`` is masked-averaged over datapoints
            (& outputs) per time step and then plainly averaged over time; ``bce_all`` is
            masked-averaged over all dimensions at once
        """
        logger.info(f"BCE calculation for {dataset.subset_name}.")
        treatment_pred = torch.tensor(self.get_propensity_scores(dataset))
        current_treatments = torch.tensor(dataset.data["current_treatments"])
        active_entries = dataset.data["active_entries"]

        per_entry = self.bce_loss(treatment_pred, current_treatments, kind="predict")
        # Trailing axis added so that the loss lines up with the mask
        masked_bce = per_entry.unsqueeze(-1).numpy() * active_entries

        # Calculation like in original paper (Masked-Averaging over datapoints (& outputs) and then non-masked time axis)
        per_step = masked_bce.sum(0).sum(-1) / active_entries.sum(0).sum(-1)
        bce_orig = per_step.mean()

        # Masked averaging over all dimensions at once
        bce_all = masked_bce.sum() / active_entries.sum()

        return bce_orig, bce_all

    def get_normalised_masked_rmse(self, dataset: Dataset, one_step_counterfactual=False):
        """
        Masked RMSE of the predictions on ``dataset``, divided by ``dataset.norm_const``.

        ``hparams.exp.unscale_rmse`` switches the comparison to the unscaled outputs,
        ``hparams.exp.percentage_rmse`` multiplies every result by 100.

        Args:
            dataset: dataset to evaluate
            one_step_counterfactual: additionally return the RMSE restricted to the last
                active entry of every sequence

        Returns:
            ``(rmse_orig, rmse_all)`` or, with ``one_step_counterfactual``,
            ``(rmse_orig, rmse_all, rmse_last)``
        """
        logger.info(f"RMSE calculation for {dataset.subset_name}.")
        outputs_scaled = self.get_predictions(dataset)
        unscale = self.hparams.exp.unscale_rmse
        percentage = self.hparams.exp.percentage_rmse

        if unscale:
            output_stds = dataset.scaling_params["output_stds"]
            output_means = dataset.scaling_params["output_means"]
            predictions = outputs_scaled * output_stds + output_means
            targets = dataset.data["unscaled_outputs"]
        else:
            predictions = outputs_scaled
            targets = dataset.data["outputs"]

        active_entries = dataset.data["active_entries"]
        # Batch-wise masked-MSE calculation is tricky, thus calculating for full dataset at once
        squared_error = (predictions - targets) ** 2
        mse = squared_error * active_entries

        # Calculation like in original paper (Masked-Averaging over datapoints (& outputs) and then non-masked time axis)
        mse_orig = (mse.sum(0).sum(-1) / active_entries.sum(0).sum(-1)).mean()
        # Masked averaging over all dimensions at once
        mse_all = mse.sum() / active_entries.sum()

        results = [
            np.sqrt(mse_orig) / dataset.norm_const,
            np.sqrt(mse_all) / dataset.norm_const,
        ]

        if one_step_counterfactual:
            # Only considering last active entry with actual counterfactuals:
            # mask minus the mask shifted one step back in time
            num_samples, time_dim, output_dim = active_entries.shape
            shifted = np.concatenate(
                [active_entries[:, 1:, :], np.zeros((num_samples, 1, output_dim))], axis=1
            )
            last_entries = active_entries - shifted

            mse_last = (squared_error * last_entries).sum() / last_entries.sum()
            results.append(np.sqrt(mse_last) / dataset.norm_const)

        if percentage:
            results = [value * 100.0 for value in results]

        return tuple(results)

    def get_pehe_one_step(self, dataset: Dataset):
        """
        One-step PEHE: RMSE between the predicted and the true individual treatment effect,
        evaluated on the last active entry of every sequence and divided by
        ``dataset.norm_const`` (times 100 when ``hparams.exp.percentage_rmse`` is set).

        The predicted effect is the difference of two prediction passes in which the
        treatments of the last time index are set to all ones and to all zeros.
        The original treatments are restored afterwards, so the caller's dataset is left
        unchanged even if a prediction pass raises.

        Args:
            dataset: dataset providing ``data["ITE"]``, ``data["current_treatments"]``,
                ``data["active_entries"]`` and the output scaling parameters

        Returns:
            Normalised PEHE (scalar)
        """
        logger.info(f"PEHE calculation for {dataset.subset_name}.")
        percentage = self.hparams.exp.percentage_rmse

        output_stds, output_means = (
            dataset.scaling_params["output_stds"],
            dataset.scaling_params["output_means"],
        )

        num_samples, time_dim, output_dim = dataset.data["active_entries"].shape
        # Mask minus the mask shifted one step back in time = last active entry only
        last_entries = dataset.data["active_entries"] - np.concatenate(
            [dataset.data["active_entries"][:, 1:, :], np.zeros((num_samples, 1, output_dim))],
            axis=1,
        )

        treatments = dataset.data["current_treatments"]
        # Snapshot of the factual treatments, the two passes below overwrite them in place
        factual_last_treatments = treatments[:, -1, :].copy()
        try:
            treatments[:, -1, :] = np.ones((num_samples, 1))
            yt_1 = self.get_predictions(dataset)
            yt_1 = yt_1 * output_stds + output_means

            treatments[:, -1, :] = np.zeros((num_samples, 1))
            yt_0 = self.get_predictions(dataset)
            yt_0 = yt_0 * output_stds + output_means
        finally:
            treatments[:, -1, :] = factual_last_treatments

        ites_pred = yt_1 - yt_0
        ite_real = dataset.data["ITE"]

        # Add the trailing output axis when the ground truth is stored without it
        if len(ite_real.shape) == 2:
            ite_real = ite_real[:, :, np.newaxis]

        pehe_last = ((ites_pred - ite_real) ** 2) * last_entries

        pehe_last = pehe_last.sum() / last_entries.sum()
        rmse_normalised_last = np.sqrt(pehe_last) / dataset.norm_const

        if percentage:
            rmse_normalised_last *= 100.0

        return rmse_normalised_last

    def get_normalised_n_step_rmses(self, dataset: Dataset, datasets_mc: List[Dataset] = None):
        """
        Per-horizon-step masked RMSE of the autoregressive predictions, divided by
        ``dataset.norm_const`` (times 100 when ``hparams.exp.percentage_rmse`` is set).

        Sequences whose ``data_processed_seq["outputs"]`` contain NaN are excluded.

        Args:
            dataset: dataset with ``data_processed_seq`` ground truth
            datasets_mc: when given, passed to ``get_autoregressive_predictions`` instead of
                ``dataset`` (ground truth is still taken from ``dataset``)

        Returns:
            1-D array with one normalised RMSE per projection step
        """
        logger.info(f"RMSE calculation for {dataset.subset_name}.")
        assert self.model_type in ("decoder", "multi", "g_net", "msm_regressor", "cdvae")

        unscale = self.hparams.exp.unscale_rmse
        percentage = self.hparams.exp.percentage_rmse
        outputs_scaled = self.get_autoregressive_predictions(
            dataset if datasets_mc is None else datasets_mc
        )

        active_entries = dataset.data_processed_seq["active_entries"]
        if unscale:
            output_stds, output_means = (
                dataset.scaling_params["output_stds"],
                dataset.scaling_params["output_means"],
            )
            outputs_unscaled = outputs_scaled * output_stds + output_means

            mse = (
                (outputs_unscaled - dataset.data_processed_seq["unscaled_outputs"]) ** 2
            ) * active_entries
        else:
            mse = ((outputs_scaled - dataset.data_processed_seq["outputs"]) ** 2) * active_entries

        # Indices of sequences without any NaN ground truth; setdiff1d keeps an integer dtype
        # even when nothing is left, so the result is always a valid index array
        nan_idx = np.unique(np.where(np.isnan(dataset.data_processed_seq["outputs"]))[0])
        not_nan = np.setdiff1d(np.arange(outputs_scaled.shape[0]), nan_idx)

        # Calculation like in original paper (Masked-Averaging over datapoints (& outputs) and then non-masked time axis)
        mse_orig = mse[not_nan].sum(0).sum(-1) / active_entries[not_nan].sum(0).sum(-1)
        rmses_normalised_orig = np.sqrt(mse_orig) / dataset.norm_const

        if percentage:
            rmses_normalised_orig *= 100.0

        return rmses_normalised_orig

    @staticmethod
    def set_hparams(model_args: DictConfig, new_args: dict, input_size: int, model_type: str):
        """
        Abstract hook, always raises NotImplementedError in this class.

        Called by ``train_eval_factual`` and ``finetune`` with the model sub-config, a sampled
        (or best) hyperparameter configuration, the model input size and the sub-model type.
        NOTE(review): presumably overriders write ``new_args`` into ``model_args`` in place
        (the return value is ignored by both callers) - confirm.
        """
        raise NotImplementedError()

    def finetune(self, resources_per_trial: dict):
        """
        Hyperparameter tuning with ray[tune]

        Runs ``tune_range`` random trials of ``train_eval_factual`` over ``hparams_grid``
        (both taken from the sub-config of the current model type), picks the configuration
        minimising ``val_{tuning_criterion}_all``, writes it into the hyperparameters via
        ``set_hparams`` and re-initialises this instance with them.

        Args:
            resources_per_trial: forwarded unchanged to ``tune.run``

        Returns:
            self, re-initialised with the best hyperparameters
        """
        self.prepare_data()
        sub_args = self.hparams.model[self.model_type]
        logger.info(f"Running hyperparameters selection with {sub_args['tune_range']} trials")
        # NOTE(review): `eval` on a configuration value executes arbitrary code; if `exp.gpus`
        # can ever come from an untrusted source, switch to `ast.literal_eval`.
        ray.init(
            num_gpus=len(eval(str(self.hparams.exp.gpus))),
            num_cpus=4,
            _redis_max_memory=ray_constants.FUNCTION_SIZE_ERROR_THRESHOLD,
        )

        try:
            hparams_grid = {k: tune.choice(v) for k, v in sub_args["hparams_grid"].items()}
            analysis = tune.run(
                tune.with_parameters(
                    train_eval_factual,
                    input_size=self.input_size,
                    model_cls=self.__class__,
                    tuning_criterion=self.tuning_criterion,
                    # Copies, so that trials never touch the datasets of this instance
                    train_f=deepcopy(self.dataset_collection.train_f),
                    val_f=deepcopy(self.dataset_collection.val_f),
                    orig_hparams=self.hparams,
                    autoregressive=self.autoregressive,
                    has_vitals=self.has_vitals,
                    bce_weights=self.bce_weights,
                    projection_horizon=getattr(self, "projection_horizon", None),
                ),
                resources_per_trial=resources_per_trial,
                metric=f"val_{self.tuning_criterion}_all",
                mode="min",
                config=hparams_grid,
                num_samples=sub_args["tune_range"],
                name=f"{self.__class__.__name__}{self.model_type}",
                max_failures=3,
            )
        finally:
            # Release the ray runtime even when tuning fails, otherwise it stays initialised
            ray.shutdown()

        logger.info(f"Best hyperparameters found: {analysis.best_config}.")
        logger.info("Resetting current hyperparameters to best values.")
        self.set_hparams(
            self.hparams.model, analysis.best_config, self.input_size, self.model_type
        )

        # Re-run the constructor of the concrete class with the updated hyperparameters;
        # optional sub-models are handed over only when this instance has them
        self.__init__(
            self.hparams,
            dataset_collection=self.dataset_collection,
            encoder=getattr(self, "encoder", None),
            propensity_treatment=getattr(self, "propensity_treatment", None),
            propensity_history=getattr(self, "propensity_history", None),
        )
        return self

    def visualize(self, dataset: Dataset, index=0, artifacts_path=None):
        """No-op in this class: does nothing and returns None."""
        pass

    def bce_loss(self, treatment_pred, current_treatments, kind="predict", label_smoothing=0):
        loss_iclub = ICLUB()
        mode = self.hparams.dataset.treatment_mode
        bce_weights = (
            torch.tensor(self.bce_weights).type_as(current_treatments)
            if self.hparams.exp.bce_weight
            else None
        )

        if kind == "predict":
            bce_loss = bce(treatment_pred, current_treatments, mode, bce_weights, label_smoothing)
        elif kind == "confuse":
            uniform_treatments = torch.ones_like(current_treatments)
            if mode == "multiclass":
                uniform_treatments *= 1 / current_treatments.shape[-1]
            elif mode == "multilabel":
                uniform_treatments *= 0.5
            bce_loss = bce(treatment_pred, uniform_treatments, mode)

        elif kind == "MI":  # for causal cpc
            loss_iclub = ICLUB()
            bce_loss = loss_iclub(treatment_pred, current_treatments, mode)
        else:
            raise NotImplementedError()
        return bce_loss

    def on_fit_start(self) -> None:  # Issue with logging not yet existing parameters in MlFlow
        if self.trainer.logger is not None:
            self.trainer.logger.filter_submodels = list(
                self.possible_model_types - {self.model_type}
            )

    def on_fit_end(self) -> None:  # Issue with logging not yet existing parameters in MlFlow
        if self.trainer.logger is not None:
            self.trainer.logger.filter_submodels = list(self.possible_model_types)

    def get_clusters_RE(self, data_loader):
        all_re_labels = []

        for batch in data_loader:
            re_labels = batch["re_labels"]
            all_re_labels.append(re_labels)

        all_re_labels = torch.cat(all_re_labels, dim=0)

        return all_re_labels.numpy()


class BRCausalModel(TimeVaryingCausalModel):
    """
    Abstract class for models, estimating counterfactual outcomes over time with balanced representations
    """

    model_type = None  # Will be defined in subclasses
    possible_model_types = None  # Will be defined in subclasses
    tuning_criterion = "rmse"

    def __init__(
        self,
        args: DictConfig,
        dataset_collection: Union[RealDatasetCollection, SyntheticDatasetCollection] = None,
        autoregressive: bool = None,
        has_vitals: bool = None,
        bce_weights: np.array = None,
        **kwargs,
    ):
        """
        Args:
            args: DictConfig of hyperparameters; ``args.exp.balancing``, ``args.exp.alpha``
                and ``args.exp.update_alpha`` are read here
            dataset_collection: Dataset collection
            autoregressive: Flag of including previous outcomes to modelling
            has_vitals: Flag of vitals in dataset
            bce_weights: Re-weight BCE if used
            **kwargs: Other arguments (not forwarded to the base constructor)
        """
        super().__init__(
            args,
            dataset_collection=dataset_collection,
            autoregressive=autoregressive,
            has_vitals=has_vitals,
            bce_weights=bce_weights,
        )

        # Balancing representation training parameters
        exp_args = args.exp
        self.balancing = exp_args.balancing
        # Used for gradient-reversal
        self.alpha = exp_args.alpha
        self.update_alpha = exp_args.update_alpha

    def configure_optimizers(self):
        if self.balancing == "grad_reverse" and not self.hparams.exp.weights_ema:  # one optimizer
            optimizer = self._get_optimizer(list(self.named_parameters()))

            if self.hparams.model[self.model_type]["optimizer"]["lr_scheduler"]:
                return self._get_lr_schedulers(optimizer)

            return optimizer

        else:  # two optimizers - simultaneous gradient descent update
            treatment_head_params = [
                "br_treatment_outcome_head." + s
                for s in self.br_treatment_outcome_head.treatment_head_params
            ]
            treatment_head_params = [
                k
                for k in dict(self.named_parameters())
                for param in treatment_head_params
                if k.startswith(param)
            ]
            non_treatment_head_params = [
                k for k in dict(self.named_parameters()) if k not in treatment_head_params
            ]

            assert len(treatment_head_params + non_treatment_head_params) == len(
                list(self.named_parameters())
            )

            treatment_head_params = [
                (k, v)
                for k, v in dict(self.named_parameters()).items()
                if k in treatment_head_params
            ]
            non_treatment_head_params = [
                (k, v)
                for k, v in dict(self.named_parameters()).items()
                if k in non_treatment_head_params
            ]

            if self.hparams.exp.weights_ema:
                self.ema_treatment = ExponentialMovingAverage(
                    [par[1] for par in treatment_head_params], decay=self.hparams.exp.beta
                )
                self.ema_non_treatment = ExponentialMovingAverage(
                    [par[1] for par in non_treatment_head_params], decay=self.hparams.exp.beta
                )

            treatment_head_optimizer = self._get_optimizer(treatment_head_params)
            non_treatment_head_optimizer = self._get_optimizer(non_treatment_head_params)

            if self.hparams.model[self.model_type]["optimizer"]["lr_scheduler"]:
                return self._get_lr_schedulers(
                    [non_treatment_head_optimizer, treatment_head_optimizer]
                )

            return [non_treatment_head_optimizer, treatment_head_optimizer]

    def optimizer_step(
        self,
        epoch: int = None,
        batch_idx: int = None,
        optimizer=None,
        optimizer_idx: int = None,
        *args,
        **kwargs,
    ) -> None:
        """
        Regular optimizer step, followed (with ``hparams.exp.weights_ema``) by an update of
        the EMA tracker that belongs to the stepped optimizer: index 0 -> ``ema_non_treatment``,
        index 1 -> ``ema_treatment``.
        """
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs)
        if not self.hparams.exp.weights_ema:
            return
        if optimizer_idx == 0:
            self.ema_non_treatment.update()
        elif optimizer_idx == 1:
            self.ema_treatment.update()

    def _calculate_bce_weights(self) -> None:
        if self.hparams.dataset.treatment_mode == "multiclass":
            current_treatments = self.dataset_collection.train_f.data["current_treatments"]
            current_treatments = current_treatments.reshape(-1, current_treatments.shape[-1])
            current_treatments = current_treatments[
                self.dataset_collection.train_f.data["active_entries"].flatten().astype(bool)
            ]
            current_treatments = np.argmax(current_treatments, axis=1)

            self.bce_weights = (
                len(current_treatments)
                / np.bincount(current_treatments)
                / len(np.bincount(current_treatments))
            )
        else:
            raise NotImplementedError()

    def on_fit_start(self) -> None:  # Issue with logging not yet existing parameters in MlFlow
        if self.trainer.logger is not None:
            self.trainer.logger.filter_submodels = (
                ["decoder"] if self.model_type == "encoder" else ["encoder"]
            )

    def on_fit_end(self) -> None:  # Issue with logging not yet existing parameters in MlFlow
        if self.trainer.logger is not None:
            self.trainer.logger.filter_submodels = ["encoder", "decoder"]

    def training_step(self, batch, batch_ind, optimizer_idx=0):
        for par in self.parameters():
            par.requires_grad = True

        if optimizer_idx == 0:  # grad reversal or domain confusion representation update
            if self.hparams.exp.weights_ema:
                with self.ema_treatment.average_parameters():
                    treatment_pred, outcome_pred, _ = self(batch)
            else:
                treatment_pred, outcome_pred, _ = self(batch)

            mse_loss = F.mse_loss(outcome_pred, batch["outputs"], reduce=False)
            if self.balancing == "grad_reverse":
                bce_loss = self.bce_loss(
                    treatment_pred, batch["current_treatments"].double(), kind="predict"
                )
            elif self.balancing == "domain_confusion":
                bce_loss = self.bce_loss(
                    treatment_pred, batch["current_treatments"].double(), kind="confuse"
                )
                bce_loss = self.br_treatment_outcome_head.alpha * bce_loss
            else:
                raise NotImplementedError()

            # Masking for shorter sequences
            # Attention! Averaging across all the active entries (= sequence masks) for full batch
            bce_loss = (batch["active_entries"].squeeze(-1) * bce_loss).sum() / batch[
                "active_entries"
            ].sum()
            mse_loss = (batch["active_entries"] * mse_loss).sum() / batch["active_entries"].sum()

            loss = bce_loss + mse_loss

            self.log(
                f"{self.model_type}_train/loss", loss, on_epoch=True, on_step=False, sync_dist=True
            )
            self.log(
                f"{self.model_type}_train/bce_loss",
                bce_loss,
                on_epoch=True,
                on_step=False,
                sync_dist=True,
            )
            self.log(
                f"{self.model_type}_train/mse_loss",
                mse_loss,
                on_epoch=True,
                on_step=False,
                sync_dist=True,
            )
            self.log(
                f"{self.model_type}_alpha",
                self.br_treatment_outcome_head.alpha,
                on_epoch=True,
                on_step=False,
                sync_dist=True,
            )

            return loss

        elif optimizer_idx == 1:  # domain classifier update
            if self.hparams.exp.weights_ema:
                with self.ema_non_treatment.average_parameters():
                    treatment_pred, _, _ = self(batch, detach_treatment=True)
            else:
                treatment_pred, _, _ = self(batch, detach_treatment=True)

            bce_loss = self.bce_loss(
                treatment_pred, batch["current_treatments"].double(), kind="predict"
            )
            if self.balancing == "domain_confusion":
                bce_loss = self.br_treatment_outcome_head.alpha * bce_loss

            # Masking for shorter sequences
            # Attention! Averaging across all the active entries (= sequence masks) for full batch
            bce_loss = (batch["active_entries"].squeeze(-1) * bce_loss).sum() / batch[
                "active_entries"
            ].sum()
            self.log(
                f"{self.model_type}_train/bce_loss_cl",
                bce_loss,
                on_epoch=True,
                on_step=False,
                sync_dist=True,
            )

            return bce_loss

    def test_step(self, batch, batch_ind, **kwargs):
        if self.hparams.exp.weights_ema:
            with self.ema_non_treatment.average_parameters():
                with self.ema_treatment.average_parameters():
                    treatment_pred, outcome_pred, _ = self(batch)
        else:
            treatment_pred, outcome_pred, _ = self(batch)

        if self.balancing == "grad_reverse":
            bce_loss = self.bce_loss(
                treatment_pred, batch["current_treatments"].double(), kind="predict"
            )
        elif self.balancing == "domain_confusion":
            bce_loss = self.bce_loss(
                treatment_pred, batch["current_treatments"].double(), kind="confuse"
            )
        elif self.balancing == "mutual_info":
            bce_loss = self.bce_loss(
                treatment_pred, batch["current_treatments"].double(), kind="MI"
            )

        mse_loss = F.mse_loss(outcome_pred, batch["outputs"], reduce=False)

        # Masking for shorter sequences
        # Attention! Averaging across all the active entries (= sequence masks) for full batch
        bce_loss = (batch["active_entries"].squeeze(-1) * bce_loss).sum() / batch[
            "active_entries"
        ].sum()
        mse_loss = (batch["active_entries"] * mse_loss).sum() / batch["active_entries"].sum()
        loss = bce_loss + mse_loss

        subset_name = self.test_dataloader().dataset.subset_name
        self.log(
            f"{self.model_type}_{subset_name}/loss",
            loss,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.log(
            f"{self.model_type}_{subset_name}/bce_loss",
            bce_loss,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.log(
            f"{self.model_type}_{subset_name}/mse_loss",
            mse_loss,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )

    def validation_step(self, batch, batch_ind, **kwargs):
        if self.hparams.exp.weights_ema:
            with self.ema_non_treatment.average_parameters():
                with self.ema_treatment.average_parameters():
                    treatment_pred, outcome_pred, _ = self(batch)
        else:
            treatment_pred, outcome_pred, _ = self(batch)

        if self.balancing == "grad_reverse":
            bce_loss = self.bce_loss(
                treatment_pred, batch["current_treatments"].double(), kind="predict"
            )
        elif self.balancing == "domain_confusion":
            bce_loss = self.bce_loss(
                treatment_pred, batch["current_treatments"].double(), kind="confuse"
            )

        elif self.balancing == "mutual_info":
            bce_loss = self.bce_loss(
                treatment_pred, batch["current_treatments"].double(), kind="MI"
            )

        mse_loss = F.mse_loss(outcome_pred, batch["outputs"], reduce=False)

        # Masking for shorter sequences
        # Attention! Averaging across all the active entries (= sequence masks) for full batch
        bce_loss = (batch["active_entries"].squeeze(-1) * bce_loss).sum() / batch[
            "active_entries"
        ].sum()
        mse_loss = (batch["active_entries"] * mse_loss).sum() / batch["active_entries"].sum()
        loss = bce_loss + mse_loss

        subset_name = self.val_dataloader().dataset.subset_name
        self.log(
            f"{self.model_type}_{subset_name}/loss",
            loss,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.log(
            f"{self.model_type}_{subset_name}/bce_loss",
            bce_loss,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.log("val/loss", mse_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True)

    def predict_step(self, batch, batch_idx, dataset_idx=None):
        """
        Generates normalised output predictions
        """
        if self.hparams.exp.weights_ema:
            with self.ema_non_treatment.average_parameters():
                _, outcome_pred, br = self(batch)
        else:
            _, outcome_pred, br = self(batch)
        return outcome_pred.cpu(), br.cpu()

    def get_representations(self, data) -> np.array:
        if not isinstance(data, DataLoader):
            logger.info(f"Balanced representations inference for {data.subset_name}.")
            data_loader = DataLoader(
                data, batch_size=self.hparams.dataset.val_batch_size, shuffle=False
            )
        else:
            data_loader = data

        _, br = (torch.cat(arrs) for arrs in zip(*self.trainer.predict(self, data_loader)))
        return br.numpy()

    def get_predictions(self, dataset: Dataset) -> np.array:
        """
        Run ``trainer.predict`` over ``dataset`` (unshuffled, ``dataset.val_batch_size``) and
        return the first element of every ``predict_step`` result, concatenated over batches,
        as a numpy array.
        """
        logger.info(f"Predictions for {dataset.subset_name}.")
        # Creating Dataloader
        data_loader = DataLoader(
            dataset, shuffle=False, batch_size=self.hparams.dataset.val_batch_size
        )
        per_batch = self.trainer.predict(self, data_loader)
        # Transpose list of (outcome, br) pairs into two concatenated tensors
        outcome_pred, _ = [torch.cat(parts) for parts in zip(*per_batch)]
        return outcome_pred.numpy()
