"""Datasets used to investigate fast convergence conditions of NGD."""

from math import pi
from typing import Tuple
from warnings import warn

from torch import Generator, Tensor, allclose, diag, eye, rand, randn
from torch.linalg import eigh


def forster_transform(
    X: Tensor, maxiter: int = 1000, atol: float = 1e-5, rtol: float = 1e-3
) -> Tensor:
    """Find a linear transformation that brings the data into isotropic position.

    The iteration alternates between whitening the current data (multiplying by
    the inverse matrix square root of its Gram matrix) and re-normalizing every
    row to unit L2-norm, accumulating the applied linear maps in ``A``.

    See Appendix A (Algorithm 1) of Zhang et al for details.

    Note:
        The whitening step divides by the square roots of the eigenvalues of
        ``Zᵀ Z``. If ``X`` does not have full column rank, or contains an
        all-zero row, the iterates become non-finite and the tolerance check
        can never succeed.

    Args:
        X: Data set of shape ``[n, d]``. It is not modified.
        maxiter: Maximum number of iterations. The algorithm will stop after this
            number of iterations even if the specified tolerance has not been reached.
            Default: ``1000``.
        atol: Absolute tolerance for convergence. Default: ``1e-5``.
        rtol: Relative tolerance for convergence. Default: ``1e-3``.

    Returns:
        A matrix ``A`` of shape ``[d, d]`` such that, if we construct ``Z = X @ A``
        and then normalize each row of ``Z`` by its L2-norm, the resulting data set
        ``Z'`` satisfies ``Z'ᵀ Z' = (n / d) I_d`` within the specified tolerances.
        If the iteration budget is exhausted before the tolerances are met, a
        warning is emitted and the last iterate is returned without guarantees.
    """
    n, d = X.shape
    A = eye(d, device=X.device, dtype=X.dtype)
    # The target Gram matrix does not change between iterations; build it once.
    ZTZ_truth = (n / d) * eye(d, device=X.device, dtype=X.dtype)

    def _tolerance_reached(transform: Tensor) -> bool:
        """Check whether ``transform`` satisfies the isotropy condition.

        Args:
            transform: Candidate transformation of shape ``[d, d]``.

        Returns:
            True if the row-normalized ``X @ transform`` has Gram matrix
            ``(n / d) I_d`` within ``atol``/``rtol``. Otherwise False.
        """
        Z_check = X @ transform
        Z_check = Z_check / Z_check.norm(dim=1, keepdim=True)
        return allclose(Z_check.T @ Z_check, ZTZ_truth, atol=atol, rtol=rtol)

    Z = X
    step = 0
    converged = _tolerance_reached(A)

    while not converged and step < maxiter:
        # inverse matrix square root via eigendecomposition
        evals, evecs = eigh(Z.T @ Z)
        T = evecs @ diag(1.0 / evals.sqrt()) @ evecs.T

        # whiten, then project every data point back onto the unit sphere
        Z = Z @ T
        Z = Z / Z.norm(dim=1, keepdim=True)

        # The row normalization makes the result invariant to the scale of A,
        # so fix A's scale via its first entry.
        A = A @ T
        A = A / A[0, 0]

        step += 1
        converged = _tolerance_reached(A)

    # Only warn if the budget ran out *and* the tolerance is still not met.
    if not converged:
        warn("Maxiter exceeded. Specified tolerance not reached.", stacklevel=2)

    return A


def synthetic_1d_regression(
    num_data: int, seed: int = 123456789
) -> Tuple[Tensor, Tensor]:
    """Generate a reproducible toy regression problem with scalar targets.

    Every input ``x`` has two features sampled from a uniform distribution.
    Its target follows ``y = sin(2π * (x₁ + x₂)) + 0.1 * ε`` with ``ε`` sampled
    from a standard normal distribution.

    Args:
        num_data: Number of data points.
        seed: Random seed for the internal generator. Default: ``123456789``.

    Returns:
        Tuple of inputs and targets. Inputs are of shape ``[num_data, 2]``, targets
        are of shape ``[num_data, 1]``.
    """
    rng = Generator()
    rng.manual_seed(seed)

    # Sampling order is part of the contract: inputs first, then the noise.
    inputs = rand((num_data, 2), generator=rng)
    noise = randn((num_data, 1), generator=rng)

    phase = 2 * pi * inputs.sum(1, keepdim=True)
    targets = phase.sin() + 0.1 * noise

    return inputs, targets
