import math
import random
import sys

import behavioral_cloning.features as bc_features
import gin
import numpy as np
import torch
import torch.nn.functional as F

from causal_discovery.intervention_strategies import (
    get_strategy_from_name_and_graph_fitting_object,
)
from causal_discovery.sampling import sample_interventionalSamples
from causal_discovery.utils import evaluate_likelihoods_for_model

# Extend the module search path with the parent directory (relative to the
# current working directory). NOTE(review): presumably this is what makes the
# `causal_graphs` import below resolvable when run from this folder — confirm.
sys.path.append("../")

from causal_graphs.variable_distributions import LeafCategDist, _random_categ

import causal_discovery.logger as lg
from causal_discovery.datasets import InterventionalDataset


@gin.configurable
class GraphFitting(object):
    def __init__(
        self,
        parent_enco_object,
        model,
        graph,
        num_batches,
        num_graphs,
        theta_only_num_graphs,
        batch_size,
        lambda_sparse,
        sample_size_inters=None,
        max_graph_stacking=200,
        exclude_inters=None,
        force_online_data=False,
        interventions_policy="round_robin",
        interventions_check_max=-1,
        log_grads=False,
        theta_coeff=1.0,
        gamma_coeff=1.0,
        policy_softmax_temperature=1.0,
        use_global_theta_gradients=False,
        mix_round_robin_every_n_epochs=-1,
        mix_round_robin_from_epoch=-1,
        num_hypothetical_graphs=10,
        dataset_to_save=None,
        trained_model_path=None,
        features_fn=None,
        BALD_num_int_samples=128,
        BALD_batch_size=10,
        only_gamma=False,
        sdi_dag_regularization=0.5,
        sdi_like_interventions=False,
        dataset=None,
    ):
        """
        Creates a GraphFitting object that summarizes all functionalities
        for performing the graph fitting stage of ENCO.

        Parameters
        ----------
        parent_enco_object : ENCO object that uses the graph fitter
                             Its `epoch` attribute is read when round-robin mixing
                             is enabled (see `sample_next_var_idx`).
        model : MultivarMLP
                PyTorch module of the neural networks that model the conditional
                distributions.
        graph : CausalDAG
                Causal graph on which we want to perform causal structure learning.
        num_batches : int
                      Number of batches to use per MC sample in the graph fitting stage.
                      Usually 1, only higher needed if GPU is running out of memory for
                      common batch sizes.
        num_graphs : int
                     Number of graph samples to use for estimating the gradients in the
                     graph fitting stage. Usually in the range 20-100.
        theta_only_num_graphs : int
                                Number of graph samples to use in the graph fitting stage if
                                gamma is frozen. Needs to be an even number, and usually 2 or 4.
        batch_size : int
                     Size of the batches to use in the gradient estimators.
        lambda_sparse : float
                        Sparsity regularizer value to use in the graph fitting stage.
        sample_size_inters : int or None
                             Number of samples to use per intervention. If an exported graph is
                             given as input and sample_size_inters is smaller than the exported
                             interventional dataset, the first sample_size_inters samples will be
                             taken. Must not be None when `dataset` is None.
        max_graph_stacking : int
                             Number of graphs that can maximally be evaluated in parallel on the
                             device. If you run out of GPU memory, try to lower this number. It
                             will then evaluate the graphs sequentially, which can be slightly
                             slower but uses less memory.
        exclude_inters : list
                         A list of variable indices that should be excluded from sampling
                         interventions from. This should be used to apply ENCO on intervention
                         sets on a subset of the variable set. If None, an empty list will be
                         assumed, i.e., interventions on all variables will be used.
        force_online_data : bool
                            If true, samples new interventional data every step, instead of
                            using a pre-sampled set of data.
        interventions_policy : string
                               The name of the interventional policy to use. It must contain
                               "trained" exactly when `trained_model_path` is given.
        interventions_check_max : int
                                  If positive, defines the maximum number of nodes on which the
                                  best intervention can be estimated. Must stay -1 for the
                                  "trained" and "soft_trained" policies.
        log_grads : bool
                    If true and the interventional policy is gradient-based, the norm of the
                    gradients will be logged.
        theta_coeff : float
                      Weight of the mean squared theta gradient in `compute_score`.
        gamma_coeff : float
                      Weight of the mean squared gamma gradient in `compute_score`.
        policy_softmax_temperature : float
                                     The temperature used for the softmax policies.
        use_global_theta_gradients : bool
                                     If true, `get_hypothetical_gradient_magnitude` scores with
                                     the unmasked theta gradients instead of the masked ones.
        mix_round_robin_every_n_epochs : int
                                         If positive, the round-robin strategy is used whenever
                                         (epoch + 1) is a multiple of this value.
        mix_round_robin_from_epoch : int
                                     If positive, the round-robin strategy is used for every
                                     epoch greater than or equal to this value.
        num_hypothetical_graphs : int
                                  Number of graphs used for the var_idx score estimations for
                                  the hypothetical_gradients, ait, and bald methods.
        dataset_to_save : object
                          Only stored on the instance here; not used within this constructor.
        trained_model_path : str or None
                             Path of a TorchScript policy loaded with `torch.jit.load` and put
                             into eval mode.
        features_fn : str or None
                      Name of an attribute of `behavioral_cloning.features` to look up and
                      store as `self.features_fn`.
        BALD_num_int_samples : int
                               Number of interventional samples per graph used for the MI
                               estimation in BALD.
        BALD_batch_size : int
                          The size of the batch of subsequent interventions, used only for the
                          batch_bald interventional policy.
        only_gamma : bool
                     If true, graphs are sampled from sigmoid(gamma) alone and the gradients
                     come from `gradient_estimator_only_gamma`.
        sdi_dag_regularization : float
                                 Coefficient of the DAG regularization term in
                                 `gradient_estimator_only_gamma`.
        sdi_like_interventions : bool
                                 If true, online interventions re-initialise the intervened
                                 variable's distribution instead of replacing it by a uniform
                                 one. Requires `force_online_data`.
        dataset : object or None
                  Pre-built interventional dataset. If None, an `InterventionalDataset` is
                  created from `graph`, `sample_size_inters` and `batch_size`.
        """
        # Collaborators. Keeping a back-reference to the parent is not very
        # nice, but sometimes we need to access its methods.
        self.parent_enco_object = parent_enco_object
        self.model = model
        self.graph = graph

        # Sampling / estimation hyper-parameters.
        self.num_batches = num_batches
        self.num_graphs = num_graphs
        self.theta_only_num_graphs = theta_only_num_graphs
        self.batch_size = batch_size
        self.sample_size_inters = sample_size_inters
        self.max_graph_stacking = max_graph_stacking
        self.lambda_sparse = lambda_sparse
        self.theta_coeff = theta_coeff
        self.gamma_coeff = gamma_coeff
        self.inter_vars = []

        # Entries [v, w] are set to one for every pair of excluded variables.
        self.exclude_inters = [] if exclude_inters is None else exclude_inters
        self.theta_grad_mask = torch.zeros(self.graph.num_vars, self.graph.num_vars)
        for excluded_var in self.exclude_inters:
            self.theta_grad_mask[excluded_var, self.exclude_inters] = 1.0

        # Use the dataset handed in, otherwise build one from the graph.
        if dataset is None:
            assert self.sample_size_inters is not None
            dataset = InterventionalDataset(
                self.graph,
                dataset_size=self.sample_size_inters,
                batch_size=self.batch_size,
            )
        self.dataset = dataset

        num_excluded = len(self.exclude_inters)
        if num_excluded > 0:
            excluded_names = ", ".join(str(i) for i in sorted(self.exclude_inters))
            print(
                f"Excluding interventions on the following {num_excluded}"
                f" out of {graph.num_vars} variables: "
                f"{excluded_names}"
            )

        # Intervention-policy configuration.
        self.force_online_data = force_online_data
        self.interventions_policy = interventions_policy
        self.interventions_check_max = interventions_check_max
        self.log_grads = log_grads
        self.policy_softmax_temperature = policy_softmax_temperature
        self.use_global_theta_gradients = use_global_theta_gradients
        self.mix_round_robin_every_n_epochs = mix_round_robin_every_n_epochs
        self.mix_round_robin_from_epoch = mix_round_robin_from_epoch
        self.num_hypothetical_graphs = num_hypothetical_graphs
        self.BALD_num_int_samples = BALD_num_int_samples
        self.BALD_batch_size = BALD_batch_size
        self.dataset_to_save = dataset_to_save

        # A model path must be given exactly for the "trained" policy family.
        assert ("trained" in interventions_policy) == bool(trained_model_path)
        self.trained_policy = None
        if trained_model_path:
            loaded_policy = torch.jit.load(trained_model_path)
            loaded_policy.eval()
            self.trained_policy = loaded_policy

        self.features_fn = getattr(bc_features, features_fn) if features_fn else None

        self.only_gamma = only_gamma
        self.sdi_dag_regularization = sdi_dag_regularization

        # SDI-like interventions are only implemented for online sampling.
        assert force_online_data or not sdi_like_interventions
        self.sdi_like_interventions = sdi_like_interventions

        if interventions_policy in ("trained", "soft_trained"):
            assert self.interventions_check_max == -1

        # The strategies receive `self`, so they are built last, once every
        # attribute above is in place.
        self.intervention_strategy = get_strategy_from_name_and_graph_fitting_object(
            interventions_policy.lower(), self
        )
        self.round_robin_strategy = get_strategy_from_name_and_graph_fitting_object(
            "round_robin", self
        )

    def perform_update_step(self, gamma, theta, var_idx=-1, only_theta=False):
        """
        Runs one complete update of the graph fitting stage: draw a batch of graphs,
        score them on a batch of interventional data, and write the estimated
        gradients into `gamma.grad` and `theta.grad`.

        Parameters
        ----------
        gamma : nn.Parameter
                Parameter tensor representing the gamma parameters in ENCO.
        theta : nn.Parameter
                Parameter tensor representing the theta parameters in ENCO.
        var_idx : int
                  Variable to intervene on for this update. A negative value means
                  that the variable is chosen by the intervention strategy.
        only_theta : bool
                     If True, gamma is frozen and the gradients are only estimated for
                     theta, using `theta_only_num_graphs` mirrored graph samples. See
                     Appendix D.2 in the paper for details on the gamma freezing stage.

        Returns
        -------
        (theta_mask, var_idx) : the theta update mask produced by `gradient_estimator`
                                and the index of the variable that was intervened on.
        """
        # The two modes differ only in the number of graphs and in mirroring.
        graphs_to_sample = self.theta_only_num_graphs if only_theta else self.num_graphs
        adj_matrices, log_likelihoods, var_idx = self.get_MC_samples(
            gamma,
            theta,
            num_batches=self.num_batches,
            num_graphs=graphs_to_sample,
            batch_size=self.batch_size,
            var_idx=var_idx,
            mirror_graphs=bool(only_theta),
        )

        # Turn the scores into gradients and hand them to the optimizer.
        estimates = self.gradient_estimator(
            adj_matrices, log_likelihoods, gamma, theta, var_idx
        )
        gamma_grads, theta_grads, theta_mask = estimates[0], estimates[1], estimates[2]
        gamma.grad = gamma_grads
        theta.grad = theta_grads

        if self.log_grads:
            logger = lg.NEPTUNE_LOGGER
            logger.log(
                name="grad_l2_step",
                value=self.compute_score(gamma_grads, theta_grads),
            )
            logger.log(name="gamma_grad_l2", value=(gamma_grads**2).mean())
            logger.log(name="theta_grad_l2", value=(theta_grads**2).mean())

        return theta_mask, var_idx

    @torch.no_grad()
    def get_MC_samples(
        self,
        gamma,
        theta,
        num_batches,
        num_graphs,
        batch_size,
        var_idx=-1,
        mirror_graphs=False,
        hypothetical_graph=None,
    ):
        """
        Samples and evaluates a batch of graph structures on a batch of interventional data.

        Parameters
        ----------
        gamma : nn.Parameter
                Parameter tensor representing the gamma parameters in ENCO.
        theta : nn.Parameter
                Parameter tensor representing the theta parameters in ENCO.
        num_batches : int
                      Number of batches to use per MC sample.
        num_graphs : int
                     Number of graph structures to sample.
        batch_size : int
                     Size of interventional data batches. When pre-sampled data is used,
                     it is overwritten by the size of the batches the dataset returns.
        var_idx : int
                  Variable on which should be intervened to obtain the update. If none is given, i.e.,
                  a negative value, the variable will be selected by the intervention strategy.
        mirror_graphs : bool
                        This variable should be true if only theta is optimized. In this case, the first
                        half of the graph structure samples is identical to the second half, except that
                        the values of the outgoing edges of the intervened variable are flipped. This
                        allows for more efficient, low-variance gradient estimators. See details in
                        the paper.
        hypothetical_graph : object or None
                             If given, the data batch is generated by `get_hypothetical_samples`
                             for this graph instead of coming from the dataset or the true graph.
                             `var_idx` must be non-negative in that case.

        Returns
        -------
        adj_matrices : torch.Tensor
                       The sampled adjacency matrices, one per graph sample.
        log_likelihoods : torch.Tensor
                          For every graph sample, the values returned by `evaluate_likelihoods`
                          averaged over the batch dimension and over `num_batches`.
        var_idx : int
                  The variable that was intervened on.
        """
        if mirror_graphs:
            assert (
                num_graphs % 2 == 0
            ), "Number of graphs must be divisible by two for mirroring"
        device = self.get_device()

        # Sample data batch
        if hypothetical_graph is not None:
            int_sample = self.get_hypothetical_samples(
                var_idx=var_idx, hypothetical_graph=hypothetical_graph
            )
        elif hasattr(self, "dataset") and not self.force_online_data:
            # Pre-sampled data
            if var_idx < 0:
                var_idx = self.sample_next_var_idx(gamma=gamma, theta=theta)
            int_sample = torch.cat(
                [self.dataset.get_batch(var_idx) for _ in range(num_batches)],
                dim=0,
            ).to(device)

            # The dataset decides how large a batch really is.
            batch_size = int_sample.shape[0] // num_batches
        else:
            # If no dataset exists, data is newly sampled from the graph
            intervention_dict, var_idx = self.sample_intervention(
                self.graph,
                dataset_size=num_batches * batch_size,
                var_idx=var_idx,
                gamma=gamma,
                theta=theta,
            )

            if not self.sdi_like_interventions or isinstance(
                self.graph.variables[var_idx].prob_dist.prob_func, LeafCategDist
            ):
                # Regular way. Soft, perfect (uniform) interventions
                int_sample = self.graph.sample(
                    interventions=intervention_dict,
                    batch_size=num_batches * batch_size,
                    as_array=True,
                )
            else:
                # SDI-like interventions. Soft, imperfect, based on randomly changing MLP.
                import copy  # local import: only this branch needs it

                # `state_dict()` hands out tensors that share storage with the
                # live parameters. Deep-copy them so that an in-place
                # re-initialisation cannot overwrite the snapshot as well.
                saved_mlp = copy.deepcopy(
                    self.graph.variables[var_idx].prob_dist.prob_func.net.state_dict()
                )
                try:
                    self.graph.reinit_distribution(var_idx)
                    int_sample = self.graph.sample(
                        batch_size=num_batches * batch_size,
                        as_array=True,
                    )
                finally:
                    # Always put the original weights back, even if sampling
                    # failed, so the graph is never left in its perturbed state.
                    self.graph.variables[
                        var_idx
                    ].prob_dist.prob_func.net.load_state_dict(saved_mlp)

            int_sample = torch.from_numpy(int_sample).to(device)

        # Split number of graph samples across multiple iterations if not all can fit into memory
        num_graphs_list = [
            min(self.max_graph_stacking, num_graphs - i * self.max_graph_stacking)
            for i in range(math.ceil(num_graphs * 1.0 / self.max_graph_stacking))
        ]
        # Pair every chunk size with the offset of its first graph.
        num_graphs_list = [
            (num_graphs_list[i], sum(num_graphs_list[:i]))
            for i in range(len(num_graphs_list))
        ]
        # Tensors needed for sampling
        if self.only_gamma:
            edge_prob = torch.sigmoid(gamma).detach()
        else:
            edge_prob = (torch.sigmoid(gamma) * torch.sigmoid(theta)).detach()
        edge_prob_batch = edge_prob[None].expand(num_graphs, -1, -1)

        # Inner function for sampling a batch of random adjacency matrices from current belief probabilities
        def sample_adj_matrix():
            sample_matrix = torch.bernoulli(edge_prob_batch)
            # No self-loops.
            sample_matrix = sample_matrix * (
                1
                - torch.eye(sample_matrix.shape[-1], device=sample_matrix.device)[None]
            )
            if (
                mirror_graphs
            ):  # First and second half of tensors are identical, except the intervened variable
                sample_matrix[num_graphs // 2 :] = sample_matrix[: num_graphs // 2]
                sample_matrix[num_graphs // 2 :, var_idx] = (
                    1 - sample_matrix[num_graphs // 2 :, var_idx]
                )
                # Flipping the row also flipped the diagonal entry; reset it.
                sample_matrix[:, var_idx, var_idx] = 0.0
            return sample_matrix

        # Evaluate log-likelihoods under sampled adjacency matrix and data
        adj_matrices = []
        log_likelihoods = []
        for n_idx in range(num_batches):
            batch = int_sample[n_idx * batch_size : (n_idx + 1) * batch_size]
            if n_idx == 0:
                # The same graph samples are reused for all data batches.
                adj_matrix = sample_adj_matrix()
                adj_matrices.append(adj_matrix)

            for c_idx, (graph_count, start_idx) in enumerate(num_graphs_list):
                adj_matrix_expanded = (
                    adj_matrix[start_idx : start_idx + graph_count, None]
                    .expand(-1, batch_size, -1, -1)
                    .flatten(0, 1)
                )
                batch_exp = batch[None, :].expand(graph_count, -1, -1).flatten(0, 1)
                nll = self.evaluate_likelihoods(batch_exp, adj_matrix_expanded, var_idx)
                nll = nll.reshape(graph_count, batch_size, -1)

                if n_idx == 0:
                    log_likelihoods.append(nll.mean(dim=1))
                else:
                    log_likelihoods[c_idx] += nll.mean(dim=1)

        # Combine all data
        adj_matrices = torch.cat(adj_matrices, dim=0)
        log_likelihoods = torch.cat(log_likelihoods, dim=0) / num_batches

        return adj_matrices, log_likelihoods, var_idx

    @torch.no_grad()
    def gradient_estimator_only_gamma(
        self, adj_matrices, log_likelihoods, gamma, var_idx
    ):
        log_likelihoods = log_likelihoods.unsqueeze(dim=1)
        sig_gamma = torch.sigmoid(gamma)

        norm_probs = F.softmax(-log_likelihoods, dim=0)
        gamma_grads = torch.sum((sig_gamma - adj_matrices) * norm_probs, dim=0)

        # add sparsity regularization
        gamma_grads += self.lambda_sparse * sig_gamma * (1 - sig_gamma)

        # add DAG regularization
        gamma_grads += (
            self.sdi_dag_regularization
            * torch.sinh(sig_gamma * sig_gamma.t())
            * sig_gamma
            * (1 - sig_gamma)
            * sig_gamma.t()
        )

        # zero out gradients on variable with intervention and on diagonal
        gamma_grads[:, var_idx] = 0.0
        gamma_grads[
            torch.arange(gamma_grads.shape[0]), torch.arange(gamma_grads.shape[1])
        ] = 0.0

        return (
            gamma_grads,
            torch.zeros_like(gamma_grads),
            torch.zeros_like(gamma_grads),
            torch.zeros_like(gamma_grads),
        )

    @torch.no_grad()
    def gradient_estimator(self, adj_matrices, log_likelihoods, gamma, theta, var_idx):
        """
        Returns the estimated gradients for gamma and theta. It uses the low-variance gradient estimators
        proposed in Section 3.3 of the paper.

        Parameters
        ----------
        adj_matrices : torch.FloatTensor, shape [batch_size, num_vars, num_vars]
                       The adjacency matrices on which the interventional data has been evaluated on.
        log_likelihoods : torch.FloatTensor, shape [batch_size, num_vars]
                          The average log-likelihood under the adjacency matrices for all variables
                          in the graph.
        gamma : nn.Parameter
                Parameter tensor representing the gamma parameters in ENCO.
        theta : nn.Parameter
                Parameter tensor representing the theta parameters in ENCO.
        var_idx : int
                  Variable on which the intervention was performed.
        """

        if self.only_gamma:
            return self.gradient_estimator_only_gamma(
                adj_matrices, log_likelihoods, gamma, var_idx
            )

        batch_size = adj_matrices.shape[0]
        log_likelihoods = log_likelihoods.unsqueeze(dim=1)

        orient_probs = torch.sigmoid(theta)
        edge_probs = torch.sigmoid(gamma)

        # Gradient calculation
        num_pos = adj_matrices.sum(dim=0)
        num_neg = batch_size - num_pos
        mask = ((num_pos > 0) * (num_neg > 0)).float()
        pos_grads = (log_likelihoods * adj_matrices).sum(dim=0) / num_pos.clamp_(
            min=1e-5
        )
        neg_grads = (log_likelihoods * (1 - adj_matrices)).sum(dim=0) / num_neg.clamp_(
            min=1e-5
        )
        gamma_grads = (
            mask
            * edge_probs
            * (1 - edge_probs)
            * orient_probs
            * (pos_grads - neg_grads + self.lambda_sparse)
        )
        theta_grads = (
            mask
            * orient_probs
            * (1 - orient_probs)
            * edge_probs
            * (pos_grads - neg_grads)
        )

        # Masking gamma for incoming edges to intervened variable
        gamma_grads[:, var_idx] = 0.0
        gamma_grads[
            torch.arange(gamma_grads.shape[0]), torch.arange(gamma_grads.shape[1])
        ] = 0.0

        theta_grads_unmasked = theta_grads.clone()
        theta_grads_unmasked = (
            theta_grads_unmasked - theta_grads_unmasked.transpose(0, 1)
        ) / 2

        # Masking all theta's except the ones with a intervened variable
        theta_zero_mask = self.theta_grad_mask.clone().to(theta_grads.device)
        theta_zero_mask[var_idx] = 1.0
        theta_grads *= theta_zero_mask
        theta_grads -= theta_grads.clone().transpose(0, 1)  # theta_ij = -theta_ji

        # Creating a mask which theta's are actually updated for the optimizer
        # 0.1 multiplier reduces learning rate for variables without interventions
        theta_mask = 0.1 * self.theta_grad_mask.clone().to(theta_grads.device)
        theta_mask[var_idx] = 1.0
        theta_mask[:, var_idx] = 1.0
        theta_mask[var_idx, var_idx] = 0.0

        return gamma_grads, theta_grads, theta_mask, theta_grads_unmasked

    def sample_next_var_idx(self, gamma, theta, intervention_strategy=None):
        """Returns next variable to intervene on."""
        possible_interventions = [
            i for i in range(len(self.graph.variables)) if i not in self.exclude_inters
        ]

        if intervention_strategy is None:
            intervention_strategy = self.intervention_strategy

        if (
            self.mix_round_robin_every_n_epochs > 0
            and (
                (self.parent_enco_object.epoch + 1)
                % self.mix_round_robin_every_n_epochs
                == 0
            )
        ) or (
            self.mix_round_robin_from_epoch > 0
            and self.parent_enco_object.epoch >= self.mix_round_robin_from_epoch
        ):
            var_idx = self.round_robin_strategy.acquire(
                gamma, theta, possible_interventions
            )
        else:
            if 0 < self.interventions_check_max < len(possible_interventions):
                possible_interventions = random.sample(
                    possible_interventions, k=self.interventions_check_max
                )

            var_idx = intervention_strategy.acquire(
                gamma, theta, possible_interventions
            )

        return var_idx

    def sample_intervention(
        self,
        graph,
        dataset_size,
        var_idx=-1,
        gamma=None,
        theta=None,
    ):
        """
        Draws the intervention values for one intervened variable.

        Parameters
        ----------
        graph : CausalDAG
                Graph whose variable is intervened on.
        dataset_size : int
                       Number of intervention values to draw.
        var_idx : int
                  Variable to intervene on. If negative, it is chosen via
                  `sample_next_var_idx`.
        gamma, theta : nn.Parameter or None
                       Only passed on to `sample_next_var_idx`.

        Returns
        -------
        (intervention_dict, var_idx) : a dict mapping the variable's name to an array
                                       of `dataset_size` sampled category indices, and
                                       the index of the intervened variable.
        """
        # Select variable to intervene on
        if var_idx < 0:
            var_idx = self.sample_next_var_idx(gamma=gamma, theta=theta)
        target = graph.variables[var_idx]

        # Soft, perfect intervention => replace p(X_n) by random categorical.
        # Scale is set to 0.0, which represents a uniform distribution.
        category_probs = _random_categ(
            size=(target.prob_dist.num_categs,), scale=0.0, axis=-1
        )

        # Draw one-hot samples and convert them to category indices.
        one_hot = np.random.multinomial(n=1, pvals=category_probs, size=(dataset_size,))
        category_indices = np.argmax(one_hot, axis=-1)

        return {target.name: category_indices}, var_idx

    @torch.no_grad()
    def evaluate_likelihoods(self, int_sample, adj_matrix, var_idx):
        """
        Scores an interventional data batch under the given graph structures.

        Thin wrapper that forwards this object's model and device, together with
        `int_sample`, `adj_matrix` and the intervened variable `var_idx`, to
        `evaluate_likelihoods_for_model`. The result is intended to be the
        negative log-likelihood of the batch.
        """
        model = self.model
        device = self.get_device()
        return evaluate_likelihoods_for_model(
            model, device, int_sample, adj_matrix, var_idx
        )

    def get_device(self):
        return self.model.device

    @torch.no_grad()
    def get_hypothetical_samples(self, var_idx, hypothetical_graph=None):
        """
        Generates an interventional data batch for a hypothetical graph.

        Forwards to `sample_interventionalSamples`, using this object's model and
        device, the largest number of categories among the graph's variables, and
        `num_batches * batch_size` as the number of samples.

        Parameters
        ----------
        var_idx : int
                  Variable to intervene on; must be non-negative.
        hypothetical_graph : object or None
                             Passed through as the `config` argument.
        """
        assert var_idx >= 0
        max_categs = max(v.prob_dist.num_categs for v in self.graph.variables)
        total_samples = self.num_batches * self.batch_size

        return sample_interventionalSamples(
            config=hypothetical_graph,
            target_node=var_idx,
            model=self.model,
            device=self.get_device(),
            nb_categs=max_categs,
            batch_size=total_samples,
        )

    @torch.no_grad()
    def get_hypothetical_gradient_magnitude(
        self, gamma, theta, var_idx, hypothetical_graph=None
    ):
        """
        Scores an intervention on `var_idx` by the size of the gradients it would produce.

        Graphs are sampled and evaluated with `get_MC_samples` (without mirroring),
        gradients are estimated with `gradient_estimator`, and the result of
        `compute_score` is returned. If `self.use_global_theta_gradients` is set,
        the unmasked theta gradients are scored instead of the masked ones.

        Parameters
        ----------
        gamma : nn.Parameter
                Parameter tensor representing the gamma parameters in ENCO.
        theta : nn.Parameter
                Parameter tensor representing the theta parameters in ENCO.
        var_idx : int
                  Variable on which the intervention would be performed.
        hypothetical_graph : object or None
                             Passed through to `get_MC_samples`.
        """
        mc_samples = self.get_MC_samples(
            gamma,
            theta,
            num_batches=self.num_batches,
            num_graphs=self.num_graphs,
            batch_size=self.batch_size,
            var_idx=var_idx,
            mirror_graphs=False,
            hypothetical_graph=hypothetical_graph,
        )
        adj_matrices, log_likelihoods, var_idx = mc_samples

        # Determine gradients for gamma and theta
        estimates = self.gradient_estimator(
            adj_matrices, log_likelihoods, gamma, theta, var_idx
        )
        gamma_grads, masked_theta_grads, _theta_mask, unmasked_theta_grads = estimates

        if self.use_global_theta_gradients:
            return self.compute_score(gamma_grads, unmasked_theta_grads)
        return self.compute_score(gamma_grads, masked_theta_grads)

    def compute_score(self, gamma_grads, theta_grads):
        return (gamma_grads**2).mean() * self.gamma_coeff + (
            theta_grads**2
        ).mean() * self.theta_coeff
