import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn

from ..encoder import gaussian_encode
from .base import BaseLCM
from util import inverse_softplus
from util import mask, clean_and_clamp, correlation
from transforms import make_mlp_structure_transform
from nets import make_mlp
from .normflow import Planar


class SoftILCM(BaseLCM):
    """
    Top-level class for generative models with
    - an SCM with a learned or fixed causal graph
    - separate encoder and decoder (i.e. a VAE) outputting noise encodings
    - VI over intervention targets
    """

    def __init__(
            self,
            causal_model,
            encoder,
            noise_encoder,
            noise_decoder,
            decoder,
            dim_z,
            num_components,
            action_classifier=None,
            object_classifier=None,
            intervention_set="atomic_or_none",
            averaging_strategy="stochastic",
            adversarial=True,
            noise_model='multiply',
            init_std=1.0,
            num_samples=10,
            min_std=1.0e-3,
    ):
        """
        Build the model from pre-constructed components.

        Arguments:
        ----------
        causal_model, encoder, decoder, intervention_set, action_classifier, object_classifier :
            Forwarded to ``BaseLCM.__init__``.
        noise_encoder, noise_decoder :
            Encoder / decoder applied to the difference ``x2 - x1``.
        dim_z : int
            Number of latent dimensions.
        num_components :
            Stored as ``self.num_components``.
        averaging_strategy : str
            "z2", "average"/"mean" or "stochastic"; consumed by ``_project_to_manifold``.
        adversarial : bool
            If True, a ``MINE`` critic is created as ``self.mine_model``.
        noise_model : str
            Selects the auxiliary nets and prior; ``compute_priors`` only implements "additive".
        init_std :
            Accepted for interface compatibility; not used here.
        num_samples : int
            Number of posterior samples passed to ``gaussian_encode``.
        min_std : float
            Stored as ``self._min_std``.
        """
        super().__init__(
            causal_model,
            encoder,
            decoder=decoder,
            dim_z=dim_z,
            intervention_set=intervention_set,
            action_classifier=action_classifier,
            object_classifier=object_classifier,
        )

        # Plain hyper-parameters (not sub-modules, so their assignment order is irrelevant).
        self.dim_z = dim_z
        self.num_components = num_components
        self.num_samples = num_samples
        self.averaging_strategy = averaging_strategy
        self.noise_model = noise_model
        self.adversarial = adversarial
        self._min_std = min_std

        # Sub-modules: keep this registration order, it determines the order of
        # parameters() / state_dict() entries and of any RNG draws at construction.
        if self.adversarial:
            self.mine_model = MINE()
        self.bce_criterion = nn.BCELoss()
        self.noise_encoder = noise_encoder
        self.noise_decoder = noise_decoder

        # Reads self.noise_model / self.dim_z, so it must come after they are set.
        self.create_auxiliary_nets()

    def forward(
            self,
            x1,
            x2,
            s1=None,
            s2=None,
            true_action=None,
            true_object=None,
            beta=1.0,
            beta_intervention_target=None,
            pretrain_beta=None,
            full_likelihood=True,
            likelihood_reduction="sum",
            graph_mode="hard",
            graph_temperature=1.0,
            graph_samples=1,
            pretrain=False,
            model_interventions=True,
            model_noise=False,
            deterministic_intervention_encoder=False,
            intervention_encoder_offset=1.0e-4,
            **kwargs,
    ):
        """
        Compute the training loss for a pair of observations (before / after an intervention).

        Arguments:
        ----------
        x1, x2 : torch.Tensor, first axis is the batch
            Observations before and after the intervention. Must be finite.
        s1, s2 : optional
            Side information passed to ``self.encoder.mean_std``; ``s2`` is also passed
            to the noise encoder.
        true_action :
            Per-sample intervention-target indices. Used to build the intervention mask
            and as the action-classifier target.
        true_object :
            Targets for the object classifier.
        beta : float
            Weight of the KL terms.
        pretrain, pretrain_beta :
            If ``pretrain`` is True, the call is delegated to ``forward_pretrain`` with
            ``beta=pretrain_beta``.
        full_likelihood, likelihood_reduction :
            Forwarded to the decoders.
        model_interventions :
            Forwarded to the SCM prior as ``include_intervened``.
        model_noise : bool
            If True, the noise-encoding prior / posterior terms are included.
        beta_intervention_target, graph_mode, graph_temperature, graph_samples,
        deterministic_intervention_encoder, intervention_encoder_offset, kwargs :
            Accepted for interface compatibility; not used by this method.

        Returns:
        --------
        loss : torch.Tensor
            The loss returned by ``_compute_outputs``.
        outputs : dict with str keys and torch.Tensor values
            Detailed breakdown of the model outputs and internals.
        """
        feature_dims = list(range(1, x1.dim()))
        assert torch.all(torch.isfinite(x1)) and torch.all(torch.isfinite(x2))

        if pretrain:
            return self.forward_pretrain(
                x1,
                x2,
                s1,
                s2,
                beta=pretrain_beta,
                full_likelihood=full_likelihood,
                likelihood_reduction=likelihood_reduction,
            )

        # Posterior parameters for e1, e2 and for the difference ("noise") encoding.
        e1_mean, e1_std = self.encoder.mean_std(x1, s1)
        e2_mean, e2_std = self.encoder.mean_std(x2, s2)
        noise_mean, noise_std = self.noise_encoder.mean_std(x2 - x1, s2)

        # Regularisers computed on unprojected samples; the third return value is not needed here.
        e_norm, consistency_mse, _ = self._compute_latent_reg_consistency_mse(
            e1_mean, e1_std, e2_mean, e2_std, noise_mean, noise_std,
            feature_dims, x1, x2,
            beta=beta,
        )

        outputs = {"noise_std": noise_std}

        # Samples projected onto the counterfactual manifold implied by true_action.
        e1_proj, e2_proj, log_q_e1, log_q_e2 = self._project_and_sample(
            e1_mean, e1_std, e2_mean, e2_std, true_action
        )
        noise_proj, log_q_noise = self._project_and_sample_noise(noise_mean, noise_std)

        (
            log_likelihood,
            log_likelihood_noise,
            log_posterior_eps,
            log_posterior_noise,
            log_prior_eps,
            log_prior_noise,
            mse,
            inverse_consistency_mse,
            mi_loss_e2_e1,
            mi_loss_n1_e1,
            mi_loss_n1_e2,
            action_loss,
            object_loss,
            extra_outputs,
        ) = self._compute_elbo_terms(
            x1=x1,
            x2=x2,
            e1_proj=e1_proj,
            e2_proj=e2_proj,
            noise_proj=noise_proj,
            feature_dims=feature_dims,
            full_likelihood=full_likelihood,
            interventions=true_action,
            true_objects=true_object,
            likelihood_reduction=likelihood_reduction,
            log_posterior_eps1_proj=log_q_e1,
            log_posterior_eps2_proj=log_q_e2,
            log_posterior_noise=log_q_noise,
            model_noise=model_noise,
            model_interventions=model_interventions,
        )

        # Merge the per-term outputs, adding a new axis at dim 1 (concatenated on key clashes).
        for name, value in extra_outputs.items():
            value = value.unsqueeze(1)
            outputs[name] = torch.cat((outputs[name], value), dim=1) if name in outputs else value

        loss = self._compute_outputs(
            beta=beta,
            action_labels=true_action,
            consistency_mse=consistency_mse,
            e1_std=e1_std,
            e2_std=e2_std,
            e_norm=e_norm,
            interventions=true_action,
            log_likelihood=log_likelihood,
            log_likelihood_noise=log_likelihood_noise,
            log_posterior_eps=log_posterior_eps,
            log_posterior_noise=log_posterior_noise,
            log_prior_eps=log_prior_eps,
            log_prior_noise=log_prior_noise,
            mse=mse,
            outputs=outputs,
            inverse_consistency_mse=inverse_consistency_mse,
            mi_loss_e2_e1=mi_loss_e2_e1,
            mi_loss_n1_e1=mi_loss_n1_e1,
            mi_loss_n1_e2=mi_loss_n1_e2,
            action_loss=action_loss,
            object_loss=object_loss,
        )
        return loss, outputs

    def forward_pretrain(self, x1, x2, s1, s2, beta, full_likelihood=False,
                         likelihood_reduction="sum"):
        """
        Pre-training objective: a beta-VAE loss on x1, x2 and the difference encoding.

        Arguments:
        ----------
        x1, x2 : torch.Tensor, first axis is the batch; must be finite.
        s1, s2 : side information for ``self.encoder.mean_std``. The noise encoder
            receives ``s1`` here.
            NOTE(review): ``forward`` passes ``s2`` to the noise encoder instead; confirm which is intended.
        beta : weight of the KL terms.
        full_likelihood, likelihood_reduction : forwarded to the decoder.

        Returns:
        --------
        beta_vae_loss : torch.Tensor
        outputs : dict with "z_regularization", "consistency_mse" and "encoder_std"
            ("encoder_std" holds the noise encoder's std).
        """
        assert torch.all(torch.isfinite(x1)) and torch.all(torch.isfinite(x2))
        reduce_dims = list(range(1, x1.dim()))

        mean_1, std_1 = self.encoder.mean_std(x1, s1)
        mean_2, std_2 = self.encoder.mean_std(x2, s2)
        diff_mean, diff_std = self.noise_encoder.mean_std(x2 - x1, s1)

        e_norm, consistency_mse, pretrain_loss = self._compute_latent_reg_consistency_mse(
            mean_1, std_1, mean_2, std_2, diff_mean, diff_std,
            reduce_dims, x1, x2,
            beta=beta,
            full_likelihood=full_likelihood,
            likelihood_reduction=likelihood_reduction,
        )

        return pretrain_loss, {
            "z_regularization": e_norm,
            "consistency_mse": consistency_mse,
            "encoder_std": diff_std,
        }

    def encode_to_noise(self, x, deterministic=True):
        """
        Map data ``x`` to its noise encoding.

        Arguments:
        ----------
        x : torch.Tensor, first axis is the batch
            Data to encode (no side information is passed to the encoder).
        deterministic : bool, optional
            Forwarded to ``gaussian_encode``; True requests the deterministic encoding.

        Returns:
        --------
        e : torch.Tensor
            First element returned by ``gaussian_encode`` for the encoder's mean / std.
        """
        posterior_mean, posterior_std = self.encoder.mean_std(x)
        sample, _ = gaussian_encode(posterior_mean, posterior_std, deterministic=deterministic)
        return sample

    def encode_to_causal(self, x, x2, deterministic=True):
        """
        Given data x, return its causal encoding.

        Arguments:
        ----------
        x : torch.Tensor, first axis is the batch
            Data to encode.
        x2 :
            Unused; kept for interface compatibility.
        deterministic : bool, optional
            Forwarded to ``gaussian_encode``. If True, enforces deterministic encoding
            (e.g. by not adding noise in a Gaussian VAE).

        Returns:
        --------
        z : torch.Tensor
            ``self.scm.noise_to_causal`` applied to the noise encoding, using the
            adjacency matrix from ``_sample_adjacency_matrices(mode="deterministic")``.
        """
        e_mean, e_std = self.encoder.mean_std(x)
        # Honour the caller's ``deterministic`` flag. It used to be hard-coded to False,
        # so even the default (deterministic=True) call returned a random sample.
        e, _ = gaussian_encode(e_mean, e_std, deterministic=deterministic)
        adjacency_matrix = self._sample_adjacency_matrices(mode="deterministic", n=x.shape[0])
        return self.scm.noise_to_causal(e, adjacency_matrix=adjacency_matrix)

    def latents_correlations(self, x, x2):
        """
        Diagnostic: pairwise correlations between sampled latents for a batch of pairs.

        ``e`` is one posterior sample of ``self.encoder`` at ``x``. ``noise`` is one
        posterior sample of ``self.noise_encoder`` at ``x2 - x``. ``shifts`` / ``scales``
        come from ``self.scm._solve(e, e, adjacency_matrix=None)``. Each entry is
        ``correlation(a[:, i], b[:, j]).item()``.

        Returns:
        --------
        e_e_corr : np.ndarray (dim_z, dim_z)
            corr(e_i, e_j) off the diagonal; the diagonal stays 1.
        n_e_corr : np.ndarray (dim_z, dim_z)
            corr(noise_i, e_j) for all i, j.
        shift_e_corr : np.ndarray (dim_z, dim_z)
            corr(shift_i, e_j) off the diagonal; the diagonal stays 0.
        scales_e_corr : np.ndarray (dim_z, dim_z)
            corr(scale_i, e_j) off the diagonal; the diagonal stays 0.
        """
        e_mean, e_std = self.encoder.mean_std(x)
        noise_mean, noise_std = self.noise_encoder.mean_std(x2 - x)
        e, _ = gaussian_encode(e_mean, e_std, deterministic=False)
        noise, _ = gaussian_encode(noise_mean, noise_std, deterministic=False)
        # The encoding of x2 alone and the per-latent ``self.g[i](noise)`` outputs were
        # previously computed here but never used. Dropping them saves dim_z forward
        # passes and removes the hard dependency on ``self.g``.

        _, shifts, scales, _ = self.scm._solve(e, e, adjacency_matrix=None)

        dim = self.dim_z
        shift_e_corr = np.zeros((dim, dim))
        scales_e_corr = np.zeros((dim, dim))
        e_e_corr = np.ones((dim, dim))
        n_e_corr = np.ones((dim, dim))
        for i in range(dim):
            # Off-diagonal correlations against the noise encoding e.
            for j in range(dim):
                if i == j:
                    continue
                shift_e_corr[i, j] = correlation(shifts[:, i], e[:, j]).item()
                scales_e_corr[i, j] = correlation(scales[:, i], e[:, j]).item()
                e_e_corr[i, j] = correlation(e[:, i], e[:, j]).item()

            # Difference-encoding vs. e, including the diagonal.
            for j in range(dim):
                n_e_corr[i, j] = correlation(noise[:, i], e[:, j]).item()

        return e_e_corr, n_e_corr, shift_e_corr, scales_e_corr

    def decode_noise(self, e, deterministic=True):
        """
        Decode a noise encoding back to data space.

        Arguments:
        ----------
        e : torch.Tensor, first axis is the batch
            Noise encoding.
        deterministic : bool, optional
            Forwarded to ``self.decoder``; True requests deterministic decoding.

        Returns:
        --------
        x : torch.Tensor
            First element returned by ``self.decoder``.
        """
        reconstruction, _ = self.decoder(e, deterministic=deterministic)
        return reconstruction

    def decode_causal(self, z, deterministic=True):
        """
        Decode causal latents back to data space.

        The causal variables are first mapped to noise encodings with
        ``self.scm.causal_to_noise`` (using the adjacency matrix from
        ``_sample_adjacency_matrices(mode="deterministic")``) and then decoded.

        Arguments:
        ----------
        z : torch.Tensor, first axis is the batch
            Causal latent variables.
        deterministic : bool, optional
            Forwarded to ``self.decoder``.

        Returns:
        --------
        x : torch.Tensor
            First element returned by ``self.decoder``.
        """
        graph = self._sample_adjacency_matrices(mode="deterministic", n=z.shape[0])
        noise_encoding = self.scm.causal_to_noise(z, adjacency_matrix=graph)
        reconstruction, _ = self.decoder(noise_encoding, deterministic=deterministic)
        return reconstruction

    def log_likelihood(self, x1, x2, interventions=None, n_latent_samples=20, **kwargs):
        """
        Importance-weighted (IWAE-style) estimate of the log likelihood of a pair.

        `log p(x) ~= log mean_k exp(ELBO_k)`, where every pair is replicated
        ``n_latent_samples`` times and each replica yields one ELBO sample.

        Arguments:
        ----------
        x1, x2 : torch.Tensor
            Observation pair (before / after the intervention).
        interventions :
            Per-sample intervention-target indices, forwarded to ``forward`` as
            ``true_action``. ``forward`` needs them to build the intervention mask,
            so None is not supported in practice.
        n_latent_samples : int
            Number of replicas per pair.
        kwargs :
            Unused.

        Returns:
        --------
        log_likelihood : torch.Tensor
            ``self._contract`` of the per-replica ELBOs with mode "logmeanexp".
        """
        x1_ = self._expand(x1, repeats=n_latent_samples)
        x2_ = self._expand(x2, repeats=n_latent_samples)
        interventions_ = self._expand(interventions, repeats=n_latent_samples)

        # ``forward``'s third positional parameter is ``s1`` (encoder side information),
        # so the interventions must be passed by keyword as ``true_action``.
        _, outputs = self.forward(x1_, x2_, true_action=interventions_, beta=1.0)

        # Weight with the ELBO itself. The returned training loss can also contain
        # classifier cross-entropy terms, which are not part of the likelihood bound.
        return self._contract(outputs["elbo"], mode="logmeanexp", repeats=n_latent_samples)

    def encode_decode(self, x, deterministic=True):
        """Encode ``x`` to the noise space and decode it again, returning the reconstruction.

        ``deterministic`` is forwarded to both ``encode_to_noise`` and ``decode_noise``.
        """
        return self.decode_noise(
            self.encode_to_noise(x, deterministic=deterministic),
            deterministic=deterministic,
        )

    def predict_classes(self, x1, x2, s1=None, s2=None,):
        """
        Predict action and object class probabilities for an observation pair.

        Both observations are encoded with one random posterior sample each. The
        causal variables are solved with ``self.scm._solve`` (z1 from (e1, e1), z2 from
        (e2, e1)), and the classifiers are applied to ``cat((z2, z1), dim=1)``.

        Returns:
        --------
        action_prob, object_prob : torch.Tensor
            Softmax over dim 1 of the action / object classifier logits.
        """
        mean_1, std_1 = self.encoder.mean_std(x1, s1)
        mean_2, std_2 = self.encoder.mean_std(x2, s2)

        sample_1, _ = gaussian_encode(mean_1, std_1, deterministic=False)
        sample_2, _ = gaussian_encode(mean_2, std_2, deterministic=False)

        causal_1 = self.scm._solve(sample_1, sample_1, None)[0]
        causal_2 = self.scm._solve(sample_2, sample_1, None)[0]
        pair_features = torch.cat((causal_2, causal_1), 1)

        action_logits = self.action_classifier(pair_features)
        object_logits = self.object_classifier(pair_features)
        return torch.softmax(action_logits, 1), torch.softmax(object_logits, 1)

    def encode_decode_pair(self, x1, x2, intervention, deterministic=True):
        """
        Encode a pair, project it onto the manifold implied by ``intervention`` and decode it.

        Arguments:
        ----------
        x1, x2 : torch.Tensor
            Observation pair.
        intervention :
            Per-sample index of the intervened latent.
        deterministic : bool
            Forwarded to ``_project_and_sample``; decoding always uses ``decode_noise``'s default.

        Returns:
        --------
        tuple of (x1_reco, x2_reco, e1_mean, e2_mean, e1_proj, e2_proj, one_hot_interventions)
        """
        mean_1, std_1 = self.encoder.mean_std(x1)
        mean_2, std_2 = self.encoder.mean_std(x2)

        # One-hot mask of the intervened latent for every sample (returned to the caller).
        intervention_mask = torch.zeros_like(mean_1)
        intervention_mask[torch.arange(len(intervention)), intervention] = 1.0

        proj_1, proj_2, _, _ = self._project_and_sample(
            mean_1, std_1, mean_2, std_2, intervention, deterministic=deterministic
        )

        reco_1 = self.decode_noise(proj_1)
        reco_2 = self.decode_noise(proj_2)
        return reco_1, reco_2, mean_1, mean_2, proj_1, proj_2, intervention_mask

    def infer_intervention(
            self,
            x1,
            x2,
            deterministic=True,
    ):
        """
        Given a data pair, infer the intervention target and reconstruct x2 under it.

        The intervention is taken to be the argmax of the action-classifier
        probabilities from ``predict_classes``. Elsewhere in this class the action
        classifier is trained against the same index tensor that selects the intervened
        latent (``_compute_elbo_terms`` / ``_project_and_sample``).
        NOTE(review): this assumes the number of action classes equals ``dim_z``; confirm.
        ``predict_classes`` samples the encoder posterior, so the inferred index is
        stochastic even when ``deterministic`` is True.

        Returns:
        --------
        most_likely_intervention_idx : torch.Tensor
            Per-sample argmax over the action probabilities.
        None
            Placeholder kept for interface compatibility.
        x2_reco : torch.Tensor
            Reconstruction of x2 after projecting onto the inferred intervention.

        Raises:
        -------
        NotImplementedError
            If the action or object classifier is missing (``predict_classes`` needs both).
        """
        # The previous implementation called ``encode_decode_pair`` without its required
        # ``intervention`` argument and unpacked 9 values from its 7-tuple, so it could
        # never run.
        if (getattr(self, "action_classifier", None) is None
                or getattr(self, "object_classifier", None) is None):
            raise NotImplementedError(
                "infer_intervention requires action and object classifiers to infer the intervention"
            )

        action_prob, _ = self.predict_classes(x1, x2)
        most_likely_intervention_idx = torch.argmax(action_prob, dim=1)

        x2_reco = self.encode_decode_pair(
            x1, x2, most_likely_intervention_idx, deterministic=deterministic
        )[1]

        return most_likely_intervention_idx, None, x2_reco

    def mutual_information(self, a, b, interventions):
        """
        Sum of MINE losses between the intervened latent of ``a`` and every other latent of ``b``.

        Samples are grouped by intervention value ``t``. Within each group and for
        every column ``j != t`` of ``b``, the critic ``self.mine_model`` scores joint
        pairs ``(a[:, t], b[:, j])`` and pairs in which ``a[:, t]`` is permuted within
        the group. ``mine_loss`` of the two score sets is accumulated.
        Groups containing a single sample are skipped.
        """
        total = 0
        for target in torch.unique(interventions):
            members = (interventions == target).nonzero().squeeze()
            if members.dim() == 0:
                # A lone match squeezes to a 0-d tensor, which has no length to permute.
                continue

            for j in range(b.shape[1]):
                if j == target:
                    continue
                shuffled = members[torch.randperm(len(members))]
                # The b column is left unshuffled for both the joint and the marginal pairs.
                b_column = b[members, j].unsqueeze(1)

                joint_preds = self.mine_model(a[members, target].unsqueeze(1), b_column)
                marginal_preds = self.mine_model(a[shuffled, target].unsqueeze(1), b_column)
                total += mine_loss(joint_preds, marginal_preds)

        return total

    def _project_and_sample_noise(self, noise_mean, noise_std):
        """
        Sample the difference ("noise") encoding.

        Unlike the e1/e2 encodings there is no projection step. This is a single
        ``gaussian_encode`` call with ``reduction="sum"`` and that function's default
        sample count and ``deterministic`` flag.

        Returns:
        --------
        (sample, log_posterior) as returned by ``gaussian_encode``.
        """
        draw, log_density = gaussian_encode(noise_mean, noise_std, reduction="sum")
        return draw, log_density

    def _project_and_sample(self, e1_mean, e1_std, e2_mean, e2_std, intervention, deterministic=False):
        """
        Sample (e1, e2) from posteriors projected onto the counterfactual manifold.

        Latents not targeted by ``intervention`` share one posterior (see
        ``_project_to_manifold``). After sampling, e2 copies those components from e1,
        so only the intervened component of e2 is drawn from its own posterior.

        Arguments:
        ----------
        e1_mean, e1_std, e2_mean, e2_std : torch.Tensor indexed as [sample, latent]
        intervention : per-sample index of the intervened latent
        deterministic : bool, forwarded to ``gaussian_encode``

        Returns:
        --------
        e1_proj, e2_proj : samples
        log_posterior1_proj : log q(e1) from ``gaussian_encode`` (default reduction)
        log_posterior2_proj : per-latent log q(e2) masked to the intervened latent and summed over dim 1
        """
        mask = torch.zeros_like(e1_mean)
        mask[torch.arange(len(intervention)), intervention] = 1.0
        complement = 1.0 - mask

        mean_1, std_1, mean_2, std_2 = self._project_to_manifold(mask, e1_mean, e1_std, e2_mean, e2_std)

        sample_1, log_q1 = gaussian_encode(mean_1, std_1, self.num_samples, deterministic=deterministic)
        raw_sample_2, log_q2_per_latent = gaussian_encode(
            mean_2, std_2, self.num_samples, deterministic=deterministic, reduction="none"
        )

        # Keep the pair counterfactually consistent: non-intervened latents come from e1.
        sample_2 = mask * raw_sample_2 + complement * sample_1
        log_q2 = torch.sum(log_q2_per_latent * mask, dim=1, keepdim=True)

        return sample_1, sample_2, log_q1, log_q2

    def _project_to_manifold(self, intervention, e1_mean, e1_std, e2_mean, e2_std):
        if self.averaging_strategy == "z2":
            lam = torch.ones_like(e1_mean)
        elif self.averaging_strategy in ["average", "mean"]:
            lam = 0.5 * torch.ones_like(e1_mean)
        elif self.averaging_strategy == "stochastic":
            lam = torch.rand_like(e1_mean)
        else:
            raise ValueError(f"Unknown averaging strategy {self.averaging_strategy}")

        projection_mean = lam * e1_mean + (1.0 - lam) * e2_mean
        projection_std = lam * e1_std + (1.0 - lam) * e2_std

        e1_mean = intervention * e1_mean + (1.0 - intervention) * projection_mean
        e1_std = intervention * e1_std + (1.0 - intervention) * projection_std
        e2_mean = intervention * e2_mean + (1.0 - intervention) * projection_mean
        e2_std = intervention * e2_std + (1.0 - intervention) * projection_std

        return e1_mean, e1_std, e2_mean, e2_std

    def _compute_latent_reg_consistency_mse(
            self,
            e1_mean,
            e1_std,
            e2_mean,
            e2_std,
            noise_mean,
            noise_std,
            feature_dims,
            x1,
            x2,
            beta,
            full_likelihood=False,
            likelihood_reduction="sum",
    ):
        """
        Regularisers and a plain beta-VAE loss from unprojected posterior samples.

        One stochastic sample is drawn (``self.num_samples`` per ``gaussian_encode``)
        for e1, e2 and the difference encoding.

        Returns:
        --------
        e_norm : sum of squared e1 and e2 samples over dim 1
        consistency_mse : squared reconstruction error of x1 and x2, summed over ``feature_dims``
        beta_vae_loss : -log p(x1|e1) - log p(x2|e2) + beta * (log q - log prior), where
            the KL part covers e1, e2 and the difference encoding. The prior is the SCM
            base density applied element-wise. No likelihood term is included for the
            difference encoding.
        """

        def base_log_prob(latent):
            # Element-wise base density, summed over the dim_z latents.
            per_entry = self.scm.base_density.log_prob(latent.reshape((-1, 1)))
            return torch.sum(per_entry.reshape((-1, self.dim_z)), dim=1, keepdim=True)

        def decode_at(latent, target):
            return self.decoder(
                latent,
                eval_likelihood_at=target,
                deterministic=True,
                full=full_likelihood,
                reduction=likelihood_reduction,
            )

        sample_1, log_q1 = gaussian_encode(e1_mean, e1_std, self.num_samples, deterministic=False)
        sample_2, log_q2 = gaussian_encode(e2_mean, e2_std, self.num_samples, deterministic=False)
        sample_noise, log_q_noise = gaussian_encode(noise_mean, noise_std, self.num_samples, deterministic=False)

        # Latent-norm regulariser (useful early in training).
        e_norm = torch.sum(sample_1 ** 2, 1, keepdim=True) + torch.sum(sample_2 ** 2, 1, keepdim=True)

        reco_1, log_p_x1 = decode_at(sample_1, x1)
        reco_2, log_p_x2 = decode_at(sample_2, x2)

        consistency_mse = torch.sum((reco_1 - x1) ** 2, feature_dims).unsqueeze(1)
        consistency_mse += torch.sum((reco_2 - x2) ** 2, feature_dims).unsqueeze(1)

        log_prior_1 = base_log_prob(sample_1)
        log_prior_2 = base_log_prob(sample_2)
        log_prior_noise = base_log_prob(sample_noise)

        kl_term = log_q1 + log_q2 + log_q_noise - log_prior_1 - log_prior_2 - log_prior_noise
        beta_vae_loss = -log_p_x1 - log_p_x2 + beta * kl_term

        return e_norm, consistency_mse, beta_vae_loss

    def _compute_outputs(
            self,
            beta,
            action_labels,
            consistency_mse,
            e1_std,
            e2_std,
            e_norm,
            interventions,
            log_likelihood,
            log_likelihood_noise,
            log_posterior_eps,
            log_posterior_noise,
            log_prior_eps,
            log_prior_noise,
            mse,
            outputs,
            inverse_consistency_mse,
            mi_loss_e2_e1,
            mi_loss_n1_e1,
            mi_loss_n1_e2,
            action_loss,
            object_loss
    ):
        # Put together to compute the ELBO and beta-VAE loss
        kl_eps = log_posterior_eps - log_prior_eps
        log_posterior = log_posterior_eps
        log_prior = log_prior_eps
        elbo = log_likelihood - kl_eps
        beta_vae_loss = - log_likelihood + beta * kl_eps

        if log_prior_noise is not None:
            log_posterior += log_posterior_noise
            log_prior += log_prior_noise
            kl_noise = log_posterior_noise - log_prior_noise
            elbo += log_likelihood_noise - kl_noise
            beta_vae_loss += - log_likelihood_noise + beta * kl_noise

            outputs["kl_noise"] = kl_noise

        if action_loss is not None:
            beta_vae_loss += action_loss
            beta_vae_loss += object_loss
            outputs["action_ce"] = action_loss
            outputs["object_ce"] = object_loss

        # Track individual components
        outputs["elbo"] = elbo
        outputs["beta_vae_loss"] = beta_vae_loss
        outputs["kl_epsilon"] = kl_eps
        outputs["log_likelihood"] = log_likelihood
        outputs["log_posterior"] = log_posterior
        outputs["log_prior"] = log_prior
        outputs["interventions"] = interventions
        outputs["mse"] = mse
        outputs["consistency_mse"] = consistency_mse
        outputs["inverse_consistency_mse"] = inverse_consistency_mse
        outputs["z_regularization"] = e_norm
        outputs["encoder_std"] = 0.5 * torch.mean(e1_std + e2_std, dim=1, keepdim=True)

        if mi_loss_n1_e1 is not None:
            outputs["mi_loss_n1_e1"] = mi_loss_n1_e1
        if mi_loss_n1_e2 is not None:
            outputs["mi_loss_n1_e2"] = mi_loss_n1_e2
        if mi_loss_e2_e1 is not None:
            outputs["mi_loss_e2_e1"] = mi_loss_e2_e1

        return beta_vae_loss

    def _compute_elbo_terms(
            self,
            x1,
            x2,
            e1_proj,
            e2_proj,
            noise_proj,
            feature_dims,
            full_likelihood,
            interventions,
            true_objects,
            likelihood_reduction,
            log_posterior_eps1_proj,
            log_posterior_eps2_proj,
            log_posterior_noise,
            model_noise,
            model_interventions,
    ):
        """
        Compute the likelihood, prior, posterior and auxiliary terms for projected samples.

        Arguments:
        ----------
        x1, x2 : observation pair; the noise decoder is evaluated at ``x2 - x1``.
        e1_proj, e2_proj, noise_proj : projected latent samples.
        interventions : intervention-target indices (also the action-classifier target).
        true_objects : object-classifier targets.
        log_posterior_eps1_proj, log_posterior_eps2_proj, log_posterior_noise :
            Posterior log-densities of the samples.
        model_noise : if False, the noise prior / posterior are returned as None.
        model_interventions : forwarded to the SCM prior as ``include_intervened``.

        Returns:
        --------
        tuple of (log_likelihood, log_likelihood_noise, log_posterior_eps, log_posterior_noise,
        log_prior_eps, log_prior_noise, mse, inverse_consistency_mse, mi_loss_e2_e1,
        mi_loss_n1_e1, mi_loss_n1_e2, action_loss, object_loss, outputs). Optional terms are None.
        """
        # Compute posterior q(e1, e2_I | I)
        log_posterior_eps_proj = log_posterior_eps1_proj + log_posterior_eps2_proj

        assert torch.all(torch.isfinite(log_posterior_eps_proj)) and torch.all(torch.isfinite(log_posterior_noise))

        x1_reco_proj, log_likelihood1_proj = self.decoder(
            e1_proj,
            eval_likelihood_at=x1,
            deterministic=True,
            full=full_likelihood,
            reduction=likelihood_reduction,
        )

        x2_reco_proj, log_likelihood2_proj = self.decoder(
            e2_proj,
            eval_likelihood_at=x2,
            deterministic=True,
            full=full_likelihood,
            reduction=likelihood_reduction,
        )
        x2_reco_noise, log_likelihood_noise = self.noise_decoder(
            noise_proj,
            eval_likelihood_at=x2 - x1,
            deterministic=True,
            full=full_likelihood,
            reduction=likelihood_reduction,
        )

        log_likelihood_proj = log_likelihood1_proj + log_likelihood2_proj
        assert torch.all(torch.isfinite(log_likelihood_proj))

        # Reconstruction MSE of x1, x2 and of the difference x2 - x1.
        mse_proj = torch.sum((x1_reco_proj - x1) ** 2, feature_dims).unsqueeze(1)
        mse_proj += torch.sum((x2_reco_proj - x2) ** 2, feature_dims).unsqueeze(1)
        mse_proj += torch.sum((x2_reco_noise - (x2 - x1)) ** 2, feature_dims).unsqueeze(1)

        # Compute inverse consistency MSE: |z - encode(decode(z))|^2
        e1_reencoded = self.encode_to_noise(x1_reco_proj, deterministic=False)
        e2_reencoded = self.encode_to_noise(x2_reco_proj, deterministic=False)
        inverse_consistency_mse_proj = torch.sum((e1_reencoded - e1_proj) ** 2, 1, keepdim=True)
        inverse_consistency_mse_proj += torch.sum((e2_reencoded - e2_proj) ** 2, 1, keepdim=True)

        # Compute prior p(e1, e2 | I). The SCM's auxiliary outputs are not used.
        log_prior_eps, _ = self.scm.log_prob_noise_weakly_supervised(
            e1_proj,
            e2_proj,
            interventions,
            adjacency_matrix=None,
            include_intervened=model_interventions,
            include_nonintervened=False,
        )

        outputs = {
            "noise_proj": noise_proj,
            "e1_proj": e1_proj,
            "e2_proj": e2_proj,
        }

        if model_noise:
            # The noise-model prior replaces the SCM prior computed above.
            log_prior_noise1, log_prior_eps1, log_prior_eps2, zi, latent_outputs = self.compute_priors(
                noise_proj, e1_proj, e2_proj, interventions
            )
            log_prior_noise = log_prior_noise1
            outputs.update(latent_outputs)
            log_prior_eps = log_prior_eps1 + log_prior_eps2
        else:
            log_prior_noise = None
            log_posterior_noise = None

        # Only the noise-vs-e1 adversarial term is currently used.
        mi_loss_n1_e2 = None
        mi_loss_e2_e1 = None
        if self.adversarial and model_noise:
            mi_loss_n1_e1 = self.compute_adversarial_term(noise_proj, e1_proj)
        else:
            mi_loss_n1_e1 = None

        # ``hasattr`` is true even when the classifier was left at its default of None,
        # which led to calling ``None(...)``. Test for an actual classifier instead.
        if getattr(self, "action_classifier", None) is not None:
            z1, _, _, _ = self.scm._solve(e1_proj, e1_proj, None)
            z2, _, _, _ = self.scm._solve(e2_proj, e1_proj, None)
            logits_action = self.action_classifier(torch.cat((z2, z1), 1))
            logits_objects = self.object_classifier(torch.cat((z2, z1), 1))
            action_loss = F.cross_entropy(logits_action, interventions, reduction='none').unsqueeze(1)
            object_loss = F.cross_entropy(logits_objects, true_objects, reduction='none').unsqueeze(1)
        else:
            action_loss = None
            object_loss = None

        return (
            log_likelihood_proj,
            log_likelihood_noise,
            log_posterior_eps_proj,
            log_posterior_noise,
            log_prior_eps,
            log_prior_noise,
            mse_proj,
            inverse_consistency_mse_proj,
            mi_loss_e2_e1,
            mi_loss_n1_e1,
            mi_loss_n1_e2,
            action_loss,
            object_loss,
            outputs,
        )

    def compute_adversarial_term(self, latent1, latent2):
        """
        Discriminator (BCE) loss of the MINE critic, summed over all latent pairs.

        For every column pair (i, j), ``self.mine_model`` should label joint samples
        ``(latent1[:, i], latent2[:, j])`` as 1 and samples in which ``latent1[:, i]`` is
        shuffled across the batch (``shuffle_with_no_fixed_points``) as 0. The mean
        of the two BCE losses is averaged over the ``self.num_samples`` slices of the
        last axis and summed over (i, j).
        """
        # Targets depend only on the batch size and device, so build them once.
        batch_size = latent1.shape[0]
        real_labels = torch.ones(batch_size, 1).to(latent1.device)
        fake_labels = torch.zeros(batch_size, 1).to(latent1.device)

        total = 0
        for i in range(latent1.shape[1]):
            for j in range(latent2.shape[1]):
                column_a = latent1[:, i].unsqueeze(1)
                column_b = latent2[:, j].unsqueeze(1)  # left unshuffled for both terms
                shuffled_a = shuffle_with_no_fixed_points(latent1[:, i].unsqueeze(1))

                per_slice = []
                for k in range(self.num_samples):
                    real_scores = self.mine_model(column_a[:, :, k], column_b[:, :, k])
                    fake_scores = self.mine_model(shuffled_a[:, :, k], column_b[:, :, k])
                    real_bce = self.bce_criterion(real_scores, real_labels)
                    fake_bce = self.bce_criterion(fake_scores, fake_labels)
                    per_slice.append((real_bce + fake_bce) / 2)

                total += torch.stack(per_slice).mean()

        return total

    def get_model_adversarial_loss(self, x1, x2, interventions):
        """
        Encoder-side adversarial loss against the MINE critic.

        For every latent pair (i, j) with i != j, ``noise1[:, i]`` is shuffled across
        the batch (``shuffle_with_no_fixed_points``), paired with ``e1_proj[:, j]`` and
        scored by ``self.mine_model`` on the first sample slice. The loss is
        ``-mean(log(score + 1e-8))``, summed over the pairs.

        Arguments:
        ----------
        x1, x2 : observation pair.
        interventions : per-sample intervention-target indices.

        Returns:
        --------
        loss_adv : accumulated loss (0 if there are no pairs).
        """
        e1_mean, e1_std = self.encoder.mean_std(x1)
        e2_mean, e2_std = self.encoder.mean_std(x2)
        noise1_mean, noise1_std = self.noise_encoder.mean_std(x2 - x1)

        # ``_project_and_sample_noise`` takes only (mean, std); the extra
        # ``interventions`` argument previously passed here raised a TypeError.
        noise1, _ = self._project_and_sample_noise(noise1_mean, noise1_std)
        e1_proj, e2_proj, _, _ = self._project_and_sample(e1_mean, e1_std, e2_mean, e2_std, interventions)

        # Only zi's width is used below.
        _, _, _, zi, _ = self.compute_priors(noise1, e1_proj, e2_proj, interventions)

        loss_adv = 0
        for i in range(zi.shape[1]):
            for j in range(e1_proj.shape[1]):
                if j == i:
                    continue

                marginal_A = shuffle_with_no_fixed_points(noise1[:, i]).unsqueeze(1)
                marginal_B = e1_proj[:, j].unsqueeze(1)  # unshuffled

                outputs_fake_enc = self.mine_model(marginal_A[:, :, 0], marginal_B[:, :, 0])
                # Leading minus for MI minimisation (a plus would maximise MI).
                loss_adv += -1 * torch.mean(torch.log(outputs_fake_enc + 1e-8))

        return loss_adv

    def compute_priors(self, noise1_proj, e1_proj, e2_proj, interventions):
        """
        Log-priors of the noise encoding and of (e1, e2) under the additive noise model.

        Only ``self.noise_model == "additive"`` is implemented. For every latent i:
            zi = (e2[:, i] - shift_i - g_i(noise)) / scale_i
        with shifts / scales from ``self.scm._solve(e2, e1)``. zi is scored under
        Normal(0.5 * e1[:, i], 1), and ``log_det_scale[:, i]`` is added. The entry of the
        intervened latent is selected per sample.

        Returns:
        --------
        log_prior_noise1 : SCM base-density log-prob of the noise encoding, summed over latents
        log_prior_eps1 : SCM base-density log-prob of e1, summed over latents
        log_prior_eps2 : log-prob of the intervened component of e2 (see above)
        zi : all per-latent zi, concatenated along dim 1
        latent_outputs : dict with "gn" (intervened entry), "shifts", "scales"
            ("shifts" / "scales" concatenate one full copy per latent)

        Raises:
        -------
        NotImplementedError
            For any other ``self.noise_model``.
        """
        if self.noise_model != "additive":
            raise NotImplementedError(
                f"compute_priors only supports noise_model='additive', got {self.noise_model!r}"
            )

        all_log = []
        all_zi = []
        all_gn = []
        all_shifts = []
        all_scales = []
        for i in range(self.dim_z):
            gn = self.g[i](noise1_proj)
            # NOTE(review): this call does not depend on i; it stays inside the loop in case
            # ``_solve`` is not deterministic.
            _, shifts, scales, log_det_scale = self.scm._solve(e2_proj, e1_proj, adjacency_matrix=None)
            zi = (e2_proj[:, i].unsqueeze(1) - shifts[:, i].unsqueeze(1) - gn) / scales[:, i].unsqueeze(1)

            # Conditional mean 0.5 * e1_i. ``torch.tensor(t)`` used to build this, which
            # warns on every call and is equivalent to ``t.clone().detach()``. The
            # stop-gradient is now explicit.
            # NOTE(review): confirm gradients are meant not to flow into e1 through this mean.
            mean = (0.5 * e1_proj[:, i].unsqueeze(1)).detach().clone().to(e2_proj.device)
            log_prior_eps2 = (
                torch.distributions.Normal(mean, 1).log_prob(zi.reshape((-1, 1))).reshape((-1, 1))
                + log_det_scale[:, i].unsqueeze(1)
            )

            all_log.append(log_prior_eps2)
            all_gn.append(gn)
            all_zi.append(zi)
            all_shifts.append(shifts)
            all_scales.append(scales)

        # Pick, for every sample, the entry of its intervened latent.
        rows = torch.arange(len(interventions))
        log_prior_eps2 = torch.cat(all_log, 1)[rows, interventions].unsqueeze(1)
        gn = torch.cat(all_gn, 1)[rows, interventions].unsqueeze(1)
        zi = torch.cat(all_zi, 1)
        shifts = torch.cat(all_shifts, 1)
        scales = torch.cat(all_scales, 1)

        log_prior_eps1 = torch.sum(
            self.scm.base_density.log_prob(e1_proj.reshape((-1, 1))).reshape((-1, self.dim_z)),
            dim=1,
            keepdim=True,
        )

        log_prior_noise1 = torch.sum(
            self.scm.base_density.log_prob(noise1_proj.reshape((-1, 1))).reshape((-1, self.dim_z)),
            dim=1,
            keepdim=True,
        )

        latent_outputs = {"gn": gn, "shifts": shifts, "scales": scales}

        return log_prior_noise1, log_prior_eps1, log_prior_eps2, zi, latent_outputs

    def create_auxiliary_nets(self):
        """Build ``self.g``, one small MLP per latent, for the additive noise model.

        Each net maps the full ``dim_z``-dimensional noise encoding through two
        hidden layers of width 64 to a single output. ``compute_priors`` uses
        that output as the additive noise term for its latent.

        Raises:
            NotImplementedError: if ``self.noise_model`` is not ``"additive"``.
        """
        if self.noise_model != "additive":
            raise NotImplementedError

        # One shared layer-size list is passed to every make_mlp call.
        layer_sizes = [self.dim_z, 64, 64, 1]
        self.g = torch.nn.ModuleList([make_mlp(layer_sizes) for _ in range(self.dim_z)])

    def get_masked_context(self, i, epsilon):
        """Hide latent ``i`` from ``epsilon`` before it is used as context.

        A mask of ones of shape ``(epsilon.shape[0], self.dim_z)`` with column
        ``i`` zeroed is built on ``epsilon``'s device. It is passed to
        ``mask(..., concat_mask=True)``. Only entry ``i`` is masked; all other
        latents stay visible. The output layout is defined by ``mask``
        (presumably the masked values with the mask concatenated — TODO confirm).
        """
        keep = torch.ones((epsilon.shape[0], self.dim_z), device=epsilon.device)
        keep[:, i] = 0.0
        return mask(epsilon, keep, concat_mask=True)

    def _init_layers(self, param_net, initialization):
        if initialization == "default":
            # param_net outputs mean and log std parameters of a Gaussian (log std only if
            # homoskedastic = False), as a function of the causal parents.
            # We usually want to initialize param_net such that:
            #  - log std is very close to zero
            #  - mean is reasonably close to zero, but may already have some nontrivial dependence on
            #    the parents
            mean_bias_std = 1.0e-3
            mean_weight_std = 0.1
            log_std_bias_std = 1.0e-6
            log_std_weight_std = 1.0e-3
            log_std_bias_mean = 0.0
        elif initialization == "strong_effects":
            # However, when creating a GT model as an initialized neural SCM, we want slightly more
            # interesting initializations, with pronounced causal effects. That's what the
            # enhance_causal_effects keyword is for. When that's True, we would like the Gaussian mean
            # to depend quite strongly on the parents, and also would appreciate some non-trivial
            # heteroskedasticity (log std depending on the parents).
            mean_bias_std = 0.2
            mean_weight_std = 1.5
            log_std_bias_std = 1.0e-6
            log_std_weight_std = 0.1
            log_std_bias_mean = 0.0
        elif initialization == "broad":
            # For noise-centric models we want that the typical initial standard deviation in p(e2 | e1)
            # is large, around 10
            mean_bias_std = 1.0e-3
            mean_weight_std = 0.1
            log_std_bias_std = 1.0e-6
            log_std_weight_std = 1.0e-3
            log_std_bias_mean = 2.3
        else:
            raise ValueError(f"Unknown initialization scheme {initialization}")

        last_layer = list(param_net._modules.values())[-1]
        nn.init.normal_(last_layer.bias[0], mean=log_std_bias_mean, std=log_std_bias_std)
        nn.init.normal_(last_layer.weight[0, :], mean=0.0, std=log_std_weight_std)
        nn.init.normal_(last_layer.bias[1], mean=0.0, std=mean_bias_std)
        nn.init.normal_(last_layer.weight[1, :], mean=0.0, std=mean_weight_std)

    def load_state_dict(self, state_dict, topological_order=None, dummy_values=None, strict=True):
        """Load parameters, then recompute the SCM's ancestor structure.

        Args:
            state_dict: parameters to load; forwarded to the parent implementation.
            topological_order: not used by this implementation.
            dummy_values: not used by this implementation.
            strict: forwarded positionally to the parent ``load_state_dict``.

        Returns:
            Whatever the parent ``load_state_dict`` returns. For
            ``torch.nn.Module`` this is the named tuple of missing and unexpected
            keys. It was previously discarded, so callers always got ``None``.
        """
        result = super().load_state_dict(state_dict, strict)
        # Refresh the SCM's cached ancestor information after new weights arrive
        # (presumably derived from the loaded parameters — verify in scm).
        self.scm._compute_ancestors()
        return result


# Define the MINE network
class MINE(nn.Module):
    """Small MLP critic that maps a pair of inputs to a probability.

    The two inputs are concatenated along dim 1 (2 features in total). They
    pass through three LeakyReLU(0.2) hidden layers of width 10 and a final
    linear layer followed by a sigmoid. Despite the name, this module outputs a
    probability; in this file it is trained with ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        widths = [2, 10, 10, 10]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Linear(10, 1), nn.Sigmoid()]
        # Same layer order as a literal Sequential, so state_dict keys are fc.0/2/4/6.
        self.fc = nn.Sequential(*layers)

        self.apply(self._weights_init)

    def _weights_init(self, m):
        """Xavier-normal weights and zero biases for linear/conv layers; others untouched."""
        if not isinstance(m, (nn.Linear, nn.Conv2d)):
            return
        nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.zeros_(m.bias.data)

    def forward(self, z1, z2):
        """Score the pair ``(z1, z2)``; the inputs are joined along dim 1."""
        pair = torch.cat([z1, z2], dim=1)
        return self.fc(pair)


def mine_loss(joint_preds, marginal_preds):
    """Negative Donsker-Varadhan bound on mutual information (MINE objective).

    Computes ``-(mean(T_joint) - log(mean(exp(T_marginal))))``. Minimising it
    maximises the MI estimate between the two variables.

    The log-mean-exp term is evaluated as ``logsumexp(T) - log(n)`` rather than
    ``log(mean(exp(T)))``. This stops large critic scores from overflowing
    ``exp`` to ``inf`` (float32 overflows just above 88).

    Args:
        joint_preds: critic scores on samples from the joint distribution.
        marginal_preds: critic scores on samples from the product of marginals;
            averaged over all elements regardless of shape.

    Returns:
        Scalar tensor with the loss.
    """
    flat = marginal_preds.reshape(-1)
    # new_tensor keeps dtype/device consistent with the scores.
    log_count = torch.log(flat.new_tensor(flat.numel()))
    log_mean_exp = torch.logsumexp(flat, dim=0) - log_count
    return -(torch.mean(joint_preds) - log_mean_exp)


def shuffle_with_no_fixed_points(tensor):
    """Permute ``tensor`` along dim 0 so that no element keeps its position.

    Uses rejection sampling: random permutations are drawn until one has no
    fixed points, which gives a uniform random derangement.

    Args:
        tensor: tensor to shuffle along its first dimension.

    Returns:
        A new tensor with the rows of ``tensor`` deranged. An empty first
        dimension is returned as an (empty) copy.

    Raises:
        ValueError: if the first dimension has exactly one element. No
            derangement exists then, and the previous implementation looped
            forever.
    """
    n = tensor.size(0)
    if n == 1:
        raise ValueError("cannot shuffle a single element without a fixed point")

    identity = torch.arange(n)
    perm = torch.randperm(n)
    while (perm == identity).any():
        perm = torch.randperm(n)

    return tensor[perm]