import copy
import math
import os
from typing import Any, List, Optional, Union

import numpy as np
import torch
import torchsde
from pytorch_lightning import LightningDataModule, LightningModule
from torch.distributions import MultivariateNormal
from torchdyn.core import NeuralODE
from torchvision import transforms

from .components.augmentation import (
    AugmentationModule,
    AugmentedVectorField,
    Sequential,
)
from .components.distribution_distances import compute_distribution_distances
from .components.optimal_transport import OTPlanSampler
from .components.plotting import (
    plot_paths,
    plot_samples,
    plot_trajectory,
    store_trajectories,
)
from .components.schedule import ConstantNoiseScheduler, NoiseScheduler
from .components.solver import FlowSolver
from .utils import get_wandb_logger


class CFMLitModule(LightningModule):
    """Conditional Flow Matching Module for training generative models and models over time."""

    def __init__(
        self,
        net: Any,
        optimizer: Any,
        datamodule: LightningDataModule,
        augmentations: AugmentationModule,
        partial_solver: FlowSolver,
        scheduler: Optional[Any] = None,
        neural_ode: Optional[Any] = None,
        ot_sampler: Optional[Union[str, Any]] = None,
        sigma_min: float = 0.1,
        avg_size: int = -1,
        leaveout_timepoint: int = -1,
        test_nfe: int = 100,
        plot: bool = False,
        nice_name: str = "CFM",
    ) -> None:
        """Initialize a conditional flow matching network either as a generative model or for a
        sequence of timepoints.

        Note: DDP does not currently work with NeuralODE objects from torchdyn
        in the init so we initialize them every time we need to do a sampling
        step.

        Args:
            net: factory called as ``net(dim=...)`` that builds the module representing
                dx/dt = f(t, x).
            optimizer: partial torch.optimizer missing parameters; called as
                ``optimizer(params=...)`` in ``configure_optimizers``.
            datamodule: datamodule object; must expose ``dim`` (vector data) or ``dims``
                (image data) and may expose ``IS_TRAJECTORY``.
            augmentations: training-time regularization module; its ``regs`` are wrapped
                around the network and its ``augmenter`` is used to build ``aug_node``.
            partial_solver: factory called as ``partial_solver(net, dim)`` to build the ODE
                solver used at evaluation time.
            scheduler: optional factory called with the optimizer to build an LR scheduler
                that is stepped once per epoch.
            neural_ode: optional factory applied to the augmented vector field; when given,
                ``self.aug_node`` is built from it.
            ot_sampler: ot_sampler specified as an object or string. ``None`` or the string
                "None" disables minibatch OT; any other string is used as the
                ``OTPlanSampler`` method.
            sigma_min: sigma_min determines the width of the Gaussian smoothing of the data and
                interpolations. It also sets the OT regularization to ``2 * sigma_min**2``.
            avg_size: if > 0, the regression target is averaged over this many samples
                (see ``average_ut``); disabled otherwise.
            leaveout_timepoint: which (if any) timepoint to leave out during the training phase
            test_nfe: for image testing, an int number of fixed integration steps, or a string
                to switch to the adaptive "tsit5" solver.
            plot: if true, log intermediate plots during validation
            nice_name: display name; only stored in the hyperparameters by this class.
        """
        super().__init__()
        # Store the scalar configuration in ``self.hparams``; module/partial objects are
        # excluded because they are kept as attributes below.
        self.save_hyperparameters(
            ignore=[
                "net",
                "optimizer",
                "scheduler",
                "datamodule",
                "augmentations",
                "partial_solver",
            ],
            logger=False,
        )
        self.datamodule = datamodule
        self.is_trajectory = False
        if hasattr(datamodule, "IS_TRAJECTORY"):
            self.is_trajectory = datamodule.IS_TRAJECTORY
        # Vector datamodules expose ``dim`` and image datamodules expose ``dims``; which one is
        # present decides whether samples are processed as vectors or as images.
        if hasattr(datamodule, "dim"):
            self.dim = datamodule.dim
            self.is_image = False
        elif hasattr(datamodule, "dims"):
            self.dim = datamodule.dims
            self.is_image = True
        else:
            raise NotImplementedError("Datamodule must have either dim or dims")
        self.net = net(dim=self.dim)
        self.augmentations = augmentations
        # Training-time wrapper combining the network with the regularizers in
        # ``augmentations.regs``.
        self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim)
        # Fixed set of regularizers (l1, l2, squared l2) used during evaluation.
        self.val_augmentations = AugmentationModule(
            # cnf_estimator=None,
            l1_reg=1,
            l2_reg=1,
            squared_l2_reg=1,
        )
        self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim)
        if neural_ode is not None:
            self.aug_node = Sequential(
                self.augmentations.augmenter,
                neural_ode(self.aug_net),
            )

        self.partial_solver = partial_solver
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.ot_sampler = ot_sampler
        # Config files may pass the literal string "None" to disable OT.
        if ot_sampler == "None":
            self.ot_sampler = None
        if isinstance(self.ot_sampler, str):
            # regularization taken for optimal Schrodinger bridge relationship
            self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2)
        self.criterion = torch.nn.MSELoss()

    def forward_integrate(self, batch: Any, t_span: torch.Tensor):
        """Forward pass with integration over t_span intervals.

        (t, x, t_span) -> [x_t_span].

        Args:
            batch: raw batch, converted with ``unpack_batch``. The indexing below assumes the
                unpacked tensor is (batch, times, dim), i.e. trajectory data — TODO confirm.
            t_span: integration times; ``t_span[0]`` selects the starting timepoint.

        Returns:
            The trajectory integrated over ``t_span``.
        """
        X = self.unpack_batch(batch)
        # Tensors cannot be indexed with a float tensor element, so convert the start to int.
        X_start = X[:, int(t_span[0]), :]
        node = getattr(self, "node", None)
        if node is not None:
            # Keep supporting subclasses that attach a torchdyn NeuralODE as ``self.node``.
            return node.trajectory(X_start, t_span=t_span)
        # This class never assigns ``self.node``; as noted in ``__init__``, ODE objects are
        # built on demand, so construct a solver the same way the evaluation code does.
        solver = self.partial_solver(self.net, self.dim)
        return solver.odeint(X_start, t_span)

    def forward(self, t: torch.Tensor, x: torch.Tensor):
        """Evaluate the learned vector field, mapping (t, x) to dx/dt.

        Args:
            t: time(s) at which the field is evaluated.
            x: state(s) at which the field is evaluated.

        Returns:
            The output of ``self.net(t, x)``.
        """
        vector_field = self.net
        velocity = vector_field(t, x)
        return velocity

    def unpack_batch(self, batch):
        """Turn a raw dataloader batch into a single tensor.

        * Trajectory data: the per-timepoint tensors are stacked along a new axis 1.
        * Non-integer ``dim`` (image data): the first element is kept and the targets dropped.
        * Anything else is returned unchanged.
        """
        if not self.is_trajectory:
            if isinstance(self.dim, int):
                return batch
            # image classification style batch: (images, targets)
            return batch[0]
        return torch.stack(batch, dim=1)

    def preprocess_batch(self, X, training=False):
        """Converts a batch of data into matched random pairs of (x0, x1).

        Args:
            X: unpacked batch. For trajectory data it is unpacked as
                (batch_size, times, dim); otherwise each row is a data sample.
            training: when True, ``hparams.leaveout_timepoint`` is bridged over and OT
                pairing (if configured) is applied to the consecutive-timepoint pairs.

        Returns:
            ``(x0, x1, t_select)`` where ``t_select`` holds, per sample, the index of the
            starting timepoint (a single zero for non-trajectory data). ``t_select`` is always
            allocated on ``X.device`` because it is later combined with tensors there.
        """
        t_select = torch.zeros(1, device=X.device)
        if self.is_trajectory:
            batch_size, times, _ = X.shape
            if not hasattr(self.datamodule, "HAS_JOINT_PLANS"):
                # Resample the OT plan: one stacked (x_t, x_next) pair per interval, giving a
                # tensor indexed as [interval, 0 or 1, sample, ...].
                tmp_ot_list = []
                for t in range(times - 1):
                    if training and t + 1 == self.hparams.leaveout_timepoint:
                        # bridge directly over the left-out timepoint
                        tmp_ot = torch.stack((X[:, t], X[:, t + 2]))
                    else:
                        tmp_ot = torch.stack((X[:, t], X[:, t + 1]))
                    if (
                        training
                        and self.ot_sampler is not None
                        and t != self.hparams.leaveout_timepoint
                    ):
                        tmp_ot = torch.stack(self.ot_sampler.sample_plan(tmp_ot[0], tmp_ot[1]))

                    tmp_ot_list.append(tmp_ot)
                tmp_ot_list = torch.stack(tmp_ot_list)

            # Randomly sample a starting timepoint for every sample.
            if training and self.hparams.leaveout_timepoint > 0:
                # Select random except for the leftout timepoint
                t_select = torch.randint(times - 2, size=(batch_size,), device=X.device)
                t_select[t_select >= self.hparams.leaveout_timepoint] += 1
            else:
                # Allocate on X's device like the branch above: t_select is later added to
                # ``t`` and used as a boolean mask on tensors living on X's device.
                t_select = torch.randint(times - 1, size=(batch_size,), device=X.device)
            x0 = []
            x1 = []
            for i in range(batch_size):
                ti = t_select[i]
                ti_next = ti + 1
                if training and ti_next == self.hparams.leaveout_timepoint:
                    ti_next += 1
                if hasattr(self.datamodule, "HAS_JOINT_PLANS"):
                    # X holds indices into the datamodule's per-timepoint data; the partner is
                    # drawn from row ``index_batch`` of the (normalised) coupling ``pi``.
                    x0.append(torch.tensor(self.datamodule.timepoint_data[ti][X[i, ti]]))
                    pi = self.datamodule.pi[ti]
                    if training and ti + 1 == self.hparams.leaveout_timepoint:
                        pi = self.datamodule.pi_leaveout[ti]
                    index_batch = X[i][ti]
                    i_next = np.random.choice(
                        pi.shape[1], p=pi[index_batch] / pi[index_batch].sum()
                    )
                    x1.append(torch.tensor(self.datamodule.timepoint_data[ti_next][i_next]))
                else:
                    x0.append(tmp_ot_list[ti][0][i])
                    x1.append(tmp_ot_list[ti][1][i])
            x0, x1 = torch.stack(x0), torch.stack(x1)
        else:
            # If no trajectory assume generate from standard normal
            x0 = torch.randn_like(X)
            x1 = X
        return x0, x1, t_select

    def average_ut(self, x, t, mu_t, sigma_t, ut):
        """Re-estimate the regression target as a kernel-weighted average of conditional targets.

        For every sample i, ``avg_size - 1`` partner indices are drawn uniformly *with
        replacement* and the sample itself is always added. The target becomes
        ``sum_j p_ij u_j / sum_j p_ij`` with ``p_ij = exp(-||x_i - mu_j||^2 / (2 sigma_t^2))``.
        This is used when ``hparams.avg_size > 0``, in which case ``step`` draws one shared t.

        Args:
            x: sampled locations, 2-D (batch, dim) as required by ``torch.cdist``.
            t: times seen by the network.
            mu_t: path means for each sample.
            sigma_t: path std (scalar or broadcastable per-row tensor).
            ut: conditional target velocities, (batch, dim).

        Returns:
            ``(x[:1], ut[:1], t[:1])`` — the batch is reduced to its first sample.
        """
        # Note the argument order: rows index x, columns index mu_t.
        pt = torch.exp(-0.5 * (torch.cdist(x, mu_t) ** 2) / (sigma_t**2))
        batch_size = x.shape[0]
        # Random partners (with replacement), plus the sample itself in the last column.
        ind = torch.randint(batch_size, size=(batch_size, self.hparams.avg_size - 1))
        ind = torch.cat([ind, torch.arange(batch_size)[:, None]], dim=1).to(pt.device)
        pt_sub = torch.gather(pt, 1, ind)  # (batch, avg_size)
        ut_sub = ut[ind]  # (batch, avg_size, dim)
        p_sum = torch.sum(pt_sub, dim=1, keepdim=True)
        ut = torch.sum(pt_sub[:, :, None] * ut_sub, dim=1) / p_sum
        # Reduce batch size because they are all the same
        return x[:1], ut[:1], t[:1]

    def calc_mu_sigma(self, x0, x1, t):
        """Mean and standard deviation of the Gaussian probability path at time ``t``.

        The mean moves on a straight line from ``x0`` (t=0) to ``x1`` (t=1); the standard
        deviation is the constant ``hparams.sigma_min``.
        """
        return t * x1 + (1 - t) * x0, self.hparams.sigma_min

    def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
        """Conditional target velocity of the straight-line path, ``x1 - x0``.

        ``x``, ``t``, ``mu_t`` and ``sigma_t`` are accepted so subclasses can share the
        signature, but they do not affect this target.
        """
        velocity = x1 - x0
        return velocity

    def calc_loc_and_target(self, x0, x1, t, t_select, training):
        """Sample a location on the conditional probability path and its target velocity.

        Args:
            x0: source samples.
            x1: target samples paired with ``x0``.
            t: per-sample times in [0, 1), shape (batch,).
            t_select: starting-timepoint index per sample (or a single zero).
            training: if True, pairs that bridge over ``hparams.leaveout_timepoint`` are
                stretched to span two time units.

        Returns:
            ``(x, ut, t, mu_t, sigma_t, eps_t)``: the sampled location, the target velocity,
            the time the network sees (offset by ``t_select``), and the path mean, standard
            deviation and noise used to draw ``x``. The caller's ``t`` is not modified.
        """
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape)
        eps_t = torch.randn_like(mu_t)
        x = mu_t + sigma_t * eps_t
        ut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t)

        # Pairs starting right before the left-out timepoint cover two time units, so the
        # target velocity is halved and the time is stretched accordingly.
        if training and self.hparams.leaveout_timepoint > 0:
            bridged = t_select + 1 == self.hparams.leaveout_timepoint
            # Work on a copy so the caller's tensor is not mutated in place.
            t = t.clone()
            ut[bridged] /= 2
            t[bridged] *= 2

        # The time the network sees is offset by the index of the starting timepoint.
        t = t + t_select.reshape(-1, *t.shape[1:])
        return x, ut, t, mu_t, sigma_t, eps_t

    def step(self, batch: Any, training: bool = False):
        """Compute the regularization and the flow-matching MSE for one batch.

        Returns:
            ``(mean regularization, MSE between predicted and target velocity)``.
        """
        data = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(data, training)
        n_samples = data.shape[0]
        averaging = self.hparams.avg_size > 0
        # One shared time when averaging targets, otherwise an independent time per sample.
        raw_t = torch.rand(1).repeat(n_samples) if averaging else torch.rand(n_samples)
        t = raw_t.type_as(data)
        # Minibatch OT re-pairing; trajectory data is already paired in preprocess_batch.
        if not self.is_trajectory and self.ot_sampler is not None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        x, ut, t, mu_t, sigma_t, _ = self.calc_loc_and_target(x0, x1, t, t_select, training)
        if averaging:
            x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)

        reg, vt = self.augmentations(self.aug_net(t, x, augmented_input=False))
        return torch.mean(reg), self.criterion(vt, ut)

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the training loss (MSE + regularization) and log its parts every step."""
        prefix = "train"
        reg, mse = self.step(batch, training=True)
        loss = mse + reg
        metrics = {f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg}
        self.log_dict(metrics, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def image_eval_step(self, batch: Any, batch_idx: int, prefix: str):
        """Sample images from noise with the learned flow and save them as PNGs (for FID).

        Args:
            batch: image batch; only ``batch[0]`` (the images) is used, for its batch size,
                dtype and device.
            batch_idx: batch index, used in the output file names.
            prefix: unused; kept so the signature matches ``eval_step``.

        Returns:
            ``{"x": batch[0]}``.

        Raises:
            NotImplementedError: if ``hparams.test_nfe`` is neither an int nor a str.
        """
        # Only import what is used: the optional pl_bolts dependency is not needed here.
        from torchvision.utils import save_image

        solver = self.partial_solver(self.net, self.dim)
        if isinstance(self.hparams.test_nfe, int):
            # Fixed grid with ``test_nfe`` integration steps.
            t_span = torch.linspace(0, 1, int(self.hparams.test_nfe) + 1)
        elif isinstance(self.hparams.test_nfe, str):
            # Any string selects the adaptive tsit5 solver; only the end points are needed.
            solver.ode_solver = "tsit5"
            t_span = torch.linspace(0, 1, 2)
        else:
            raise NotImplementedError(f"Unknown test procedure {self.hparams.test_nfe}")
        noise = torch.randn(batch[0].shape[0], *self.dim).type_as(batch[0])
        traj = solver.odeint(noise, t_span)[-1]
        os.makedirs("images", exist_ok=True)
        # Undo the CIFAR-10 normalisation channel-wise (x * std + mean) with two Normalize
        # ops: the first multiplies by std / 255, the second adds mean / 255.
        mean = [-x / 255.0 for x in [125.3, 123.0, 113.9]]
        std = [255.0 / x for x in [63.0, 62.1, 66.7]]
        inv_normalize = transforms.Compose(
            [
                transforms.Normalize(mean=[0.0, 0.0, 0.0], std=std),
                transforms.Normalize(mean=mean, std=[1.0, 1.0, 1.0]),
            ]
        )
        traj = torch.clip(inv_normalize(traj), min=0, max=1.0)
        for i, image in enumerate(traj):
            save_image(image, fp=f"images/{batch_idx}_{i}.png")
        return {"x": batch[0]}

    def eval_step(self, batch: Any, batch_idx: int, prefix: str):
        """Shared validation/test step.

        At test time on image data, samples are generated and saved via ``image_eval_step``.
        For non-image validation batches whose elements share the same leading size, the
        training objective (without training-only pairing) is computed and logged per epoch.

        Returns:
            A dict whose ``"x"`` entry is consumed by ``preprocess_epoch_end``.
        """
        if self.is_image and prefix == "test":
            self.image_eval_step(batch, batch_idx, prefix)
        leading_sizes = [item.shape[0] for item in batch]

        if self.is_image or prefix != "val":
            return {"x": batch}
        if leading_sizes.count(leading_sizes[0]) != len(leading_sizes):
            # jagged batch: cannot be unpacked into one tensor
            return {"x": batch}

        reg, mse = self.step(batch, training=False)
        loss = mse + reg
        metrics = {f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg}
        self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
        return {"loss": loss, "mse": mse, "reg": reg, "x": self.unpack_batch(batch)}

    def preprocess_epoch_end(self, outputs: List[Any], prefix: str):
        """Gather the epoch outputs into the arrays used for evaluation.

        Returns:
            ``(ts, x, x0, x_rest)``: the number of timepoints, all data, the first timepoint
            and the remaining timepoints. For non-trajectory data a standard-normal sample
            is stacked in front of the data, producing a two-timepoint "trajectory".
        """
        if not self.is_trajectory:
            if isinstance(self.dim, int):
                data = torch.cat([out["x"] for out in outputs])
            else:
                # images: at most 100 images from the first batch, labels dropped
                data = outputs[0]["x"][0][:100]
            # Sample some random points for the plotting function
            noise = torch.randn_like(data)
            x = torch.stack([noise, data], dim=1)
            return x.shape[1], x, x[:, 0], x[:, 1:]

        if prefix == "test" and isinstance(outputs[0]["x"], list):
            # jagged test trajectories: one tensor per timepoint
            x = outputs[0]["x"]
            return len(x), x, x[0], x[1:]

        if hasattr(self.datamodule, "HAS_JOINT_PLANS"):
            x = [torch.tensor(points) for points in self.datamodule.timepoint_data]
            return len(x), x, x[0], x[1:]

        x = torch.cat([out["x"] for out in outputs])
        return x.shape[1], x, x[:, 0, :], x[:, 1:]

    def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
        """Integrate the learned flow from ``x0`` and log evaluation metrics.

        For image data, the flow is integrated once over t in [0, 1]. For other data, each
        interval t in [i, i + 1] is integrated in turn with the validation regularizers
        attached. Their means are logged, together with distribution distances between the
        integrated end points and ``x_rest``, and ``nfe``. Datamodules exposing
        ``GAUSSIAN_CLOSED_FORM`` additionally get KL divergences logged at 21 times in [0, 1].

        Args:
            ts: number of timepoints (``ts - 1`` intervals are integrated).
            x0: samples at the first timepoint.
            x_rest: samples at the remaining timepoints.
            outputs: epoch outputs; only used to detect jagged test trajectories.
            prefix: metric-name prefix ("val" or "test").

        Returns:
            ``(trajs, full_trajs)``: the end point of each integrated interval (for images,
            the start and end of the single integration) and the concatenated full
            trajectories.
        """
        # Build a trajectory
        t_span = torch.linspace(0, 1, 101)
        regs = []
        trajs = []
        full_trajs = []
        solver = self.partial_solver(self.net, self.dim)
        nfe = 0
        x0_tmp = x0.clone()

        if self.is_image:
            traj = solver.odeint(x0, t_span)
            full_trajs.append(traj)
            trajs.append(traj[0])
            trajs.append(traj[-1])
            nfe += solver.nfe

        if not self.is_image:
            # With augmentations attached, odeint returns (trajectory, augmented states).
            solver.augmentations = self.val_augmentations
            for i in range(ts - 1):
                traj, aug = solver.odeint(x0_tmp, t_span + i)
                full_trajs.append(traj)
                traj, aug = traj[-1], aug[-1]
                x0_tmp = traj
                regs.append(torch.mean(aug, dim=0).detach().cpu().numpy())
                trajs.append(traj)
                nfe += solver.nfe

        full_trajs = torch.cat(full_trajs)

        if not self.is_image:
            regs = np.stack(regs).mean(axis=0)
            names = [f"{prefix}/{name}" for name in self.val_augmentations.names]
            self.log_dict(dict(zip(names, regs)), sync_dist=True)

            # Evaluate the fit
            if (
                self.is_trajectory
                and prefix == "test"
                and isinstance(outputs[0]["x"], list)
                and not hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM")
            ):
                # Redo the solver for each interval, restarting every interval after the first
                # from the observed data at its start. ``full_trajs`` is intentionally kept
                # from the sequential rollout above so callers still receive a tensor to plot.
                trajs = []
                nfe = 0
                x0_tmp = x0
                for i in range(ts - 1):
                    traj, _ = solver.odeint(x0_tmp, t_span + i)
                    traj = traj[-1]
                    x0_tmp = x_rest[i]
                    trajs.append(traj)
                    nfe += solver.nfe
                names, dists = compute_distribution_distances(trajs[:-1], x_rest[:-1])
            else:
                names, dists = compute_distribution_distances(trajs, x_rest)
            names = [f"{prefix}/{name}" for name in names]
            d = dict(zip(names, dists))
            if self.hparams.leaveout_timepoint >= 0:
                # Duplicate the left-out timepoint's metrics under a stable "t_out" key.
                to_add = {
                    f"{prefix}/t_out/{key.split('/')[-1]}": val
                    for key, val in d.items()
                    if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}")
                }
                d.update(to_add)
            d[f"{prefix}/nfe"] = nfe

            self.log_dict(d, sync_dist=True)

        if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
            solver.augmentations = None
            t_span = torch.linspace(0, 1, 21)
            traj = solver.odeint(x0, t_span)
            assert traj.shape[0] == t_span.shape[0]
            kls = [
                self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
            ]
            self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True)
            self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(len(kls))}, sync_dist=True)

        return trajs, full_trajs

    def eval_epoch_end(self, outputs: List[Any], prefix: str):
        """Shared end-of-epoch evaluation for validation and test.

        The steps are:

        * At test time on image data, compute FID for any images saved under ``images/``.
        * Integrate the flow and log metrics (``forward_eval_integrate``).
        * Optionally plot the results.
        * At test time on non-image data, call ``store_trajectories``.
        """
        wandb_logger = get_wandb_logger(self.loggers)
        if prefix == "test" and self.is_image:
            os.makedirs("images", exist_ok=True)
            if os.listdir("images"):
                path = "./data/fid_stats_cifar10_train.npz"
                from pytorch_fid import fid_score

                # Do not hard-code "cuda": fall back to CPU so FID also works without a GPU.
                device = "cuda" if torch.cuda.is_available() else "cpu"
                # pytorch_fid signature: (paths, batch_size, device, dims, num_workers)
                fid = fid_score.calculate_fid_given_paths(["images", path], 256, device, 2048, 0)
                self.log(f"{prefix}/fid", fid)

        ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix)
        trajs, full_trajs = self.forward_eval_integrate(ts, x0, x_rest, outputs, prefix)

        if self.hparams.plot:
            if isinstance(self.dim, int):
                plot_trajectory(
                    x,
                    full_trajs,
                    title=f"{self.current_epoch}_ode",
                    key="ode_path",
                    wandb_logger=wandb_logger,
                )
            else:
                plot_samples(
                    trajs[-1],
                    title=f"{self.current_epoch}_samples",
                    wandb_logger=wandb_logger,
                )

        if prefix == "test" and not self.is_image:
            store_trajectories(x, self.net)

    def validation_step(self, batch: Any, batch_idx: int):
        """Lightning validation hook; delegates to ``eval_step`` with the "val" prefix."""
        stage = "val"
        return self.eval_step(batch, batch_idx, stage)

    def validation_epoch_end(self, outputs: List[Any]):
        """Lightning hook; runs the shared end-of-epoch evaluation with the "val" prefix."""
        stage = "val"
        self.eval_epoch_end(outputs, stage)

    def test_step(self, batch: Any, batch_idx: int):
        """Lightning test hook; delegates to ``eval_step`` with the "test" prefix."""
        stage = "test"
        return self.eval_step(batch, batch_idx, stage)

    def test_epoch_end(self, outputs: List[Any]):
        """Lightning hook; runs the shared end-of-epoch evaluation with the "test" prefix."""
        stage = "test"
        self.eval_epoch_end(outputs, stage)

    def configure_optimizers(self):
        """Build the optimizer and, if configured, an LR scheduler stepped every epoch.

        Returns:
            The optimizer alone, or ``([optimizer], [scheduler config])`` when a scheduler
            factory was given.
        """
        optimizer = self.optimizer(params=self.parameters())
        if self.scheduler is not None:
            lr_config = {"scheduler": self.scheduler(optimizer), "interval": "epoch"}
            return [optimizer], [lr_config]
        return optimizer

    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        """Step ``scheduler`` with the current epoch; ``optimizer_idx`` and ``metric`` are unused."""
        current_epoch = self.current_epoch
        scheduler.step(epoch=current_epoch)


class RectifiedFlowLitModule(CFMLitModule):
    def __init__(
        self,
        net: Any,
        optimizer: Any,
        datamodule: LightningDataModule,
        augmentations: AugmentationModule,
        partial_solver: FlowSolver,
        val_augmentations: Optional[AugmentationModule] = None,
        scheduler: Optional[Any] = None,
        neural_ode: Optional[Any] = None,
        ot_sampler: Optional[Union[str, Any]] = None,
        sigma_min: float = 0.1,
        rectify_epochs: Optional[List[int]] = None,
        test_nfe: int = 100,
        avg_size: int = -1,
        leaveout_timepoint: int = -1,
        plot: bool = False,
        nice_name: str = "Rect",
    ) -> None:
        """Initialize a conditional flow matching network either as a generative model or for a
        sequence of timepoints.

        Args:
            net: factory called as ``net(dim=...)`` that builds the module representing
                dx/dt = f(t, x).
            optimizer: partial torch.optimizer missing parameters.
            datamodule: datamodule object; must expose ``dim`` (vector data) or ``dims``
                (image data) and may expose ``IS_TRAJECTORY``.
            augmentations: training-time regularization module.
            partial_solver: factory called as ``partial_solver(net, dim)`` for evaluation.
            val_augmentations: regularizers used at evaluation time; defaults to l1, l2 and
                squared l2 regularization when ``None``.
            scheduler: optional factory called with the optimizer to build an LR scheduler.
            neural_ode: optional factory applied to the augmented vector field.
            ot_sampler: ot_sampler specified as an object or string. ``None`` or the string
                "None" disables minibatch OT.
            sigma_min: sigma_min determines the width of the Gaussian smoothing of the data and
                interpolations.
            rectify_epochs: epochs at whose end a frozen copy of the network is taken; once
                set, ``preprocess_batch`` replaces x1 by integrating that copy from x0.
            test_nfe: for image testing, an int number of fixed integration steps, or a string
                to switch to the adaptive "tsit5" solver.
            avg_size: if > 0, the regression target is averaged over this many samples.
            leaveout_timepoint: which (if any) timepoint to leave out during the training phase
            plot: if true, log intermediate plots during validation
            nice_name: display name; only stored in the hyperparameters by this class.
        """
        # Deliberately skip CFMLitModule.__init__ and call LightningModule.__init__ directly;
        # the attributes are set up below (with configurable ``val_augmentations``).
        super(CFMLitModule, self).__init__()
        self.save_hyperparameters(
            ignore=[
                "net",
                "optimizer",
                "scheduler",
                "datamodule",
                "augmentations",
                "val_augmentations",
                "partial_solver",
            ],
            logger=False,
        )
        self.datamodule = datamodule
        self.is_trajectory = False
        if hasattr(datamodule, "IS_TRAJECTORY"):
            self.is_trajectory = datamodule.IS_TRAJECTORY
        # ``dim`` marks vector data, ``dims`` marks image data.
        if hasattr(datamodule, "dim"):
            self.dim = datamodule.dim
            self.is_image = False
        elif hasattr(datamodule, "dims"):
            self.dim = datamodule.dims
            self.is_image = True
        else:
            raise NotImplementedError("Datamodule must have either dim or dims")
        self.net = net(dim=self.dim)
        # Frozen copy of ``net`` used for rectification; set in ``training_epoch_end``.
        self.frozen_net = None
        self.augmentations = augmentations
        self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim)
        self.val_augmentations = val_augmentations
        if val_augmentations is None:
            self.val_augmentations = AugmentationModule(
                l1_reg=1,
                l2_reg=1,
                squared_l2_reg=1,
            )
        self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim)
        if neural_ode is not None:
            self.aug_node = Sequential(
                self.augmentations.augmenter,
                neural_ode(self.aug_net),
            )
        self.partial_solver = partial_solver
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.ot_sampler = ot_sampler
        # Config files may pass the literal string "None" to disable OT.
        if ot_sampler == "None":
            self.ot_sampler = None
        if isinstance(self.ot_sampler, str):
            # regularization taken for optimal Schrodinger bridge relationship
            self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2)
        self.criterion = torch.nn.MSELoss()

    def preprocess_batch(self, X, training=False):
        """Converts a batch of data into matched pairs of (x0, x1).

        Trajectory data is paired by taking each sample's own consecutive timepoints; other
        data is paired with standard-normal noise. Once a frozen network exists (see
        ``training_epoch_end``), x1 is replaced by integrating that network from x0.

        Args:
            X: unpacked batch; (batch_size, times, dim) for trajectory data.
            training: when True, ``hparams.leaveout_timepoint`` is bridged over.

        Returns:
            ``(x0, x1, t_select)`` with ``t_select`` the starting-timepoint index per sample
            (a single zero for non-trajectory data), always on ``X.device``.
        """
        t_select = torch.zeros(1, device=X.device)
        if self.is_trajectory:
            batch_size, times, _ = X.shape
            if training and self.hparams.leaveout_timepoint > 0:
                # Select random except for the leftout timepoint
                t_select = torch.randint(times - 2, size=(batch_size,), device=X.device)
                t_select[t_select >= self.hparams.leaveout_timepoint] += 1
            else:
                # Same device as X: t_select is later combined with tensors on that device.
                t_select = torch.randint(times - 1, size=(batch_size,), device=X.device)
            rows = torch.arange(batch_size, device=X.device)
            t_next = t_select + 1
            if training:
                # skip over the left-out timepoint
                t_next[t_next == self.hparams.leaveout_timepoint] += 1
            x0, x1 = X[rows, t_select], X[rows, t_next]
        else:
            # If no trajectory assume generate from standard normal
            x0 = torch.randn_like(X)
            x1 = X

        if self.frozen_net is not None:
            # Currently only works for 2 distributions
            assert t_select[0] == 0
            t_span = torch.linspace(0, 1, 100)
            val_node = NeuralODE(self.frozen_net, solver="euler")
            with torch.no_grad():
                _, traj = val_node(x0, t_span)
                x1 = traj[-1]
        return x0, x1, t_select

    def training_epoch_end(self, training_step_outputs):
        """Take a frozen copy of the network at every epoch listed in ``rectify_epochs``.

        The copy is used by ``preprocess_batch`` to regenerate the x1 targets.
        """
        epochs = self.hparams.rectify_epochs
        if epochs is None or self.current_epoch not in epochs:
            return
        self.frozen_net = copy.deepcopy(self.net)


class ActionMatchingLitModule(CFMLitModule):
    r"""Implements Action Matching: Learning Stochastic Dynamics from Samples (Neklyudov et al.
    2022)

    Requires net to have a ``.energy`` function mapping a single tensor ``[x, t]``
    (concatenated along the last axis, so \mathbb{R}^{d+1}) to \mathbb{R}, and net.forward
    to be equal to \nabla_x(net.energy).
    """

    def step(self, batch: Any, training: bool = False):
        """Computes the action-matching loss on a batch of data.

        The loss is ``s(x0, 0) - s(x1, 1) + 0.5 ||d s / d x||^2 + d s / d t`` evaluated at
        points on the straight line between x0 and x1, averaged over the batch.

        Returns:
            ``(mean regularization, action-matching loss)``.
        """
        assert not self.is_trajectory
        energy = self.net.energy
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)

        if self.ot_sampler is not None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        t = torch.rand(X.shape[0]).type_as(X)
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        xt = t_xshape * x1 + (1 - t_xshape) * x0
        # t that network sees is incremented by first timepoint
        t = t + t_select.reshape(-1, *t.shape[1:])

        # Temporarily track gradients w.r.t. x_t and t to get the energy's partial derivatives.
        xt.requires_grad, t_xshape.requires_grad = True, True
        with torch.set_grad_enabled(True):
            st = torch.sum(energy(torch.cat([xt, t_xshape], dim=-1)))
            dsdx, dsdt = torch.autograd.grad(st, (xt, t_xshape), create_graph=True)
        xt.requires_grad, t_xshape.requires_grad = False, False
        # Boundary terms at t=0 and t=1; the time column is built with x0/x1's device and
        # dtype (a bare torch.zeros/ones would live on the CPU and break GPU training).
        a0 = energy(torch.cat([x0, torch.zeros_like(x0[:, :1])], dim=-1))
        a1 = energy(torch.cat([x1, torch.ones_like(x1[:, :1])], dim=-1))
        loss = a0 - a1 + 0.5 * (dsdx**2).sum(1, keepdim=True) + dsdt
        loss = loss.mean()
        aug_x = self.aug_net(t, xt, augmented_input=False)
        reg, vt = self.augmentations(aug_x)
        return torch.mean(reg), loss


class VariancePreservingCFM(CFMLitModule):
    """Conditional flow matching with a variance-preserving interpolation schedule.

    Follows the schedule suggested in (Anonymous et al. 2023): the path mean is
    cos(t pi/2) x_0 + sin(t pi/2) x_1 and the standard deviation is ``sigma_min``.
    """

    def calc_mu_sigma(self, x0, x1, t):
        """Mean cos(t pi/2) x0 + sin(t pi/2) x1 and constant std ``hparams.sigma_min``."""
        assert not self.is_trajectory
        angle = math.pi / 2 * t
        mean = torch.cos(angle) * x0 + torch.sin(angle) * x1
        return mean, self.hparams.sigma_min

    def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
        """Time derivative of the mean: pi/2 (cos(t pi/2) x1 - sin(t pi/2) x0).

        ``x``, ``mu_t`` and ``sigma_t`` are part of the shared signature but unused here.
        """
        angle = math.pi / 2 * t
        return math.pi / 2 * (torch.cos(angle) * x1 - torch.sin(angle) * x0)


class SBCFMLitModule(CFMLitModule):
    """Implements a Schrodinger Bridge based conditional flow matching model.

    This is similar to the OT-CFM loss, but the standard deviation varies as
    sigma_min * sqrt(t * (1 - t)). This has provably equal probability flow to the
    Schrodinger bridge solution when the transport is computed with the squared Euclidean
    distance on R^d.
    """

    def calc_mu_sigma(self, x0, x1, t):
        """Straight-line mean and Brownian-bridge std ``sigma_min * sqrt(t - t**2)``."""
        assert not self.is_trajectory
        mean = t * x1 + (1 - t) * x0
        std = self.hparams.sigma_min * torch.sqrt(t - t**2)
        return mean, std

    def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
        """Conditional velocity ``(sigma_t' / sigma_t) (x - mu_t) + x1 - x0``.

        For the bridge std, sigma_t' / sigma_t = (1 - 2t) / (2 t (1 - t)), which is
        undefined at t = 0 and t = 1. ``sigma_t`` itself is not needed.
        """
        log_std_rate = (1 - 2 * t) / (2 * t * (1 - t))
        return log_std_rate * (x - mu_t) + x1 - x0


class SF2MLitModule(CFMLitModule):
    """Score and flow matching ("SF2M") module built on ``CFMLitModule``.

    Learns a flow ``vt`` (from ``net``) and a score term ``st`` (from a separate ``score_net``, or
    from the second half of ``net``'s output channels). Both are learned along conditional paths
    whose mean and variance come from the noise scheduler ``self.sigma``. When
    ``outer_loop_epochs`` is set, it periodically replaces the data coupling with trajectory pairs
    simulated from the learned SDE.
    """

    def __init__(
        self,
        net: Any,
        optimizer: Any,
        datamodule: LightningDataModule,
        augmentations: AugmentationModule,
        partial_solver: FlowSolver,
        score_net: Optional[Any] = None,
        scheduler: Optional[Any] = None,
        ot_sampler: Optional[Union[str, Any]] = None,
        sigma: Optional[NoiseScheduler] = None,
        sigma_min: float = 0.1,
        outer_loop_epochs: Optional[int] = None,
        score_weight: float = 1.0,
        avg_size: int = -1,
        leaveout_timepoint: int = -1,
        test_nfe: int = 100,
        test_sde: bool = False,
        plot: bool = False,
        nice_name: Optional[str] = "SF2M",
    ) -> None:
        """Initialize a score and flow matching network either as a generative model or for a
        sequence of timepoints.

        Args:
            net: partial torch module (called with ``dim=``) representing dx/dt = f(t, x) for t
                in [1, T] missing dimension.
            optimizer: partial torch.optimizer missing parameters.
            datamodule: datamodule object needs to have a "dim" (vector data) or "dims" (image
                data) property, and optionally "IS_TRAJECTORY".
            augmentations: module whose ``regs`` are attached to ``net`` through
                ``AugmentedVectorField``; its output provides the regularization term.
            partial_solver: partial ``FlowSolver`` constructor used for ODE/SDE integration.
            score_net: partial torch module (called with ``dim=``) representing the score
                function of the flow. If not supplied it is assumed that the net contains both
                flow and score.
            scheduler: optional scheduler, stored as ``self.scheduler``.
            ot_sampler: ot_sampler specified as an object or string. If none (or the string
                "None") then no OT is used in minibatch. Any other string is used as the
                ``OTPlanSampler`` method with regularization ``2 * sigma.F(1)``.
            sigma: noise scheduler; determines the width of the Gaussian smoothing of the data
                and interpolations. Defaults to ``ConstantNoiseScheduler(sigma_min)``.
            sigma_min: noise level used to build the default ``sigma`` scheduler.
            outer_loop_epochs: if set, every this many epochs resample and store trajectory
                pairs simulated with the learned SDE.
            score_weight: weight of the score-matching loss.
            avg_size: if > 0, a single random t is shared across the batch and targets are passed
                through ``average_ut``.
            leaveout_timepoint: which (if any) timepoint to leave out during the training phase.
            test_nfe: number of integration steps for image evaluation (int), or a string to use
                the adaptive "tsit5" ODE solver.
            test_sde: if true, image evaluation integrates the SDE instead of the ODE.
            plot: if true, log intermediate plots during validation.
            nice_name: model name stored with the hyperparameters.
        """
        # Deliberately skip CFMLitModule.__init__ and initialize LightningModule directly; all
        # attributes this class needs are set up below.
        super(CFMLitModule, self).__init__()
        # NOTE(review): "sigma_scheduler" is not a constructor argument, so ``sigma`` and
        # ``score_net`` are still saved as hyperparameters -- confirm this is intended.
        self.save_hyperparameters(
            ignore=[
                "net",
                "optimizer",
                "scheduler",
                "datamodule",
                "augmentations",
                "sigma_scheduler",
                "partial_solver",
            ],
            logger=False,
        )
        self.datamodule = datamodule
        self.is_trajectory = False
        if hasattr(datamodule, "IS_TRAJECTORY"):
            self.is_trajectory = datamodule.IS_TRAJECTORY
        # dims is either an integer or a tuple. This helps us to decide whether to process things as
        # a vector or as an image.
        if hasattr(datamodule, "dim"):
            self.dim = datamodule.dim
            self.is_image = False
        elif hasattr(datamodule, "dims"):
            self.dim = datamodule.dims
            self.is_image = True
        else:
            raise NotImplementedError("Datamodule must have either dim or dims")
        self.net = net(dim=self.dim)
        self.separate_score = score_net is not None
        self.score_net = score_net
        if self.separate_score:
            self.score_net = score_net(dim=self.dim)
        self.partial_solver = partial_solver
        self.augmentations = augmentations
        self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim)
        # Fixed set of regularizers that are only logged during validation.
        self.val_augmentations = AugmentationModule(
            l1_reg=1,
            l2_reg=1,
            squared_l2_reg=1,
        )
        self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.sigma = sigma
        if sigma is None:
            self.sigma = ConstantNoiseScheduler(sigma_min)
        self.ot_sampler = ot_sampler
        if ot_sampler == "None":
            self.ot_sampler = None
        if isinstance(self.ot_sampler, str):
            # regularization taken for optimal Schrodinger bridge relationship
            self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * self.sigma.F(1))
        self.criterion = torch.nn.MSELoss()

        # If we are doing outer loops holds the current dataset
        self.stored_data = None
        self.tmp_stored_data = None

    def calc_mu_sigma(self, x0, x1, t):
        """Return the conditional mean and std of the noise-scheduled bridge at time ``t``.

        ``mu_t = x0 + (x1 - x0) F(t) / F(1)`` and ``sigma_t = sqrt(F(t) - F(t)**2 / F(1))``, where
        ``F`` is ``self.sigma.F``.
        """
        ft = self.sigma.F(t)
        fone = self.sigma.F(1)
        mu_t = x0 + (x1 - x0) * ft / fone
        sigma_t = torch.sqrt(ft - ft**2 / fone)
        return mu_t, sigma_t

    def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
        """Return the conditional flow ``mu_t' + (sigma_t' / sigma_t) (x - mu_t)``.

        Assumes ``self.sigma.F`` is the antiderivative of ``self.sigma(t) ** 2``; this is what
        makes ``mu_t_prime`` below the derivative of ``mu_t`` from ``calc_mu_sigma``.
        """
        ft = self.sigma.F(t)
        fone = self.sigma.F(1)
        g2 = self.sigma(t) ** 2
        # Time derivative of the conditional variance sigma_t**2 = F(t) - F(t)**2 / F(1).
        var_prime = g2 - 2 * ft * g2 / fone
        # sigma_t' / sigma_t = (d sigma_t**2 / dt) / (2 sigma_t**2). Dividing by sigma_t alone
        # would not reduce to SBCFMLitModule's Brownian-bridge flow
        # (1 - 2t) / (2 t (1 - t)) for a constant sigma.
        sigma_t_prime_over_sigma_t = var_prime / (2 * sigma_t**2 + 1e-8)
        mu_t_prime = (x1 - x0) * g2 / fone
        ut = sigma_t_prime_over_sigma_t * (x - mu_t) + mu_t_prime
        return ut

    def calc_loc_and_target(self, x0, x1, t, t_select, training):
        """Sample a point on the conditional path and compute its flow and score targets.

        Returns ``(x, ut, t, mu_t, sigma_t, score_target)``. Here ``t`` has ``t_select`` added,
        and ``score_target`` is the standard-normal noise used to build ``x``. Note that ``t``
        is modified in place for the leave-out timepoint.
        """
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape)
        eps_t = torch.randn_like(mu_t)
        x = mu_t + sigma_t * eps_t
        ut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t)

        # if we are starting from right before the leaveout_timepoint then we
        # divide the target by 2
        if training and self.hparams.leaveout_timepoint > 0:
            ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2
            t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2

        # t that network sees is incremented by first timepoint
        score_target = eps_t
        t = t + t_select.reshape(-1, *t.shape[1:])
        return x, ut, t, mu_t, sigma_t, score_target

    def forward_flow_and_score(self, t, x):
        """Return ``(reg, vt, st)``: regularization, flow and score network outputs at (t, x).

        Without a separate score network, ``net``'s output channels are split in half into flow
        and score.
        """
        if self.separate_score:
            reg, vt = self.augmentations(self.aug_net(t, x, augmented_input=False))
            st = self.score_net(t, x)
            return reg, vt, st
        reg, vtst = self.augmentations(self.aug_net(t, x, augmented_input=False))
        split_idx = vtst.shape[1] // 2
        vt, st = vtst[:, :split_idx], vtst[:, split_idx:]
        return reg, vt, st

    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data.

        Returns ``(mean(reg) + score_weight * score_loss, flow_loss)``.
        """
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)
        # Either randomly sample a single T or sample a batch of T's
        if self.hparams.avg_size > 0:
            t = torch.rand(1).repeat(X.shape[0]).type_as(X)
        else:
            t = torch.rand(X.shape[0]).type_as(X)
        # Resample the plan if we are using optimal transport (stored outer-loop data is
        # already paired).
        if self.ot_sampler is not None and self.stored_data is None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)
        # Un-shifted copy of t; calc_loc_and_target shifts t and may rescale it in place.
        t_orig = t.clone()

        x, ut, t, mu_t, sigma_t, score_target = self.calc_loc_and_target(
            x0, x1, t, t_select, training
        )

        if self.hparams.avg_size > 0:
            x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)

        reg, vt, st = self.forward_flow_and_score(t, x)
        flow_loss = self.criterion(vt, ut)
        # -2 sigma_t st / sigma(t)^2 is regressed onto eps, i.e. st targets
        # sigma(t)^2 / 2 times the conditional score -eps / sigma_t.
        score_loss = self.criterion(
            -sigma_t * st / (self.sigma(t_orig.reshape(sigma_t.shape)) ** 2) * 2, score_target
        )
        return torch.mean(reg) + self.hparams.score_weight * score_loss, flow_loss

    def forward_sde_eval(self, ts, x0, x_rest, outputs, prefix):
        """Integrate the learned SDE over ``ts - 1`` unit intervals and log fit metrics.

        Logs distribution distances against ``x_rest``, the number of function evaluations and
        the summed ``sdeint(logqp=True)`` KL term. If the datamodule exposes
        ``GAUSSIAN_CLOSED_FORM``, it also logs closed-form KLs. Returns ``(trajs, full_trajs)``:
        the list of per-interval endpoints and the concatenated full trajectories.
        """
        # Build a trajectory
        t_span = torch.linspace(0, 1, 2)
        solver = self.partial_solver(
            self.net, self.dim, score_field=self.score_net, sigma=self.sigma
        )

        trajs = []
        full_trajs = []
        nfe = 0
        kldiv_total = 0
        x0_tmp = x0.clone()
        for i in range(ts - 1):
            traj, kldiv = solver.sdeint(x0_tmp, t_span + i, logqp=True)
            kldiv_total += torch.mean(kldiv[-1])
            x0_tmp = traj[-1]
            trajs.append(traj[-1])
            full_trajs.append(traj)
            nfe += solver.nfe
        full_trajs = torch.cat(full_trajs)
        if not self.is_image:
            # Evaluate the fit
            if (
                self.is_trajectory
                and prefix == "test"
                and isinstance(outputs[0]["x"], list)
                and not hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM")
            ):
                # One-step-ahead evaluation: each interval starts from the observed data.
                trajs = []
                full_trajs = []
                nfe = 0
                kldiv_total = 0
                x0_tmp = x0.clone()
                for i in range(ts - 1):
                    traj, kldiv = solver.sdeint(x0_tmp, t_span + i, logqp=True)
                    x0_tmp = x_rest[i]
                    kldiv_total += torch.mean(kldiv[-1])
                    trajs.append(traj[-1])
                    full_trajs.append(traj)
                    nfe += solver.nfe
                # Keep the return type consistent with the rollout path (a single tensor).
                full_trajs = torch.cat(full_trajs)
                names, dists = compute_distribution_distances(trajs[:-1], x_rest[:-1])
            else:
                names, dists = compute_distribution_distances(trajs, x_rest)
            names = [f"{prefix}/sde/{name}" for name in names]
            d = dict(zip(names, dists))
            if self.hparams.leaveout_timepoint >= 0:
                to_add = {
                    f"{prefix}/sde/t_out/{key.split('/')[-1]}": val
                    for key, val in d.items()
                    if key.startswith(f"{prefix}/sde/t{self.hparams.leaveout_timepoint}")
                }
                d.update(to_add)
            d[f"{prefix}/sde/nfe"] = nfe
            d[f"{prefix}/sde/kldiv"] = kldiv_total
            self.log_dict(d, sync_dist=True)
        if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
            solver.augmentations = None
            t_span = torch.linspace(0, 1, 21)
            solver.dt = 0.05
            traj = solver.sdeint(x0, t_span)
            assert traj.shape[0] == t_span.shape[0]
            kls = [
                self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
            ]
            self.log_dict(
                {f"{prefix}/sde/kl/mean": torch.stack(kls).mean().item()},
                sync_dist=True,
            )
            self.log_dict({f"{prefix}/sde/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)
        return trajs, full_trajs

    def eval_epoch_end(self, outputs: List[Any], prefix: str):
        """Run the parent's evaluation, then the SDE evaluation (and optional plot) on vector data."""
        super().eval_epoch_end(outputs, prefix)
        wandb_logger = get_wandb_logger(self.loggers)
        ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix)
        if isinstance(self.dim, int):
            _, sde_traj = self.forward_sde_eval(ts, x0, x_rest, outputs, prefix)

        if self.hparams.plot:
            if isinstance(self.dim, int):
                plot_trajectory(
                    x,
                    sde_traj,
                    title=f"{self.current_epoch}_sde_traj",
                    key="sde",
                    wandb_logger=wandb_logger,
                )

    def preprocess_batch(self, X, training=False):
        """Converts a batch of data into matched a random pair of (x0, x1).

        During training with stored outer-loop trajectories, pairs are drawn at random from
        ``self.stored_data`` instead of the incoming batch, with ``t_select`` of zero.
        """
        if self.stored_data is not None and training:
            # Randomly sample a batch from the stored data.
            idx = torch.randint(self.stored_data.shape[0], size=(X.shape[0],))
            X = self.stored_data[idx]
            t_select = torch.zeros(1, device=X.device)
            return X[:, 0], X[:, 1], t_select
        return super().preprocess_batch(X, training)

    def training_step(self, batch: Any, batch_idx: int):
        """Training step that also simulates and caches trajectory pairs on outer-loop epochs.

        On epochs where ``(current_epoch + 1) % outer_loop_epochs == 0``, the learned SDE
        integrates the first half of the batch forward from x0 and the second half backward from
        x1. The trajectories accumulate in ``self.tmp_stored_data`` and become
        ``self.stored_data`` in ``training_epoch_end``.
        """
        # If we are doing outerloops we need to resample and store forward and backwards batches.
        if (
            self.hparams.outer_loop_epochs is not None
            and (self.current_epoch + 1) % self.hparams.outer_loop_epochs == 0
        ):
            X = self.unpack_batch(batch)
            x0, x1, t_select = self.preprocess_batch(X, training=True)
            assert not torch.any(t_select)  # resampling outerloop can only handle 2 timepoints
            t_span = torch.linspace(0, 1, 2)
            solver = self.partial_solver(
                self.net, self.dim, score_field=self.score_net, sigma=self.sigma
            )
            batch_size = x0.shape[0]
            with torch.no_grad():
                forward_traj = solver.sdeint(x0[: batch_size // 2], t_span)
                backward_traj = torch.flip(
                    solver.sdeint(x1[batch_size // 2 :], t_span, reverse=True), (0,)
                )
            # Presumably (time, batch, ...) from sdeint; stored as (batch, time, ...) so that
            # preprocess_batch can read X[:, 0] / X[:, 1].
            stored_traj = torch.cat([forward_traj, backward_traj], dim=1)
            stored_traj = stored_traj.transpose(0, 1)
            if batch_idx == 0:
                self.tmp_stored_data = []
            self.tmp_stored_data.append(stored_traj)
        return super().training_step(batch, batch_idx)

    def training_epoch_end(self, training_step_outputs):
        """On outer-loop epochs, promote the trajectories cached this epoch to ``stored_data``."""
        if (
            self.hparams.outer_loop_epochs is not None
            and (self.current_epoch + 1) % self.hparams.outer_loop_epochs == 0
        ):
            self.stored_data = torch.cat(self.tmp_stored_data, dim=0).detach().clone()

    def image_eval_step(self, batch: Any, batch_idx: int, prefix: str):
        """Sample images from the model and save them as PNGs and as a tensor file.

        Draws ``5 * batch_size`` Gaussian samples and integrates them with the ODE (or the SDE
        when ``test_sde`` is set). It then undoes the per-channel normalization given by the
        hard-coded mean/std constants and writes ``images/{batch_idx}_{i}.png`` and
        ``compressed_images/{batch_idx}.pt``.
        """
        from torchvision.utils import save_image

        solver = self.partial_solver(self.net, self.dim)
        if isinstance(self.hparams.test_nfe, int):
            t_span = torch.linspace(0, 1, int(self.hparams.test_nfe) + 1)
        elif isinstance(self.hparams.test_nfe, str):
            solver.ode_solver = "tsit5"
            t_span = torch.linspace(0, 1, 2).type_as(batch[0])
        else:
            raise NotImplementedError(f"Unknown test procedure {self.hparams.test_nfe}")
        if self.hparams.test_sde:
            solver = self.partial_solver(
                self.net, self.dim, score_field=self.score_net, sigma=self.sigma
            )
            solver.dt = 1 / int(self.hparams.test_nfe)
            t_span = torch.linspace(0, 1, 2).type_as(batch[0])
            integrator = solver.sdeint
        else:
            integrator = solver.odeint
        x0 = torch.randn(5 * batch[0].shape[0], *self.dim).type_as(batch[0])
        traj = integrator(x0, t_span)[-1]
        os.makedirs("images", exist_ok=True)
        mean = [-x / 255.0 for x in [125.3, 123.0, 113.9]]
        std = [255.0 / x for x in [63.0, 62.1, 66.7]]
        # Two Normalize passes: first multiply by the data std, then add back the data mean.
        inv_normalize = transforms.Compose(
            [
                transforms.Normalize(mean=[0.0, 0.0, 0.0], std=std),
                transforms.Normalize(mean=mean, std=[1.0, 1.0, 1.0]),
            ]
        )
        traj = inv_normalize(traj)
        traj = torch.clip(traj, min=0, max=1.0)
        for i, image in enumerate(traj):
            save_image(image, fp=f"images/{batch_idx}_{i}.png")
        os.makedirs("compressed_images", exist_ok=True)
        torch.save(traj.cpu(), f"compressed_images/{batch_idx}.pt")
        return {"x": batch[0]}


class OneWaySF2MLitModule(SF2MLitModule):
    """SF2M variant that regresses ``net`` directly onto the forward (x0 -> x1) bridge drift.

    Only the forward SDE drift is learned; no separate score loss is computed.
    """

    def calc_loc_and_target(self, x0, x1, t, t_select, training):
        """Sample x on the conditional path and return the forward-drift regression target.

        Returns ``(x, forward_target, t, mu_t, sigma_t, None)``, where ``t`` is the network time
        returned by the parent (incremented by ``t_select``).
        """
        # Un-shifted time, copied before the parent adds t_select and possibly rescales t in
        # place: the bridge drift below is defined for t in [0, 1].
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))).clone()
        x, _, t_shifted, mu_t, sigma_t, eps_t = super().calc_loc_and_target(
            x0, x1, t, t_select, training
        )
        # The parent's score target is the standard-normal noise used to build x, so it is used
        # directly (same as DSBMLitModule).
        forward_target = (
            x1 - x0 - (self.sigma(t_xshape) * torch.sqrt(t_xshape / (1 - t_xshape + 1e-6))) * eps_t
        )
        return x, forward_target, t_shifted, mu_t, sigma_t, None

    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data.

        Returns ``(mean(reg), forward_flow_loss)``, where the forward loss is a time-weighted
        squared error against the forward drift.
        """
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)
        # Either randomly sample a single T or sample a batch of T's
        if self.hparams.avg_size > 0:
            t = torch.rand(1).repeat(X.shape[0]).type_as(X)
        else:
            t = torch.rand(X.shape[0]).type_as(X)
        # Resample the plan if we are using optimal transport
        if self.ot_sampler is not None and self.stored_data is None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        # Loss weight from the un-shifted t, computed before calc_loc_and_target (which may
        # modify t in place), mirroring DSBMLitModule.step.
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        forward_scaling = (1 + self.sigma(t_xshape) ** 2 * t_xshape / (1 - t_xshape + 1e-6)) ** -1
        x, forward_target, t, _, _, _ = self.calc_loc_and_target(x0, x1, t, t_select, training)
        reg, vt, _ = self.forward_flow_and_score(t, x)
        forward_flow_loss = torch.mean(forward_scaling * (vt - forward_target) ** 2)
        return torch.mean(reg), forward_flow_loss

    def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
        """Integrate the learned SDE over ``ts - 1`` unit intervals and log evaluation metrics.

        Logs validation regularizers (vector data only), distribution distances against
        ``x_rest``, the NFE count and, when available, closed-form Gaussian KLs. Returns
        ``(trajs, full_trajs)``.
        """
        # Build a trajectory
        t_span = torch.linspace(0, 1, 101).type_as(x0)
        regs = []
        trajs = []
        full_trajs = []
        solver = self.partial_solver(
            self.net, self.dim, score_field=self.score_net, sigma=self.sigma
        )
        nfe = 0
        x0_tmp = x0.clone()
        for i in range(ts - 1):
            if not self.is_image:
                solver.augmentations = self.val_augmentations
                traj, aug = solver.sdeint(x0_tmp, t_span + i)
                aug = aug[-1]
                regs.append(torch.mean(aug, dim=0).detach().cpu().numpy())
            else:
                traj = solver.sdeint(x0_tmp, t_span + i)
            full_trajs.append(traj)
            traj = traj[-1]
            x0_tmp = traj
            trajs.append(traj)
            nfe += solver.nfe

        if not self.is_image:
            regs = np.stack(regs).mean(axis=0)
            names = [f"{prefix}/{name}" for name in self.val_augmentations.names]
            self.log_dict(dict(zip(names, regs)), sync_dist=True)

            # Evaluate the fit
            names, dists = compute_distribution_distances(trajs, x_rest)
            names = [f"{prefix}/{name}" for name in names]
            d = dict(zip(names, dists))
            if self.hparams.leaveout_timepoint >= 0:
                to_add = {
                    f"{prefix}/t_out/{key.split('/')[-1]}": val
                    for key, val in d.items()
                    if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}")
                }
                d.update(to_add)
            d[f"{prefix}/nfe"] = nfe
            self.log_dict(d, sync_dist=True)

        if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
            solver.augmentations = None
            t_span = torch.linspace(0, 1, 21)
            # NOTE(review): this closed-form check integrates with odeint although net is
            # trained as an SDE drift here -- confirm this is intended.
            traj = solver.odeint(x0, t_span)
            assert traj.shape[0] == t_span.shape[0]
            kls = [
                self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
            ]

            self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True)
            self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)

        full_trajs = torch.cat(full_trajs)
        return trajs, full_trajs


class DSBMLitModule(SF2MLitModule):
    """Based on SF2M module except directly regresses against the target SDE drift rather than
    separating the ODE and Score components.

    ``net`` (the ``vt`` output) is regressed onto the forward (x0 -> x1) bridge drift. The score
    output ``st`` is regressed onto the backward (x1 -> x0) bridge drift.
    """

    def calc_loc_and_target(self, x0, x1, t, t_select, training):
        """Sample x on the conditional path and return forward and backward drift targets.

        Returns ``(x, forward_target, t + t_select, mu_t, sigma_t, backward_target)``.
        """
        # Un-shifted time, copied before the parent adds t_select and possibly rescales t in
        # place: the bridge drifts are defined for t in [0, 1].
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))).clone()
        x, _, t_plus_t_select, mu_t, sigma_t, eps_t = super().calc_loc_and_target(
            x0, x1, t, t_select, training
        )
        sigma = self.sigma(t_xshape)
        forward_target = x1 - x0 - (sigma * torch.sqrt(t_xshape / (1 - t_xshape + 1e-6))) * eps_t
        backward_target = (
            x0 - x1 - (sigma * torch.sqrt((1 - t_xshape) / (t_xshape + 1e-6))) * eps_t
        )
        return x, forward_target, t_plus_t_select, mu_t, sigma_t, backward_target

    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data.

        Returns ``(mean(reg) + backward_flow_loss, forward_flow_loss)``.

        Raises:
            ValueError: if either drift loss is not finite.
        """
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)
        # Either randomly sample a single T or sample a batch of T's
        if self.hparams.avg_size > 0:
            t = torch.rand(1).repeat(X.shape[0]).type_as(X)
        else:
            t = torch.rand(X.shape[0]).type_as(X)
        # Resample the plan if we are using optimal transport
        if self.ot_sampler is not None and self.stored_data is None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        # Loss weights from the un-shifted t, computed before calc_loc_and_target (which may
        # modify t in place). Shaping them as (batch, 1, ..., 1) broadcasts against inputs of
        # any rank; indexing with [:, None] would only match (batch, dim) inputs.
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        sigma_sq = self.sigma(t_xshape) ** 2
        forward_scaling = (1 + sigma_sq * t_xshape / (1 - t_xshape + 1e-6)) ** -1
        backward_scaling = (1 + sigma_sq * (1 - t_xshape) / (t_xshape + 1e-6)) ** -1
        x, forward_target, t, _, _, backward_target = self.calc_loc_and_target(
            x0, x1, t, t_select, training
        )
        reg, vt, st = self.forward_flow_and_score(t, x)
        forward_flow_loss = torch.mean(forward_scaling * (vt - forward_target) ** 2)
        backward_flow_loss = torch.mean(backward_scaling * (st - backward_target) ** 2)
        if not torch.isfinite(forward_flow_loss) or not torch.isfinite(backward_flow_loss):
            raise ValueError("Loss Not Finite")

        return torch.mean(reg) + backward_flow_loss, forward_flow_loss

    def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
        """Integrate ``net`` with the ODE solver over ``ts - 1`` unit intervals and log metrics.

        Logs validation regularizers (vector data only), distribution distances against
        ``x_rest``, the NFE count and, when available, closed-form Gaussian KLs. Returns
        ``(trajs, full_trajs)``.
        """
        # Build a trajectory
        t_span = torch.linspace(0, 1, 101)
        regs = []
        trajs = []
        full_trajs = []
        solver = self.partial_solver(
            self.net, self.dim, score_field=self.score_net, sigma=self.sigma
        )
        nfe = 0
        x0_tmp = x0.clone()
        for i in range(ts - 1):
            if not self.is_image:
                solver.augmentations = self.val_augmentations
                traj, aug = solver.odeint(x0_tmp, t_span + i)
            else:
                traj = solver.odeint(x0_tmp, t_span + i)
            full_trajs.append(traj)
            if not self.is_image:
                traj, aug = traj[-1], aug[-1]
            else:
                traj = traj[-1]
                aug = torch.tensor(0.0)
            x0_tmp = traj
            regs.append(torch.mean(aug, dim=0).detach().cpu().numpy())
            trajs.append(traj)
            nfe += solver.nfe

        if not self.is_image:
            regs = np.stack(regs).mean(axis=0)
            names = [f"{prefix}/{name}" for name in self.val_augmentations.names]
            self.log_dict(dict(zip(names, regs)), sync_dist=True)

            # Evaluate the fit
            names, dists = compute_distribution_distances(trajs, x_rest)
            names = [f"{prefix}/{name}" for name in names]
            d = dict(zip(names, dists))
            if self.hparams.leaveout_timepoint >= 0:
                to_add = {
                    f"{prefix}/t_out/{key.split('/')[-1]}": val
                    for key, val in d.items()
                    if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}")
                }
                d.update(to_add)
            d[f"{prefix}/nfe"] = nfe
            self.log_dict(d, sync_dist=True)

        if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
            solver.augmentations = None
            t_span = torch.linspace(0, 1, 21)
            traj = solver.odeint(x0, t_span)
            assert traj.shape[0] == t_span.shape[0]
            kls = [
                self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
            ]

            self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True)
            self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)

        full_trajs = torch.cat(full_trajs)
        return trajs, full_trajs


class DSBMSharedLitModule(SF2MLitModule):
    """SF2M variant that trains the flow and score networks jointly on two drift losses.

    With ``vt`` from ``net`` and ``s = sigma_t * score_net(t, x)``, the "forward" loss regresses
    ``vt + s`` onto ``ut + score_target``. The "backward" loss regresses ``-vt + s`` onto
    ``-ut + score_target``. A separate ``score_net`` is required, since it is called directly.
    """

    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data.

        Returns ``(mean(reg) + backward_flow_loss, forward_flow_loss)``.
        """
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)
        # Either randomly sample a single T or sample a batch of T's
        if self.hparams.avg_size > 0:
            t = torch.rand(1).repeat(X.shape[0]).type_as(X)
        else:
            t = torch.rand(X.shape[0]).type_as(X)
        # Resample the plan if we are using optimal transport. Stored outer-loop trajectories
        # are already paired, so they are not re-matched (same as the other SF2M steps).
        if self.ot_sampler is not None and self.stored_data is None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        x, ut, t, mu_t, sigma_t, score_target = self.calc_loc_and_target(
            x0, x1, t, t_select, training
        )

        if self.hparams.avg_size > 0:
            x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)
        aug_x = self.aug_net(t, x, augmented_input=False)
        reg, vt = self.augmentations(aug_x)
        # Evaluate the score network once and share it between both losses.
        scaled_score = sigma_t * self.score_net(t, x)
        forward_flow_loss = self.criterion(vt + scaled_score, ut + score_target)
        backward_flow_loss = self.criterion(-vt + scaled_score, -ut + score_target)
        return torch.mean(reg) + backward_flow_loss, forward_flow_loss


class FMLitModule(CFMLitModule):
    """Flow matching loss in the style of Anonymous et al. 2023.

    Maps the standard normal distribution N(x | 0, 1) to the data distribution. The conditional
    flows are the optimal transport flows from the standard normal to a narrow Gaussian (width
    ``sigma_min``) around a datapoint.
    """

    def calc_mu_sigma(self, x0, x1, t):
        """Return mean ``t * x1`` and std ``1 - (1 - sigma_min) * t``; ``x0`` is ignored."""
        assert not self.is_trajectory
        del x0
        shrink = 1 - self.hparams.sigma_min
        return t * x1, 1 - shrink * t

    def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
        """Return the OT conditional flow ``(x1 - (1 - sigma_min) x) / (1 - (1 - sigma_min) t)``."""
        del x0, mu_t, sigma_t
        shrink = 1 - self.hparams.sigma_min
        return (x1 - shrink * x) / (1 - shrink * t)


class SplineCFMLitModule(CFMLitModule):
    """Implements cubic spline version of OT-CFM.

    Trajectories are matched across timepoints with ``self.ot_sampler.sample_trajectory``. A
    natural cubic spline through each matched trajectory then provides the conditional mean
    and velocity targets.
    """

    def preprocess_batch(self, X, training=False):
        """Converts a batch of data into matched a random pair of (x0, x1).

        Args:
            X: batch of trajectories, unpacked as ``(batch_size, times, dim)``.
            training: whether this is a training batch (enables leave-out handling).

        Returns:
            ``(x0, x1, t_select, spline)``: consecutive points of each matched trajectory, the
            per-sample start index, and the natural cubic spline fitted through the trajectories.
        """
        from torchcubicspline import NaturalCubicSpline, natural_cubic_spline_coeffs

        lotp = self.hparams.leaveout_timepoint
        valid_times = torch.arange(X.shape[1]).type_as(X)
        batch_size, times, _ = X.shape
        # TODO handle leaveout case
        # t_select is created on X's device so that it can be combined with tensors derived
        # from X (t, ut, spline coefficients) in step().
        if training and lotp > 0:
            # Select random except for the leftout timepoint
            t_select = torch.randint(times - 2, size=(batch_size,), device=X.device)
            X = torch.cat([X[:, :lotp], X[:, lotp + 1 :]], dim=1)
            valid_times = valid_times[valid_times != lotp]
        else:
            t_select = torch.randint(times - 1, size=(batch_size,), device=X.device)
        traj = torch.from_numpy(self.ot_sampler.sample_trajectory(X)).type_as(X)
        # Gather traj[i, t_select[i]] and traj[i, t_select[i] + 1] for every sample i.
        rows = torch.arange(batch_size, device=traj.device)
        x0 = traj[rows, t_select]
        x1 = traj[rows, t_select + 1]
        if training and lotp > 0:
            # Map indices of the reduced trajectory back to original timepoint indices.
            t_select[t_select >= lotp] += 1

        coeffs = natural_cubic_spline_coeffs(valid_times, traj)
        spline = NaturalCubicSpline(coeffs)
        return x0, x1, t_select, spline

    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data.

        Returns ``(mean(reg), MSE(vt, ut))`` with the spline velocity as the target ``ut``.
        """
        assert self.is_trajectory
        X = self.unpack_batch(batch)
        x0, x1, t_select, spline = self.preprocess_batch(X, training)

        # t [batch, 1], on X's device/dtype so it matches the spline coefficients.
        t = torch.rand(X.shape[0], 1).type_as(X)
        # coeffs [batch, times, dims]
        # t that network sees is incremented by first timepoint
        t = t + t_select[:, None]
        # The spline evaluates every trajectory at a scalar time; keep row i for sample i.
        ut = torch.stack([spline.derivative(b[0])[i] for i, b in enumerate(t)], dim=0)
        mu_t = torch.stack([spline.evaluate(b[0])[i] for i, b in enumerate(t)], dim=0)
        sigma_t = self.hparams.sigma_min

        # if we are starting from right before the leaveout_timepoint then we
        # divide the target by 2
        if training and self.hparams.leaveout_timepoint > 0:
            ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2
            t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2

        x = mu_t + sigma_t * torch.randn_like(x0)
        aug_x = self.aug_net(t, x, augmented_input=False)
        reg, vt = self.augmentations(aug_x)
        return torch.mean(reg), self.criterion(vt, ut)


class CNFLitModule(CFMLitModule):
    """Continuous normalizing flow trained by minimizing negative log-likelihood.

    Observations are integrated backwards with ``self.aug_node`` to a standard-normal prior.
    The log-likelihood combines the prior log-density and the integrated change in
    log-density.
    """

    def forward_integrate(self, batch: Any, t_span: torch.Tensor):
        """Forward pass with integration over t_span intervals.

        (t, x, t_span) -> [x_t_span]. ``t_span`` is shifted by +1 before delegating to the
        parent.
        """
        return super().forward_integrate(batch, t_span + 1)

    def step(self, batch: Any, training: bool = False):
        """Return ``(mean regularization, mean negative log-likelihood)`` over timepoints."""
        obs = self.unpack_batch(batch)
        if not self.is_trajectory:
            obs = obs[:, None, :]
        even_ts = torch.arange(obs.shape[1]).to(obs) + 1
        self.prior = MultivariateNormal(
            torch.zeros(self.dim).type_as(obs), torch.eye(self.dim).type_as(obs)
        )
        # Minimize the log likelihood by integrating all back to the initial timepoint
        reversed_ts = torch.cat([torch.flip(even_ts, [0]), torch.tensor([0]).type_as(even_ts)])

        # If only one timepoint then Gaussian is at t0, data t1
        # If multiple timepoints then Gaussian is at t_{-1} data is at times 0 to T
        if self.is_trajectory:
            reversed_ts -= 1
        num_times = len(even_ts)
        losses = []
        regs = []
        for t in range(len(reversed_ts) - 1):
            # ``t`` walks the timepoints in reverse; this is the observation index it starts from.
            obs_idx = num_times - t - 1
            # When leaving out a timepoint simply skip it in the backwards integration. The
            # comparison uses the observation index, because leaveout_timepoint counts
            # timepoints in forward order (as in the other modules of this file).
            if self.hparams.leaveout_timepoint == obs_idx:
                continue
            ts, x = reversed_ts[t:], obs[:, obs_idx, :]
            _, x = self.aug_node(x, ts)
            x = x[-1]
            # Assume log prob is in zero spot
            delta_logprob, reg, x = self.augmentations(x)
            logprob = self.prior.log_prob(x).to(x) - delta_logprob
            losses.append(-torch.mean(logprob))
            # negative because we are integrating backwards
            regs.append(-reg)
        reg = torch.mean(torch.stack(regs))
        loss = torch.mean(torch.stack(losses))
        return reg, loss
