import sys
import time
from copy import deepcopy

import gin
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data

# Put the parent directory on the module search path before the project
# imports below run. Presumably this is what lets `causal_discovery` resolve
# when the script is launched from its own folder - TODO confirm.
sys.path.append("../")

import causal_discovery.logger as lg
from causal_discovery import graph_fitting, intervention_strategies
from causal_discovery.datasets import (
    DynamicInterventionalDataset,
    ObservationalCategoricalData,
)
from causal_discovery.distribution_fitting import DistributionFitting
from causal_discovery.multivariable_flow import create_continuous_model
from causal_discovery.multivariable_mlp import create_model
from causal_discovery.optimizers import AdamGamma, AdamTheta
from causal_discovery.utils import (
    BestStateDictCheckpointer,
    find_best_acyclic_graph,
    flatten_dict,
    track,
)


@gin.configurable
class NewApproach(object):
    def __init__(
        self,
        graph,
        hidden_dims=[64],
        use_flow_model=False,
        lr_model=5e-3,
        betas_model=(0.9, 0.999),
        weight_decay=0.0,
        lr_gamma=2e-2,
        betas_gamma=(0.9, 0.9),
        lr_theta=1e-1,
        betas_theta=(0.9, 0.999),
        model_iters=1000,
        graph_iters=100,
        batch_size=128,
        GF_num_batches=1,
        GF_num_graphs=100,
        lambda_sparse=0.004,
        latent_threshold=0.35,
        use_theta_only_stage=False,
        theta_only_num_graphs=4,
        theta_only_iters=1000,
        max_graph_stacking=200,
        sample_size_obs=5000,
        init_n_true_relations=0,
        only_gamma=False,
        model_use_embedding=True,
        verbose=False,
        # New parameters (wrt. ENCO parameters)
        num_inner_loop_epochs: int = 10,
        int_data_collection_policy: str = "round_robin",
        int_data_collection_batch_size: int = 32,
        reset_model_structural_params: bool = True,
        reset_model_MLPs: bool = True,
        use_best_model_to_select_intervention: bool = False,
        reduce_length_of_first_super_epochs: bool = False,
        **graph_fitting_kwargs,
    ):
        """
        Creates an object that performs causal graph discovery.

        Parameters
        ----------
        graph
            Ground-truth causal graph. Its ``num_vars``, ``is_categorical``,
            ``variables``, ``adj_matrix``, ``node_relations``,
            ``exclude_inters`` and ``num_latents`` attributes are read here
            or in ``spawn_graph_params``.
        hidden_dims, use_flow_model, model_use_embedding
            Forwarded to the model factory (``create_model`` for categorical
            graphs, ``create_continuous_model`` otherwise). ``hidden_dims``
            is passed through unchanged.
        lr_model, betas_model, weight_decay
            Settings of the Adam optimizer that trains the distribution
            fitting model.
        lr_gamma, betas_gamma, lr_theta, betas_theta
            Settings of the optimizers for the gamma and theta parameters.
        model_iters
            Number of model updates per ``distribution_fitting_step``.
        graph_iters
            Number of iterations per ``graph_fitting_step`` (when the step is
            not a theta-only stage).
        batch_size
            Batch size of the observational data loader; also handed to the
            interventional dataset and to the graph fitting module.
        GF_num_batches, GF_num_graphs, lambda_sparse, max_graph_stacking,
        theta_only_num_graphs
            Forwarded to ``graph_fitting.GraphFitting``.
        latent_threshold
            Threshold applied to the confounder scores when the graph has
            latent confounders.
        use_theta_only_stage, theta_only_iters
            If enabled, graph fitting steps at even epoch indices use
            ``theta_only_iters`` iterations and do not step gamma.
            Incompatible with ``only_gamma``.
        sample_size_obs
            Size of the observational dataset.
        init_n_true_relations
            Number of variable pairs whose gamma/theta entries are initialised
            from the ground-truth adjacency matrix.
        only_gamma
            If True, edge probabilities and predictions use gamma only.
        verbose
            If True, ``print_graph_statistics`` prints to stdout.
        num_inner_loop_epochs: int
            Number of epochs of inner (ENCO like) optimization procedure.
        int_data_collection_policy: str
            Interventional data collection policy.
        int_data_collection_batch_size: int
            Size of interventional data batch acquired every super-epoch.
        reset_model_structural_params: bool
            Whether to reset structural parameters every super-epoch.
        reset_model_MLPs: bool
            Whether to reset functional parameters every super-epoch.
        use_best_model_to_select_intervention: bool
            Whether to use best model (wrt. the SHD metric) from each
            super-epoch to select next intervention target. If set to False the
            last model is used.
        reduce_length_of_first_super_epochs: bool
            Whether to shorten early super-epochs using
            ``choose_number_of_inner_epochs``.
        **graph_fitting_kwargs
            Extra keyword arguments for ``graph_fitting.GraphFitting``. Must
            contain ``interventions_policy="nonempty_round_robin"`` and must
            not enable ``force_online_data``.

        Raises
        ------
        AssertionError
            If the graph fitting kwargs violate the constraints above, or if
            ``use_theta_only_stage`` and ``only_gamma`` are both set.
        NotImplementedError
            If ``int_data_collection_policy`` is one of the unsupported
            policies.
        """
        self.graph = graph
        self.num_vars = graph.num_vars
        self.num_inner_loop_epochs = num_inner_loop_epochs
        # Create observational dataset
        obs_dataset = ObservationalCategoricalData(graph, dataset_size=sample_size_obs)
        obs_data_loader = data.DataLoader(
            obs_dataset, batch_size=batch_size, shuffle=True, drop_last=True
        )
        self.int_dataset = DynamicInterventionalDataset(self.graph, batch_size)
        # Create neural networks for fitting the conditional distributions
        if graph.is_categorical:
            num_categs = max([v.prob_dist.num_categs for v in graph.variables])
            self.model = create_model(
                num_vars=self.num_vars,
                num_categs=num_categs,
                hidden_dims=hidden_dims,
                use_embedding=model_use_embedding,
            )
        else:
            self.model = create_continuous_model(
                num_vars=self.num_vars,
                hidden_dims=hidden_dims,
                use_flow_model=use_flow_model,
            )
        model_optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=lr_model,
            betas=betas_model,
            weight_decay=weight_decay,
        )
        # Initialize graph parameters
        self.true_adj_matrix = torch.from_numpy(graph.adj_matrix).bool()
        self.lr_gamma, self.betas_gamma, self.lr_theta, self.betas_theta = (
            lr_gamma,
            betas_gamma,
            lr_theta,
            betas_theta,
        )
        self.init_n_true_relations = init_n_true_relations
        self.spawn_graph_params(
            self.num_vars,
            self.lr_gamma,
            self.betas_gamma,
            self.lr_theta,
            self.betas_theta,
            n_true_relations=self.init_n_true_relations,
        )
        # Initialize distribution and graph fitting modules
        self.distribution_fitting_module = DistributionFitting(
            model=self.model,
            optimizer=model_optimizer,
            data_loader=obs_data_loader,
            graph=graph,
            force_online_data=False,
        )
        self.only_gamma = only_gamma
        assert (
            graph_fitting_kwargs.get("force_online_data", False) == False
        ), "force_online_data is not supported by this discovery method"
        # `.get` so that a missing key fails with the explanatory assertion
        # message instead of a bare KeyError.
        assert (
            graph_fitting_kwargs.get("interventions_policy") == "nonempty_round_robin"
        ), 'graph fitting requires interventions_policy="nonempty_round_robin"'
        self.graph_fitting_module = graph_fitting.GraphFitting(
            parent_enco_object=self,
            model=self.model,
            graph=graph,
            num_batches=GF_num_batches,
            num_graphs=GF_num_graphs,
            theta_only_num_graphs=theta_only_num_graphs,
            batch_size=batch_size,
            lambda_sparse=lambda_sparse,
            max_graph_stacking=max_graph_stacking,
            exclude_inters=self.graph.exclude_inters,
            only_gamma=only_gamma,
            dataset=self.int_dataset,
            **graph_fitting_kwargs,
        )
        # Those acquisition methods can't work in the "new approach".
        if int_data_collection_policy in (
            "gradients_l2",
            "soft_gradients_l2",
            "ce_shd_reduction",
        ):
            # Every fragment that interpolates a value needs its own f-prefix.
            raise NotImplementedError(
                f"{self.__class__.__name__} graph discovery method is currently"
                f" not working with {int_data_collection_policy} intervention"
                " selection policy."
            )
        self.int_data_collection_strategy = (
            intervention_strategies.get_strategy_from_name_and_graph_fitting_object(
                int_data_collection_policy, self.graph_fitting_module
            )
        )
        self.int_data_collection_batch_size = int_data_collection_batch_size
        # Save other hyperparameters
        self.model_iters = model_iters
        self.graph_iters = graph_iters
        self.reset_model_structural_params = reset_model_structural_params
        self.reset_model_MLPs = reset_model_MLPs
        self.use_theta_only_stage = use_theta_only_stage
        self.theta_only_iters = theta_only_iters
        self.latent_threshold = latent_threshold
        self.true_node_relations = torch.from_numpy(graph.node_relations)
        self.metric_log = []
        # -1 marks "not measured yet" for both timers
        self.iter_time = -1
        self.dist_fit_time = -1
        self.use_best_model_to_select_intervention = (
            use_best_model_to_select_intervention
        )
        self.reduce_length_of_first_super_epochs = reduce_length_of_first_super_epochs
        if self.use_best_model_to_select_intervention:
            self.best_state_dict_checkpointer = BestStateDictCheckpointer()

        # Some debugging info for user
        print(f"Distribution fitting model:\n{str(self.model)}")
        print(f"Dataset size:\n- Observational: {len(obs_dataset)}")

        assert not (use_theta_only_stage and only_gamma)
        self.verbose = verbose

    def spawn_graph_params(
        self, num_vars, lr_gamma, betas_gamma, lr_theta, betas_theta, n_true_relations
    ):
        """
        Create gamma and theta parameters and their optimizers, prepare true pairs.

        Parameters
        ----------
        num_vars
            Side length of the square gamma and theta matrices.
        lr_gamma, betas_gamma
            Learning rate and the two beta coefficients for the gamma optimizer.
        lr_theta, betas_theta
            Learning rate and the two beta coefficients for the theta optimizer.
        n_true_relations
            Number of distinct variable pairs to draw (without replacement)
            into ``self.true_pairs``.

        Raises
        ------
        ValueError
            Raised by ``np.random.choice`` if ``n_true_relations`` exceeds the
            number of available pairs.
        """
        self.gamma = nn.Parameter(torch.zeros(num_vars, num_vars))
        # For latent confounders, we need to track interventional and
        # observational gradients separately => different optimizer
        if self.graph.num_latents > 0:
            self.gamma_optimizer = AdamGamma(
                self.gamma, lr=lr_gamma, beta1=betas_gamma[0], beta2=betas_gamma[1]
            )
        else:
            self.gamma_optimizer = torch.optim.Adam(
                [self.gamma], lr=lr_gamma, betas=betas_gamma
            )

        self.theta = nn.Parameter(torch.zeros(num_vars, num_vars))
        self.theta_optimizer = AdamTheta(
            self.theta, lr=lr_theta, beta1=betas_theta[0], beta2=betas_theta[1]
        )

        # All unordered pairs (a, b) with b < a. The reshape keeps the array
        # two-dimensional even when there are no pairs (num_vars < 2), so the
        # column indexing `pairs[:, 0]` in init_graph_params stays valid.
        pairs = np.array(
            [(a, b) for a in range(num_vars) for b in range(a)], dtype=int
        ).reshape(-1, 2)
        chosen = np.random.choice(len(pairs), size=n_true_relations, replace=False)
        self.true_pairs = pairs[chosen]

    def init_graph_params(self):
        """
        Initializes gamma and theta parameters.
        """
        self.gamma.data[:] = torch.zeros(
            self.num_vars, self.num_vars
        )  # Init with zero => prob 0.5
        self.gamma.data[
            torch.arange(self.num_vars), torch.arange(self.num_vars)
        ] = -9e15  # Mask diagonal

        self.theta.data[:] = torch.zeros(
            self.num_vars, self.num_vars
        )  # Init with zero => prob 0.5

        pairs = self.true_pairs
        infill_matrix = (
            self.true_adj_matrix.float() - 0.5
        ) * 10  # get values -5/5 instead of 0/1
        infill_matrix = infill_matrix.to(self.gamma.device)
        self.gamma.data[pairs[:, 0], pairs[:, 1]] = infill_matrix[
            pairs[:, 0], pairs[:, 1]
        ]
        self.gamma.data[pairs[:, 1], pairs[:, 0]] = infill_matrix[
            pairs[:, 1], pairs[:, 0]
        ]
        self.theta.data[pairs[:, 0], pairs[:, 1]] = infill_matrix[
            pairs[:, 0], pairs[:, 1]
        ]
        self.theta.data[pairs[:, 1], pairs[:, 0]] = infill_matrix[
            pairs[:, 1], pairs[:, 0]
        ]

    def reset_model_parameters(self):
        if self.reset_model_structural_params:
            self.init_graph_params()

        @torch.no_grad()
        def weight_reset(m: nn.Module):
            """refs - https://discuss.pytorch.org/t/reset-model-weights/19180/7"""
            # - check if the current module has reset_parameters & if it's callabed called it on m
            reset_parameters = getattr(m, "reset_parameters", None)
            if callable(reset_parameters):
                m.reset_parameters()

        if self.reset_model_MLPs:
            self.model.apply(fn=weight_reset)

    def choose_number_of_inner_epochs(self, outer_epoch):
        """
        This is simple heuristic for reducing number of updates on early epochs.
        It is a linear function of interventional data collected so far.
        It has been verified experimentally.
        """
        MAGIC_CONSTANT = 100 / self.graph_iters
        epochs = min(self.num_inner_loop_epochs, int(outer_epoch * MAGIC_CONSTANT))
        epochs = max(1, epochs)
        return epochs

    def discover_graph(self, num_epochs=30, stop_early=False):
        """
        Outer training loop alternating data acquisition and graph fitting.

        After one initial reset and distribution fitting pass, each
        super-epoch picks an intervention target, adds a batch of
        interventional data for it, resets the parameters (subject to the
        reset flags) and runs ``fit_graph``. ``stop_early`` is forwarded to
        ``fit_graph``. Returns the final binary adjacency matrix.
        """
        # Fit the distributions once before any interventional data exists
        self.reset_model_parameters()
        self.distribution_fitting_step()

        for super_epoch in range(num_epochs):
            # Optionally go back to the best checkpoint of the previous
            # super-epoch before choosing the next intervention target.
            restore_best = (
                self.use_best_model_to_select_intervention and super_epoch > 0
            )
            if restore_best:
                best_state = self.best_state_dict_checkpointer.get_best_state_dict()
                self.load_state_dict(best_state)

            # Pick the variable to intervene on and log the choice
            var_idx = self.graph_fitting_module.sample_next_var_idx(
                self.gamma,
                self.theta,
                intervention_strategy=self.int_data_collection_strategy,
            )
            lg.NEPTUNE_LOGGER.log(name="var_idx", value=var_idx)
            lg.NEPTUNE_LOGGER.log(
                name="ds/var_idx", value=var_idx, step_name="data_samples"
            )

            # Acquire the new interventional batch and log the running total
            acquired = self.int_data_collection_batch_size
            self.int_dataset.add_batch(var_idx, acquired)
            lg.NEPTUNE_LOGGER.log(
                name="data_samples", value=(super_epoch + 1) * acquired
            )
            lg.NEPTUNE_LOGGER.bump(acquired, step_name="data_samples")

            # Reset and train model on new data
            self.reset_model_parameters()
            epochs = (
                self.choose_number_of_inner_epochs(super_epoch)
                if self.reduce_length_of_first_super_epochs
                else self.num_inner_loop_epochs
            )
            first_epoch = super_epoch * self.num_inner_loop_epochs
            self.fit_graph(epochs, stop_early, start_epochs_from=first_epoch)

        return self.get_binary_adjmatrix()

    def fit_graph(self, num_epochs=30, stop_early=False, start_epochs_from=0):
        """
        Inner training function. It starts the loop of distribution and graph fitting.
        Returns the predicted binary adjacency matrix.

        Parameters
        ----------
        num_epochs
            Number of epochs to run.
        stop_early
            If True, stop once the prediction has been correct for 5
            consecutive epochs.
        start_epochs_from
            Index of the first epoch; epochs are numbered
            ``start_epochs_from .. start_epochs_from + num_epochs - 1``.
        """
        num_stops = 0
        # Stays None only if the loop below does not execute at all
        metrics = None
        if self.use_best_model_to_select_intervention:
            self.best_state_dict_checkpointer.reset()
        for epoch in track(
            range(start_epochs_from, start_epochs_from + num_epochs),
            leave=False,
            desc="Epoch loop",
        ):
            self.epoch = epoch
            start_time = time.time()
            # Update Model
            self.distribution_fitting_step()
            self.dist_fit_time = time.time() - start_time
            if self.int_dataset.size() > 0:
                # Update graph parameters
                self.graph_fitting_step()
            self.iter_time = time.time() - start_time
            # Print stats
            metrics = self.print_graph_statistics(epoch=epoch + 1, log_metrics=True)
            if self.use_best_model_to_select_intervention:
                self.best_state_dict_checkpointer.update(
                    metrics["SHD"], self.get_state_dict()
                )
            # Early stopping if perfect reconstruction for 5 epochs (for faster prototyping)
            if stop_early and self.is_prediction_correct():
                num_stops += 1
                if num_stops >= 5:
                    print("Stopping early due to perfect discovery")
                    break
            else:
                num_stops = 0
        if metrics is None:
            # No epoch ran (empty range): evaluate the current parameters so
            # the SHD log below does not hit an unbound variable.
            metrics = self.get_metrics()
        lg.NEPTUNE_LOGGER.log(
            name="ds/SHD", value=metrics["SHD"], step_name="data_samples"
        )
        return self.get_binary_adjmatrix()

    def distribution_fitting_step(self):
        """
        Performs one stage of distribution fitting (``model_iters`` updates).

        The matrix handed to every update is sigmoid(gamma), multiplied
        element-wise by sigmoid(theta) unless ``only_gamma`` is set.
        """
        # Probabilities to sample input masks from
        edge_probs = torch.sigmoid(self.gamma)
        if not self.only_gamma:
            edge_probs = edge_probs * torch.sigmoid(self.theta)

        progress = track(
            range(self.model_iters), leave=False, desc="Distribution fitting loop"
        )
        # Not every iterable returned by `track` can show a description
        can_describe = hasattr(progress, "set_description")
        for _ in progress:
            loss = self.distribution_fitting_module.perform_update_step(
                sample_matrix=edge_probs
            )
            if can_describe:
                progress.set_description("Model update loop, loss: %4.2f" % loss)

    def graph_fitting_step(self):
        """
        Performs one stage of graph fitting, i.e. a loop of single iterations.

        When the theta-only stage is enabled and the current epoch index is
        even, ``theta_only_iters`` iterations are run; otherwise
        ``graph_iters``.
        """
        # For large graphs, freeze gamma in every second graph fitting stage
        only_theta = self.use_theta_only_stage and self.epoch % 2 == 0
        if only_theta:
            iters = self.theta_only_iters
        else:
            iters = self.graph_iters
        # Update gamma and theta in a loop
        for _ in track(range(iters), leave=False, desc="Graph fitting loop"):
            self.graph_fitting_single_iteration()

    def graph_fitting_single_iteration(self, var_idx=-1, suppress_logging=False):
        """
        Run one gradient update of the graph parameters and return metrics.

        Parameters
        ----------
        var_idx
            Forwarded to the graph fitting module's ``perform_update_step``,
            which returns the index actually used.
        suppress_logging
            If True, the metrics are not sent to the Neptune logger.

        Returns
        -------
        dict
            Output of ``get_metrics`` plus a ``"gradient_magnitude"`` entry.
        """
        freeze_gamma = self.use_theta_only_stage and self.epoch % 2 == 0

        for optimizer in (self.gamma_optimizer, self.theta_optimizer):
            optimizer.zero_grad()

        theta_mask, var_idx = self.graph_fitting_module.perform_update_step(
            self.gamma, self.theta, var_idx=var_idx, only_theta=freeze_gamma
        )

        # Score the fresh gradients before the optimizers consume them
        grad_score = self.graph_fitting_module.compute_score(
            self.gamma.grad, self.theta.grad
        )
        gradient_magnitude = float(grad_score)

        if freeze_gamma:
            # In the gamma freezing stages, we do not update gamma
            pass
        elif isinstance(self.gamma_optimizer, AdamGamma):
            self.gamma_optimizer.step(var_idx)
        else:
            self.gamma_optimizer.step()
        self.theta_optimizer.step(theta_mask)

        metrics = self.get_metrics()
        metrics["gradient_magnitude"] = gradient_magnitude

        if suppress_logging:
            return metrics

        for key, value in flatten_dict(metrics).items():
            lg.NEPTUNE_LOGGER.log(name=key, value=value)
        lg.NEPTUNE_LOGGER.bump()
        return metrics

    def get_binary_adjmatrix(self):
        """
        Returns the predicted, binary adjacency matrix of the causal graph.
        """
        binary_gamma = self.gamma > 0.0
        binary_theta = self.theta > 0.0
        if self.only_gamma:
            A = binary_gamma
        else:
            A = binary_gamma * binary_theta
        # If we consider latent confounders, we mask all edges that have a confounder score greater than the threshold
        if self.graph.num_latents > 0:
            A = A * (self.get_confounder_scores() < self.latent_threshold)

        return (A == 1).cpu()

    def get_acyclic_adjmatrix(self):
        """
        Returns the predicted, acyclic adjacency matrix of the causal graph.

        The matrix is produced by ``find_best_acyclic_graph`` from the
        sigmoids of gamma and theta and moved to the CPU.
        """
        gamma_probs = torch.sigmoid(self.gamma)
        theta_probs = torch.sigmoid(self.theta)
        acyclic = find_best_acyclic_graph(gamma=gamma_probs, theta=theta_probs)
        return acyclic.cpu()

    def is_prediction_correct(self):
        """
        Returns true if the prediction corresponds to the correct, underlying causal graph. Otherwise false.
        If latent confounders exist, those need to be correct as well to return true.
        """
        correct_pred = (self.get_binary_adjmatrix() == self.true_adj_matrix).all()
        if self.graph.num_latents > 0:
            conf_metrics = self.get_confounder_metrics()
            correct_pred = correct_pred and (
                (conf_metrics["FP"] + conf_metrics["FN"]) == 0
            )
        return correct_pred

    def get_confounder_scores(self):
        """
        Returns a matrix of shape [num_vars, num_vars] where the element (i,j) represents the confounder score
        between the variable pair X_i and X_j, i.e., lc(X_i,X_j).

        Scores are only available when gamma is optimised with ``AdamGamma``;
        for any other optimizer an all-zero matrix shaped like gamma is
        returned.
        """
        if not isinstance(self.gamma_optimizer, AdamGamma):
            return torch.zeros_like(self.gamma)

        # The last axis of `updates` is split into two slices; going by the
        # names they are the observational and interventional statistics
        # (presumably in that order - verify against AdamGamma).
        gamma_obs_sig, gamma_int_sig = torch.unbind(
            torch.sigmoid(self.gamma_optimizer.updates), dim=-1
        )
        l_score = gamma_obs_sig * (1 - gamma_int_sig)
        # Scores are a symmetric matrix. The product is taken out of place:
        # multiplying a tensor in place by its own transposed view reads and
        # writes the same storage, so the result would depend on evaluation
        # order (or be rejected as overlapping memory).
        return l_score * l_score.T

    @torch.no_grad()
    def print_graph_statistics(self, epoch=-1, log_metrics=False, m=None):
        """
        Prints statistics and metrics of the current graph prediction. It is executed
        during training to track the training progress.
        """
        if m is None:
            m = self.get_metrics()
        if log_metrics:
            if epoch > 0:
                m["epoch"] = epoch
            self.metric_log.append(m)

        if self.verbose:
            if epoch > 0:
                print("--- [EPOCH %i] ---" % epoch)
            print(
                "Graph - SHD: %i, Recall: %4.2f%%, Precision: %4.2f%% (TP=%i,FP=%i,FN=%i,TN=%i)"
                % (
                    m["SHD"],
                    100.0 * m["recall"],
                    100.0 * m["precision"],
                    m["TP"],
                    m["FP"],
                    m["FN"],
                    m["TN"],
                )
            )
            print(
                "      -> FP:",
                ", ".join(
                    ["%s=%i" % (key, m["FP_details"][key]) for key in m["FP_details"]]
                ),
            )
            print(
                "Theta - Orientation accuracy: %4.2f%% (TP=%i,FN=%i)"
                % (m["orient"]["acc"] * 100.0, m["orient"]["TP"], m["orient"]["FN"])
            )

            if self.graph.num_latents > 0 and "confounders" in m:
                print(
                    "Latent confounders - TP=%i,FP=%i,FN=%i,TN=%i"
                    % (
                        m["confounders"]["TP"],
                        m["confounders"]["FP"],
                        m["confounders"]["FN"],
                        m["confounders"]["TN"],
                    )
                )

        if (
            epoch > 0 and self.num_vars >= 10
        ):  # For large graphs, we print runtime statistics for better time estimates
            gpu_mem = (
                torch.cuda.max_memory_allocated(device="cuda:0") / 1.0e9
                if torch.cuda.is_available()
                else -1
            )
            if self.verbose:
                print(
                    "-> Iteration time: %imin %is"
                    % (int(self.iter_time) // 60, int(self.iter_time) % 60)
                )
                print(
                    "-> Fitting time: %imin %is"
                    % (int(self.dist_fit_time) // 60, int(self.dist_fit_time) % 60)
                )
                print("-> Used GPU memory: %4.2fGB" % (gpu_mem))
            stats = {
                "iteration_time": self.iter_time,
                "fitting_time": self.dist_fit_time,
                "gpu_mem": gpu_mem,
                "epoch": epoch,
            }

            # No need to log other metrics here, we log them more often
            for k, v in flatten_dict(stats).items():
                lg.NEPTUNE_LOGGER.log(name=k, value=v)

            m.update(stats)
        return m

    @torch.no_grad()
    def get_cross_entropy_shd(self):
        if self.only_gamma:
            probs = torch.sigmoid(self.gamma)
        else:
            probs = torch.sigmoid(self.gamma) * torch.sigmoid(self.theta)
        ce = torch.nn.functional.binary_cross_entropy(
            probs.cpu(), self.true_adj_matrix.float()
        )
        return float(ce)

    @torch.no_grad()
    def get_metrics(self, enforce_acyclic_graph=False):
        """
        Returns a dictionary with detailed metrics comparing the current prediction to the ground truth graph.
        """
        # Standard metrics (TP,TN,FP,FN) for edge prediction
        binary_matrix = self.get_binary_adjmatrix()
        if enforce_acyclic_graph:
            assert not self.only_gamma
            binary_matrix = self.get_acyclic_adjmatrix()
        else:
            binary_matrix = self.get_binary_adjmatrix()
        false_positives = torch.logical_and(binary_matrix, ~self.true_adj_matrix)
        false_negatives = torch.logical_and(~binary_matrix, self.true_adj_matrix)
        TP = torch.logical_and(binary_matrix, self.true_adj_matrix).float().sum().item()
        TN = (
            torch.logical_and(~binary_matrix, ~self.true_adj_matrix)
            .float()
            .sum()
            .item()
        )
        FP = false_positives.float().sum().item()
        FN = false_negatives.float().sum().item()
        TN = (
            TN - self.gamma.shape[-1]
        )  # Remove diagonal as those are not being predicted
        recall = TP / max(TP + FN, 1e-5)
        precision = TP / max(TP + FP, 1e-5)
        # Structural Hamming Distance score
        rev = torch.logical_and(binary_matrix, self.true_adj_matrix.T)
        num_revs = rev.float().sum().item()
        SHD = (
            false_positives + false_negatives + rev + rev.T
        ).float().sum().item() - num_revs

        ce_shd = self.get_cross_entropy_shd()

        # Get details on False Positives (what relations have the nodes of the false positives?)
        FP_elems = torch.where(torch.logical_and(binary_matrix, ~self.true_adj_matrix))
        FP_relations = self.true_node_relations[FP_elems]
        FP_dict = {
            "ancestors": (FP_relations == -1).sum().item(),  # i->j => j is a child of i
            "descendants": (FP_relations == 1).sum().item(),
            "confounders": (FP_relations == 2).sum().item(),
            "independents": (FP_relations == 0).sum().item(),
        }

        # Details on orientation prediction of theta, independent of gamma
        orient_TP = (
            torch.logical_and(self.true_adj_matrix == 1, self.theta.cpu() > 0.0)
            .float()
            .sum()
            .item()
        )
        orient_FN = (
            torch.logical_and(self.true_adj_matrix == 1, self.theta.cpu() <= 0.0)
            .float()
            .sum()
            .item()
        )
        orient_acc = orient_TP / max(1e-5, orient_TP + orient_FN)
        orient_dict = {"TP": int(orient_TP), "FN": int(orient_FN), "acc": orient_acc}

        # Summarizing all results in single dictionary
        metrics = {
            "TP": int(TP),
            "TN": int(TN),
            "FP": int(FP),
            "FN": int(FN),
            "SHD": int(SHD),
            "ce_shd": ce_shd,
            "reverse": int(num_revs),
            "recall": recall,
            "precision": precision,
            "FP_details": FP_dict,
            "orient": orient_dict,
        }

        if self.graph.num_latents > 0 and not enforce_acyclic_graph:
            metrics["confounders"] = self.get_confounder_metrics()
        return metrics

    @torch.no_grad()
    def get_confounder_metrics(self):
        """
        Returns metrics for detecting the latent confounders in the graph.

        A variable pair is predicted as confounded when its score in the
        strict upper triangle of the confounder-score matrix reaches
        ``latent_threshold``. Predictions are matched against columns 1 and 2
        of ``self.graph.latents``.
        """
        # Only the strict upper triangle is inspected
        scores = torch.triu(self.get_confounder_scores(), diagonal=1)
        hit_indices = torch.where(scores >= self.latent_threshold)
        predicted_pairs = torch.stack(hit_indices, dim=-1).cpu().numpy()

        # For every predicted pair: does any ground-truth latent list it?
        pair_equal = predicted_pairs[:, None, :] == self.graph.latents[None, :, 1:]
        is_known = pair_equal.all(axis=-1).any(axis=1).astype(np.int32)

        # Determine TP, FP, FN, and TN for latent confounder prediction
        tp = is_known.sum()
        fp = (1 - is_known).sum()
        fn = self.graph.num_latents - tp
        num_pairs = self.num_vars * (self.num_vars - 1)
        tn = num_pairs - (tp + fp + fn)

        latent_scores = scores[self.graph.latents[:, 1], self.graph.latents[:, 2]]
        return {
            "TP": int(tp),
            "FP": int(fp),
            "FN": int(fn),
            "TN": int(tn),
            "scores": latent_scores.cpu().numpy().tolist(),
        }

    def get_state_dict(self):
        """
        Returns a dictionary of all important parameters to save the current prediction status.
        """
        state_dict = {
            "gamma": self.gamma.data.detach(),
            "theta": self.theta.data.detach(),
            "model": self.distribution_fitting_module.model.state_dict(),
        }
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads parameters from a state dictionary, obtained from 'get_state_dict'.
        """
        self.gamma.data = state_dict["gamma"]
        self.theta.data = state_dict["theta"]
        self.distribution_fitting_module.model.load_state_dict(state_dict["model"])

    def to(self, device):
        """
        Moves all PyTorch parameters to a specified device.
        """
        self.distribution_fitting_module.model.to(device)
        self.model.to(device)
        self.gamma.data = self.gamma.data.to(device)
        self.theta.data = self.theta.data.to(device)
        self.theta_optimizer.to(device)
        if hasattr(self.gamma_optimizer, "to"):
            self.gamma_optimizer.to(device)

    def save_state_and_optimizers(self):
        """
        Snapshot gamma, theta and both optimizers for a later restore.

        Returns
        -------
        tuple
            ``(saved_gamma, saved_theta, saved_gamma_optimizer,
            saved_theta_optimizer)``, suitable for
            ``load_state_and_optimizers``. For an ``AdamGamma`` optimizer the
            third entry is a copy of the optimizer object, otherwise a copy
            of its state dict.
        """
        # save gamma, theta and optimizers
        saved_gamma = deepcopy(self.gamma.data)
        saved_theta = deepcopy(self.theta.data)
        if isinstance(self.gamma_optimizer, AdamGamma):
            saved_gamma_optimizer = deepcopy(self.gamma_optimizer)
        else:
            # The state dict is copied so that the snapshot does not follow
            # subsequent optimizer steps through shared state tensors.
            saved_gamma_optimizer = deepcopy(self.gamma_optimizer.state_dict())
        assert isinstance(self.theta_optimizer, AdamTheta)
        saved_theta_optimizer = deepcopy(self.theta_optimizer)

        return saved_gamma, saved_theta, saved_gamma_optimizer, saved_theta_optimizer

    def load_state_and_optimizers(
        self, saved_gamma, saved_theta, saved_gamma_optimizer, saved_theta_optimizer
    ):
        """
        Restore a snapshot taken with ``save_state_and_optimizers``.

        Gamma and theta values are written into the existing tensors; the
        optimizers are replaced by copies of the saved objects (or, for a
        non-``AdamGamma`` gamma optimizer, loaded from the saved state dict).
        Gradients of both optimizers are cleared afterwards.
        """
        # Write the values into the live tensors (no rebinding)
        self.gamma.data.copy_(saved_gamma)
        self.theta.data.copy_(saved_theta)

        # NOTE(review): copying an optimizer object also copies whatever
        # parameter tensors it references - confirm that AdamGamma/AdamTheta
        # still act on self.gamma/self.theta after this restore.
        if not isinstance(self.gamma_optimizer, AdamGamma):
            self.gamma_optimizer.load_state_dict(saved_gamma_optimizer)
        else:
            self.gamma_optimizer = deepcopy(saved_gamma_optimizer)
        assert isinstance(self.theta_optimizer, AdamTheta)
        self.theta_optimizer = deepcopy(saved_theta_optimizer)

        for optimizer in (self.gamma_optimizer, self.theta_optimizer):
            optimizer.zero_grad()
