"""Test `rla_pinns.forward_laplacian`."""

from test.test_manual_differentiation import CASE_IDS, CASES, set_up
from typing import Callable, Dict, List, Union

from einops import einsum
from pytest import mark
from torch import Tensor, allclose, linspace, ones, outer, rand, randn
from torch.nn import Sequential

from rla_pinns.autodiff_utils import autograd_input_hessian
from rla_pinns.forward_laplacian import manual_forward_laplacian


def reshape_square(t: Tensor) -> Tensor:
    """Reshape an arbitrary tensor into a square matrix.

    Args:
        t: An arbitrary tensor. Number of elements must have an integer square root.

    Returns:
        A square matrix with the same elements as `t`.

    Raises:
        ValueError: If the number of elements does not have an integer square root.
    """
    from math import isqrt

    num_elements = t.numel()
    # Exact integer square root. `num_elements ** 0.5` goes through a float and
    # can land on the wrong integer for large element counts.
    dim = isqrt(num_elements)
    if dim**2 != num_elements:
        raise ValueError("Number of elements must have an integer square root.")
    return t.reshape(dim, dim)


# Maps a pytest id to a function that picks input coordinates from `X`.
COORDINATE_FNS = {
    # `None` selects every input coordinate.
    "coordinates=None": lambda _: None,
    # Every other column index of `X`, starting at 0.
    "coordinates=even": lambda X: list(range(0, X.shape[1], 2)),
}
def _num_coordinates(X: Tensor, coordinates: Union[None, List[int]]) -> int:
    """Return how many input coordinates the coefficients must cover.

    Args:
        X: Input data. All dimensions except the leading one are flattened.
        coordinates: Selected coordinates, or `None` to use all of them.

    Returns:
        `len(coordinates)` if coordinates are given, otherwise `X.shape[1:].numel()`.
    """
    return X.shape[1:].numel() if coordinates is None else len(coordinates)


# Maps a pytest id to a function `(X, coordinates) -> coefficients` that builds the
# weights of the second-derivative sum, on the same device and dtype as `X`.
COEFFICIENTS_FNS = {
    "coefficients=None": lambda *_: None,
    # symmetric non-diagonal
    "coefficients=equal": lambda X, coordinates: ones(
        _num_coordinates(X, coordinates), _num_coordinates(X, coordinates)
    ).to(X.device, X.dtype),
    # non-symmetric
    "coefficients=linspace": lambda X, coordinates: reshape_square(
        linspace(1, 2, _num_coordinates(X, coordinates) ** 2).to(X.device, X.dtype)
    ),
    # also includes negative coefficients
    "coefficients=random": lambda X, coordinates: randn(
        _num_coordinates(X, coordinates), _num_coordinates(X, coordinates)
    ).to(X.device, X.dtype),
    # sum of vector outer product (list format)
    "coefficients=list+random": lambda X, coordinates: [
        rand(_num_coordinates(X, coordinates)).to(X.device, X.dtype)
        for _ in range(5)
    ],
}


@mark.parametrize(
    "coefficients_fn", COEFFICIENTS_FNS.values(), ids=COEFFICIENTS_FNS.keys()
)
@mark.parametrize("coordinate_fn", COORDINATE_FNS.values(), ids=COORDINATE_FNS.keys())
@mark.parametrize("case", CASES, ids=CASE_IDS)
def test_manual_forward_laplacian(
    case: Dict,
    coordinate_fn: Callable[[Tensor], Union[None, List[int]]],
    coefficients_fn: Callable[
        [Tensor, Union[None, List[int]]], Union[None, Tensor, List[Tensor]]
    ],
):
    """Compute forward Laplacian (or weighted second derivatives), check with autodiff.

    The reference value contracts the input Hessian from `autograd_input_hessian`
    with the coefficient matrix and sums over the batch. The coefficient matrix is
    the identity if `coefficients` is `None`, and the sum of outer products
    `c c^T` if `coefficients` is a list of vectors.

    Args:
        case: A dictionary describing a test case.
        coordinate_fn: A function that takes the input data `X` and returns
            the coordinates whose second derivatives enter the Laplacian
            (`None` selects all coordinates).
        coefficients_fn: A function that takes the input data `X` and the coordinates
            generated by `coordinate_fn`, and returns the coefficients for the weighted
            sum of second derivatives.

    Raises:
        AssertionError: If the computed Laplacian does not match the true Laplacian.
        ValueError: If the coefficients format is not supported.
    """
    layers, X = set_up(case)
    coordinates = coordinate_fn(X)
    coefficients = coefficients_fn(X, coordinates)

    # reference computation via automatic differentiation
    true_hessian_X = autograd_input_hessian(
        Sequential(*layers), X, coordinates=coordinates
    )
    if coordinates is not None:
        # the Hessian must be restricted to the selected coordinates
        assert true_hessian_X.shape[1:] == (len(coordinates), len(coordinates))

    if coefficients is None:
        # plain Laplacian: sum of the Hessian traces over the batch
        true_laplacian_X = einsum(true_hessian_X, "batch d d -> ")
    elif isinstance(coefficients, Tensor):
        true_laplacian_X = einsum(true_hessian_X, coefficients, "batch i j, i j -> ")
    elif isinstance(coefficients, list):
        assert len(coefficients) > 0
        # a list of vectors c_k encodes the coefficient matrix sum_k c_k c_k^T
        coeff_mat = sum(outer(c, c) for c in coefficients)
        true_laplacian_X = einsum(true_hessian_X, coeff_mat, "batch i j, i j -> ")
    else:
        raise ValueError(f"Invalid coefficients: {coefficients}.")

    # forward-Laplacian computation; its per-layer output gets its own name so it
    # does not shadow the input `coefficients`
    layer_results = manual_forward_laplacian(
        layers, X, coordinates=coordinates, coefficients=coefficients
    )
    laplacian_X = einsum(layer_results[-1]["laplacian"], "n d -> ")

    assert allclose(true_laplacian_X, laplacian_X), (
        f"Forward Laplacian {laplacian_X} does not match reference {true_laplacian_X}."
    )
