import abc
import math
import pathlib
from copy import deepcopy
from collections import defaultdict
from typing import Optional, List, Tuple, Dict, Union

import networkx as nx
from nflows.transforms import MaskedAffineAutoregressiveTransform
from tqdm import tqdm

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.distributions as D

import nflows.flows
import nflows.distributions
import nflows.transforms
import nflows.utils
import nflows.nn.nets.resnet

import gpytorch
import gpytorch.distributions
import gpytorch.kernels
import gpytorch.likelihoods
import gpytorch.means
import gpytorch.mlls
from gpytorch.models import ExactGP

import nfmc_jax.sinf.SINF
import nfmc_jax.sinf.optimize
import nfmc_jax.sinf.GIS
from nfmc_jax.flows.checks import check_latent_space_valid
from nfmc_jax.flows.debug import FlowDebugger
from nfmc_jax.hierarchical.ecoflow import build_ecoflow

from nfmc_jax.sinf.SlicedWasserstein import Stiefel_SGD
from nfmc_jax.sinf.SINF import SlicedTransport, whiten
from nfmc_jax.sinf.optimize import lossfunc3, regularization

from nfmc_jax.TRENF.TRENF import NF
from nfmc_jax.utils.torch_distributions import gaussian_log_prob


class FlowInterface(abc.ABC):
    """
    Abstract access layer between sampling algorithms (e.g. DLA) and a normalizing flow.

    Concrete subclasses create, train, sample from and evaluate a specific flow implementation.
    """

    def __init__(self, debugger: Optional[FlowDebugger] = None):
        """
        Set up the shared interface state.

        :param debugger: optional debugger object; debugging hooks are no-ops when it is None.
        """
        self.flow = None
        self.debugger = debugger
        self._auto_step: int = 0  # Training epoch counter

        # Latent points reused across debug calls when generative paths are recorded.
        self._z_samples: Optional[torch.Tensor] = None

    @property
    def debug(self):
        """True when a debugger is attached."""
        return self.debugger is not None

    def debug_step(self):
        """Advance the debugger by one step if debugging is enabled."""
        if not self.debug:
            return
        self.debugger.step()

    def debug_animate(self):
        """Ask the debugger to animate its collected data if debugging is enabled."""
        if not self.debug:
            return
        self.debugger.animate()

    @torch.no_grad()
    def add_debug_data(self, x_train: torch.Tensor, x_val: torch.Tensor = None, scalars: dict = None):
        """
        Hand the debugger whatever data its writer/visualizer configuration asks for.

        :param x_train: training points in data space; its second dimension sets the latent dimensionality.
        :param x_val: optional validation points in data space.
        :param scalars: optional extra named scalars to record alongside the log-densities.
        """
        if not self.debug:
            return

        extra_scalars = dict() if scalars is None else scalars
        has_val = x_val is not None

        # The two configuration objects decide which pieces of data get computed at all.
        debugger = self.debugger
        w_cfg = debugger.file_writer.config
        v_cfg = debugger.visualizer.config

        if w_cfg.write_scalars or v_cfg.plot_scalars:
            debugger.add_scalar('logq_train', self.logq(x_train).mean())
            if has_val:
                debugger.add_scalar('logq_val', self.logq(x_val).mean())
            for name, value in extra_scalars.items():
                debugger.add_scalar(name, value)

        if w_cfg.write_samples:
            # TODO create necessary method in file_writer
            pass

        if w_cfg.write_training_paths or v_cfg.plot_training_paths:
            debugger.add_training_paths(self.forward_paths(x_train))

        if w_cfg.write_training_reconstructions or v_cfg.plot_training_reconstructions:
            # Round trip: data -> latent -> data.
            train_roundtrip = self.inverse(self.forward(x_train))
            debugger.add_training_reconstructions(x_train, train_roundtrip)

        if (w_cfg.write_validation_paths or v_cfg.plot_validation_paths) and has_val:
            debugger.add_validation_paths(self.forward_paths(x_val))

        if (w_cfg.write_validation_reconstructions or v_cfg.plot_validation_reconstructions) and has_val:
            val_roundtrip = self.inverse(self.forward(x_val))
            debugger.add_validation_reconstructions(x_val, val_roundtrip)

        if w_cfg.write_generative_paths or v_cfg.plot_generative_paths:
            if self._z_samples is None:
                # Draw the latent points once so every call pushes the same points through the flow.
                latent_shape = (v_cfg.n_latent_points, x_train.shape[1])
                self._z_samples = torch.randn(*latent_shape)
            debugger.add_generative_paths(self.inverse_paths(self._z_samples))

    @abc.abstractmethod
    def create_flow(self, *args, **kwargs):
        """
        Build the normalizing flow object.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def train_flow(self, *args, **kwargs):
        """
        Fit the normalizing flow.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw points in data space from the normalizing flow.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def logq(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the flow log-density at the given data space points.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        """
        Run the forward pass and return the latent points together with logj.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        """
        Run the inverse pass and return the data space points together with logj.
        """
        raise NotImplementedError

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Map data space points to latent space.

        :param x: points in data space with shape (n, d).
        :return: points in latent space with shape (n, d).
        """
        outputs = self.forward_with_logj(x, **kwargs)
        return outputs[0]

    def inverse(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Map latent space points to data space.

        :param z: points in latent space with shape (n, d).
        :return: points in data space with shape (n, d).
        """
        outputs = self.inverse_with_logj(z, **kwargs)
        return outputs[0]

    def logj_forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Log Jacobian determinant of the forward pass at points x (x is pushed forward to obtain it).
        """
        outputs = self.forward_with_logj(x, **kwargs)
        return outputs[1]

    def logj_backward(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Log Jacobian determinant of the inverse pass at points z (z is pushed backward to obtain it).
        """
        outputs = self.inverse_with_logj(z, **kwargs)
        return outputs[1]

    def grad_x_logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Gradient of logq with respect to x, i.e. d/dx (log q(x)).

        The density of x is obtained by:
        * mapping x to latent space;
        * evaluating the latent density;
        * correcting for the change of volume via the Jacobian determinant.
        The resulting scalar (log density of x) is then differentiated with respect to x.

        Not implemented at this level; subclasses may override.

        :param x: points in data space with shape (n, d).
        :return: gradient with shape (n, d).
        """
        raise NotImplementedError

    def grad_z_logp(self, z: torch.Tensor, grad_wrt_x, **kwargs) -> torch.Tensor:
        """
        Gradient of logp with respect to z.

        This is grad_x(logp) dx/dz, so grad_wrt_x should be grad_x(logp):
        grad_z logp = d/dz logp = d/dx logp * dx/dz = grad_wrt_x * dx/dz.

        Not implemented at this level; subclasses may override.

        :param z: points in latent space with shape (n, d).
        :param grad_wrt_x: vector used in the Jacobian-vector product.
        :return: gradient of logp wrt z.
        """
        raise NotImplementedError

    def grad_z_logj(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Gradient of logj with respect to z.

        Not implemented at this level; subclasses may override.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Push x through the flow in the forward direction, keeping every intermediate state.
        If the flow has L layers and x has shape (n, n_dim), then the output has shape (L, n, n_dim).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def inverse_paths(self, z: torch.Tensor):
        """
        Push z through the flow in the inverse direction, keeping every intermediate state.
        If the flow has L layers and z has shape (n, n_dim), then the output has shape (L, n, n_dim).
        """
        raise NotImplementedError


class TorchFlowInterface(FlowInterface, abc.ABC):
    """
    Flow interface for PyTorch-based flows.

    Adds device handling, optimizer configuration and autograd-based gradient helpers on top of FlowInterface.
    """

    def __init__(self,
                 debugger: Optional[FlowDebugger] = None,
                 device: torch.device = torch.device('cpu'),
                 optimizer_kwargs: dict = None,
                 n_dim: int = None):
        """
        :param debugger: optional debugger forwarded to FlowInterface.
        :param device: device that inputs are moved to before being passed through the flow.
        :param optimizer_kwargs: keyword arguments for the optimizers; an empty dict is used when None.
        :param n_dim: data dimensionality; required by sample_with_logq.
        """
        super().__init__(debugger=debugger)
        self.device = device
        self.optimizer_kwargs = dict() if optimizer_kwargs is None else optimizer_kwargs
        self.n_dim = n_dim

    def sample_with_logq(self, n_samples: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample points from the flow and obtain their logq.

        Latent points are drawn from a standard normal, pushed through the inverse pass, and the
        log-density is the latent log-density minus the log Jacobian determinant of the inverse pass.

        :param n_samples: number of points to draw.
        :return: tuple (x, logq) of data space samples and their log-densities.
        :raises ValueError: if n_dim was not set.
        """
        if self.n_dim is None:
            raise ValueError("n_dim must be set, but got None")
        z = torch.randn(n_samples, self.n_dim, device=self.device)
        x, logj_backward = self.inverse_with_logj(z)
        # Standard normal log-density of z.
        log_base = -0.5 * self.n_dim * math.log(2.0 * math.pi) - torch.sum(z ** 2, dim=1) / 2
        logq = log_base - logj_backward
        return x, logq

    def grad_x_logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute gradient of logq wrt x i.e. d/dx (log q(x)).

        The log-density of x is evaluated with logq, summed to a scalar, and differentiated with autograd.
        The input tensor is left untouched: differentiation happens through a detached copy.

        :param x: points in data space with shape (n, d).
        :return: gradient with shape (n, d), detached from the autograd graph.
        """
        # Detach first so that toggling requires_grad never leaks onto the caller's tensor
        # (x.to(device) returns x itself when it already lives on the target device).
        x_tmp = x.detach().to(self.device).requires_grad_(True)
        # enable_grad lets this work even when the caller is inside torch.no_grad().
        with torch.enable_grad():
            total_logq = torch.sum(self.logq(x_tmp))
            grad_x = torch.autograd.grad(total_logq, x_tmp)[0]
        return grad_x.detach()

    def grad_z_logp(self, z: torch.Tensor, grad_wrt_x, **kwargs) -> torch.Tensor:
        """
        Compute the gradient of logp wrt z.

        We are computing grad_x(logp)dx/dz, so grad_wrt_x should be grad_x(logp).
        Also: grad_z logp = d/dz logp = d/dx logp * dx/dz = grad_wrt_x * dx/dz.
        The input tensor is left untouched: differentiation happens through a detached copy.

        :param z: points in latent space with shape (n, d).
        :param grad_wrt_x: vector to be used in the vector-Jacobian product (passed as grad_outputs).
        :return: gradient of logp wrt z, detached from the autograd graph.
        """
        z_tmp = z.detach().to(self.device).requires_grad_(True)
        grad_wrt_x_tmp = grad_wrt_x.to(self.device)

        with torch.enable_grad():
            x = self.inverse(z_tmp)
            grad_z = torch.autograd.grad(x, z_tmp, grad_outputs=grad_wrt_x_tmp)[0]

        return grad_z.detach()

    def grad_z_logj(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Compute the gradient wrt z of the negated log Jacobian determinant of the inverse pass.

        The input tensor is left untouched: differentiation happens through a detached copy.

        :param z: points in latent space with shape (n, d).
        :return: gradient with the same shape as z, detached from the autograd graph.
        """
        z_tmp = z.detach().to(self.device).requires_grad_(True)

        with torch.enable_grad():
            logj_backward = self.logj_backward(z_tmp)
            # The sign flip turns the inverse-pass log determinant into the forward-pass one.
            log_jacobian = torch.sum(-logj_backward)
            grad_jacobian = torch.autograd.grad(log_jacobian, z_tmp)[0]
        return grad_jacobian.detach()

    @abc.abstractmethod
    def get_optimizers(self):
        """
        Get optimizers for flow training.

        :return: list of torch optimizers, corresponding to different parameter groups of the flow.
        """
        raise NotImplementedError


class RQNSFInterface(TorchFlowInterface):
    """
    Interface to a rational-quadratic neural spline flow (RQ-NSF) built from nflows coupling layers.
    """

    def __init__(self,
                 n_dim: int,
                 n_layers: int = 5,
                 n_hidden: int = 5,
                 n_blocks_per_layer: int = 2,
                 n_bins: int = 64,
                 activation: callable = torch.sigmoid,
                 dropout_prob: float = 0.0,
                 debugger: Optional[FlowDebugger] = None,
                 device: torch.device = torch.device('cpu'),
                 optimizer_kwargs: dict = None):
        """
        :param n_dim: number of dimensions in the data; must be at least 2.
        :param n_layers: number of coupling layers.
        :param n_hidden: number of hidden features in each residual network.
        :param n_blocks_per_layer: number of residual blocks in each residual network.
        :param n_bins: number of spline bins.
        :param activation: activation function for the residual networks.
        :param dropout_prob: dropout probability for the residual networks.
        :param debugger: optional debugger.
        :param device: device the flow lives on.
        :param optimizer_kwargs: keyword arguments passed to the Adam optimizer.
        :raises ValueError: if n_dim is smaller than 2.
        """
        if n_dim < 2:
            # The nflows implementation of RQ-NSF does not support 1D data
            raise ValueError(f"n_dim must be at least 2 but got {n_dim}")
        super().__init__(n_dim=n_dim, debugger=debugger, device=device, optimizer_kwargs=optimizer_kwargs)
        self.n_dim = n_dim
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_blocks_per_layer = n_blocks_per_layer
        self.n_bins = n_bins
        self.activation = activation
        self.dropout_prob = dropout_prob
        self.base_dist = nflows.distributions.normal.StandardNormal(shape=[self.n_dim])
        self.flow: Optional[nflows.flows.base.Flow] = None

    def create_flow(self, *args, **kwargs):
        """
        Create a RQ-NSF object and store it in self.flow.

        Reference: https://github.com/bayesiains/nsf/blob/master/experiments/plane.py
        """

        def make_transform_net(in_features, out_features):
            # Conditioner network used by every coupling layer.
            return nflows.nn.nets.resnet.ResidualNet(
                in_features=in_features,
                out_features=out_features,
                hidden_features=self.n_hidden,
                num_blocks=self.n_blocks_per_layer,
                use_batch_norm=True,
                dropout_probability=self.dropout_prob,
                activation=self.activation
            )

        transforms = []
        for layer_index in range(self.n_layers):
            # The mask parity alternates between consecutive layers.
            mask = nflows.utils.create_alternating_binary_mask(
                features=self.n_dim,
                even=(layer_index % 2 == 0)
            )
            transforms.append(
                nflows.transforms.PiecewiseRationalQuadraticCouplingTransform(
                    mask=mask,
                    transform_net_create_fn=make_transform_net,
                    tails='linear',
                    tail_bound=5,
                    num_bins=self.n_bins,
                    apply_unconditional_transform=False
                )
            )
        transform = nflows.transforms.base.CompositeTransform(transforms)
        self.flow = nflows.flows.base.Flow(transform, self.base_dist).to(self.device)

    def train_flow(self,
                   x: torch.Tensor,
                   n_epochs: int = 1000,
                   weights: torch.Tensor = None,
                   use_tqdm: bool = False,
                   x_val: Optional[torch.Tensor] = None):
        """
        Train the RQ-NSF using weighted maximum likelihood estimation.

        Each epoch is one full-batch Adam step on the negative weighted mean of logq.

        :param x: training data with shape (n, d).
        :param n_epochs: number of training epochs.
        :param weights: non-negative weights for samples x; uniform when None.
        :param use_tqdm: display a progress bar.
        :param x_val: optional validation data, only forwarded to the debugger.
        :raises ValueError: if the weights mismatch x in length, contain negative entries or (when
            training for at least one epoch) do not have a positive sum, or if n_epochs is negative.
        """
        if weights is None:
            weights = torch.ones(len(x), device=x.device)
        if len(weights) != len(x):
            raise ValueError(f"x and weights should match in dimension 0 but got sizes {len(x)} and {len(weights)}.")
        if torch.any(weights < 0):
            raise ValueError("Weights should be non-negative.")
        if n_epochs < 0:
            raise ValueError("Number of epochs should be non-negative.")

        _x = x.to(self.device)
        _weights = weights.to(self.device)
        _x_val = x_val.to(self.device) if x_val is not None else None

        # The normaliser is constant across epochs; a non-positive (or NaN) value would turn the
        # loss into NaN and silently corrupt the flow parameters.
        weight_total = torch.sum(_weights)
        if n_epochs > 0 and not weight_total > 0:
            raise ValueError("Weights should have a positive sum.")

        self.flow.train()
        optimizer = self.get_optimizers()[0]
        epochs = tqdm(range(n_epochs)) if use_tqdm else range(n_epochs)
        for _ in epochs:
            optimizer.zero_grad()
            loss = -torch.sum(_weights * self.logq(_x)) / weight_total
            loss.backward()
            optimizer.step()
            self.add_debug_data(x_train=_x, x_val=_x_val)
            self.debug_step()
        self.flow.eval()
        self.debug_animate()

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """
        Sample points in data space from the RQ-NSF.
        """
        return self.flow.sample(num_samples=n_samples)

    def logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the log density at the supplied data space locations.
        """
        return self.flow.log_prob(x.to(self.device))

    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        """
        Push x through the flow transform; returns (z, logj_forward).
        """
        z, logj_forward = self.flow._transform(x.to(self.device))
        return z, logj_forward

    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        """
        Push z through the inverse flow transform; returns (x, logj_backward).
        """
        x, logj_backward = self.flow._transform.inverse(z.to(self.device))
        return x, logj_backward

    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Push x through the layers one by one and stack the input plus every intermediate state.
        """
        # implemented like forward and _cascade methods of nflows.transforms.base.CompositeTransform.
        paths = [torch.clone(x.to(self.device))]
        for layer in self.flow._transform._transforms:
            paths.append(layer(paths[-1], None)[0])  # context = None
        return torch.stack(paths)

    def inverse_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Push latent points through the inverted layers in reverse order and stack the input plus
        every intermediate state. The parameter is named x for backward compatibility.
        """
        # implemented like inverse, and _cascade methods of nflows.transforms.base.CompositeTransform.
        paths = [torch.clone(x.to(self.device))]
        for layer in self.flow._transform._transforms[::-1]:
            paths.append(layer.inverse(paths[-1], None)[0])  # context = None
        return torch.stack(paths)

    def get_optimizers(self):
        """
        :return: single-element list with an Adam optimizer over all flow parameters.
        """
        return [optim.Adam(self.flow.parameters(), **self.optimizer_kwargs)]


class RealNVPInterface(TorchFlowInterface):
    """
    Interface to the nflows SimpleRealNVP flow.
    """

    def __init__(self,
                 n_dim: int,
                 n_layers: int = 5,
                 n_hidden: int = 5,
                 n_blocks_per_layer: int = 2,
                 activation: callable = torch.sigmoid,
                 dropout_prob: float = 0.0,
                 debugger: Optional[FlowDebugger] = None,
                 device: torch.device = torch.device('cpu'),
                 optimizer_kwargs: dict = None):
        """
        :param n_dim: number of dimensions in the data.
        :param n_layers: number of flow layers.
        :param n_hidden: number of hidden features.
        :param n_blocks_per_layer: number of blocks per layer.
        :param activation: activation function.
        :param dropout_prob: dropout probability.
        :param debugger: optional debugger.
        :param device: device the flow lives on.
        :param optimizer_kwargs: keyword arguments passed to the Adam optimizer.
        """
        super().__init__(n_dim=n_dim, debugger=debugger, device=device, optimizer_kwargs=optimizer_kwargs)
        self.n_dim = n_dim
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_blocks_per_layer = n_blocks_per_layer
        self.activation = activation
        self.dropout_prob = dropout_prob
        self.base_dist = nflows.distributions.normal.StandardNormal(shape=[self.n_dim])
        self.flow: Optional[nflows.flows.base.Flow] = None

    def create_flow(self, *args, **kwargs):
        """
        Create a SimpleRealNVP object and store it in self.flow.
        """
        self.flow = nflows.flows.realnvp.SimpleRealNVP(
            features=self.n_dim,
            hidden_features=self.n_hidden,
            num_layers=self.n_layers,
            activation=self.activation,
            dropout_probability=self.dropout_prob,
            num_blocks_per_layer=self.n_blocks_per_layer
        ).to(self.device)

    def train_flow(self,
                   x: torch.Tensor,
                   n_epochs: int = 1000,
                   weights: torch.Tensor = None,
                   use_tqdm: bool = False,
                   x_val: Optional[torch.Tensor] = None):
        """
        Train RealNVP using weighted maximum likelihood estimation.

        Each epoch is one full-batch Adam step on the negative weighted mean of logq.

        :param x: training data with shape (n, d).
        :param n_epochs: number of training epochs for RealNVP.
        :param weights: non-negative weights for samples x; uniform when None.
        :param use_tqdm: display a progress bar.
        :param x_val: optional validation data, only forwarded to the debugger.
        :raises ValueError: if the weights mismatch x in length, contain negative entries or (when
            training for at least one epoch) do not have a positive sum, or if n_epochs is negative.
        """
        if weights is None:
            weights = torch.ones(len(x), device=x.device)
        if len(weights) != len(x):
            raise ValueError(f"x and weights should match in dimension 0 but got sizes {len(x)} and {len(weights)}.")
        if torch.any(weights < 0):
            raise ValueError("Weights should be non-negative.")
        if n_epochs < 0:
            raise ValueError("Number of epochs should be non-negative.")

        _x = x.to(self.device)
        _weights = weights.to(self.device)
        _x_val = x_val.to(self.device) if x_val is not None else None

        # The normaliser is constant across epochs; a non-positive (or NaN) value would turn the
        # loss into NaN and silently corrupt the flow parameters.
        weight_total = torch.sum(_weights)
        if n_epochs > 0 and not weight_total > 0:
            raise ValueError("Weights should have a positive sum.")

        self.flow.train()
        optimizer = self.get_optimizers()[0]
        epochs = tqdm(range(n_epochs)) if use_tqdm else range(n_epochs)
        for _ in epochs:
            optimizer.zero_grad()
            loss = -torch.sum(_weights * self.logq(_x)) / weight_total
            loss.backward()
            optimizer.step()
            self.add_debug_data(x_train=_x, x_val=_x_val)
            self.debug_step()
        self.flow.eval()
        self.debug_animate()

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """
        Sample points in data space from RealNVP.
        """
        return self.flow.sample(num_samples=n_samples)

    def logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the log density at the supplied data space locations.
        """
        return self.flow.log_prob(x.to(self.device))

    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        """
        Push x through the flow transform; returns (z, logj_forward).
        """
        z, logj_forward = self.flow._transform(x.to(self.device))
        return z, logj_forward

    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        """
        Push z through the inverse flow transform; returns (x, logj_backward).
        """
        x, logj_backward = self.flow._transform.inverse(z.to(self.device))
        return x, logj_backward

    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Push x through the layers one by one and stack the input plus every intermediate state.
        """
        # implemented like forward and _cascade methods of nflows.transforms.base.CompositeTransform.
        paths = [torch.clone(x.to(self.device))]
        for layer in self.flow._transform._transforms:
            paths.append(layer(paths[-1], None)[0])  # context = None
        return torch.stack(paths)

    def inverse_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Push latent points through the inverted layers in reverse order and stack the input plus
        every intermediate state. The parameter is named x for backward compatibility.
        """
        # implemented like inverse, and _cascade methods of nflows.transforms.base.CompositeTransform.
        paths = [torch.clone(x.to(self.device))]
        for layer in self.flow._transform._transforms[::-1]:
            paths.append(layer.inverse(paths[-1], None)[0])  # context = None
        return torch.stack(paths)

    def get_optimizers(self):
        """
        :return: single-element list with an Adam optimizer over all flow parameters.
        """
        return [optim.Adam(self.flow.parameters(), **self.optimizer_kwargs)]


class MAFInterface(TorchFlowInterface):
    """
    Interface to a (optionally conditional) masked autoregressive flow built from nflows layers.
    """

    def __init__(self,
                 n_dim: int,
                 context_dim: int = None,
                 n_layers: int = 5,
                 n_hidden: int = 5,
                 activation: callable = torch.sigmoid,
                 dropout_prob: float = 0.0,
                 debugger: Optional[FlowDebugger] = None,
                 device: torch.device = torch.device('cpu'),
                 optimizer_kwargs: dict = None):
        """
        MAF interface.

        :param n_dim: number of dimensions in the data.
        :param context_dim: number of context features; None for an unconditional flow.
        :param n_layers: number of MADE layers.
        :param n_hidden: number of hidden units in MADE layers.
        :param activation: activation function for MADE.
        :param dropout_prob: dropout probability for MADE.
        :param debugger: optional debugger.
        :param device: device the flow lives on.
        :param optimizer_kwargs: keyword arguments passed to the Adam optimizer.
        """

        super().__init__(n_dim=n_dim, debugger=debugger, device=device, optimizer_kwargs=optimizer_kwargs)
        self.n_dim = n_dim
        self.context_dim = context_dim
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.activation = activation
        self.dropout_prob = dropout_prob
        self.base_dist = nflows.distributions.normal.StandardNormal(shape=[self.n_dim])
        self.flow: Optional[nflows.flows.base.Flow] = None

    def _on_device(self, tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Move an optional tensor (e.g. the context) to the flow device; None passes through.
        return None if tensor is None else tensor.to(self.device)

    def _regularized_weights(self):
        # Yield the 'weight' parameters (not biases) of every autoregressive layer in the flow.
        for layer in self.flow._transform._transforms:
            if isinstance(layer, MaskedAffineAutoregressiveTransform):
                for parameter_name, parameter in layer.autoregressive_net.named_parameters():
                    if parameter_name.endswith('weight'):
                        yield parameter

    def create_flow(self, *args, **kwargs):
        """
        Create a MAF object and store it in self.flow.

        Each of the n_layers steps consists of a permutation layer followed by an autoregressive layer.
        """
        transforms = []
        for _ in range(self.n_layers):
            transforms.append(nflows.transforms.permutations.ReversePermutation(features=self.n_dim))
            transforms.append(
                MaskedAffineAutoregressiveTransform(
                    features=self.n_dim,
                    hidden_features=self.n_hidden,
                    activation=self.activation,
                    dropout_probability=self.dropout_prob,
                    context_features=self.context_dim
                )
            )
        maf_transform = nflows.transforms.base.CompositeTransform(transforms)
        self.flow = nflows.flows.base.Flow(maf_transform, self.base_dist).to(self.device)

    def laplace_prior(self, scale=1.0):
        """
        Log of an (unnormalised) Laplace prior with parameter b equal to scale over the flow weights.

        This is equivalent to L1 regularization. Increasing scale means less regularization.
        Weights are regularized, biases are not.
        """
        total = 0.0
        for parameter in self._regularized_weights():
            total += parameter.abs().sum()
        return -total / scale

    def gaussian_prior(self, scale=1.0):
        """
        Log of an (unnormalised) Gaussian prior with standard deviation equal to scale over the flow weights.

        This is equivalent to L2 regularization. Increasing scale means less regularization.
        Weights are regularized, biases are not.
        """
        total = 0.0
        for parameter in self._regularized_weights():
            total += parameter.square().sum()
        return -total / (2 * (scale ** 2))

    def compute_loss(self, x: torch.Tensor, weights: torch.Tensor, context_x: torch.Tensor = None) -> torch.Tensor:
        """
        Negative log posterior: weighted log likelihood plus a Laplace prior (scale 0.1) on the weights.

        The sample weights are rescaled to sum to the number of samples, so uniform weights give the
        plain sum of log-densities and the balance between likelihood and prior does not depend on
        how the caller normalised the weights. If weights is None, x is empty, or the weights do not
        have a positive sum, the unweighted sum is used.

        :param x: data with shape (n, d).
        :param weights: non-negative sample weights with shape (n,).
        :param context_x: optional context for a conditional flow.
        :return: scalar loss.
        """
        log_q = self.logq(x, context_x)

        log_likelihood = None
        if weights is not None and len(log_q) > 0:
            sample_weights = weights.to(log_q.device)
            weight_total = torch.sum(sample_weights)
            if weight_total > 0:
                log_likelihood = len(log_q) * torch.sum(sample_weights * log_q) / weight_total
        if log_likelihood is None:
            log_likelihood = torch.sum(log_q)

        # Log prior: Laplace on weights, improper uniform (ignored) on everything else
        log_prior = self.laplace_prior(scale=0.1)

        return -(log_likelihood + log_prior)

    def train_flow(self,
                   x: torch.Tensor,
                   n_epochs: int = 1000,
                   weights: torch.Tensor = None,
                   use_tqdm: bool = False,
                   context_x: torch.Tensor = None,
                   val_frac: float = 0.0,
                   early_stopping_patience: int = 50,
                   return_info: bool = False):
        """
        Train MAF.

        Each epoch is one full-batch Adam step on compute_loss. When the validation set is non-empty,
        the parameters with the lowest validation loss are restored after training. After training the
        flow is put in eval mode and its parameters no longer require gradients.

        :param x: training data with shape (n, d).
        :param n_epochs: number of training epochs for MAF.
        :param weights: non-negative weights for samples x; uniform when None.
        :param use_tqdm: display progress bar.
        :param context_x: context data for conditional MAF.
        :param val_frac: fraction of data to be designated as a validation set. Validation samples chosen randomly.
        :param early_stopping_patience: after this number of epochs with no validation loss decrease, stop training.
        :param return_info: if True, return a dictionary with training info.
        :return: the training info dictionary if return_info is True, otherwise None.
        :raises ValueError: on mismatched or negative weights, negative n_epochs or val_frac outside [0, 1].
        """
        if weights is None:
            weights = torch.ones(len(x), device=self.device) / len(x)
        if len(weights) != len(x):
            raise ValueError(f"x and weights should match in dimension 0 but got sizes {len(x)} and {len(weights)}.")
        if torch.any(weights < 0):
            raise ValueError("Weights should be non-negative.")
        if n_epochs < 0:
            raise ValueError("Number of epochs should be non-negative.")
        if not 0 <= val_frac <= 1:
            raise ValueError("Fraction of validation data must be between 0 and 1")

        # Random split: the first n_val permuted indices are validation, the rest training.
        n_val = int(len(x) * val_frac)
        permutation = torch.randperm(len(x))
        train_idx, val_idx = permutation[n_val:], permutation[:n_val]

        x_train = x[train_idx].to(self.device)
        weights_train = weights[train_idx].to(self.device)
        context_train = None if context_x is None else context_x[train_idx].to(self.device)

        x_val = x[val_idx].to(self.device)
        weights_val = weights[val_idx].to(self.device)
        context_val = None if context_x is None else context_x[val_idx].to(self.device)

        # A small val_frac can round down to an empty validation set; validating on it would
        # compare the prior term only, so validation is keyed on the actual sample count.
        has_val = n_val > 0
        x_val_debug = x_val if has_val else None

        best_epoch = 0
        best_state = deepcopy(self.flow.state_dict()) if has_val else None
        best_val_loss = torch.inf

        training_info = {
            'train loss': torch.inf,
            'val loss': torch.inf,
            'best val loss': torch.inf,
            'best epoch': 0,
            'epochs elapsed': 0
        }

        self.flow.train()
        self.flow.requires_grad_(True)
        optimizer = self.get_optimizers()[0]
        epochs = tqdm(range(n_epochs)) if use_tqdm else range(n_epochs)
        for epoch in epochs:
            optimizer.zero_grad()
            loss = self.compute_loss(x_train, weights_train, context_train)
            loss.backward()
            optimizer.step()
            self.add_debug_data(x_train=x_train, x_val=x_val_debug)
            self.debug_step()

            training_info['train loss'] = float(loss)
            training_info['epochs elapsed'] += 1

            if has_val:
                with torch.no_grad():
                    val_loss = float(self.compute_loss(x_val, weights_val, context_val))
                training_info['val loss'] = val_loss
                if val_loss < best_val_loss:
                    best_epoch = epoch
                    best_val_loss = val_loss
                    best_state = deepcopy(self.flow.state_dict())
                    training_info['best epoch'] = best_epoch + 1
                    training_info['best val loss'] = val_loss
                elif epoch - best_epoch >= early_stopping_patience:
                    break

        if has_val:
            self.flow.load_state_dict(best_state)

        self.flow.eval()
        self.flow.requires_grad_(False)
        self.debug_animate()

        # The latent check must see the same context the flow was trained with.
        z_train = self.forward(x_train, context_x=context_train)
        training_info['latent space error percentage'] = check_latent_space_valid(z_train, warn=False)

        if return_info:
            return training_info

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """
        Sample points in data space from MAF.
        """
        return self.flow.sample(num_samples=n_samples)

    def logq(self, x: torch.Tensor, context_x: torch.Tensor = None, **kwargs) -> torch.Tensor:
        """
        Evaluate the MAF log density at the supplied data space locations.

        :param x: points in data space.
        :param context_x: optional context for a conditional flow.
        """
        return self.flow.log_prob(x.to(self.device), context=self._on_device(context_x))

    def forward_with_logj(self, x: torch.Tensor, context_x: torch.Tensor = None, **kwargs):
        """
        Push x through the flow transform; returns (z, logj_forward).
        """
        z, logj_forward = self.flow._transform(x.to(self.device), context=self._on_device(context_x))
        return z, logj_forward

    def inverse_with_logj(self, z: torch.Tensor, context_x: torch.Tensor = None, **kwargs):
        """
        Push z through the inverse flow transform; returns (x, logj_backward).
        """
        x, logj_backward = self.flow._transform.inverse(z.to(self.device), context=self._on_device(context_x))
        return x, logj_backward

    def forward_paths(self, x: torch.Tensor, context_x: torch.Tensor = None) -> torch.Tensor:
        """
        Push x through the layers one by one and stack the input plus every intermediate state.
        """
        # implemented like forward and _cascade methods of nflows.transforms.base.CompositeTransform.
        context = self._on_device(context_x)
        paths = [torch.clone(x.to(self.device))]
        for layer in self.flow._transform._transforms:
            paths.append(layer(paths[-1], context)[0])
        return torch.stack(paths)

    def inverse_paths(self, x: torch.Tensor, context_x: torch.Tensor = None) -> torch.Tensor:
        """
        Push latent points through the inverted layers in reverse order and stack the input plus
        every intermediate state. The parameter is named x for backward compatibility.
        """
        # implemented like inverse, and _cascade methods of nflows.transforms.base.CompositeTransform.
        context = self._on_device(context_x)
        paths = [torch.clone(x.to(self.device))]
        for layer in self.flow._transform._transforms[::-1]:
            paths.append(layer.inverse(paths[-1], context)[0])
        return torch.stack(paths)

    def get_optimizers(self):
        """
        :return: single-element list with an Adam optimizer over all flow parameters.
        """
        return [optim.Adam(self.flow.parameters(), **self.optimizer_kwargs)]


class HierarchicalMAFInterface(TorchFlowInterface):
    def __init__(self,
                 rv_mask: torch.Tensor,
                 dag_edges,
                 n_layers: int = 5,
                 n_hidden: int = 5,
                 activation: callable = torch.sigmoid,
                 device: torch.device = torch.device('cpu'),
                 debugger: Optional[FlowDebugger] = None,
                 optimizer_kwargs: dict = None):
        """
        Collection of MAF interfaces, one per group of dimensions.

        :param rv_mask: tensor with one entry per data dimension; dimensions sharing a value form one group.
        :param dag_edges: iterable of ``(source, destination)`` group id pairs. The dimensions of all sources
            pointing at a group determine that group's context size.
        :param n_layers: number of layers of every group MAF.
        :param n_hidden: number of hidden units of every group MAF.
        :param activation: activation function of every group MAF.
        :param device: torch device used by this interface and by every group MAF.
        :param debugger: debugger object for this interface (group MAFs are created without one).
        :param optimizer_kwargs: optional optimizer keyword arguments, shared with every group MAF.
        """
        n_dim = len(rv_mask)
        super().__init__(debugger, device, optimizer_kwargs, n_dim)

        self.rv_mask = rv_mask
        self.dag_edges = dag_edges
        groups, group_dims = torch.unique(self.rv_mask, return_counts=True)
        # .tolist() works wherever the mask lives, whereas .numpy() requires a CPU tensor.
        self.rv_groups = [int(group) for group in groups.tolist()]
        self.rv_group_dims = [int(count) for count in group_dims.tolist()]

        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.activation = activation

        self.interfaces: Dict[int, MAFInterface] = {
            rv_id: MAFInterface(
                n_dim=group_dim,
                context_dim=int(self.context_mask(rv_id).sum()),
                n_layers=self.n_layers,
                n_hidden=int(self.n_hidden),
                activation=self.activation,
                debugger=None,  # No support yet
                device=device,
                optimizer_kwargs=optimizer_kwargs
            ) for rv_id, group_dim in zip(self.rv_groups, self.rv_group_dims)
        }

    def get_optimizers(self):
        raise NotImplementedError

    def context_mask(self, rv_id):
        """
        Get a boolean mask over data dimensions that marks the context of group ``rv_id``, i.e. every
        dimension whose group is the source of an edge ending in ``rv_id``.
        """
        parents = [e0 for (e0, e1) in self.dag_edges if e1 == rv_id]
        # Match dtype and device explicitly so that an empty parent list does not create a float CPU tensor.
        parent_ids = torch.tensor(parents, dtype=self.rv_mask.dtype, device=self.rv_mask.device)
        return torch.isin(self.rv_mask, parent_ids)

    def create_flow(self, *args, **kwargs):
        """Create the flow of every group, forwarding all arguments unchanged."""
        for rv_id in self.rv_groups:
            self.interfaces[rv_id].create_flow(*args, **kwargs)

    def train_flow(self, x, *args, **kwargs):
        """
        Train every group MAF on its own columns of ``x``.

        :return: empty dictionary.
        """
        for rv_id in self.rv_groups:
            # TODO pass in context here
            self.interfaces[rv_id].train_flow(x[:, self.rv_mask == rv_id], *args, **kwargs)
        return dict()

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """Sample every group independently and write the result into that group's columns."""
        output = torch.zeros(n_samples, self.n_dim, device=self.device)
        for rv_id in self.rv_groups:
            output[:, self.rv_mask == rv_id] = self.interfaces[rv_id].sample(n_samples)  # TODO pass in context here
        return output

    def logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Sum of the group log densities, each evaluated on that group's columns of ``x``.

        ``kwargs`` is accepted for consistency with the other interfaces and is not used.
        """
        log_prob_per_sample = torch.stack([
            self.interfaces[rv_id].logq(x[:, self.rv_mask == rv_id]) for rv_id in self.rv_groups  # TODO context here
        ]).sum(dim=0)
        return log_prob_per_sample

    def forward_with_logj(self, x: torch.Tensor, context_x: torch.Tensor = None, **kwargs):
        """
        Push every group of columns through its own MAF and add up the log-Jacobians.

        :param x: points in data space.
        :param context_x: optional context, handed unchanged to every group MAF.
        :return: tuple ``(z, logj)`` with ``z`` shaped like ``x``.
        """
        # Both outputs are allocated on self.device, matching sample() above.
        z_output = torch.zeros_like(x, device=self.device)
        logj_output = torch.zeros(len(x), device=self.device)
        for rv_id in self.rv_groups:
            group_columns = self.rv_mask == rv_id
            # The sub-interface parameter is named context_x; passing it as `context` would be
            # swallowed by **kwargs and silently ignored.
            z, logj_forward = self.interfaces[rv_id].forward_with_logj(
                x[:, group_columns],
                context_x=context_x,
                **kwargs
            )
            z_output[:, group_columns] = z
            logj_output += logj_forward
        return z_output, logj_output

    def inverse_with_logj(self, z: torch.Tensor, context_x: torch.Tensor = None, **kwargs):
        """
        Push every group of latent columns through the inverse of its own MAF and add up the log-Jacobians.

        :param z: points in latent space.
        :param context_x: optional context, handed unchanged to every group MAF.
        :return: tuple ``(x, logj)`` with ``x`` shaped like ``z``.
        """
        # Both outputs are allocated on self.device, matching sample() above.
        x_output = torch.zeros_like(z, device=self.device)
        logj_output = torch.zeros(len(z), device=self.device)
        for rv_id in self.rv_groups:
            group_columns = self.rv_mask == rv_id
            # See forward_with_logj: the keyword must be context_x to reach the sub-interface.
            x, logj_backward = self.interfaces[rv_id].inverse_with_logj(
                z[:, group_columns],
                context_x=context_x,
                **kwargs
            )
            x_output[:, group_columns] = x
            logj_output += logj_backward
        return x_output, logj_output

    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def inverse_paths(self, z: torch.Tensor):
        raise NotImplementedError


class LinearGaussianInterface(TorchFlowInterface):
    def __init__(self, n_dim: int):
        """
        Affine "flow" that standardizes data with a per-dimension mean and standard deviation.

        :param n_dim: number of data dimensions.
        """
        super().__init__(n_dim=n_dim)
        # This is just an affine transformation, (determined by mu, std) that maps to a standard normal
        # In other words, standardization.
        self.mean = torch.zeros(n_dim)
        self.std = torch.ones(n_dim)

    def get_optimizers(self):
        """There is nothing to optimize; the parameters are computed in closed form."""
        return []

    def create_flow(self, *args, **kwargs):
        pass

    def train_flow(self, x: torch.Tensor, *args, **kwargs):
        """
        Fit the transformation to ``x``.

        :param x: training data; statistics are taken over dimension 0.
        :return: empty dictionary.
        """
        # Sample mean and standard deviation (torch's default std applies Bessel's correction)
        self.mean = x.mean(dim=0)
        self.std = x.std(dim=0)
        return dict()

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """Draw standard normal latent points and map them to data space."""
        # Draw on the device and dtype of the fitted statistics so the inverse map does not mix devices.
        z = torch.randn(n_samples, self.n_dim, device=self.mean.device, dtype=self.mean.dtype)
        x = self.inverse(z)
        return x

    def logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Log density of ``x``: latent Gaussian log probability summed over dimension 1, plus the forward
        log-Jacobian. ``kwargs`` is accepted for consistency with the other interfaces and is not used.
        """
        z, logj = self.forward_with_logj(x)
        return gaussian_log_prob(z).sum(dim=1) + logj

    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        """
        Standardize ``x``.

        :return: tuple ``(z, logj)`` where every entry of ``logj`` equals ``-sum(log(std))``.
        """
        z = (x - self.mean.view(1, -1)) / self.std.view(1, -1)
        logj = torch.zeros(len(x), device=x.device) - torch.log(self.std).sum()
        return z, logj

    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        """
        Undo the standardization.

        :return: tuple ``(x, logj)`` where every entry of ``logj`` equals ``sum(log(std))``.
        """
        x = z * self.std.view(1, -1) + self.mean.view(1, -1)
        logj = torch.zeros(len(z), device=z.device) + torch.log(self.std).sum()
        return x, logj

    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def inverse_paths(self, z: torch.Tensor):
        raise NotImplementedError


class TRENFInterface(TorchFlowInterface):
    def __init__(self,
                 device: torch.device = torch.device('cpu'),
                 debugger: Optional[FlowDebugger] = None,
                 optimizer_kwargs: dict = None):
        """
        TRENF interface.

        :param device: torch device on which the flow is placed and evaluated.
        :param debugger: optional debugger object, forwarded to the base class.
        :param optimizer_kwargs: optional optimizer keyword arguments, forwarded to the base class.
        """

        super().__init__(debugger=debugger, device=device, optimizer_kwargs=optimizer_kwargs)
        self.flow: Optional[nfmc_jax.TRENF.TRENF.NF] = None

    def create_flow(self, ndim: int, layers: list, **kwargs):
        """
        Create a TRENF object on ``self.device`` with gradients disabled.

        :param ndim: number of data dimensions.
        :param layers: layer specification handed to ``NF``.
        """
        self.n_dim = ndim
        self.flow = NF(ndim, layers, Nconditional=0).to(self.device).requires_grad_(False)

    def train_flow(self,
                   x: torch.Tensor,
                   n_epochs: int = 1000,
                   lr: float = 1e-3,
                   batch_size: Optional[int] = None,
                   weights: torch.Tensor = None,
                   use_tqdm: bool = False,
                   val_frac: float = 0,
                   verbose: bool = False,
                   **kwargs):
        """
        Train TRENF using maximum likelihood estimation.

        The parameters with the lowest epoch loss (validation loss if ``val_frac`` is non-zero, training loss
        otherwise) are restored at the end. If no epoch runs or no epoch yields a finite improvement, the
        parameters the flow had when training started are kept.

        :param x: training data with shape (n, d).
        :param n_epochs: number of training epochs for TRENF.
        :param lr: learning rate of the optimizer.
        :param batch_size: mini-batch size; defaults to the size of the training set.
        :param weights: non-negative weights for samples x.
        :param use_tqdm: show a progress bar over epochs.
        :param val_frac: fraction of samples held out for validation.
        :param verbose: print the losses after every epoch.
        :raises ValueError: if weights and x differ in length, a weight is negative, or n_epochs is negative.
        """
        if weights is None:
            weights = torch.ones(len(x), device=x.device)
        if len(weights) != len(x):
            raise ValueError(f"x and weights should match in dimension 0 but got sizes {len(x)} and {len(weights)}.")
        if torch.any(weights < 0):
            raise ValueError("Weights should be non-negative.")
        if n_epochs < 0:
            raise ValueError("Number of epochs should be non-negative.")

        self.flow.requires_grad_(True).train()
        optimizer = self.get_optimizers(lr=lr)[0]
        pbar = tqdm(range(n_epochs)) if use_tqdm else range(n_epochs)
        if val_frac:
            perm = torch.randperm(len(x))
            xp = x[perm]
            weightsp = weights[perm]
            N_train = int((1 - val_frac) * len(x))
            x_train = xp[:N_train]
            weights_train = weightsp[:N_train]
            x_val = xp[N_train:]
            weights_val = weightsp[N_train:]
            if batch_size is None:
                batch_size = N_train
            trainset = torch.utils.data.TensorDataset(x_train, weights_train)
            trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)
            valset = torch.utils.data.TensorDataset(x_val, weights_val)
            valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, drop_last=False)
        else:
            if batch_size is None:
                batch_size = len(x)
            x_val = None
            trainset = torch.utils.data.TensorDataset(x, weights)
            trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)

        best_loss = float('inf')
        # Start from the current parameters so that load_state_dict below always has a valid state,
        # even with n_epochs == 0 or when every epoch loss is NaN.
        best_state = deepcopy(self.flow.state_dict())
        for epoch in pbar:
            total_loss = 0
            ntotal = 0
            for data, weight in trainloader:
                data = data.to(self.device)
                weight = weight.to(self.device)
                optimizer.zero_grad()
                loss = -torch.sum(weight * self.flow.evaluate_density(data)) / torch.sum(weight)
                loss.backward()
                optimizer.step()
                # Undo the per-batch normalization so the epoch loss is a weighted mean over all batches.
                batch_weight = torch.sum(weight).item()
                total_loss += batch_weight * loss.item()
                ntotal += batch_weight
            loss_train = total_loss / ntotal
            if val_frac:
                total_loss_val = 0
                ntotal_val = 0
                with torch.no_grad():
                    for data, weight in valloader:
                        data = data.to(self.device)
                        weight = weight.to(self.device)
                        total_loss_val += -torch.sum(weight * self.flow.evaluate_density(data)).item()
                        ntotal_val += torch.sum(weight).item()
                    loss_val = total_loss_val / ntotal_val
                    if loss_val < best_loss:
                        best_loss = loss_val
                        best_state = deepcopy(self.flow.state_dict())
                if verbose:
                    print('Epoch:', epoch, 'train loss:', loss_train, 'validate loss:', loss_val)
            else:
                if loss_train < best_loss:
                    best_loss = loss_train
                    best_state = deepcopy(self.flow.state_dict())
                if verbose:
                    print('Epoch:', epoch, 'loss:', loss_train)
            self.add_debug_data(x_train=x, x_val=x_val)
            self.debug_step()
        self.flow.load_state_dict(best_state)
        self.flow.requires_grad_(False).eval()
        self.debug_animate()

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """
        Sample points in data space from TRENF.
        """
        return self.sample_with_logq(n_samples)[0]

    def logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the TRENF density at the supplied data space locations.

        :param x: points in data space; moved to ``self.device`` before evaluation.
        """
        return self.flow.evaluate_density(x.to(self.device))

    def forward_with_logj(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Push points x to latent space.

        :param x: points in data space with shape (n, d); moved to ``self.device`` first.
        :return: tuple of points in latent space with shape (n, d) and the second output of ``flow.forward``.
        """
        z, logj = self.flow.forward(x.to(self.device))
        return z, logj

    def inverse_with_logj(self, z: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Push points z to data space.

        :param z: points in latent space with shape (n, d); moved to ``self.device`` first.
        :return: tuple of points in data space with shape (n, d) and the negated second output of
            ``flow.inverse``.
        """
        x, logj = self.flow.inverse(z.to(self.device))
        return x, -logj

    @torch.no_grad()
    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute forward paths for given data. This means having the original data and all intermediate data points,
        given by forward TRENF layer transformations.

        Returns: torch.Tensor with shape (n_layers + 1, x.shape[0], x.shape[1]). The first index corresponds to original
        data points, the last index corresponds to latent representations. The result is returned on the CPU.
        """
        paths = [torch.clone(x.to(self.device))]
        for layer in self.flow.layer:
            paths.append(layer(paths[-1])[0])
        paths = torch.stack(paths).cpu()
        return paths

    @torch.no_grad()
    def inverse_paths(self, z: torch.Tensor):
        """
        Compute inverse paths for given latent points. This means having the latent data and all intermediate data
        points, given by inverse TRENF layer transformations.

        Returns: torch.Tensor with shape (n_layers + 1, z.shape[0], z.shape[1]). The first index corresponds to latent
        data points, the last index corresponds to data space representations. The result is returned on the CPU.
        """
        paths = [torch.clone(z.to(self.device))]
        for layer in self.flow.layer[::-1]:
            paths.append(layer.inverse(paths[-1])[0])
        paths = torch.stack(paths).cpu()
        return paths

    def get_optimizers(self, lr: float = 1e-3):
        """Return a single-element list with an Adam optimizer over the flow parameters."""
        return [optim.Adam(self.flow.parameters(), lr=lr)]


class SINFInterface(TorchFlowInterface):
    def __init__(self,
                 debugger: Optional[FlowDebugger] = None,
                 device: torch.device = torch.device('cpu'),
                 optimizer_kwargs: dict = None):
        """
        SINF interface. The flow is built by GIS directly from data, so "training" means rebuilding it.

        :param debugger: optional debugger object, forwarded to the base class.
        :param device: torch device to use for SINF.
        :param optimizer_kwargs: optional optimizer keyword arguments, forwarded to the base class.
        """
        super().__init__(debugger=debugger, device=device, optimizer_kwargs=optimizer_kwargs)

        self.device = device
        self.flow: Optional[nfmc_jax.sinf.SINF.SINF] = None
        # Keyword arguments for GIS; entries supplied to create_flow are merged in and kept.
        self.gis_kwargs = dict(verbose=False)

    def create_flow(self, x: torch.Tensor, weights: torch.Tensor = None, val_frac: float = 0.0, **kwargs):
        """
        Build the SINF model by running GIS on a random train/validation split of ``x``.

        :param x: data with shape (n, d).
        :param weights: optional per-sample weights, shuffled and split together with ``x``.
        :param val_frac: fraction of samples handed to GIS as validation data; no validation data if 0.
        :param kwargs: additional GIS keyword arguments; they are stored in ``self.gis_kwargs``.
        """
        self.gis_kwargs.update(kwargs)
        if self.device == torch.device('cpu'):
            self.gis_kwargs['nocuda'] = True

        self.n_dim = x.shape[1]

        has_weights = weights is not None
        has_validation = val_frac > 0.0

        shuffle = torch.randperm(len(x))
        n_train = int((1.0 - val_frac) * len(x))
        x_shuffled = x[shuffle]
        w_shuffled = weights[shuffle] if has_weights else None

        # Every tensor handed to GIS is deep-copied, because SINF modifies training data in place.
        data_train = deepcopy(x_shuffled[:n_train].to(self.device))
        weight_train = deepcopy(w_shuffled[:n_train].to(self.device)) if has_weights else None
        data_validate = deepcopy(x_shuffled[n_train:].to(self.device)) if has_validation else None
        if has_validation and has_weights:
            weight_validate = deepcopy(w_shuffled[n_train:].to(self.device))
        else:
            weight_validate = None

        self.flow = nfmc_jax.sinf.GIS.GIS(
            data_train=data_train,
            weight_train=weight_train,
            data_validate=data_validate,
            weight_validate=weight_validate,
            **self.gis_kwargs
        )

    def train_flow(self, x: torch.Tensor, weights: torch.Tensor = None, val_frac: float = 0.0, **kwargs):
        """
        Retrain GIS, i.e. rebuild the flow from scratch on the given data.

        :param x: training data with shape (n_train, d).
        :param weights: non-negative weights for samples x.
        :param val_frac: Fraction of samples to use for validation.
        """
        data = x.to(self.device)
        sample_weights = None if weights is None else weights.to(self.device)
        self.create_flow(data, sample_weights, val_frac, **kwargs)

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """
        Sample points in data space from SINF (first element of what the flow's ``sample`` returns).
        """
        drawn = self.flow.sample(nsample=n_samples, device=self.device)
        return drawn[0]

    def logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate SINF density at supplied data space locations.
        """
        on_device = x.to(self.device)
        return self.flow.evaluate_density(on_device)

    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        """Push ``x`` through the flow and return the two outputs of ``flow.forward`` unchanged."""
        latent, logj_forward = self.flow.forward(data=x.to(self.device))
        return latent, logj_forward

    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        """Push ``z`` through the flow inverse; the second output of ``flow.inverse`` is returned negated."""
        data, logj = self.flow.inverse(data=z.to(self.device))
        # Correction due to the implementation of SINF
        return data, -logj

    @torch.no_grad()
    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute forward paths for given data: the original points followed by the output of every SINF layer,
        applied in order.

        Returns: torch.Tensor with shape (n_layers + 1, x.shape[0], x.shape[1]). The first index corresponds to original
        data points, the last index corresponds to latent representations.
        """
        current = torch.clone(x.to(self.device))
        trajectory = [current]
        for layer in self.flow.layer:
            current = layer(current)[0]
            trajectory.append(current)
        return torch.stack(trajectory)

    @torch.no_grad()
    def inverse_paths(self, z: torch.Tensor):
        """
        Compute inverse paths for given latent points: the latent points followed by the output of every SINF
        layer inverse, applied in reverse layer order.

        Returns: torch.Tensor with shape (n_layers + 1, z.shape[0], z.shape[1]). The first index corresponds to latent
        data points, the last index corresponds to data space representations.
        """
        current = torch.clone(z.to(self.device))
        trajectory = [current]
        for layer in self.flow.layer[::-1]:
            current = layer.inverse(current)[0]
            trajectory.append(current)
        return torch.stack(trajectory)

    def get_optimizers(self):
        """No optimizers are used by this interface."""
        return list()


class SNFInterface(SINFInterface):
    def __init__(self,
                 debugger: Optional[FlowDebugger] = None,
                 device: torch.device = torch.device('cpu'),
                 optimizer_kwargs: dict = None):
        """
        SNF interface.
        Uses GIS as initialization.
        TODO add documentation.

        :param debugger: optional debugger object, forwarded to the base class.
        :param device: torch device to use for the flow.
        :param optimizer_kwargs: optional optimizer keyword arguments, forwarded to the base class.
        """
        super().__init__(device=device, debugger=debugger, optimizer_kwargs=optimizer_kwargs)

    def create_flow(self, x, *args, random_init: bool = False, **kwargs):
        """
        Create the flow, either from GIS (default) or as a stack of freshly constructed layers.

        :param x: data with shape (n, d).
        :param random_init: if True, build ``kwargs['iteration']`` new SlicedTransport layers instead of
            running GIS.
        :param kwargs: must contain ``iteration`` when ``random_init`` is True; otherwise forwarded to the
            parent ``create_flow``.
        """
        if random_init:
            # Random initialization
            assert 'iteration' in kwargs, "random_init requires the 'iteration' keyword argument."
            n_dim = x.shape[1]
            # Keep n_dim in sync, as the GIS-based path does.
            self.n_dim = n_dim
            self.flow = nfmc_jax.sinf.SINF.SINF(n_dim)
            for _ in range(kwargs['iteration']):
                self.flow.layer.append(SlicedTransport(ndim=n_dim))
            # train_flow moves the data to self.device, so the flow has to live there too.
            self.flow = self.flow.to(self.device)
        else:
            # SINF initialization
            super().create_flow(x, *args, **kwargs)

    def train_flow(self,
                   x: torch.Tensor,
                   weights: torch.Tensor = None,
                   x_val: torch.Tensor = None,
                   weights_val: torch.Tensor = None,
                   n_epochs: int = 1000,
                   lr_Psi: float = 1e-2,
                   lr_A: float = 2e-5,
                   lr_Psi_decay: float = 1.0,
                   lr_A_decay: float = 1.0,
                   reg: float = 0.0,
                   reg1: float = 0.0,
                   reg2: float = 0.0,
                   use_tqdm: bool = False,
                   **kwargs):
        """
        Train the existing flow with full-batch gradient steps on ``lossfunc3``.

        If validation data is given, the parameters with the lowest validation loss are restored at the end;
        otherwise the parameters of the last epoch are kept.

        :param x: training data; moved to ``self.device``.
        :param weights: optional weights for x, handed to ``lossfunc3``.
        :param x_val: optional validation data; moved to ``self.device``.
        :param weights_val: optional weights for x_val, handed to ``lossfunc3``.
        :param n_epochs: number of optimization steps.
        :param lr_Psi: learning rate of the Adam optimizer (Psi parameter group).
        :param lr_A: learning rate of the Stiefel_SGD optimizer (A parameter group).
        :param lr_Psi_decay: exponential learning rate decay parameter for Psi. 1 == no change, 0.99 == slight change.
        :param lr_A_decay: exponential learning rate decay parameter for A. 1 == no change, 0.99 == slight change.
        :param reg: first argument of ``regularization``; the penalty is added only if reg, reg1 or reg2 is positive.
        :param reg1: second argument of ``regularization``.
        :param reg2: third argument of ``regularization``.
        :param use_tqdm: show a progress bar with the current training loss.
        """

        self.flow.train()
        self.flow.requires_grad_(True)

        optimizer_A, optimizer_Psi = self.get_optimizers(lr_A, lr_Psi)
        lr_scheduler_A = optim.lr_scheduler.ExponentialLR(optimizer=optimizer_A, gamma=lr_A_decay)
        lr_scheduler_Psi = optim.lr_scheduler.ExponentialLR(optimizer=optimizer_Psi, gamma=lr_Psi_decay)

        _x = x.to(self.device)
        _weights = weights.to(self.device) if weights is not None else None
        _x_val = x_val.to(self.device) if x_val is not None else None
        _weights_val = weights_val.to(self.device) if weights_val is not None else None

        # np.infty was removed in NumPy 2.0; a plain float infinity behaves the same on every version.
        best_loss = float('inf')
        best_state = deepcopy(self.flow.state_dict())
        pbar = tqdm(range(n_epochs)) if use_tqdm else range(n_epochs)
        for _epoch in pbar:
            optimizer_Psi.zero_grad()
            optimizer_A.zero_grad()

            # lossfunc3 returns a pair; only the first element (the loss) is used.
            loss = lossfunc3(model=self.flow, x=_x, weight=_weights)[0]
            if reg > 0 or reg1 > 0 or reg2 > 0:
                loss = loss + regularization(self.flow, reg, reg1, reg2)

            loss.backward()

            optimizer_Psi.step()
            optimizer_A.step()
            lr_scheduler_A.step()
            lr_scheduler_Psi.step()

            if _x_val is not None:
                with torch.no_grad():
                    val_loss = lossfunc3(model=self.flow, x=_x_val, weight=_weights_val)[0]
                    if val_loss < best_loss:
                        best_loss = val_loss
                        best_state = deepcopy(self.flow.state_dict())

            self.add_debug_data(x_train=_x, x_val=_x_val)
            self.debug_step()

            if use_tqdm:
                pbar.set_postfix(train_loss=float(loss))

            self._auto_step += 1

        if x_val is not None:
            self.flow.load_state_dict(best_state)
        self.flow.eval()
        self.flow.requires_grad_(False)
        self.debug_animate()

    def get_optimizers(self, lr_A=2e-5, lr_Psi=1e-2):
        """
        Build the two optimizers used for training.

        SlicedTransport layers contribute ``A`` to the A group and the parameters of ``transform1D`` to the
        Psi group; whiten layers contribute ``E`` to the A group and ``mean`` and ``D`` to the Psi group.

        :return: list ``[optimizer_A, optimizer_Psi]`` (Stiefel_SGD with momentum 0.9, and Adam).
        """
        group_A = []
        group_Psi = []
        for layer in self.flow.layer:
            if isinstance(layer, SlicedTransport):
                group_A.append(layer.A)
                group_Psi.extend(layer.transform1D.parameters())
            elif isinstance(layer, whiten):
                group_A.append(layer.E)
                group_Psi.append(layer.mean)
                group_Psi.append(layer.D)

        optimizer_A = Stiefel_SGD(group_A, lr=lr_A, momentum=0.9)
        optimizer_Psi = optim.Adam(group_Psi, lr=lr_Psi)

        return [optimizer_A, optimizer_Psi]


class SimplifiedSNFInterface(SNFInterface):
    def __init__(self, *args, **kwargs):
        """
        SNFInterface where all rotation matrices are identities. In other words, we only have splines.
        """
        super().__init__(*args, **kwargs)

    def set_rotation_identities(self):
        """
        Set all rotation matrices to identities.
        """
        for layer in self.flow.layer:
            if isinstance(layer, SlicedTransport):
                rotation = layer.A
                # Keep the shape, dtype and device of the existing matrix (assumes A is 2-D); for a square
                # A this is the same identity as before, but it no longer lands on the default device.
                layer.A.data = torch.eye(
                    rotation.shape[0], rotation.shape[1], dtype=rotation.dtype, device=rotation.device
                )

    def create_flow(self, *args, **kwargs):
        """Create the flow as the parent does, then overwrite every rotation with an identity."""
        super().create_flow(*args, **kwargs)
        self.set_rotation_identities()

    def train_flow(self,
                   x: torch.Tensor,
                   weights: torch.Tensor = None,
                   x_val: torch.Tensor = None,
                   weights_val: torch.Tensor = None,
                   n_epochs: int = 1000,
                   lr_Psi: float = 1e-2,
                   lr_Psi_decay: float = 1.0,
                   reg: float = 0.0,
                   reg1: float = 0.0,
                   reg2: float = 0.0,
                   use_tqdm: bool = False,
                   **kwargs):
        """
        Train only the Psi parameter group with full-batch gradient steps on ``lossfunc3``.

        If validation data is given, the parameters with the lowest validation loss are restored at the end;
        otherwise the parameters of the last epoch are kept.

        :param x: training data; moved to ``self.device``.
        :param weights: optional weights for x, handed to ``lossfunc3``.
        :param x_val: optional validation data; moved to ``self.device``.
        :param weights_val: optional weights for x_val, handed to ``lossfunc3``.
        :param n_epochs: number of optimization steps.
        :param lr_Psi: learning rate of the Adam optimizer.
        :param lr_Psi_decay: exponential learning rate decay parameter for Psi. 1 == no change, 0.99 == slight change.
            The default is 1 (no decay); the earlier default of 0 set the learning rate to zero after the
            first epoch.
        :param reg: first argument of ``regularization``; the penalty is added only if reg, reg1 or reg2 is positive.
        :param reg1: second argument of ``regularization``.
        :param reg2: third argument of ``regularization``.
        :param use_tqdm: show a progress bar with the current training loss.
        """

        self.flow.train()
        self.flow.requires_grad_(True)

        optimizer_Psi = self.get_optimizers(lr_Psi)[0]
        lr_scheduler_Psi = optim.lr_scheduler.ExponentialLR(optimizer=optimizer_Psi, gamma=lr_Psi_decay)

        _x = x.to(self.device)
        _weights = weights.to(self.device) if weights is not None else None
        _x_val = x_val.to(self.device) if x_val is not None else None
        _weights_val = weights_val.to(self.device) if weights_val is not None else None

        # np.infty was removed in NumPy 2.0; a plain float infinity behaves the same on every version.
        best_loss = float('inf')
        best_state = deepcopy(self.flow.state_dict())
        pbar = tqdm(range(n_epochs)) if use_tqdm else range(n_epochs)
        for _epoch in pbar:
            optimizer_Psi.zero_grad()

            # lossfunc3 returns a pair; only the first element (the loss) is used.
            loss = lossfunc3(model=self.flow, x=_x, weight=_weights)[0]
            if reg > 0 or reg1 > 0 or reg2 > 0:
                loss = loss + regularization(self.flow, reg, reg1, reg2)

            loss.backward()

            optimizer_Psi.step()
            lr_scheduler_Psi.step()

            if _x_val is not None:
                with torch.no_grad():
                    val_loss = lossfunc3(model=self.flow, x=_x_val, weight=_weights_val)[0]
                    if val_loss < best_loss:
                        best_loss = val_loss
                        best_state = deepcopy(self.flow.state_dict())

            self.add_debug_data(x_train=_x, x_val=_x_val)
            self.debug_step()

            if use_tqdm:
                pbar.set_postfix(train_loss=float(loss))

            self._auto_step += 1

        if x_val is not None:
            self.flow.load_state_dict(best_state)
        self.flow.eval()
        self.flow.requires_grad_(False)
        self.debug_animate()

    def get_optimizers(self, lr_Psi=1e-2, **kwargs):
        """
        Build the optimizer for the Psi parameter group: ``transform1D`` parameters of SlicedTransport layers
        plus ``mean`` and ``D`` of whiten layers. Rotation matrices are deliberately left out.

        :return: single-element list with an Adam optimizer.
        """
        group_Psi = []
        for layer in self.flow.layer:
            if isinstance(layer, SlicedTransport):
                group_Psi.extend(layer.transform1D.parameters())
            elif isinstance(layer, whiten):
                group_Psi.append(layer.mean)
                group_Psi.append(layer.D)

        optimizer_Psi = optim.Adam(group_Psi, lr=lr_Psi)
        return [optimizer_Psi]


class PPInterface(FlowInterface):
    def __init__(self, x_ref: torch.Tensor, sigma: torch.Tensor, **kwargs):
        """
        This interface is not based on a flow, so it may not function in all settings. It is meant to be used in DLA.
        :param x_ref: reference sample used to define the PP KDE
        :param sigma: tensor of KDE scale parameters for each dimension
        """
        super().__init__()
        self.x_ref = x_ref
        self.sigma = sigma

    def create_flow(self, *args, **kwargs):
        raise NotImplementedError

    def train_flow(self, *args, **kwargs):
        pass

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Log density of the Gaussian KDE defined by ``x_ref`` and ``sigma``.

        :param x: evaluation points with shape (n, d).
        :param kwargs: accepted for consistency with the other interfaces; not used.
        :return: tensor with one log density per row of x, on the device and dtype of x.
        """
        n_dim = x.shape[1]
        n_ref = self.x_ref.shape[0]
        variance = self.sigma ** 2
        # Sum of logs instead of log of a product: the product under/overflows in high dimension.
        log_norm = -0.5 * n_dim * np.log(2.0 * np.pi) - 0.5 * torch.sum(torch.log(variance))

        log_kde = x.new_zeros(x.shape[0])
        # One point at a time keeps memory at O(n_ref * d) rather than O(n * n_ref * d).
        for i, point in enumerate(x):
            log_kernel = log_norm - 0.5 * torch.sum((point - self.x_ref) ** 2 / variance, dim=1)
            log_kde[i] = torch.logsumexp(log_kernel, dim=0) - np.log(n_ref)

        return log_kde

    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        raise NotImplementedError

    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        raise NotImplementedError

    @torch.no_grad()
    def grad_x_logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Gradient of the KDE log density with respect to the evaluation points.

        For every point the gradient is ``-sum_i w_i * (x - x_ref_i) / sigma^2``, where ``w_i`` are the
        kernel values normalized to sum to one.

        :param x: evaluation points with shape (n, d).
        :return: tensor shaped like x.
        """
        variance = self.sigma ** 2
        gradient = torch.zeros_like(x)
        for k, point in enumerate(x):
            diff = point - self.x_ref
            log_kernel = -0.5 * torch.sum(diff ** 2 / variance, dim=1)
            # Softmax computes exp(k_i) / sum_j exp(k_j) without underflow, so points far away from every
            # reference sample no longer produce 0 / 0 = NaN.
            kernel_weights = torch.softmax(log_kernel, dim=0)
            gradient[k] = -torch.sum(kernel_weights.reshape(-1, 1) * diff / variance, dim=0)

        return gradient

    def grad_z_logp(self, z: torch.Tensor, grad_wrt_x, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def grad_z_logj(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def inverse_paths(self, z: torch.Tensor):
        raise NotImplementedError

    def get_optimizers(self):
        raise NotImplementedError


class GPInterface(TorchFlowInterface):
    def __init__(self,
                 gp_model: gpytorch.models.GP = gpytorch.models.ExactGP,
                 gp_likelihood: Optional[gpytorch.likelihoods.Likelihood] = None,
                 gp_mean: Optional[gpytorch.means.Mean] = None,
                 gp_kernel: Optional[gpytorch.kernels.Kernel] = None,
                 debugger: Optional[FlowDebugger] = None,
                 device: torch.device = torch.device('cpu'),
                 optimizer_kwargs: dict = None):
        """
        GP interface. Used to surrogate target densities with a GP - not based on an NF.
        :param gp_model: gpytorch model class that the surrogate derives from.
        :param gp_likelihood: gpytorch likelihood object. Defaults to a new GaussianLikelihood.
        :param gp_mean: gpytorch mean object. Defaults to a new ConstantMean.
        :param gp_kernel: gpytorch kernel object. Defaults to a new ScaleKernel wrapping a Matern kernel with nu=2.5.
        :param debugger: debugger object.
        :param device: torch device used for the GP.
        :param optimizer_kwargs: optional optimizer keyword arguments.
        """

        super().__init__(debugger=debugger, device=device, optimizer_kwargs=optimizer_kwargs)
        # The default modules are created per instance. As default argument values they would be built
        # once and shared, so hyperparameters trained through one interface would leak into every other.
        if gp_likelihood is None:
            gp_likelihood = gpytorch.likelihoods.GaussianLikelihood()
        if gp_mean is None:
            gp_mean = gpytorch.means.ConstantMean()
        if gp_kernel is None:
            gp_kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=2.5))

        self.gp_model = gp_model
        self.gp_likelihood = gp_likelihood
        self.gp_mean = gp_mean
        self.gp_kernel = gp_kernel
        self.gp_surrogate: Optional[gpytorch.models.GP] = None

    def create_flow(self,
                    samples: torch.Tensor,
                    logp: torch.Tensor,
                    *args,
                    **kwargs):
        """
        Create a GP object with ``samples`` as training inputs and ``logp`` as training targets.

        The stored likelihood, mean and kernel modules are reused, so their current hyperparameters carry
        over to the new surrogate.
        """

        class gpytorchGPModel(self.gp_model):
            def __init__(self, train_x, train_y, likelihood, mean, kernel):
                super().__init__(train_x, train_y, likelihood)
                self.mean_module = mean
                self.covar_module = kernel

            def forward(self, x):
                mean_x = self.mean_module(x)
                covar_x = self.covar_module(x)
                return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

        _samples = samples.to(self.device)
        _logp = logp.to(self.device)
        # The data is moved to self.device, so the modules have to live there as well.
        self.gp_likelihood = self.gp_likelihood.to(self.device)
        surrogate = gpytorchGPModel(_samples, _logp, self.gp_likelihood, self.gp_mean, self.gp_kernel)
        self.gp_surrogate = surrogate.to(self.device)

    def train_flow(self,
                   samples: torch.Tensor,
                   logp: torch.Tensor,
                   n_epochs: int = 1000,
                   use_tqdm: bool = False,
                   **kwargs):
        """
        Train the GP surrogate by maximizing the exact marginal log likelihood.

        :param samples: training data with shape (n, d).
        :param logp: logp of the training data with shape (n,).
        :param n_epochs: number of training epochs for GP.
        :param use_tqdm: show a progress bar over epochs.
        :param kwargs: accepted for consistency with the other interfaces; not used.
        :raises ValueError: if n_epochs is negative.
        """
        if n_epochs < 0:
            raise ValueError("Number of epochs should be non-negative.")

        _samples = samples.to(self.device)
        _logp = logp.to(self.device)

        self.create_flow(_samples, _logp)
        self.gp_surrogate.train()
        self.gp_likelihood.train()
        self.gp_surrogate.requires_grad_(True)
        self.gp_likelihood.requires_grad_(True)

        optimizer = self.get_optimizers()[0]
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.gp_likelihood, self.gp_surrogate)
        pbar = tqdm(range(n_epochs)) if use_tqdm else range(n_epochs)
        for _ in pbar:
            optimizer.zero_grad()
            output = self.gp_surrogate(_samples)
            loss = -mll(output, _logp)
            loss.backward()
            optimizer.step()
            self.add_debug_data(x_train=_samples)
            self.debug_step()
        self.gp_likelihood.eval()
        self.gp_surrogate.eval()
        self.gp_likelihood.requires_grad_(False)
        self.gp_surrogate.requires_grad_(False)
        self.debug_animate()

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def logq(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Evaluate the GP surrogate at the supplied sample locations: the mean of the likelihood applied to the
        surrogate output.
        """
        return self.gp_likelihood(self.gp_surrogate(x.to(self.device))).mean

    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        raise NotImplementedError

    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        raise NotImplementedError

    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def inverse_paths(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def get_optimizers(self):
        """Return a single-element list with an Adam optimizer over the surrogate parameters."""
        return [optim.Adam(self.gp_surrogate.parameters(), **self.optimizer_kwargs)]


class HierarchicalFlowInterface(TorchFlowInterface):
    """
    Flow interface that combines per-random-vector flows with coupling transforms along the edges of a DAG.
    See __init__ for a description of the model.
    """

    class Graph:
        def __init__(self, n_vertices):
            """
            Directed graph class with an implementation of topological sort.
            Adapted from https://www.geeksforgeeks.org/python-program-for-topological-sorting/.

            :param n_vertices: number of vertices. Vertices are the integers 0, ..., n_vertices - 1.
            """
            self.adjacency_list = defaultdict(list)
            self.n_vertices = n_vertices  # No. of vertices

        def add_edge(self, u, v):
            """Add the directed edge u -> v."""
            self.adjacency_list[u].append(v)

        def add_edges(self, edge_list):
            """Add every directed edge (u, v) contained in edge_list."""
            for u, v in edge_list:
                self.add_edge(u, v)

        def _topological_sort_util(self, v, visited, stack):
            """
            Depth-first visit of v. The vertex is appended to stack only after all of its unvisited successors,
            i.e. stack is filled in post-order.
            """
            # Mark the current node as visited.
            visited[v] = True

            # Recur for all the vertices adjacent to this vertex
            for i in self.adjacency_list[v]:
                if not visited[i]:
                    self._topological_sort_util(i, visited, stack)

            # Appending (instead of inserting at the front) keeps each visit O(1); the caller reverses the list once.
            stack.append(v)

        def topological_sort(self):
            """
            Return the vertices in a topological order, i.e. every edge (u, v) has u placed before v.

            :return: list of vertex indices.
            :raises ValueError: if the graph contains a cycle (no topological order exists).
            """
            # Mark all the vertices as not visited
            visited = [False] * self.n_vertices
            post_order = []

            # Call the recursive helper function starting from all vertices one by one
            for i in range(self.n_vertices):
                if not visited[i]:
                    self._topological_sort_util(i, visited, post_order)

            # Reversed post-order is a topological order whenever the graph is acyclic.
            order = post_order[::-1]

            # If the graph has a cycle, at least one edge must point backwards in any linear order.
            position = {vertex: k for k, vertex in enumerate(order)}
            for u in order:
                for v in self.adjacency_list[u]:
                    if position[v] <= position[u]:
                        raise ValueError(f"The graph contains a cycle through the edge ({u}, {v}).")

            return order

    class CouplingTransform(nn.Module):
        def __init__(self, src_dim, dst_dim, hidden_dim=5):
            """
            Conditioner network of an affine coupling transform.

            A fully connected ReLU network that maps src_dim input features to 2 * dst_dim output features.

            :param src_dim: number of input (source) features.
            :param dst_dim: number of destination features; the network outputs two vectors of this size.
            :param hidden_dim: width of the hidden layers.
            """
            super().__init__()
            self.src_dim = src_dim
            self.dst_dim = dst_dim
            self.hidden_dim = hidden_dim

            self.transform = nn.Sequential(
                nn.Linear(src_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, dst_dim * 2)
            )

        def forward(self, x: torch.Tensor):
            """
            Compute the coupling parameters.

            :param x: source features; the last dimension should have size src_dim.
            :return: tuple (s, t), the first and second halves of the network output along the last dimension.
            """
            out = self.transform(x)
            s = out[..., :self.dst_dim]
            t = out[..., self.dst_dim:]
            return s, t

    def __init__(self,
                 rv_mask: torch.Tensor,
                 dag_edges: List[Tuple[int, int]],
                 interfaces: List[TorchFlowInterface] = None,
                 coupling_repeats: Union[Dict[int, int], int] = 1,
                 debugger: Optional[FlowDebugger] = None,
                 device=torch.device('cpu'),
                 optimizer_kwargs: dict = None):
        """
        Flow interface for hierarchical modeling.

        We assume there is an underlying directed acyclic graph (DAG) of random variables or groups of random variables
        (random vectors). The directed edges correspond to conditional distributions, i.e. if we have an edge (u, v),
        then the random vector v depends on u. The dependence is modeled with a coupling transform. Coupling transforms
        have a neural network which receives the input u and outputs vectors s, t with the same dimensionality as v.
        The vector u is kept the same while v is transformed using exp(s) * v + t. This is performed for all edges
        (u, v) until the entire DAG has been topologically traversed. Afterwards, each transformed vector is modeled
        with its respective normalizing flow.

        :param rv_mask: array where rv_mask[i] = k means that the i-th element of the input x is modeled by the k-th
            flow (the node with index k in the DAG). rv_mask[i] = k means that x_i belongs to the k-th random vector.
            Random vector indices must be 0, 1, ..., n_rvs - 1.
        :param dag_edges: directed edges between random vectors. Each element of the list is a tuple with format
            (source_vectors, destination_vectors). Cycles are not allowed.
        :param interfaces: flow interfaces corresponding to vectors at the same index. In other words, which flows
            are used to model different random vectors. Flow interfaces need to be implemented in PyTorch, otherwise
            we cannot put them in the same computation graph as PyTorch coupling layers. Flow interfaces must already
            have flow objects, i.e. create_flow should be called prior to passing interfaces as arguments to this
            constructor. If None, all interfaces are MAFs.
        :param coupling_repeats: number of coupling transforms applied to each destination random vector. An int is
            used for every random vector. A dict maps random vector indices to counts; missing indices get 0.
        :param debugger: optional flow debugger, passed to the parent constructor.
        :param device: device for the coupling transforms and for the tensors processed by this interface.
        :param optimizer_kwargs: keyword arguments for the Adam optimizer of the coupling transforms.
        :raises ValueError: if dag_edges contain a cycle.
        """
        super().__init__(debugger=debugger, device=device, optimizer_kwargs=optimizer_kwargs, n_dim=len(rv_mask))

        self.n_dim = len(rv_mask)

        rvs, counts = torch.unique(rv_mask, return_counts=True)
        assert rvs[0] == 0 and len(rvs) == (rvs[-1] + 1)  # Ensure random vectors are zero indexed
        self.n_rvs = len(rvs)
        self.rv_dims = counts
        self.dag_edges = dag_edges
        self.rv_mask = rv_mask

        if isinstance(coupling_repeats, int):
            self.coupling_repeats = {i: coupling_repeats for i in range(self.n_rvs)}
        else:
            self.coupling_repeats = {i: 0 for i in range(self.n_rvs)}
            self.coupling_repeats.update(coupling_repeats)

        if interfaces is None:
            self.interfaces = [MAFInterface(n_dim=int(d)) for d in self.rv_dims]
            for interface in self.interfaces:
                interface.create_flow()
        else:
            assert self.n_rvs == len(interfaces)
            self.interfaces = interfaces

        dag = self.Graph(n_vertices=self.n_rvs)
        dag.add_edges(self.dag_edges)
        self.sorted_rvs = dag.topological_sort()

        # Filled by create_flow: destination random vector index -> list of its coupling transforms.
        self.coupling_transforms: Dict[int, List[HierarchicalFlowInterface.CouplingTransform]] = dict()
        # Destination random vector index -> sorted indices of the random vectors pointing towards it.
        self.dst_src_rvs_dict = {i: sorted([e[0] for e in self.dag_edges if e[1] == i]) for i in range(self.n_rvs)}

    def add_debug_data(self, x_train: torch.Tensor, x_val: torch.Tensor = None, scalars: dict = None):
        """
        Add debug data to this interface and to every per-random-vector interface.
        """
        super().add_debug_data(x_train=x_train, x_val=x_val, scalars=scalars)
        # NOTE(review): every sub-interface receives the full x_train/x_val rather than its own columns - confirm
        # that this is intended.
        for interface in self.interfaces:
            interface.add_debug_data(x_train=x_train, x_val=x_val, scalars=scalars)

    def debug_step(self):
        """Advance the debugger of this interface and of every per-random-vector interface."""
        super().debug_step()
        for interface in self.interfaces:
            interface.debug_step()

    def debug_animate(self):
        """Run the debugger animation of this interface and of every per-random-vector interface."""
        super().debug_animate()
        for interface in self.interfaces:
            interface.debug_animate()

    def create_flow(self, *args, **kwargs):
        """
        Create the coupling transforms.

        Every random vector with at least one incoming edge receives coupling_repeats[dst] coupling transforms whose
        input size is the total dimensionality of its source random vectors. Random vectors without incoming edges
        receive an empty list. Positional and keyword arguments are ignored.
        """
        coupling_transforms = dict()
        for dst, src_list in self.dst_src_rvs_dict.items():
            if not src_list:
                coupling_transforms[dst] = []
                continue
            # self.rv_dims holds tensor elements; convert to plain ints before sizing the layers.
            src_dim = int(sum(self.rv_dims[src] for src in src_list))
            dst_dim = int(self.rv_dims[dst])
            coupling_transforms[dst] = [
                self.CouplingTransform(src_dim, dst_dim).to(self.device)
                for _ in range(self.coupling_repeats[dst])
            ]
        self.coupling_transforms = coupling_transforms

    def src_rvs(self, dst_idx):
        """
        Return the indices of random vectors that point towards dst_idx, in topological order.
        """
        parents = {u for (u, v) in self.dag_edges if v == dst_idx}
        return [u for u in self.sorted_rvs if u in parents]

    def _src_mask(self, dst_idx):
        """Boolean mask over input dimensions that belong to the source random vectors of dst_idx."""
        src_indices = torch.tensor(self.src_rvs(dst_idx), device=self.rv_mask.device)
        return torch.isin(self.rv_mask, src_indices)

    def get_coupling_transform_parameters(self, rv_idx: int = None):
        """
        Collect the parameters of the coupling transforms.

        :param rv_idx: if given, only the transforms whose destination is this random vector are considered.
        :return: list of parameters.
        """
        coupling_transforms_parameters = []
        for dst_idx, ct_list in self.coupling_transforms.items():
            if rv_idx is not None and rv_idx != dst_idx:
                continue
            for ct in ct_list:
                coupling_transforms_parameters.extend(ct.parameters())
        return coupling_transforms_parameters

    def train_flow(self,
                   x: torch.Tensor,
                   n_epochs: int = 1000,
                   weights: torch.Tensor = None,
                   use_tqdm: bool = False,
                   x_val: Optional[torch.Tensor] = None):
        """
        Jointly train all flows and coupling transforms by minimizing the weighted negative mean of logq(x).

        :param x: training samples.
        :param n_epochs: number of training epochs.
        :param weights: non-negative sample weights with the same length as x. If None, all weights are one.
        :param use_tqdm: show a progress bar.
        :param x_val: optional validation samples; only passed to the debugger.
        :raises ValueError: if weights and x differ in length, a weight is negative or n_epochs is negative.
        """
        if weights is None:
            weights = torch.ones(len(x), device=self.device)
        if len(weights) != len(x):
            raise ValueError(f"x and weights should match in dimension 0 but got sizes {len(x)} and {len(weights)}.")
        if torch.any(weights < 0):
            raise ValueError("Weights should be non-negative.")
        if n_epochs < 0:
            raise ValueError("Number of epochs should be non-negative.")

        _x = x.to(self.device)
        _weights = weights.to(self.device)
        _x_val = x_val.to(self.device) if x_val is not None else None

        for flow_interface in self.interfaces:
            flow_interface.flow.train()

        optimizers = self.get_optimizers()
        coupling_transforms_parameters = self.get_coupling_transform_parameters()
        if coupling_transforms_parameters:
            # Use the configured optimizer arguments, consistently with train_flow_stage.
            optimizers = optimizers + [
                optim.Adam(coupling_transforms_parameters, **(self.optimizer_kwargs or {}))
            ]

        weight_sum = torch.sum(_weights)
        pbar = tqdm(range(n_epochs)) if use_tqdm else range(n_epochs)
        for _ in pbar:
            for optimizer in optimizers:
                optimizer.zero_grad()
            loss = -torch.sum(_weights * self.logq(_x)) / weight_sum
            loss.backward()
            for optimizer in optimizers:
                optimizer.step()
            self.add_debug_data(x_train=_x, x_val=_x_val)
            self.debug_step()
        for flow_interface in self.interfaces:
            flow_interface.flow.eval()
        self.debug_animate()

    def coupling_forward_with_logj(self, x: torch.Tensor, dst_rv_idx: int = None):
        """
        Apply the coupling transforms in topological order.

        Each destination random vector v is mapped to v * exp(s) + t where (s, t) are computed from the current
        values of its source random vectors.

        :param x: input samples; columns are assigned to random vectors by self.rv_mask.
        :param dst_rv_idx: if given, stop after the transforms of this random vector have been applied.
        :return: tuple (transformed samples, log determinant of the Jacobian for each sample).
        """
        _x = torch.clone(x).to(self.device)
        logj = torch.zeros(len(x), device=self.device)

        for rv_idx in self.sorted_rvs:
            transforms = self.coupling_transforms[rv_idx]
            if transforms:
                # The masks do not change between repeats, so compute them once per random vector.
                src_mask = self._src_mask(rv_idx)
                dst_mask = self.rv_mask == rv_idx
                for transform in transforms:
                    s, t = transform(_x[:, src_mask])
                    _x[:, dst_mask] = _x[:, dst_mask] * torch.exp(s) + t
                    logj += torch.sum(s, dim=1)
            if rv_idx == dst_rv_idx:
                break
        return _x, logj

    def coupling_forward(self, x: torch.Tensor, dst_rv_idx: int = None):
        """Return only the transformed samples of coupling_forward_with_logj."""
        return self.coupling_forward_with_logj(x, dst_rv_idx)[0]

    def coupling_logj_forward(self, x: torch.Tensor, dst_rv_idx: int = None):
        """Return only the log Jacobian determinants of coupling_forward_with_logj."""
        return self.coupling_forward_with_logj(x, dst_rv_idx)[1]

    def coupling_inverse_with_logj(self, z: torch.Tensor):
        """
        Invert the coupling transforms by traversing the random vectors in reverse topological order.

        :param z: samples to which all coupling transforms have been applied.
        :return: tuple (inverted samples, log determinant of the Jacobian of the inverse for each sample).
        """
        _z = torch.clone(z).to(self.device)
        logj = torch.zeros(len(z), device=self.device)

        for dst_rv_idx in self.sorted_rvs[::-1]:
            transforms = self.coupling_transforms[dst_rv_idx]
            if not transforms:
                continue
            src_mask = self._src_mask(dst_rv_idx)
            dst_mask = self.rv_mask == dst_rv_idx
            for transform in transforms[::-1]:
                s, t = transform(_z[:, src_mask])
                _z[:, dst_mask] = (_z[:, dst_mask] - t) / torch.exp(s)
                logj -= torch.sum(s, dim=1)
        return _z, logj

    def coupling_inverse(self, x: torch.Tensor):
        """Return only the inverted samples of coupling_inverse_with_logj."""
        return self.coupling_inverse_with_logj(x)[0]

    def coupling_logj_inverse(self, x: torch.Tensor):
        """Return only the log Jacobian determinants of coupling_inverse_with_logj."""
        return self.coupling_inverse_with_logj(x)[1]

    def flow_forward_with_logj(self, x: torch.Tensor, **kwargs):
        """
        Apply the forward transform of every per-random-vector flow to its own columns of x.

        :param x: input samples.
        :param kwargs: forwarded to forward_with_logj of every interface.
        :return: tuple (transformed samples, sum of the per-flow log Jacobian determinants).
        """
        _x = torch.clone(x).to(self.device)
        logj = torch.zeros(len(x), device=self.device)

        # Apply flow on each random vector
        for rv_idx in range(self.n_rvs):
            rv_columns = self.rv_mask == rv_idx
            _x_tmp, _logj_forward = self.interfaces[rv_idx].forward_with_logj(_x[:, rv_columns], **kwargs)
            _x[:, rv_columns] = _x_tmp
            logj += _logj_forward

        return _x, logj

    def flow_forward(self, x: torch.Tensor, **kwargs):
        """Return only the transformed samples of flow_forward_with_logj."""
        return self.flow_forward_with_logj(x, **kwargs)[0]

    def flow_forward_logj(self, x: torch.Tensor, **kwargs):
        """Return only the log Jacobian determinants of flow_forward_with_logj."""
        return self.flow_forward_with_logj(x, **kwargs)[1]

    def flow_inverse_with_logj(self, z: torch.Tensor, **kwargs):
        """
        Apply the inverse transform of every per-random-vector flow to its own columns of z.

        :param z: latent samples.
        :param kwargs: forwarded to inverse_with_logj of every interface.
        :return: tuple (transformed samples, sum of the per-flow log Jacobian determinants).
        """
        _z = torch.clone(z).to(self.device)
        logj = torch.zeros(len(z), device=_z.device)

        # Apply flow inverses
        for rv_idx in range(self.n_rvs):
            rv_columns = self.rv_mask == rv_idx
            _z_tmp, _logj_backward = self.interfaces[rv_idx].inverse_with_logj(_z[:, rv_columns], **kwargs)
            _z[:, rv_columns] = _z_tmp
            logj += _logj_backward

        return _z, logj

    def flow_inverse(self, x: torch.Tensor, **kwargs):
        """Return only the transformed samples of flow_inverse_with_logj."""
        return self.flow_inverse_with_logj(x, **kwargs)[0]

    def flow_inverse_logj(self, x: torch.Tensor, **kwargs):
        """Return only the log Jacobian determinants of flow_inverse_with_logj."""
        return self.flow_inverse_with_logj(x, **kwargs)[1]

    def train_flow_stage(self,
                         x: torch.Tensor,
                         rv_idx: int,
                         n_epochs: int = 1000,
                         weights: torch.Tensor = None,
                         use_tqdm: bool = False):
        """
        Train a single flow in the DAG. This only trains the marginal at rv_idx, so it may not be that good.

        The coupling transforms whose destination is rv_idx (if any) are trained jointly with the flow.

        :param x: training samples.
        :param rv_idx: index of the random vector whose flow is trained.
        :param n_epochs: number of training epochs.
        :param weights: sample weights. If None, all weights are one.
        :param use_tqdm: show a progress bar.
        """
        flow_interface = self.interfaces[rv_idx]
        flow_interface.flow.train()
        optimizers = flow_interface.get_optimizers()

        # Move variables to self.device
        _weights = (weights if weights is not None else torch.ones(len(x))).to(self.device)
        _x = torch.clone(x).to(self.device)

        # Jointly train the coupling transform and the flow
        if self.coupling_transforms[rv_idx]:
            coupling_transforms_parameters = self.get_coupling_transform_parameters(rv_idx)
            optimizers = optimizers + [
                optim.Adam(coupling_transforms_parameters, **(self.optimizer_kwargs or {}))
            ]

        rv_columns = self.rv_mask == rv_idx
        weight_sum = torch.sum(_weights)
        pbar = tqdm(range(n_epochs)) if use_tqdm else range(n_epochs)
        for _ in pbar:
            for optimizer in optimizers:
                optimizer.zero_grad()

            # Apply coupling transforms until we get to the current stage.
            x_stage = self.coupling_forward(_x, rv_idx)[:, rv_columns]
            loss = -torch.sum(_weights * flow_interface.logq(x_stage)) / weight_sum
            loss.backward()
            for optimizer in optimizers:
                optimizer.step()
            self.add_debug_data(x_train=_x)
            self.debug_step()
        flow_interface.flow.eval()
        self.debug_animate()

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw samples by pushing standard normal noise through the inverse transform.

        :param n_samples: number of samples.
        :param kwargs: forwarded to self.inverse.
        """
        z = torch.randn(n_samples, self.n_dim, device=self.device)
        return self.inverse(z, **kwargs)

    def logq(self, x: torch.Tensor) -> torch.Tensor:
        """
        Log density of x under the model: standard normal log density of the latent z plus the forward log Jacobian
        determinant.
        """
        z, logj_forward = self.forward_with_logj(x.to(self.device))
        log_normalizer = 0.5 * self.n_dim * math.log(2 * math.pi)
        return logj_forward - log_normalizer - torch.sum(z ** 2, dim=1) / 2

    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        """
        Full forward transform: coupling transforms followed by the per-random-vector flows.

        :param x: input samples.
        :param kwargs: forwarded to the per-random-vector flows.
        :return: tuple (latent samples, total log Jacobian determinant for each sample).
        """
        # Apply coupling transforms
        _x, logj_coupling = self.coupling_forward_with_logj(x)

        # Apply flow on each random variable
        _x, logj_flow = self.flow_forward_with_logj(_x, **kwargs)
        logj = logj_coupling + logj_flow

        return _x, logj

    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        """
        Full inverse transform: per-random-vector flow inverses followed by the coupling transform inverses.

        :param z: latent samples.
        :param kwargs: forwarded to the per-random-vector flows, mirroring forward_with_logj.
        :return: tuple (samples, total log Jacobian determinant of the inverse for each sample).
        """
        # Apply flow inverses on each random variable
        _z, logj_flow = self.flow_inverse_with_logj(z, **kwargs)

        # Apply coupling transform inverses
        _z, logj_coupling = self.coupling_inverse_with_logj(_z)

        logj = logj_flow + logj_coupling

        return _z, logj

    def _merge_interface_paths(self, all_interface_paths: Dict[int, torch.Tensor], sample_shape) -> torch.Tensor:
        """
        Merge per-random-vector paths into one CPU tensor of shape (max_path_length, *sample_shape).

        Different flows may have a different number of layers. Shorter paths are padded by repeating their final
        samples as many times as needed to reach a fixed size.
        """
        max_path_length = max(v.shape[0] for v in all_interface_paths.values())
        merged = torch.zeros(
            size=(max_path_length, *sample_shape),
            dtype=torch.float,
            device=torch.device('cpu')
        )
        for rv_idx, interface_paths in all_interface_paths.items():
            len_tmp = interface_paths.shape[0]
            rv_columns = self.rv_mask == rv_idx
            merged[:len_tmp, :, rv_columns] = interface_paths
            merged[len_tmp:, :, rv_columns] = interface_paths[-1]
        return merged

    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return x followed by the intermediate samples of the per-random-vector flows, stacked along dimension 0.

        The flows receive the coupling-transformed x. The output is a fixed size CPU tensor.
        """
        x_transformed = self.coupling_forward(x)
        all_interface_paths = {
            rv_idx: interface.forward_paths(x_transformed[:, self.rv_mask == rv_idx])
            for rv_idx, interface in enumerate(self.interfaces)
        }
        all_interface_paths_array = self._merge_interface_paths(all_interface_paths, x.shape)
        # (L, n, n_dim)
        # The merged paths live on the CPU, so x is moved there before concatenation.
        return torch.concat([x.unsqueeze(0).to(all_interface_paths_array.device), all_interface_paths_array])

    def inverse_paths(self, z: torch.Tensor):
        """
        Return the intermediate samples of the per-random-vector flow inverses followed by the final samples after
        inverting the coupling transforms, stacked along dimension 0. The output is a fixed size CPU tensor.
        """
        all_interface_paths = {
            rv_idx: interface.inverse_paths(z[:, self.rv_mask == rv_idx])
            for rv_idx, interface in enumerate(self.interfaces)
        }
        all_interface_paths_array = self._merge_interface_paths(all_interface_paths, z.shape)
        # (L, n, n_dim)
        # coupling_inverse returns a tensor on self.device; move it to the device of the merged paths.
        x = self.coupling_inverse(all_interface_paths_array[-1]).to(all_interface_paths_array.device)
        return torch.concat([all_interface_paths_array, x.unsqueeze(0)])

    def get_optimizers(self):
        """Return the optimizers of all per-random-vector interfaces (coupling transforms are not included)."""
        optimizers = []
        for interface in self.interfaces:
            optimizers.extend(interface.get_optimizers())
        return optimizers


class HierarchicalSINFInterface(TorchFlowInterface):
    """
    Hierarchical interface built from SINF flows: one "information transfer" SINF fitted on all dimensions,
    followed by one SINF per random vector fitted on the corresponding columns of the transformed data.
    """

    def __init__(self, rv_mask: torch.Tensor, dag_edges: List[Tuple[int, int]], debugger: Optional[FlowDebugger] = None,
                 device=torch.device('cpu'), optimizer_kwargs: dict = None):
        """
        :param rv_mask: rv_mask[i] = k assigns the i-th input dimension to the k-th random vector. Random vector
            indices must be 0, 1, ..., n_rvs - 1.
        :param dag_edges: directed edges between random vectors. Stored, but not otherwise used by this class.
        :param debugger: optional flow debugger, passed to the parent constructor.
        :param device: passed to the parent constructor.
        :param optimizer_kwargs: passed to the parent constructor.
        """
        # NOTE(review): n_dim is not passed to the parent constructor here - confirm the parent does not need it.
        super().__init__(debugger=debugger, device=device, optimizer_kwargs=optimizer_kwargs)

        unique_rvs, dims_per_rv = torch.unique(rv_mask, return_counts=True)
        assert unique_rvs[0] == 0 and len(unique_rvs) == (unique_rvs[-1] + 1)  # Ensure random vectors are zero indexed

        self.rv_mask = rv_mask
        self.dag_edges = dag_edges
        self.n_rvs = len(unique_rvs)
        self.rv_dims = dims_per_rv

        # Both are populated by create_flow.
        self.information_transfer_interface: Optional[SINFInterface] = None
        self.rv_interfaces: Dict[int, SINFInterface] = dict()

    def get_optimizers(self):
        """This interface exposes no optimizers."""
        return []

    def create_flow(self, x: torch.Tensor, weights: torch.Tensor = None, val_frac: float = 0.0, iteration=5, **kwargs):
        """
        Fit the information transfer SINF on x, then fit one SINF per random vector on the transformed data.

        :param x: training samples.
        :param weights: sample weights, passed to every SINFInterface.create_flow call.
        :param val_frac: passed to every SINFInterface.create_flow call.
        :param iteration: passed to every SINFInterface.create_flow call.
        :param kwargs: further keyword arguments for SINFInterface.create_flow. 'n_epochs' is discarded.
        """
        kwargs.pop('n_epochs', None)
        fit_kwargs = dict(weights=weights, val_frac=val_frac, iteration=iteration, **kwargs)

        transfer_interface = SINFInterface()
        transfer_interface.create_flow(x=x, **fit_kwargs)  # TODO incorporate information transfer loss
        x_transformed = transfer_interface.forward(x)

        per_rv_interfaces = dict()
        for rv_index in range(self.n_rvs):
            component_interface = SINFInterface()
            component_interface.create_flow(x=x_transformed[:, self.rv_mask == rv_index], **fit_kwargs)
            per_rv_interfaces[rv_index] = component_interface

        self.information_transfer_interface = transfer_interface
        self.rv_interfaces = per_rv_interfaces

    def train_flow(self, *args, **kwargs):
        """Training is the same as creating the flow: all arguments are forwarded to create_flow."""
        self.create_flow(*args, **kwargs)

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """Return the samples produced by self.sample_with_logq(n_samples); kwargs are unused."""
        samples_and_logq = self.sample_with_logq(n_samples)
        return samples_and_logq[0]

    def logq(self, x: torch.Tensor) -> torch.Tensor:
        """
        Log density of x: log Jacobian determinant of the information transfer flow plus the sum of the
        per-random-vector log densities of the transformed columns.
        """
        x_transformed, total = self.information_transfer_interface.forward_with_logj(x)
        # Accumulates in place into the tensor returned by the information transfer flow.
        for rv_index in range(self.n_rvs):
            rv_columns = self.rv_mask == rv_index
            total += self.rv_interfaces[rv_index].logq(x_transformed[:, rv_columns])
        return total

    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        """
        Forward transform: information transfer flow followed by the per-random-vector flows.

        :return: tuple (latent samples, total log Jacobian determinant for each sample).
        """
        x_transformed, logj = self.information_transfer_interface.forward_with_logj(x)
        z = torch.zeros_like(x)
        for rv_index, rv_interface in ((i, self.rv_interfaces[i]) for i in range(self.n_rvs)):
            rv_columns = self.rv_mask == rv_index
            z_component, logj_component = rv_interface.forward_with_logj(x_transformed[:, rv_columns])
            logj += logj_component
            z[:, rv_columns] = z_component
        return z, logj

    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        """
        Inverse transform: per-random-vector flow inverses followed by the information transfer flow inverse.

        :return: tuple (samples, total log Jacobian determinant of the inverse for each sample).
        """
        x_transformed = torch.zeros_like(z)
        # NOTE(review): the accumulator is created on the default device - confirm this matches the device of the
        # per-random-vector log Jacobians.
        logj = torch.zeros(len(z))
        for rv_index in range(self.n_rvs):
            rv_columns = self.rv_mask == rv_index
            x_component, logj_component = self.rv_interfaces[rv_index].inverse_with_logj(z[:, rv_columns])
            logj += logj_component
            x_transformed[:, rv_columns] = x_component
        x, logj_transfer = self.information_transfer_interface.inverse_with_logj(x_transformed)
        logj += logj_transfer
        return x, logj

    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        """Not implemented for this interface."""
        raise NotImplementedError

    def inverse_paths(self, z: torch.Tensor):
        """Not implemented for this interface."""
        raise NotImplementedError


class EcoFlowInterface(TorchFlowInterface):
    def __init__(self,
                 lr: float = 1e-1,
                 weight_decay: float = 1e-4,
                 **kwargs):
        """
        Interface for EcoFlow.

        Intended for modeling posteriors in hierarchical Bayesian models.
        This model should converge quite quickly (in iterations and wall time) and should outperform generic flows in
        high dimensional, small data regimes.

        Training advice:
        * It is usually safe to keep the learning rate at around 1e-1 or 1e-2. Smaller learning rates will usually get
          to the same result, but just take longer to do so.
        * If you have N samples for training, then it is usually safe to set N / 2 as x_train and N / 2 as x_val.
        * It is not advisable to omit validation data. Validation data is about as important as training data and is
          meant for early stopping to prevent overfitting. If you skip validation data, you will usually still beat
          other flows, but splitting your dataset in halves like described above is basically a free improvement.
        * A large weight decay may be too strong due to the small number of parameters. Reduce it if fits are poor.
        * If you still have fitting issues after reducing weight decay, there may not be enough parameters in the flow.
          You can fix this by setting hidden_features to something larger (e.g. 10, 20 or 50). You can also try
          increasing the number of parameters by setting num_blocks to something larger (e.g. 1, 2 or 5).
        * If you are fitting data where you know that a particular group of dimensions has the exact same distribution
          (e.g. the conditional Gaussian dimensions in Neal's funnel), you may set pointwise_transform_type to
          'group_affine'. This makes all conditional dimensions share the coupling parameters and considerably reduces
          the number of trainable parameters in the flow. Note that this currently affects all groups in the problem,
          including the ones that should not actually share parameters. We will make this more flexible in the future.
        * If you are fitting data where dimensions are not Gaussians or conditional Gaussians, it may help to set
          pointwise_transform_type to 'spline' or 'group_spline'. However, training becomes somewhat slower at larger
          training set sizes due to the suboptimal nflows implementation of rational quadratic splines. You may also see
          worse fits at small training set sizes, because of the greater number of trainable parameters in splines.

        Example use:
        >>> from nfmc_jax.hierarchical.problems import Funnel
        >>> torch.manual_seed(0)
        >>> problem = Funnel(1000)
        >>> x_train = problem.sample(10)
        >>> x_val = problem.sample(10)
        >>> interface = EcoFlowInterface()
        >>> interface.create_flow(problem.graph())
        >>> interface.train_flow(x_train, x_val)

        :param lr: learning rate for the AdamW optimizer.
        :param weight_decay: weight decay for the AdamW optimizer.
        :param kwargs: keyword arguments for the parent constructor.
        """
        super().__init__(**kwargs)

        # Overrides whatever optimizer arguments the parent constructor has set.
        self.optimizer_kwargs = dict(
            lr=lr,
            weight_decay=weight_decay
        )

    def get_optimizers(self):
        """Return a single-element list with an AdamW optimizer over the flow parameters."""
        return [optim.AdamW(self.flow.parameters(), **self.optimizer_kwargs)]

    def create_flow(self, dag: nx.DiGraph, **kwargs):
        """
        Create EcoFlow according to the given hierarchical model DAG.

        Nodes in dag represent groups of dimensions whose distributions depend on the same predecessor groups.
        Dimensions within a group are assumed to be independent unless a group pointwise transform is specified.
        Each node should have the 'n_dim' attribute, indicating the number of dimensions it covers.

        An edge u -> v indicates that the distributions of v depend on u.
        Each edge should have the 'mask' attribute. The mask value is a list of indices [i1, i2, ...]. This indicates
        that dimensions in a group depend on a subset of dimensions in the predecessor group. This is helpful when a
        predecessor group affects two separate groups and dependence specifics need to be stated.

        See nfmc_jax.hierarchical.problems for reference DAGs.

        :param dag: networkx DiGraph object representing the DAG.
        :param kwargs: keyword arguments for build_ecoflow.
        """
        # The other methods move their inputs to self.device, so the flow has to live there as well.
        self.flow = build_ecoflow(dag, **kwargs).to(self.device)

    def train_flow(self,
                   x_train: torch.Tensor,
                   x_val: torch.Tensor = None,
                   n_epochs: int = 20000,
                   patience: int = 1000,
                   use_tqdm: bool = True):
        """
        Train EcoFlow.

        The loss is the negative mean log probability. With validation data, the parameters with the lowest
        validation loss are restored at the end of training. Training stops early when the monitored loss
        (validation loss if x_val is given, training loss otherwise) is NaN or infinite.

        :param x_train: torch tensor of training samples with shape (n_train, d).
        :param x_val: torch tensor of validation samples with shape (n_val, d). Used for early stopping.
        :param n_epochs: number of epochs.
        :param patience: stop training if we go this many epochs without a validation loss improvement.
        :param use_tqdm: use a progress bar.
        """
        _x_train = x_train.to(self.device)
        _x_val = x_val.to(self.device) if x_val is not None else None

        # A previous call leaves the flow in eval mode; switch back before optimizing.
        self.flow.train()

        optimizer = self.get_optimizers()[0]
        best_state = deepcopy(self.flow.state_dict())
        best_loss_val = torch.inf
        best_epoch = 0
        for epoch in (pbar := tqdm(range(n_epochs), disable=not use_tqdm)):
            optimizer.zero_grad()
            loss = -self.flow.log_prob(_x_train).mean()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                if _x_val is not None:
                    loss_val = -self.flow.log_prob(_x_val).mean()
                    if loss_val < best_loss_val:
                        best_loss_val = loss_val
                        best_state = deepcopy(self.flow.state_dict())
                        best_epoch = epoch
                    if epoch - best_epoch > patience:
                        break
                    if torch.isnan(loss_val) or torch.isinf(loss_val):
                        break
                    pbar.set_postfix_str(
                        f'Train loss: {float(loss):.4f} | '
                        f'Val loss: {float(loss_val):.4f} [{float(best_loss_val):.4f} @ {best_epoch}]'
                    )
                else:
                    if torch.isnan(loss) or torch.isinf(loss):
                        break
                    pbar.set_postfix_str(f'Train loss: {float(loss):.4f}')

        if _x_val is not None:
            self.flow.load_state_dict(best_state)
        self.flow.eval()

    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """
        Draw samples from the flow.

        :param n_samples: number of samples.
        :param kwargs: forwarded to the flow's sample method.
        """
        return self.flow.sample(n_samples, **kwargs)

    def logq(self, x: torch.Tensor) -> torch.Tensor:
        """Return the flow's log probability of x (x is moved to self.device first)."""
        return self.flow.log_prob(x.to(self.device))

    def forward_with_logj(self, x: torch.Tensor, **kwargs):
        """
        Apply the flow's transform to x.

        :return: tuple (z, logj_forward) as returned by the transform.
        """
        z, logj_forward = self.flow._transform(x.to(self.device))
        return z, logj_forward

    def inverse_with_logj(self, z: torch.Tensor, **kwargs):
        """
        Apply the inverse of the flow's transform to z.

        :return: tuple (x, logj_backward) as returned by the inverse transform.
        """
        x, logj_backward = self.flow._transform.inverse(z.to(self.device))
        return x, logj_backward

    def forward_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return x followed by the output of each successive layer of the flow's transform, stacked along dimension 0.
        """
        paths = [torch.clone(x.to(self.device))]
        for layer in self.flow._transform._transforms:
            paths.append(layer(paths[-1], None)[0])  # context = None
        paths = torch.stack(paths)
        return paths

    def inverse_paths(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return x followed by the output of each successive layer inverse (layers traversed in reverse order),
        stacked along dimension 0.
        """
        paths = [torch.clone(x.to(self.device))]
        for layer in [t.inverse for t in self.flow._transform._transforms[::-1]]:
            paths.append(layer(paths[-1], None)[0])  # context = None
        paths = torch.stack(paths)
        return paths


if __name__ == '__main__':
    # Demo: train an SNF interface on synthetic Gaussian data with debugging output enabled.
    from nfmc_jax.flows.debug import VisualizerConfig, FileWriterConfig

    torch.manual_seed(0)
    data_train = torch.randn(100, 5) * 0.1 + 5
    data_val = torch.randn(100, 5) * 0.1 + 5

    # This is how to create regular interfaces
    sinf = SINFInterface()
    snf = SNFInterface()
    maf = MAFInterface(n_dim=5)

    # Debugging a normalizing flow requires a writer and/or a visualizer configuration object. Both are handed to a
    # debugger, and the debugger is handed to a flow interface. While the interface trains, the debugger follows the
    # instructions in the configuration objects. Note that some interfaces may not work with a debugger.
    # Animations require the relevant data to be saved, which is requested through the writer configuration object.
    visualizer_config = VisualizerConfig(
        save_figures=True,

        plot_training_paths=True,
        plot_validation_paths=True,
        plot_generative_paths=True,
        plot_scalars=False,
        plot_training_reconstructions=True,
        plot_validation_reconstructions=True,

        animate_training_paths=True,
        animate_validation_paths=True,
        animate_generative_paths=True,
        animate_training_reconstructions=True,
        animate_validation_reconstructions=True
    )
    file_writer_config = FileWriterConfig(
        write_training_paths=True,
        write_validation_paths=True,
        write_generative_paths=True,
        write_training_reconstructions=True,
        write_validation_reconstructions=True,
        write_scalars=True
    )
    output_directory = pathlib.Path('visualization-tests/snf')
    debugger = FlowDebugger(
        directory=output_directory,
        delete_existing=True,
        visualizer_config=visualizer_config,
        file_writer_config=file_writer_config
    )

    # Create the normalizing flow with a debugger
    interface = SNFInterface(debugger=debugger)
    interface.create_flow(x_train=data_train, iteration=10, alpha=(0, 0.98))
    # snf.create_flow(x_train=data_train, data_validate=data_val, alpha=(0, 0.98))

    # Train the normalizing flow
    training_kwargs = dict(
        lr_Psi=1e-7,
        lr_A=1e-7,
        n_epochs=30,
        reg2=1e-9,
        use_tqdm=True
    )
    interface.train_flow(x=data_train, x_val=data_val, **training_kwargs)
