import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from kans.kan import *

# Factorial helper.
def jc(x):
    '''
    Return the factorial of x.

    Thin wrapper around math.factorial; the argument is forwarded unchanged.
    '''
    result = math.factorial(x)
    return result


def _binomial(n, k):
    # Exact integer binomial coefficient C(n, k). Integer floor division keeps the
    # result exact and avoids the float overflow of dividing large factorials.
    return math.factorial(n) // (math.factorial(k) * math.factorial(n - k))


def _kan_module_value(x, layer_type, in_dim, out_dim):
    # Evaluate x through a freshly built KAN variant; returns None when layer_type
    # does not name one of them. The classes are presumably supplied by the
    # wildcard import from kans.kan -- TODO confirm.
    widths = [in_dim, out_dim, in_dim]
    # Lambdas keep the class-name lookup lazy: only the requested class is resolved.
    factories = {
        "Cheby": lambda: KACN(widths),
        "Fast": lambda: FastKAN(widths),
        "Wave": lambda: WavKAN(widths),
        "Jacobi": lambda: KAJN(widths),
        "B01": lambda: KABN(widths),
        "ReLU": lambda: ReLUKAN(widths),
        "RBF": lambda: RBFKAN(in_dim, out_dim, in_dim, 100),
    }
    factory = factories.get(layer_type)
    if factory is None:
        return None
    kan = factory()
    # The module is built on the default device; move it next to the input so the
    # forward pass does not fail on a device mismatch.
    if isinstance(kan, nn.Module):
        kan = kan.to(x.device)
    return kan(x).unsqueeze(dim=2)


def _fourier_value(x, in_dim, gridsize=300):
    # Random-coefficient Fourier expansion of x with `gridsize` frequencies,
    # returned with a trailing singleton dimension.
    out_dim = in_dim
    outshape = x.shape[0:-1] + (out_dim,)
    x = x.reshape(-1, in_dim)
    n_rows = x.shape[0]
    # Frequencies start at 1 because constant terms are in the bias.
    k = torch.reshape(torch.arange(1, gridsize + 1, device=x.device), (1, 1, 1, gridsize))
    xrshp = x.reshape(n_rows, 1, in_dim, 1)
    # This should be fused to avoid materializing memory.
    c = torch.reshape(torch.cos(k * xrshp), (1, n_rows, in_dim, gridsize))
    s = torch.reshape(torch.sin(k * xrshp), (1, n_rows, in_dim, gridsize))
    fouriercoeffs = nn.Parameter(torch.randn(2, out_dim, in_dim, gridsize) /
                                 (np.sqrt(in_dim) * np.sqrt(gridsize)))
    fouriercoeffs = nn.Parameter(fouriercoeffs.to(x.device))
    y = torch.einsum("dbik,djik->bj", torch.concat([c, s], axis=0), fouriercoeffs)
    return y.reshape(outshape).unsqueeze(dim=2)


def _taylor_value(x, in_dim, order):
    # Random-coefficient polynomial of degree order - 1 in x, returned with a
    # trailing singleton dimension.
    out_dim = in_dim
    outshape = x.shape[0:-1] + (out_dim,)
    x = torch.reshape(x, (-1, in_dim))
    x_expanded = x.unsqueeze(1).expand(-1, out_dim, -1)

    y = torch.zeros((x.shape[0], out_dim), device=x.device)
    coeffs = nn.Parameter(torch.randn(out_dim, in_dim, order) * 0.01).to(x.device)
    bias = nn.Parameter(torch.zeros(1, out_dim)).to(x.device)
    # Accumulate each term of the expansion.
    for i in range(order):
        y += ((x_expanded ** i) * coeffs[:, :, i]).sum(dim=-1)
    # The bias is added once, after the expansion, not once per term.
    y += bias
    return torch.reshape(y, outshape).unsqueeze(dim=2)


def B_batch(x, grid, layer_type, n, in_dim, out_dim, extend=True, device='cpu'):
    '''
    Evaluate x on the basis selected by layer_type.

    Args:
    -----
        x : 2D torch.tensor
            inputs, shape (batch, in_dim)
        grid : 2D torch.tensor
            grids, shape (in_dim, number of grid points)
        layer_type : str
            'b_spline', 'BN', 'Cheby', 'Fast', 'Wave', 'Jacobi', 'B01', 'ReLU',
            'RBF', 'Fourier' or 'Taylor'.
        n : int
            order of the basis. When n == 0 the order-0 (indicator) B-spline basis
            is returned regardless of layer_type.
        in_dim : int
            input width, used by every branch except 'b_spline' and 'BN'.
        out_dim : int
            hidden width of the KAN variants built for 'Cheby', 'Fast', 'Wave',
            'Jacobi', 'B01', 'ReLU' and 'RBF'. 'Fourier' and 'Taylor' use in_dim
            as their output width instead.
        extend : bool
            not read by this function; kept for interface compatibility.
        device : str
            only forwarded to the recursive call; tensors are created on x.device.

    Returns:
    --------
        basis values : 3D torch.tensor
            'b_spline' (and any layer_type with n == 0): shape
            (batch, in_dim, number of grid points - 1 - n).
            All other layer types: shape (batch, in_dim, 1).

    Raises:
    -------
        ValueError
            if n != 0 and layer_type is not one of the names listed above.

    Notes
    -----
    NOTE(review): every branch except 'b_spline' and 'BN' builds its parameters
    (or its KAN module) anew inside this call, so nothing is shared between calls.
    'BN' sums the Bernstein terms over k, which by the binomial theorem equals 1
    for every x.

    Example
    -------
    >>> x = torch.rand(100, 2)
    >>> grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11)
    >>> B_batch(x, grid, 'b_spline', n=3, in_dim=2, out_dim=2).shape
    torch.Size([100, 2, 7])
    '''
    grid = grid.unsqueeze(dim=0)

    if n == 0:
        # Order-0 basis: indicator of the half-open interval [g_i, g_{i+1}).
        x = x.unsqueeze(dim=2)
        value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
    elif layer_type == 'b_spline':
        x = x.unsqueeze(dim=2)
        B_km1 = B_batch(x[:, :, 0], grid=grid[0], layer_type=layer_type, n=n - 1,
                        in_dim=in_dim, out_dim=out_dim, device=device)
        # B-spline recursion: blend neighbouring bases of order n - 1.
        value = (x - grid[:, :, :-(n + 1)]) / (grid[:, :, n:-1] - grid[:, :, :-(n + 1)]) * B_km1[:, :, :-1] + (
                    grid[:, :, n + 1:] - x) / (grid[:, :, n + 1:] - grid[:, :, 1:(-n)]) * B_km1[:, :, 1:]
    elif layer_type == 'BN':
        x = x.unsqueeze(dim=2)
        value = torch.zeros_like(x)
        # Sum the Bernstein terms for k = 0..n (inclusive).
        for k in range(n + 1):
            value += _binomial(n, k) * (x ** k) * ((1 - x) ** (n - k))
    elif layer_type == "Fourier":
        value = _fourier_value(x, in_dim)
    elif layer_type == "Taylor":
        value = _taylor_value(x, in_dim, order=n)
    else:
        value = _kan_module_value(x, layer_type, in_dim, out_dim)
        if value is None:
            # Fail loudly instead of returning a 2D zero tensor that breaks callers later.
            raise ValueError(f"Unknown layer_type: {layer_type!r}")

    # in case grid is degenerate
    value = torch.nan_to_num(value)
    return value



def coef2curve(x_eval, grid, coef, k, layer_type, in_dim, out_dim, device="cpu"):
    '''
    Turn basis coefficients into curve values: evaluate x_eval on the basis
    returned by B_batch and contract the result with coef over the basis index.

    Args:
    -----
        x_eval : torch.tensor
            evaluation points, forwarded to B_batch.
        grid : torch.tensor
            grid, forwarded to B_batch.
        coef : 3D torch.tensor
            coefficients indexed (j, l, basis index), where j is the second axis
            of the B_batch result. It is moved to the device of the B_batch
            result before the contraction.
        k : int
            basis order, passed to B_batch as n.
        layer_type : str
            basis family, forwarded to B_batch.
        in_dim, out_dim : int
            forwarded to B_batch.
        device : str
            forwarded to B_batch.

    Returns:
    --------
        y_eval : 3D torch.tensor
            indexed (i, j, l): the first two axes of the B_batch result followed
            by the second axis of coef.
    '''
    basis = B_batch(x_eval, grid, layer_type, n=k, in_dim=in_dim, out_dim=out_dim, device=device)
    coef_on_basis_device = coef.to(basis.device)
    # Contract over the shared basis index (last axis of both operands).
    return torch.einsum('ijk,jlk->ijl', basis, coef_on_basis_device)


def curve2coef(x_eval, y_eval, grid, k, type_layer, in_dim, out_dim, device, lamb=1e-8):
    '''
    Fit basis coefficients to sampled curve values with regularized least squares.

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
        grid : 2D torch.tensor
            shape (in_dim, number of grid points)
        k : int
            basis order, passed to B_batch as n.
        type_layer : str
            basis family, passed to B_batch as layer_type.
        in_dim, out_dim : int
            ignored: both are re-derived from the shapes of x_eval and y_eval.
        device : str
            forwarded to B_batch; the solve itself runs on the device of the
            basis matrix B_batch returns.
        lamb : float
            regularized least square lambda

    Returns:
    --------
        coef : 3D torch.tensor
            shape (in_dim, out_dim, number of grid points - k - 1)
    '''
    batch = x_eval.shape[0]
    in_dim = x_eval.shape[1]
    out_dim = y_eval.shape[2]
    n_coef = grid.shape[1] - k - 1

    # Keyword arguments keep `device` from binding to B_batch's `extend` slot.
    mat = B_batch(x_eval, grid, type_layer, n=k, in_dim=in_dim, out_dim=out_dim, device=device)

    # Design matrix per (input, output) pair: (in_dim, out_dim, batch, n_coef).
    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef)
    # Targets as column vectors: (in_dim, out_dim, batch, 1).
    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)

    # Solve the normal equations (X^T X + lamb * I) coef = X^T y for every pair.
    mat_t = mat.permute(0, 1, 3, 2)
    XtX = torch.einsum('ijmn,ijnp->ijmp', mat_t, mat)
    Xty = torch.einsum('ijmn,ijnp->ijmp', mat_t, y_eval)
    n = XtX.shape[2]
    # Built on the target device; broadcasting covers the two leading axes.
    identity = torch.eye(n, device=mat.device)
    A = XtX + lamb * identity
    coef = (A.pinverse() @ Xty)[:, :, :, 0]

    return coef


def extend_grid(grid, k_extend=0):
    '''
    Pad a grid with k_extend extra points on each side.

    The spacing is the average interval of the input grid per row,
    (last - first) / (number of points - 1), computed once before extending.

    Args:
    -----
        grid : 2D torch.tensor
            one grid per row.
        k_extend : int
            number of points added to each end. Default: 0 (grid returned as is).

    Returns:
    --------
        grid : 2D torch.tensor
            the input grid with 2 * k_extend more columns.
    '''
    n_intervals = grid.shape[1] - 1
    step = (grid[:, [-1]] - grid[:, [0]]) / n_intervals

    for _ in range(k_extend):
        lower = grid[:, [0]] - step
        upper = grid[:, [-1]] + step
        grid = torch.cat([lower, grid, upper], dim=1)

    return grid