import itertools
import math

import torch

# Debugger hook: LinearSolver.backward calls pydevd.settrace(...) when this is set to the pydevd module.
pydevd = None
# import pydevd

# Optional external solver backend. When importable, LinearSolver routes cholesky_inplace /
# substitution_inplace through it; otherwise it falls back to its own implementations.
try:
    import mnn
except ImportError:
    mnn = None


def ode_forward(
        coefficients: torch.Tensor,
        rhs_equation: torch.Tensor,
        init_vars: torch.Tensor,
        steps: torch.Tensor,
        n_steps: int | None = None,
        n_init_var_steps: int | None = None,
        is_step_dim_first: bool = False,
        weight_equation: float = 1.,
        weight_init_var: float = 1.,
        weight_smooth: float = 1.,
        enable_central_smoothness: bool = True,
        enable_cuda_graph: bool = True,
        enable_freeze_lhs: bool = False,
) -> torch.Tensor:
    """
    Solve a batch of linear ODE systems on a step grid as a weighted least-squares problem.

    The normal equations (ODE equation rows, initial-value rows, smoothness rows between neighbouring
    steps) are assembled by compute_ata / compute_atb and solved by LinearSolver. A dim marked [b] may
    have size 1 and is then broadcast.

    coefficients: (..., n_steps[b], n_equations, n_dims, n_orders)
    rhs_equation: (..., n_steps[b], n_equations[e])
    init_vars: (..., n_init_var_steps[b], n_dims[e], n_init_var_orders[e])
    steps: (..., n_steps-1[b])
    n_steps: defaults to the size of steps' step dim + 1
    n_init_var_steps: defaults to the size of init_vars' step dim
    is_step_dim_first: if True, all inputs and the output carry the step dim first instead,
        e.g. coefficients: (n_steps[b], ..., n_equations, n_dims, n_orders)
    weight_equation / weight_init_var / weight_smooth: row weights of the three constraint groups
    enable_central_smoothness: forwarded to compute_ata
    enable_cuda_graph: use LinearSolver's CUDA-graph path for CUDA inputs. NOTE: this sets the
        class-level LinearSolver.enable_cuda_graph, so it affects later LinearSolver uses as well.
    enable_freeze_lhs: reuse the LHS factorization captured by an earlier CUDA-graph call and only
        rebuild the right-hand side (requires enable_cuda_graph). Falls back to a full assembly when no
        captured factorization is available for this call.
    return: (..., n_steps, n_dims, n_orders)
    """

    assert not enable_freeze_lhs or enable_cuda_graph
    # enable_freeze_lhs is currently only implemented for enable_cuda_graph

    LinearSolver.enable_cuda_graph = enable_cuda_graph
    if not enable_cuda_graph:
        # Force a fresh capture the next time the CUDA-graph path is enabled.
        LinearSolver.is_graph_initialized = False

    if not is_step_dim_first:
        coefficients: torch.Tensor = move_step_dim(coefficients, 4, False)
        rhs_equation: torch.Tensor = move_step_dim(rhs_equation, 2, False)
        init_vars: torch.Tensor = move_step_dim(init_vars, 3, False)
        steps: torch.Tensor = move_step_dim(steps, 1, False)

    dtype: torch.dtype = coefficients.dtype
    device: torch.device = coefficients.device

    n_steps: int = steps.size(0) + 1 if n_steps is None else n_steps
    assert n_steps >= 2
    n_init_var_steps: int = init_vars.size(0) if n_init_var_steps is None else n_init_var_steps

    # Shape validation; the step dims may be 1 (broadcast) or full length.
    n_steps_coefficients, *batch_coefficients, n_equations, n_dims, n_orders = coefficients.shape
    assert n_steps_coefficients in [n_steps, 1]
    n_steps_rhs_equation, *batch_rhs_equation, n_equations_rhs_equation = rhs_equation.shape
    assert n_steps_rhs_equation in [n_steps, 1] and n_equations_rhs_equation == n_equations
    n_init_var_steps_rhs_init, *batch_rhs_init, n_dims_rhs_init, n_init_var_orders = init_vars.shape
    assert n_init_var_steps_rhs_init in [n_init_var_steps, 1] and n_dims_rhs_init == n_dims
    n_steps_steps, *batch_steps = steps.shape
    assert n_steps_steps in [n_steps - 1, 1]
    # The LHS depends only on coefficients and steps; the rhs also on rhs_equation and init_vars.
    batch_lhs: torch.Size = torch.broadcast_shapes(batch_coefficients, batch_steps)
    batch: torch.Size = torch.broadcast_shapes(batch_lhs, batch_rhs_equation, batch_rhs_init)

    # A frozen LHS can only be reused when LinearSolver really holds a captured factorization for this
    # call: the graph must be initialized (a call with enable_cuda_graph=False resets it, leaving stale
    # buffers behind) and the inputs must be on CUDA (the CPU path cannot accept a None LHS).
    reuse_frozen_lhs: bool = (
        enable_freeze_lhs and device.type == 'cuda' and LinearSolver.is_graph_initialized
    )

    if not reuse_frozen_lhs:
        block_diag_0, block_diag_1, block_diag_2, beta = compute_ata(
            coefficients,
            rhs_equation,
            init_vars,
            steps,
            n_steps,
            n_init_var_steps,
            weight_equation,
            weight_init_var,
            weight_smooth,
            enable_central_smoothness,
            dtype,
            device,
            n_dims,
            n_orders,
            n_init_var_orders,
            batch,
            batch_lhs,
        )
        x: torch.Tensor = LinearSolver.apply(
            block_diag_0,
            block_diag_1,
            block_diag_2,
            beta,
        )  # (n_steps, ..., n_dims * n_orders, 1)
    else:
        beta: torch.Tensor = compute_atb(
            coefficients,
            rhs_equation,
            init_vars,
            n_steps,
            n_init_var_steps,
            weight_equation,
            weight_init_var,
            dtype,
            device,
            n_orders,
            n_init_var_orders,
            batch,
        )
        # None LHS blocks tell LinearSolver to reuse its captured factorization.
        x: torch.Tensor = LinearSolver.apply(None, None, None, beta)  # (n_steps, ..., n_dims * n_orders, 1)

    x: torch.Tensor = x.reshape(n_steps, *batch, n_dims, n_orders)  # (n_steps, ..., n_dims, n_orders)

    if not is_step_dim_first:
        x: torch.Tensor = move_step_dim(x, 3, True)

    return x


def move_step_dim(x: torch.Tensor, i: int, revert: bool) -> torch.Tensor:
    """
    Permute the step dim between the "batch first" and "step first" layouts (returns a view).

    revert=False: (..., step, *trailing) -> (step, ..., *trailing)
    revert=True:  (step, ..., *trailing) -> (..., step, *trailing)
    In both directions i counts the step dim plus the i - 1 trailing (non-batch) dims.
    """
    rank: int = x.dim()
    batch_end: int = rank - i  # number of leading batch dims == position of the step dim when not first
    trailing: list[int] = list(range(batch_end + 1, rank))
    if revert:
        order: list[int] = list(range(1, batch_end + 1)) + [0] + trailing
    else:
        order = [batch_end] + list(range(batch_end)) + trailing
    return x.permute(order)


def compute_ata(
        coefficients: torch.Tensor,
        rhs_equation: torch.Tensor,
        init_vars: torch.Tensor,
        steps: torch.Tensor,
        n_steps: int,
        n_init_var_steps: int,
        weight_equation: float,
        weight_init_var: float,
        weight_smooth: float,
        enable_central_smoothness: bool,
        dtype: torch.dtype,
        device: torch.device,
        n_dims: int,
        n_orders: int,
        n_init_var_orders: int,
        batch: torch.Size,
        batch_lhs: torch.Size,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """
    Assemble the block-banded normal equations (A^T A, A^T b) of the weighted least-squares ODE system.

    Three groups of rows contribute (every weight enters squared):
      * ODE equation rows (weight_equation), per step: C^T C and C^T rhs_equation.
      * Initial-value rows (weight_init_var) on the first n_init_var_steps steps and the first
        n_init_var_orders orders of every dim.
      * Smoothness rows (weight_smooth) coupling neighbouring steps. With enable_central_smoothness the
        highest-order neighbour terms are dropped and replaced by terms coupling steps i, i+1 and i+2.

    The LHS is returned as the lower block bands of a symmetric matrix: block_diag_1[i] is the block at
    (step i+1, step i) and block_diag_2[i] the block at (step i+2, step i).

    coefficients: (n_steps[b], ..., n_equations, n_dims, n_orders)
    rhs_equation: (n_steps[b], ..., n_equations[e])
    init_vars: (n_init_var_steps[b], ..., n_dims[e], n_init_var_orders[e])
    steps: (n_steps-1[b], ...)
    batch_lhs: batch shape of the LHS blocks (ode_forward passes the broadcast of coefficients and steps)
    batch: batch shape of beta (ode_forward passes the broadcast of all inputs)
    return: (n_steps, ..., n_dims * n_orders, n_dims * n_orders),
            (n_steps-1, ..., n_dims * n_orders, n_dims * n_orders),
            (n_steps-2, ..., n_dims * n_orders, n_dims * n_orders) or None without central smoothness,
            (n_steps, ..., n_dims * n_orders, 1),
    """

    # ode equation constraints
    # C is coefficients with (n_dims, n_orders) flattened; the weight is applied once, squared, on C^T.
    c: torch.Tensor = coefficients.flatten(start_dim=-2)  # (n_steps[b], ..., n_equations, n_dims * n_orders)
    ct: torch.Tensor = c.transpose(-2, -1) * weight_equation ** 2  # (n_steps[b], ..., n_dims * n_orders, n_equations)
    block_diag_0: torch.Tensor = ct @ c  # (n_steps[b], ..., n_dims * n_orders, n_dims * n_orders)
    beta: torch.Tensor = ct @ rhs_equation[..., None]  # (n_steps[b], ..., n_dims * n_orders, 1)

    # Materialize broadcast dims (a size-1 step dim and size-1 batch dims) so later in-place updates see
    # full-size tensors. NOTE(review): zip pairs batch dims from the left, so this assumes the operand
    # already has len(batch_lhs) / len(batch) batch dims.
    block_diag_0: torch.Tensor = block_diag_0.repeat(
        n_steps // block_diag_0.size(0),
        *[ss // s for ss, s in zip(batch_lhs, block_diag_0.shape[1:-2])],
        1,
        1,
    )  # (n_steps, ..., n_dims * n_orders, n_dims * n_orders)
    beta: torch.Tensor = beta.repeat(
        n_steps // beta.size(0),
        *[ss // s for ss, s in zip(batch, beta.shape[1:-2])],
        1,
        1,
    )  # (n_steps, ..., n_dims * n_orders, 1)

    # initial-value constraints
    # Each constrained variable gets an identity row: w^2 on its diagonal entry and w^2 * value in beta.
    weight2_init_var: float = weight_init_var ** 2
    # Flat indices dim * n_orders + order for every dim and every order < n_init_var_orders.
    init_idx: torch.Tensor = torch.arange(n_init_var_orders, device=device).repeat(n_dims) \
                             + (n_orders * torch.arange(n_dims, device=device)).repeat_interleave(n_init_var_orders)
    # (n_dims * n_init_var_orders)
    block_diag_0[:n_init_var_steps, ..., init_idx, init_idx] += weight2_init_var
    # Zero-pad init_vars up to n_orders per dim so it lines up with the flattened (n_dims * n_orders) axis.
    beta[:n_init_var_steps, ..., :, 0] += torch.cat([
        init_vars * weight2_init_var,
        torch.zeros(*init_vars.shape[:-1], n_orders - n_init_var_orders, dtype=dtype, device=device),
    ], dim=-1).flatten(start_dim=-2)

    # smoothness constraints (forward & backward)
    # With E = diag(expansions[i]) (expansions[i, k] = steps[i]^k * weight_smooth), F[j, k] = 1 / (k - j)!
    # (upper triangular) and S[j, k] = (-1)^(j + k), the terms below are the normal-equation terms of the
    # per-dim residuals  F E u_i - E u_{i+1}  and  (F*S) E u_{i+1} - E u_i  (presumably Taylor-expansion
    # matching between neighbouring steps).
    order_idx: torch.Tensor = torch.arange(n_orders, device=device)  # (n_orders)
    sign_vec: torch.Tensor = order_idx % 2 * (-2) + 1  # (n_orders) +1 for even orders, -1 for odd
    sign_map: torch.Tensor = sign_vec * sign_vec[:, None]  # (n_orders, n_orders)

    expansions: torch.Tensor = steps[..., None] ** order_idx * weight_smooth  # (n_steps-1[b], ..., n_orders)
    et_e_diag: torch.Tensor = expansions ** 2  # (n_steps-1[b], ..., n_orders)
    e_outer: torch.Tensor = expansions[..., None] * expansions[..., None, :]  # (n_steps-1[b], ..., n_orders, n_orders)
    # exp(-lgamma(k - j + 1)) = 1 / (k - j)!; triu() zeroes k < j, and lgamma(0) = inf gives exp(-inf) = 0 there.
    factorials: torch.Tensor = (-(order_idx - order_idx[:, None] + 1).triu().to(dtype=dtype).lgamma()).exp()
    # (n_orders, n_orders)
    if enable_central_smoothness:
        # Drop the highest-order row of both residuals; the central terms below replace it.
        et_e_diag[..., -1] = 0.
        factorials[-1, -1] = 0.
    et_ft_f_e: torch.Tensor = e_outer * (factorials.t() @ factorials)  # (n_steps-1[b], ..., n_orders, n_orders)

    # Block (i+1, i): -E (F + F^T * S) E, the sum of both residuals' cross terms.
    smooth_block_diag_1: torch.Tensor = e_outer * -(factorials + factorials.transpose(-2, -1) * sign_map)
    # (n_steps-1[b], ..., n_orders, n_orders)
    smooth_block_diag_0: torch.Tensor = torch.zeros(n_steps, *batch_lhs, n_orders, n_orders, dtype=dtype, device=device)
    # (n_steps, ..., n_orders, n_orders)
    smooth_block_diag_0[:-1] += et_ft_f_e
    smooth_block_diag_0[1:] += et_ft_f_e * sign_map
    smooth_block_diag_0[:-1, ..., order_idx, order_idx] += et_e_diag
    smooth_block_diag_0[1:, ..., order_idx, order_idx] += et_e_diag

    smooth_block_diag_1: torch.Tensor = smooth_block_diag_1.repeat(
        (n_steps - 1) // smooth_block_diag_1.size(0),
        *([1] * len(batch_lhs)),
        1,
        1,
    )  # (n_steps-1, ..., n_orders, n_orders)
    block_diag_1: torch.Tensor = torch.zeros(
        n_steps - 1, *batch_lhs, n_dims * n_orders, n_dims * n_orders, dtype=dtype, device=device,
    )  # (n_steps-1, ..., n_dims * n_orders, n_dims * n_orders)

    if enable_central_smoothness:
        steps: torch.Tensor = steps.repeat((n_steps - 1) // steps.size(0), *([1] * len(batch_lhs)))  # (n_steps-1, ...)

        # smoothness constraints (central)
        # With s = steps[i] + steps[i+1], these are the normal-equation terms of the residual
        #   weight_smooth * s^(n_orders-3) * (u[i+2, n_orders-2] - u[i, n_orders-2] - s * u[i+1, n_orders-1]).
        # NOTE(review): indices n_orders - 2 wrap around for n_orders < 2; presumably n_orders >= 2 here.
        steps2: torch.Tensor = steps[:-1] + steps[1:]  # (n_steps-2, ...)
        weight2_smooth: float = weight_smooth ** 2
        steps26: torch.Tensor = steps2 ** (n_orders * 2 - 6) * weight2_smooth  # (n_steps-2, ...)
        steps25: torch.Tensor = steps2 ** (n_orders * 2 - 5) * weight2_smooth  # (n_steps-2, ...)
        steps24: torch.Tensor = steps2 ** (n_orders * 2 - 4) * weight2_smooth  # (n_steps-2, ...)

        smooth_block_diag_0[:-2, ..., n_orders - 2, n_orders - 2] += steps26
        smooth_block_diag_0[2:, ..., n_orders - 2, n_orders - 2] += steps26
        smooth_block_diag_0[1:-1, ..., n_orders - 1, n_orders - 1] += steps24
        smooth_block_diag_1[:-1, ..., n_orders - 1, n_orders - 2] += steps25  # block (i+1, i)
        smooth_block_diag_1[1:, ..., n_orders - 2, n_orders - 1] -= steps25  # block (i+2, i+1)
        smooth_block_diag_2: torch.Tensor | None = torch.zeros(
            n_steps - 2, *batch_lhs, n_orders, n_orders, dtype=dtype, device=device,
        )  # (n_steps-2, ..., n_orders, n_orders)
        smooth_block_diag_2[..., n_orders - 2, n_orders - 2] = -steps26  # block (i+2, i)

        block_diag_2: torch.Tensor | None = torch.zeros(
            n_steps - 2, *batch_lhs, n_dims * n_orders, n_dims * n_orders, dtype=dtype, device=device,
        )  # (n_steps-2, ..., n_dims * n_orders, n_dims * n_orders
    else:
        smooth_block_diag_2 = None
        block_diag_2 = None

    # copy to n_dims
    # The same per-dim smoothness blocks go on each dim's diagonal sub-block (no cross-dim coupling).
    for dim in range(n_dims):
        i1: int = dim * n_orders
        i2: int = (dim + 1) * n_orders
        block_diag_0[..., i1:i2, i1:i2] += smooth_block_diag_0
        block_diag_1[..., i1:i2, i1:i2] = smooth_block_diag_1
        if block_diag_2 is not None:
            block_diag_2[..., i1:i2, i1:i2] = smooth_block_diag_2

    return block_diag_0, block_diag_1, block_diag_2, beta


def compute_atb(
        coefficients: torch.Tensor,
        rhs_equation: torch.Tensor,
        init_vars: torch.Tensor,
        n_steps: int,
        n_init_var_steps: int,
        weight_equation: float,
        weight_init_var: float,
        dtype: torch.dtype,
        device: torch.device,
        n_orders: int,
        n_init_var_orders: int,
        batch: torch.Size,
) -> torch.Tensor:
    """
    Assemble only the right-hand side A^T b of the normal equations (used when the LHS is frozen).

    Produces the same beta as compute_ata, with the same floating-point operation order, so a solve
    against a frozen LHS sees exactly the rhs a full assembly would have produced.

    coefficients: (n_steps[b], ..., n_equations, n_dims, n_orders)
    rhs_equation: (n_steps[b], ..., n_equations[e])
    init_vars: (n_init_var_steps[b], ..., n_dims[e], n_init_var_orders[e])
    batch: batch shape of the result; size-1 batch dims of the operands are repeated up to it
    return: (n_steps, ..., n_dims * n_orders, 1)
    """

    # ode equation constraints: weight C^T first, exactly as compute_ata does.
    ct: torch.Tensor = coefficients.flatten(start_dim=-2).transpose(-2, -1) * weight_equation ** 2
    # (n_steps[b], ..., n_dims * n_orders, n_equations)
    beta: torch.Tensor = ct @ rhs_equation[..., None]  # (n_steps[b], ..., n_dims * n_orders, 1)

    # Materialize a size-1 step dim and size-1 batch dims so the in-place update below is full size.
    beta: torch.Tensor = beta.repeat(
        n_steps // beta.size(0),
        *[ss // s for ss, s in zip(batch, beta.shape[1:-2])],
        1,
        1,
    )  # (n_steps, ..., n_dims * n_orders, 1)

    # initial-value constraints: w^2 * value for the first n_init_var_orders orders of each dim,
    # zero-padded up to n_orders so it lines up with the flattened (n_dims * n_orders) axis.
    weight2_init_var: float = weight_init_var ** 2
    beta[:n_init_var_steps, ..., :, 0] += torch.cat([
        init_vars * weight2_init_var,
        torch.zeros(*init_vars.shape[:-1], n_orders - n_init_var_orders, dtype=dtype, device=device),
    ], dim=-1).flatten(start_dim=-2)
    return beta


class LinearSolver(torch.autograd.Function):
    enable_ldl: bool = True
    enable_cuda_graph: bool = True
    is_graph_initialized: bool = False
    n_warmups: int = 10
    enable_freeze_lhs: bool = False
    graph_cholesky: torch.cuda.CUDAGraph
    graph_substitution: torch.cuda.CUDAGraph
    block_diag_0: torch.Tensor
    block_diag_1: torch.Tensor
    block_diag_2: torch.Tensor
    rhs: torch.Tensor
    tmp_info: torch.Tensor

    @staticmethod
    def forward(
            ctx,
            block_diag_0: torch.Tensor,
            block_diag_1: torch.Tensor,
            block_diag_2: torch.Tensor | None,
            rhs: torch.Tensor,
    ) -> torch.Tensor:
        """
        block_diag_0: (n_steps, ..., n_dims * n_orders, n_dims * n_orders)
        block_diag_1: (n_steps-1, ..., n_dims * n_orders, n_dims * n_orders)
        block_diag_2: (n_steps-2, ..., n_dims * n_orders, n_dims * n_orders)
        rhs: (n_steps, ..., n_dims * n_orders, 1)
        """

        solver = mnn if mnn is not None else LinearSolver
        device: torch.device = rhs.device

        if LinearSolver.enable_cuda_graph and device.type == 'cuda':
            if not LinearSolver.is_graph_initialized:
                torch.cuda.set_device(device)

                LinearSolver.graph_cholesky = torch.cuda.CUDAGraph()
                LinearSolver.graph_substitution = torch.cuda.CUDAGraph()

                tmp_info: torch.Tensor = torch.empty(
                    block_diag_0.shape[1:-2], dtype=torch.int32, device=device,
                )  # (...)

                LinearSolver.block_diag_0 = torch.empty_like(block_diag_0)
                LinearSolver.block_diag_1 = torch.empty_like(block_diag_1)
                LinearSolver.block_diag_2 = torch.empty_like(block_diag_2) if block_diag_2 is not None else None
                LinearSolver.rhs = torch.empty_like(rhs)
                LinearSolver.tmp_info = tmp_info

                s = torch.cuda.Stream()
                s.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(s):
                    for _ in range(LinearSolver.n_warmups):
                        solver.cholesky_inplace(
                            LinearSolver.block_diag_0,
                            LinearSolver.block_diag_1,
                            LinearSolver.block_diag_2,
                            LinearSolver.tmp_info,
                            enable_ldl=LinearSolver.enable_ldl,
                        )
                torch.cuda.current_stream().wait_stream(s)

                with torch.cuda.graph(LinearSolver.graph_cholesky):
                    solver.cholesky_inplace(
                        LinearSolver.block_diag_0,
                        LinearSolver.block_diag_1,
                        LinearSolver.block_diag_2,
                        LinearSolver.tmp_info,
                        enable_ldl=LinearSolver.enable_ldl,
                    )

                s = torch.cuda.Stream()
                s.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(s):
                    for _ in range(LinearSolver.n_warmups):
                        solver.substitution_inplace(
                            LinearSolver.block_diag_0,
                            LinearSolver.block_diag_1,
                            LinearSolver.block_diag_2,
                            LinearSolver.rhs,
                            enable_ldl=LinearSolver.enable_ldl,
                        )
                torch.cuda.current_stream().wait_stream(s)

                with torch.cuda.graph(LinearSolver.graph_substitution):
                    solver.substitution_inplace(
                        LinearSolver.block_diag_0,
                        LinearSolver.block_diag_1,
                        LinearSolver.block_diag_2,
                        LinearSolver.rhs,
                        enable_ldl=LinearSolver.enable_ldl,
                    )

                torch.cuda.synchronize(device)
                LinearSolver.is_graph_initialized = True

            if block_diag_0 is not None:
                LinearSolver.block_diag_0.copy_(block_diag_0)
                LinearSolver.block_diag_1.copy_(block_diag_1)
                if LinearSolver.block_diag_2 is not None:
                    LinearSolver.block_diag_2.copy_(block_diag_2)
                LinearSolver.graph_cholesky.replay()
            else:
                LinearSolver.enable_freeze_lhs = True

            LinearSolver.rhs.copy_(rhs)
            LinearSolver.graph_substitution.replay()
            rhs = LinearSolver.rhs

        else:
            tmp_info: torch.Tensor = torch.empty(
                block_diag_0.shape[1:-2], dtype=torch.int32, device=device,
            )  # (...)
            solver.cholesky_inplace(
                block_diag_0,
                block_diag_1,
                block_diag_2,
                tmp_info,
                enable_ldl=LinearSolver.enable_ldl,
            )
            solver.substitution_inplace(
                block_diag_0,
                block_diag_1,
                block_diag_2,
                rhs,
                enable_ldl=LinearSolver.enable_ldl,
            )
            ctx.block_diag_0, ctx.block_diag_1, ctx.block_diag_2, ctx.x = block_diag_0, block_diag_1, block_diag_2, rhs

        return rhs.clone()

    @staticmethod
    def backward(
            ctx,
            rhs: torch.Tensor,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor]:
        """
        dx: (n_steps, ..., n_dims * n_orders, 1)
        """

        if pydevd:
            pydevd.settrace(suspend=False, trace_only_current_thread=True)
            # https://discuss.pytorch.org/t/custom-backward-breakpoint-doesnt-get-hit/6473/15

        if LinearSolver.enable_cuda_graph and rhs.device.type == 'cuda':
            x: torch.Tensor = -LinearSolver.rhs  # (n_steps, ..., n_dims * n_orders, 1)
            LinearSolver.rhs.copy_(rhs)
            LinearSolver.graph_substitution.replay()
            rhs = LinearSolver.rhs
        else:
            x: torch.Tensor = -ctx.x  # (n_steps, ..., n_dims * n_orders, 1)
            rhs: torch.Tensor = rhs.clone()
            solver = mnn if mnn is not None else LinearSolver
            solver.substitution_inplace(
                ctx.block_diag_0, ctx.block_diag_1, ctx.block_diag_2, rhs, enable_ldl=LinearSolver.enable_ldl,
            )

        if not LinearSolver.enable_freeze_lhs:
            da0: torch.Tensor | None = rhs * x[..., None, :, 0]
            da1: torch.Tensor | None = rhs[1:] * x[:-1, ..., None, :, 0] + x[1:] * rhs[:-1, ..., None, :, 0]
            if LinearSolver.enable_cuda_graph and LinearSolver.block_diag_2 is not None \
                    or not LinearSolver.enable_cuda_graph and ctx.block_diag_2 is not None:
                da2: torch.Tensor | None = rhs[2:] * x[:-2, ..., None, :, 0] + x[2:] * rhs[:-2, ..., None, :, 0]
            else:
                da2 = None
        else:
            da0 = da1 = da2 = None

        return da0, da1, da2, rhs

    @staticmethod
    def cholesky_inplace(
        block_diag_0: torch.Tensor | list[torch.Tensor],
        block_diag_1: torch.Tensor | list[torch.Tensor],
        block_diag_2: torch.Tensor | list[torch.Tensor] | None,
        tmp_info: torch.Tensor,
        enable_ldl: bool = True,
    ) -> None:
        """
        block_diag_0: (n_steps, ..., n_dims * n_orders, n_dims * n_orders)
        block_diag_1: (n_steps-1, ..., n_dims * n_orders, n_dims * n_orders)
        block_diag_2: (n_steps-2, ..., n_dims * n_orders, n_dims * n_orders)
        tmp_info: (...)
        """

        enable_block_diag_2: bool = block_diag_2 is not None
        n_steps: int = len(block_diag_0)
        for step in range(n_steps):
            if enable_block_diag_2 and step >= 2:
                torch.linalg.solve_triangular(
                    block_diag_0[step - 2].transpose(-2, -1),
                    block_diag_2[step - 2],
                    upper=True,
                    left=False,
                    unitriangular=False,
                    out=block_diag_2[step - 2],
                )  # block_diag_2[step - 2] @= block_diag_0[step - 2].t().inv()
                LinearSolver.bsubbmm(
                    block_diag_1[step - 1],
                    block_diag_2[step - 2],
                    block_diag_1[step - 2].transpose(-2, -1),
                )  # block_diag_1[step - 1] -= block_diag_2[step - 2] @ block_diag_1[step - 2].t()
            if step >= 1:
                torch.linalg.solve_triangular(
                    block_diag_0[step - 1].transpose(-2, -1),
                    block_diag_1[step - 1],
                    upper=True,
                    left=False,
                    unitriangular=False,
                    out=block_diag_1[step - 1],
                )  # block_diag_1[step - 1] @= block_diag_0[step - 1].t().inv()
                if enable_block_diag_2 and step >= 2:
                    LinearSolver.bsubbmm(
                        block_diag_0[step],
                        block_diag_2[step - 2],
                        block_diag_2[step - 2].transpose(-2, -1),
                    )  # block_diag_0[step] -= block_diag_2[step - 2] @ block_diag_2[step - 2].t()
                LinearSolver.bsubbmm(
                    block_diag_0[step],
                    block_diag_1[step - 1],
                    block_diag_1[step - 1].transpose(-2, -1),
                )  # block_diag_0[step] -= block_diag_1[step - 1] @ block_diag_1[step - 1].t()
            torch.linalg.cholesky_ex(
                block_diag_0[step],
                upper=False,
                check_errors=False,
                out=(block_diag_0[step], tmp_info),
            )

        if enable_ldl:
            # LDL decomposition https://en.wikipedia.org/wiki/Cholesky_decomposition#Block_variant
            # block_diag_0: Cholesky decomposition of D
            # block_diag_1: L blocks
            # block_diag_2: L blocks
            torch.linalg.solve_triangular(
                block_diag_0[:-1],
                block_diag_1,
                upper=False,
                left=False,
                unitriangular=False,
                out=block_diag_1,
            )  # block_diag_1 @= block_diag_0[:-1].inv()
            if enable_block_diag_2:
                torch.linalg.solve_triangular(
                    block_diag_0[:-2],
                    block_diag_2,
                    upper=False,
                    left=False,
                    unitriangular=False,
                    out=block_diag_2,
                )  # block_diag_2 @= block_diag_0[:-2].inv()

    @staticmethod
    def substitution_inplace(
            block_diag_0: torch.Tensor | list[torch.Tensor],
            block_diag_1: torch.Tensor | list[torch.Tensor],
            block_diag_2: torch.Tensor | list[torch.Tensor] | None,
            rhs: torch.Tensor | list[torch.Tensor],
            enable_ldl: bool = True,
    ) -> None:
        """
        block_diag_0: (n_steps, ..., n_dims * n_orders, n_dims * n_orders)
        block_diag_1: (n_steps-1, ..., n_dims * n_orders, n_dims * n_orders)
        block_diag_2: (n_steps-2, ..., n_dims * n_orders, n_dims * n_orders)
        rhs: (n_steps, ..., n_dims * n_orders, 1)
        """

        enable_block_diag_2: bool = block_diag_2 is not None
        n_steps: int = len(block_diag_0)

        # A X = B => L (D (Lt X)) = B
        for step in range(n_steps):
            # solve L Z = B, block forward substitution
            if enable_block_diag_2 and step >= 2:
                LinearSolver.bsubbmm(
                    rhs[step],
                    block_diag_2[step - 2],
                    rhs[step - 2],
                )  # rhs[step] -= block_diag_2[step - 2] @ rhs[step - 2]
            if step >= 1:
                LinearSolver.bsubbmm(
                    rhs[step],
                    block_diag_1[step - 1],
                    rhs[step - 1],
                )  # rhs[step] -= block_diag_1[step - 1] @ rhs[step - 1]
            if not enable_ldl:
                torch.linalg.solve_triangular(
                    block_diag_0[step],
                    rhs[step],
                    upper=False,
                    left=True,
                    unitriangular=False,
                    out=rhs[step],
                )  # rhs[step] = block_diag_0[step].inv() @ rhs[step]
        if enable_ldl:
            # solve D Y = Z, block forward substitution
            # torch.cholesky_solve(
            #     rhs,
            #     block_diag_0,
            #     upper=False,
            #     out=rhs,
            # )  # rhs = (block_diag_0 @ block_diag_0.t()).inv() @ rhs
            # This is slow :(
            torch.linalg.solve_triangular(
                block_diag_0,
                rhs,
                upper=False,
                left=True,
                unitriangular=False,
                out=rhs,
            )  # rhs = block_diag_0.inv() @ rhs
            torch.linalg.solve_triangular(
                block_diag_0.transpose(-2, -1),
                rhs,
                upper=True,
                left=True,
                unitriangular=False,
                out=rhs,
            )  # rhs = block_diag_0.t().inv() @ rhs
        for step in range(n_steps - 1, -1, -1):
            # solve Lt X = Y, block backward substitution
            if enable_block_diag_2 and step < n_steps - 2:
                LinearSolver.bsubbmm(
                    rhs[step],
                    block_diag_2[step].transpose(-2, -1),
                    rhs[step + 2],
                )  # rhs[step] -= block_diag_2[step].t() @ rhs[step + 2]
            if step < n_steps - 1:
                LinearSolver.bsubbmm(
                    rhs[step],
                    block_diag_1[step].transpose(-2, -1),
                    rhs[step + 1],
                )  # rhs[step] -= block_diag_1[step].t() @ rhs[step + 1]
            if not enable_ldl:
                torch.linalg.solve_triangular(
                    block_diag_0[step].transpose(-2, -1),
                    rhs[step],
                    upper=True,
                    left=True,
                    unitriangular=False,
                    out=rhs[step],
                )  # rhs[step] = block_diag_0[step].t().inv() @ rhs[step]

    @staticmethod
    def bsubbmm(c: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> None:
        """
        c -= a @ b
        """
        c -= a @ b
        # not yet supporting multiple batch dims
        # a, b, c = a.flatten(end_dim=-3), b.flatten(end_dim=-3), c.flatten(end_dim=-3)
        # torch.baddbmm(c, a, b, beta=1, alpha=-1, out=c)


def ode_forward_lite(
        coefficients: torch.Tensor,
        rhs_equation: torch.Tensor,
        rhs_init: torch.Tensor,
        steps: torch.Tensor,
        n_steps: int = None,
        n_init_var_steps: int = None,
) -> torch.Tensor:
    """
    Solve the linear least-squares ODE problem with a pure-PyTorch block-banded Cholesky factorisation.

    The normal equations of the least-squares system are assembled directly in block form. The rows are:
    equation rows; initial-value rows; forward and backward Taylor smoothness rows between consecutive
    steps; and a central smoothness row spanning three consecutive steps.

    With one block of size ``n_dims * n_orders`` per step, the matrix is symmetric block-pentadiagonal:
    ``block_diag_0[s]`` is the diagonal block of step ``s``, ``block_diag_1[s]`` couples steps ``s + 1``
    and ``s``, and ``block_diag_2[s]`` couples steps ``s + 2`` and ``s``. It is factorised as ``L L^T``
    block by block and solved by block forward/backward substitution.

    Axes marked ``[b]`` may have size 1 and are broadcast; leading batch dims of all inputs are broadcast
    against each other (aligned from the right, as in PyTorch broadcasting).

    coefficients: (..., n_steps[b], n_equations, n_dims, n_orders)
    rhs_equation: (..., n_steps[b], n_equations[e])
    rhs_init: (..., n_init_var_steps[b], n_dims[e], n_init_var_orders[e])
    steps: (..., n_steps-1[b])
    n_steps: number of steps; defaults to ``steps.size(-1) + 1``.
    n_init_var_steps: number of leading steps carrying initial-value rows; defaults to ``rhs_init.size(-3)``.
    return: (..., n_steps, n_dims, n_orders)

    Note: ``torch.linalg.cholesky_ex`` is called with ``check_errors=False``, so a non positive-definite
    system does not raise; the result is then unreliable.
    """

    dtype: torch.dtype = coefficients.dtype
    device: torch.device = coefficients.device

    n_steps: int = steps.size(-1) + 1 if n_steps is None else n_steps
    assert n_steps >= 2
    n_init_var_steps: int = rhs_init.size(-3) if n_init_var_steps is None else n_init_var_steps

    *batch_coefficients, n_steps_coefficients, n_equations, n_dims, n_orders = coefficients.shape
    assert n_steps_coefficients in [n_steps, 1]
    *batch_rhs_equation, n_steps_rhs_equation, n_equations_rhs_equation = rhs_equation.shape
    assert n_steps_rhs_equation in [n_steps, 1] and n_equations_rhs_equation == n_equations
    *batch_rhs_init, n_init_var_steps_rhs_init, n_dims_rhs_init, n_init_var_orders = rhs_init.shape
    assert n_init_var_steps_rhs_init in [n_init_var_steps, 1] and n_dims_rhs_init == n_dims
    *batch_steps, n_steps_steps = steps.shape
    assert n_steps_steps in [n_steps - 1, 1]
    batch_lhs: torch.Size = torch.broadcast_shapes(batch_coefficients, batch_steps)  # batch of the matrix
    batch: torch.Size = torch.broadcast_shapes(batch_lhs, batch_rhs_equation, batch_rhs_init)  # batch of rhs / result
    n_vars: int = n_dims * n_orders  # unknowns per step = size of one block

    # ode equation constraints: per step, C^T C goes on the diagonal block and C^T r into the rhs
    c: torch.Tensor = coefficients.flatten(start_dim=-2)  # (..., n_steps[b], n_equations, n_dims * n_orders)
    ct: torch.Tensor = c.transpose(-2, -1)  # (..., n_steps[b], n_dims * n_orders, n_equations)
    block_diag_0: torch.Tensor = ct @ c  # (..., n_steps[b], n_dims * n_orders, n_dims * n_orders)
    beta: torch.Tensor = ct @ rhs_equation[..., None]  # (..., n_steps[b], n_dims * n_orders, 1)

    # Both tensors are updated in place below, so materialise full-size copies. expand() aligns batch
    # dims from the right exactly like broadcasting, so inputs with fewer batch dims than the broadcast
    # batch are handled; clone() turns the stride-0 expanded view into an independent, writable tensor.
    block_diag_0 = block_diag_0.expand(*batch_lhs, n_steps, n_vars, n_vars).clone()
    # (..., n_steps, n_dims * n_orders, n_dims * n_orders)
    beta = beta.expand(*batch, n_steps, n_vars, 1).clone()
    # (..., n_steps, n_dims * n_orders, 1)

    # initial-value constraints: identity rows on orders [0, n_init_var_orders) of every dim for the first
    # n_init_var_steps steps. rhs_init is zero-padded to n_orders so it lines up with the flattened
    # (dim, order) layout; the padded entries add nothing.
    init_idx: torch.Tensor = torch.arange(n_init_var_orders, device=device).repeat(n_dims) \
                             + (n_orders * torch.arange(n_dims, device=device)).repeat_interleave(n_init_var_orders)
    # (n_dims * n_init_var_orders)
    block_diag_0[..., :n_init_var_steps, init_idx, init_idx] += 1.
    beta[..., :n_init_var_steps, :, 0] += torch.cat([
        rhs_init,
        torch.zeros(*rhs_init.shape[:-1], n_orders - n_init_var_orders, dtype=dtype, device=device),
    ], dim=-1).flatten(start_dim=-2)

    # smoothness constraints (forward & backward): Taylor-expansion rows linking each pair of consecutive
    # steps in both directions. The highest order gets no row of its own, hence the zeroed last entries
    # of et_e_diag and factorials below.
    order_idx: torch.Tensor = torch.arange(n_orders, device=device)  # (n_orders)
    sign_vec: torch.Tensor = order_idx % 2 * (-2) + 1  # (+1, -1, +1, ...) = sign of (-h) ** order
    sign_map: torch.Tensor = sign_vec * sign_vec[:, None]  # (n_orders, n_orders)

    expansions: torch.Tensor = steps[..., None] ** order_idx  # (..., n_steps-1[b], n_orders): h ** order
    et_e_diag: torch.Tensor = expansions ** 2  # (..., n_steps-1[b], n_orders)
    et_e_diag[..., -1] = 0.
    # factorials[i, j] = 1 / (j - i)! for j >= i; entries below the diagonal are exp(-lgamma(0)) = 0
    factorials: torch.Tensor = (-(order_idx - order_idx[:, None] + 1).triu().to(dtype=dtype).lgamma()).exp()
    # (n_orders, n_orders)
    factorials[-1, -1] = 0.
    e_outer: torch.Tensor = expansions[..., None] * expansions[..., None, :]  # (..., n_steps-1[b], n_orders, n_orders)
    et_ft_f_e: torch.Tensor = e_outer * (factorials.t() @ factorials)  # (..., n_steps-1[b], n_orders, n_orders)

    # coupling block between step s+1 (rows) and step s (columns): forward term + sign-flipped backward term
    smooth_block_diag_1: torch.Tensor = e_outer * -(factorials + factorials.transpose(-2, -1) * sign_map)
    # (..., n_steps-1[b], n_orders, n_orders)
    smooth_block_diag_0: torch.Tensor = torch.zeros(*batch_lhs, n_steps, n_orders, n_orders, dtype=dtype, device=device)
    # (..., n_steps, n_orders, n_orders)
    smooth_block_diag_0[..., :-1, :, :] += et_ft_f_e
    smooth_block_diag_0[..., 1:, :, :] += et_ft_f_e * sign_map
    smooth_block_diag_0[..., :-1, order_idx, order_idx] += et_e_diag
    smooth_block_diag_0[..., 1:, order_idx, order_idx] += et_e_diag

    # expand a broadcast step axis (size 1) to its full length; repeat() also prepends missing batch dims
    smooth_block_diag_1 = smooth_block_diag_1.repeat(
        *([1] * len(batch_lhs)),
        (n_steps - 1) // smooth_block_diag_1.size(-3),
        1,
        1,
    )  # (..., n_steps-1, n_orders, n_orders)
    steps = steps.repeat(*([1] * len(batch_lhs)), (n_steps - 1) // steps.size(-1))  # (..., n_steps-1)

    # smoothness constraints (central): one row per (step s, dim) over steps s, s+1, s+2 with span
    # h2 = steps[s] + steps[s+1], linking order n_orders-2 at s and s+2 with order n_orders-1 at s+1
    steps2: torch.Tensor = steps[..., :-1] + steps[..., 1:]  # (..., n_steps-2)
    steps26: torch.Tensor = steps2 ** (n_orders * 2 - 6)  # (..., n_steps-2)
    steps25: torch.Tensor = steps2 ** (n_orders * 2 - 5)  # (..., n_steps-2)
    steps24: torch.Tensor = steps2 ** (n_orders * 2 - 4)  # (..., n_steps-2)

    smooth_block_diag_0[..., :-2, n_orders - 2, n_orders - 2] += steps26
    smooth_block_diag_0[..., 2:, n_orders - 2, n_orders - 2] += steps26
    smooth_block_diag_0[..., 1:-1, n_orders - 1, n_orders - 1] += steps24
    smooth_block_diag_1[..., :-1, n_orders - 1, n_orders - 2] += steps25
    smooth_block_diag_1[..., 1:, n_orders - 2, n_orders - 1] -= steps25
    smooth_block_diag_2: torch.Tensor = torch.zeros(
        *batch_lhs, n_steps - 2, n_orders, n_orders, dtype=dtype, device=device,
    )  # (..., n_steps-2, n_orders, n_orders)
    smooth_block_diag_2[..., n_orders - 2, n_orders - 2] = -steps26

    # copy the per-dim smoothness blocks onto the diagonal sub-blocks of every dim
    block_diag_1: torch.Tensor = torch.zeros(
        *batch_lhs, n_steps - 1, n_vars, n_vars, dtype=dtype, device=device,
    )  # (..., n_steps-1, n_dims * n_orders, n_dims * n_orders)
    block_diag_2: torch.Tensor = torch.zeros(
        *batch_lhs, n_steps - 2, n_vars, n_vars, dtype=dtype, device=device,
    )  # (..., n_steps-2, n_dims * n_orders, n_dims * n_orders)
    for dim in range(n_dims):
        i1: int = dim * n_orders
        i2: int = (dim + 1) * n_orders
        block_diag_0[..., i1:i2, i1:i2] += smooth_block_diag_0
        block_diag_1[..., i1:i2, i1:i2] = smooth_block_diag_1
        block_diag_2[..., i1:i2, i1:i2] = smooth_block_diag_2

    # blocked cholesky decomposition, overwriting the block lists with the factor L:
    # block_diag_0_list[s] -> L[s, s], block_diag_1_list[s] -> L[s+1, s], block_diag_2_list[s] -> L[s+2, s]
    block_diag_0_list: list[torch.Tensor] = list(block_diag_0.unbind(dim=-3))
    block_diag_1_list: list[torch.Tensor] = list(block_diag_1.unbind(dim=-3))
    block_diag_2_list: list[torch.Tensor] = list(block_diag_2.unbind(dim=-3))
    for step in range(n_steps):
        if step >= 2:
            # L[step, step-2] = A[step, step-2] L[step-2, step-2]^-T
            block_diag_2_list[step - 2] = torch.linalg.solve_triangular(
                block_diag_0_list[step - 2].transpose(-2, -1),
                block_diag_2_list[step - 2],
                upper=True,
                left=False,
            )
            block_diag_1_list[step - 1] = block_diag_1_list[step - 1] \
                                          - block_diag_2_list[step - 2] @ block_diag_1_list[step - 2].transpose(-2, -1)
        if step >= 1:
            # L[step, step-1] = (A[step, step-1] - L[step, step-2] L[step-1, step-2]^T) L[step-1, step-1]^-T
            block_diag_1_list[step - 1] = torch.linalg.solve_triangular(
                block_diag_0_list[step - 1].transpose(-2, -1),
                block_diag_1_list[step - 1],
                upper=True,
                left=False,
            )
            if step >= 2:
                block_diag_0_list[step] = block_diag_0_list[step] \
                                          - block_diag_2_list[step - 2] @ block_diag_2_list[step - 2].transpose(-2, -1)
            block_diag_0_list[step] = block_diag_0_list[step] \
                                      - block_diag_1_list[step - 1] @ block_diag_1_list[step - 1].transpose(-2, -1)
        # L[step, step] = chol(Schur complement); errors are not raised (check_errors=False)
        block_diag_0_list[step], _ = torch.linalg.cholesky_ex(
            block_diag_0_list[step],
            upper=False,
            check_errors=False,
        )

    # A X = B => L (Lt X) = B
    # solve L Y = B, block forward substitution
    b_list: list[torch.Tensor] = list(beta.unbind(dim=-3))
    y_list: list[torch.Tensor | None] = [None] * n_steps
    for step in range(n_steps):
        b_step: torch.Tensor = b_list[step]
        if step >= 2:
            b_step = b_step - block_diag_2_list[step - 2] @ y_list[step - 2]
        if step >= 1:
            b_step = b_step - block_diag_1_list[step - 1] @ y_list[step - 1]
        y_list[step] = torch.linalg.solve_triangular(
            block_diag_0_list[step],
            b_step,
            upper=False,
            left=True,
        )

    # solve Lt X = Y, block backward substitution
    x_list: list[torch.Tensor | None] = [None] * n_steps
    for step in range(n_steps - 1, -1, -1):
        y_step: torch.Tensor = y_list[step]
        if step < n_steps - 2:
            y_step = y_step - block_diag_2_list[step].transpose(-2, -1) @ x_list[step + 2]
        if step < n_steps - 1:
            y_step = y_step - block_diag_1_list[step].transpose(-2, -1) @ x_list[step + 1]
        x_list[step] = torch.linalg.solve_triangular(
            block_diag_0_list[step].transpose(-2, -1),
            y_step,
            upper=True,
            left=True,
        )

    u: torch.Tensor = torch.stack(x_list, dim=-3).reshape(*batch, n_steps, n_dims, n_orders)
    # (..., n_steps, n_dims, n_orders)
    return u


def ode_forward_baseline(
        coefficients: torch.Tensor,
        rhs_equation: torch.Tensor,
        rhs_init: torch.Tensor,
        steps: torch.Tensor,
) -> torch.Tensor:
    """
    Dense reference solver: build every constraint as an explicit row and solve the normal equations.

    The unknown ``u[..., step, dim, order]`` is flattened in (step, dim, order) order into the columns of
    a dense matrix ``A``. The returned ``u`` minimises ``||A u - beta||^2``, computed via a Cholesky
    factorisation of ``A^T A``. The row blocks of ``A`` are:

    - equation rows: ``sum(coefficients[step, eq] * u[step]) = rhs_equation[step, eq]``;
    - initial-value rows: ``u[step, dim, order] = rhs_init[step, dim, order]`` for the first
      ``n_init_var_steps`` steps and ``n_init_var_orders`` orders;
    - backward / forward smoothness rows: truncated Taylor expansion from one step to its neighbour,
      for every order except the highest;
    - central smoothness rows: spanning steps ``s, s+1, s+2``.

    Cost is cubic in ``n_steps * n_dims * n_orders``, so this is meant only for checking faster
    implementations on small problems. No size-1 broadcasting is done; the batch dims are read from
    ``coefficients``, and ``torch.cat`` requires ``rhs_equation`` / ``rhs_init`` to share them.
    ``cholesky_ex`` runs with ``check_errors=False``, so a singular system does not raise.

    coefficients: (..., n_steps, n_equations, n_dims, n_orders)
    rhs_equation: (..., n_steps, n_equations)
    rhs_init: (..., n_init_var_steps, n_dims, n_init_var_orders)
    steps: (..., n_steps-1)
    return: (..., n_steps, n_dims, n_orders)
    """
    dtype: torch.dtype = coefficients.dtype
    device: torch.device = coefficients.device

    *batches, n_steps, n_equations, n_dims, n_orders = coefficients.shape
    *_, n_init_var_steps, _, n_init_var_orders = rhs_init.shape
    step_width: int = n_dims * n_orders  # unknowns per step
    n_vars: int = n_steps * step_width  # total unknowns = columns of A

    def col(step: int, dim: int, order: int) -> int:
        # Column of unknown u[step, dim, order] in the flattened (step, dim, order) layout.
        return (step * n_dims + dim) * n_orders + order

    def new_rows(n_rows: int) -> torch.Tensor:
        # Zero-initialised block of n_rows constraint rows.
        return torch.zeros(*batches, n_rows, n_vars, dtype=dtype, device=device)

    # equation constraints: one row per (step, equation)
    A_eq = new_rows(n_steps * n_equations)
    for i, (step, equation) in enumerate(itertools.product(range(n_steps), range(n_equations))):
        A_eq[..., i, step * step_width: (step + 1) * step_width] = \
            coefficients[..., step, equation, :, :].flatten(start_dim=-2)
    beta_eq = rhs_equation.flatten(start_dim=-2)

    # initial-value constraints: one row per (init step, dim, init order), selecting that unknown directly
    A_in = new_rows(n_init_var_steps * n_dims * n_init_var_orders)
    init_index = itertools.product(range(n_init_var_steps), range(n_dims), range(n_init_var_orders))
    for i, (step, dim, order) in enumerate(init_index):
        A_in[..., i, col(step, dim, order)] = 1.
    beta_in = rhs_init.flatten(start_dim=-3)

    # one smoothness row per (interval, dim, order) for every order except the highest
    smooth_index = list(itertools.product(range(n_steps - 1), range(n_dims), range(n_orders - 1)))

    # forward smoothness: sum_o h^o / (o - order)! u[step, o] - h^order u[step + 1, order]
    A_sf = new_rows(len(smooth_index))
    for i, (step, dim, order) in enumerate(smooth_index):
        h = steps[..., step]
        for o in range(order, n_orders):
            A_sf[..., i, col(step, dim, o)] = h ** o / math.factorial(o - order)
        A_sf[..., i, col(step + 1, dim, order)] = - h ** order

    # backward smoothness: same expansion from step + 1 back to step, i.e. with step size -h
    A_sb = new_rows(len(smooth_index))
    for i, (step, dim, order) in enumerate(smooth_index):
        h_neg = - steps[..., step]
        for o in range(order, n_orders):
            A_sb[..., i, col(step + 1, dim, o)] = h_neg ** o / math.factorial(o - order)
        A_sb[..., i, col(step, dim, order)] = - h_neg ** order

    # central smoothness: h2^(n_orders-3) (u[s, n_orders-2] - u[s+2, n_orders-2]) + h2^(n_orders-2) u[s+1, n_orders-1]
    # with h2 = steps[s] + steps[s+1]
    central_index = list(itertools.product(range(n_steps - 2), range(n_dims)))
    A_sc = new_rows(len(central_index))
    for i, (step, dim) in enumerate(central_index):
        h2 = steps[..., step] + steps[..., step + 1]
        A_sc[..., i, col(step, dim, n_orders - 2)] = h2 ** (n_orders - 3)
        A_sc[..., i, col(step + 1, dim, n_orders - 1)] = h2 ** (n_orders - 2)
        A_sc[..., i, col(step + 2, dim, n_orders - 2)] = - h2 ** (n_orders - 3)

    A = torch.cat([A_eq, A_in, A_sb, A_sc, A_sf], dim=-2)
    # zero rhs for every smoothness block, listed in the same order as the row blocks of A
    beta = torch.cat([
        beta_eq,
        beta_in,
        torch.zeros_like(A_sb[..., 0]),
        torch.zeros_like(A_sc[..., 0]),
        torch.zeros_like(A_sf[..., 0]),
    ], dim=-1)

    # normal equations A^T A u = A^T beta, solved with a Cholesky factorisation
    AtA = A.transpose(-2, -1) @ A
    Atb = A.transpose(-2, -1) @ beta[..., None]
    L, _ = torch.linalg.cholesky_ex(AtA, upper=False, check_errors=False)
    u = Atb.cholesky_solve(L, upper=False)
    u = u.reshape(*batches, n_steps, n_dims, n_orders)
    return u


def test() -> None:
    """
    Manual comparison harness: ``ode_forward`` vs. the dense reference ``ode_forward_baseline``.

    Hard-codes device ``cuda:0`` and dtype float64, and switches on autograd anomaly detection (with NaN
    checks) globally. Prints two things:

    - the max absolute difference between the two solutions;
    - a list giving, for each of the four inputs, the max absolute difference between the two
      implementations' gradients of ``u.sum()``.
    """
    torch.autograd.set_detect_anomaly(mode=True, check_nan=True)
    dtype = torch.float64
    device = torch.device('cuda:0')
    batches = (11,)
    # batches = ()
    n_steps, n_equations, n_dims, n_orders = 7, 2, 3, 5
    n_init_var_steps, n_init_var_orders = 3, 4
    # First call on an independent set of random inputs; its result is bound to `_` and otherwise unused.
    # NOTE(review): presumably a warm-up (e.g. CUDA-graph initialisation in ode_forward) — confirm.
    coefficients = torch.nn.Parameter(torch.randn(*batches, n_steps, n_equations, n_dims, n_orders, dtype=dtype, device=device))
    rhs_equation = torch.nn.Parameter(torch.randn(*batches, n_steps, n_equations, dtype=dtype, device=device))
    rhs_init = torch.nn.Parameter(torch.randn(*batches, n_init_var_steps, n_dims, n_init_var_orders, dtype=dtype, device=device))
    steps = torch.nn.Parameter(torch.rand(*batches, n_steps - 1, dtype=dtype, device=device))  # uniform in [0, 1)
    _ = ode_forward(coefficients, rhs_equation, rhs_init, steps)

    # Fresh random inputs for the actual comparison.
    coefficients = torch.nn.Parameter(torch.randn(*batches, n_steps, n_equations, n_dims, n_orders, dtype=dtype, device=device))
    rhs_equation = torch.nn.Parameter(torch.randn(*batches, n_steps, n_equations, dtype=dtype, device=device))
    rhs_init = torch.nn.Parameter(torch.randn(*batches, n_init_var_steps, n_dims, n_init_var_orders, dtype=dtype, device=device))
    steps = torch.nn.Parameter(torch.rand(*batches, n_steps - 1, dtype=dtype, device=device))
    u = ode_forward(coefficients, rhs_equation, rhs_init, steps)

    u.sum().backward()
    u0 = ode_forward_baseline(coefficients, rhs_equation, rhs_init, steps)
    diff = u - u0
    print(diff.abs().max().item())

    var_list = [coefficients, rhs_equation, rhs_init, steps]
    grads = [var.grad for var in var_list]  # gradients produced by ode_forward
    for var in var_list:
        var.grad = None  # reset so the baseline's backward does not accumulate onto ode_forward's gradients
    u0.sum().backward()
    grads0 = [var.grad for var in var_list]  # gradients produced by the baseline
    grads_diff = [g - g0 for g, g0 in zip(grads, grads0)]
    print([grad_diff.abs().max().item() for grad_diff in grads_diff])
    # NOTE(review): drops the local reference to u before returning; the intent is not visible here.
    u = None


if __name__ == '__main__':
    # Script entry point: run the CUDA comparison harness defined above.
    test()
