# -*- coding: utf-8 -*-
"""
Created on Sun Sep 8 19:54:22 2024

@author: kernel
"""

import torch

def B_batch(x, grid, k=0, extend=True, device='cpu'):
    '''
    Evaluate the B-spline basis functions on input x and knot grid.

    The bases are built with the Cox-de Boor recursion. It starts from the
    degree-0 indicators of the half-open knot intervals [t_j, t_{j+1}) and
    raises the degree one step at a time up to ``k``. Because the intervals
    are half-open, a point equal to the last knot, or outside the grid, gets
    all-zero basis values.

    Args:
    -----
        x : torch.Tensor
            Input data of shape (batch, in_dim).
        grid : torch.Tensor
            Knot positions of shape (in_dim, n_knots); row i is the knot vector
            used for input dimension i. For a grid of G intervals extended by k
            knots on each side, n_knots = G + 2k + 1.
        k : int
            Degree of the B-spline (k=0: piecewise constant, k=3: cubic).
            Must be non-negative.
        extend : bool
            Unused; kept for backward compatibility. The grid is NOT extended
            here (use ``extend_grid`` for that).
        device : str
            Unused; kept for backward compatibility. The result lives on the
            device of ``x`` / ``grid``.

    Returns:
    --------
        torch.Tensor
            Floating-point basis values of shape (batch, in_dim, n_knots - 1 - k),
            i.e. (batch, in_dim, G + k) for an extended grid.

    Raises:
    -------
        ValueError
            If ``k`` is negative.
    '''
    if k < 0:
        raise ValueError(f"spline degree k must be non-negative, got {k}")

    x = x.unsqueeze(dim=2)        # (batch, in_dim, 1)
    grid = grid.unsqueeze(dim=0)  # (1, in_dim, n_knots)

    # The degree-0 indicators are boolean. Give them a floating dtype so that
    # k=0 results can go straight into einsum / linear algebra. For k >= 1
    # this is the same dtype the divisions below produce anyway.
    out_dtype = torch.promote_types(x.dtype, grid.dtype)
    if not out_dtype.is_floating_point:
        out_dtype = torch.get_default_dtype()

    value = ((x >= grid[:, :, :-1]) & (x < grid[:, :, 1:])).to(out_dtype)

    for p in range(1, k + 1):
        left = ((x - grid[:, :, :-(p + 1)])
                / (grid[:, :, p:-1] - grid[:, :, :-(p + 1)])
                * value[:, :, :-1])
        right = ((grid[:, :, p + 1:] - x)
                 / (grid[:, :, p + 1:] - grid[:, :, 1:-p])
                 * value[:, :, 1:])
        # Repeated (degenerate) knots produce 0/0 or c/0. Sanitize at every
        # degree so NaN/inf never propagates into the next level.
        value = torch.nan_to_num(left + right)

    return value


def coef2curve(x_eval, grid, coef, k, device="cpu"):
    '''
    Evaluate spline curves from their B-spline coefficients.

    Every (input, output) pair has its own curve. Its value is the
    coefficient-weighted sum of the basis functions:
        y[b, i, o] = sum_n B[b, i, n] * coef[i, o, n]

    Args:
    -----
        x_eval : torch.Tensor
            Points to evaluate, shape (batch, in_dim).
        grid : torch.Tensor
            Knot vectors, shape (in_dim, n_knots). For G intervals extended by
            k knots on each side, n_knots = G + 2k + 1.
        coef : torch.Tensor
            Spline coefficients, shape (in_dim, out_dim, n_knots - 1 - k),
            i.e. (in_dim, out_dim, G + k).
        k : int
            Spline degree.
        device : str
            Unused; kept for backward compatibility. ``coef`` is moved to the
            device of the evaluated basis instead.

    Returns:
    --------
        torch.Tensor
            Curve values of shape (batch, in_dim, out_dim).
    '''
    basis = B_batch(x_eval, grid, k=k)          # (batch, in_dim, n_coef)
    coef_here = coef.to(basis.device)
    return torch.einsum('bin,ion->bio', basis, coef_here)


def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
    '''
    Fit B-spline coefficients using regularized least squares.

    For every input dimension i and output dimension o this solves the
    ridge-regression problem

        min_c ||X_i c - y[:, i, o]||^2 + lamb * ||c||^2

    where X_i = B_batch(x_eval, grid, k)[:, i, :] is the (batch, n_coef)
    design matrix. The solution is c = pinv(X_i^T X_i + lamb * I) X_i^T y[:, i, o].

    The design matrix depends only on the input dimension. The regularized
    Gram matrix and its pseudo-inverse are therefore computed once per input
    dimension and shared by all out_dim targets.

    Args:
    -----
        x_eval : torch.Tensor
            Input data of shape (batch, in_dim).
        y_eval : torch.Tensor
            Target values of shape (batch, in_dim, out_dim).
        grid : torch.Tensor
            Knot vectors of shape (in_dim, n_knots). For G intervals extended
            by k knots on each side, n_knots = G + 2k + 1.
        k : int
            Spline degree.
        lamb : float
            Ridge regularization strength (lambda).

    Returns:
    --------
        coef : torch.Tensor
            Contiguous fitted coefficients of shape
            (in_dim, out_dim, n_knots - 1 - k), i.e. (in_dim, out_dim, G + k).
    '''
    mat = B_batch(x_eval, grid, k)  # (batch, in_dim, n_coef)

    # Use one common floating dtype for the basis and the targets. The basis
    # can be boolean (k=0) or a different precision than y_eval, and the
    # batched products below should not rely on implicit type promotion.
    dtype = torch.promote_types(mat.dtype, y_eval.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    mat = mat.to(dtype)
    y_eval = y_eval.to(device=mat.device, dtype=dtype)

    # X^T X per input dimension: (in_dim, n_coef, n_coef).
    XtX = torch.einsum('bim,bin->imn', mat, mat)
    # X^T y for all outputs at once: (in_dim, n_coef, out_dim).
    Xty = torch.einsum('bim,bio->imo', mat, y_eval)

    n_coef = XtX.shape[-1]
    identity = torch.eye(n_coef, dtype=dtype, device=mat.device)
    A = XtX + lamb * identity  # identity broadcasts over in_dim

    # (in_dim, n_coef, out_dim) -> (in_dim, out_dim, n_coef). Made contiguous
    # so callers can safely .view() / reuse it as parameter storage.
    coef = (A.pinverse() @ Xty).permute(0, 2, 1).contiguous()
    return coef


def extend_grid(grid, k_extend=0):
    '''
    Pad each knot vector with k_extend extra knots at both ends.

    The step used for the new knots is the row's average spacing,
    h = (grid[:, -1] - grid[:, 0]) / (num_knots - 1). New knots are produced
    by repeatedly stepping h outward from the first and the last knot.

    Args:
    -----
        grid : torch.Tensor
            Knot vectors of shape (in_dim, num_knots).
        k_extend : int
            Number of knots to add on each side. If no knots are added,
            ``grid`` itself is returned.

    Returns:
    --------
        torch.Tensor
            Extended knots of shape (in_dim, num_knots + 2 * k_extend).
    '''
    step = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)

    lower, upper = grid[:, [0]], grid[:, [-1]]
    below, above = [], []
    for _ in range(k_extend):
        lower = lower - step
        upper = upper + step
        below.append(lower)
        above.append(upper)

    if not below:
        return grid

    # `below` was collected moving outward; reverse it so the farthest knot
    # comes first, then concatenate everything in a single call.
    return torch.cat(below[::-1] + [grid] + above, dim=1)