import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class FKANLayer(nn.Module): # Fourier-KAN layer used in the Fourier-KAN Block (cf. Figure 2 & Eq (7))
    """
    Fourier Kolmogorov-Arnold Network (Fourier-KAN) Layer.

    Output j is the sum over inputs i and harmonics k = 1..gridsize of
    a[j, i, k] * cos(k * x_i) + b[j, i, k] * sin(k * x_i)  (cf. Eq (8)),
    plus an optional learnable per-output bias.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        gridsize: number of harmonics, hyperparameter "g" (cf. Appendix F).
        add_bias: when True, add a bias of shape (1, out_features), zero-initialised.
    """
    def __init__(self, in_features, out_features, gridsize, add_bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.gridsize = gridsize  # Hyperparameter "g" (cf. Appendix F)
        self.add_bias = add_bias
        # k^2 for k = 1..g: higher harmonics start with proportionally smaller coefficients.
        self.grid_norm_factor = torch.arange(1, gridsize + 1) ** 2
        raw = torch.randn(2, out_features, in_features, gridsize)  # [0] -> cos coeffs, [1] -> sin coeffs
        self.fouriercoeffs = nn.Parameter(raw / (np.sqrt(in_features) * self.grid_norm_factor))
        if add_bias:
            self.bias = nn.Parameter(torch.zeros(1, out_features))

    def forward(self, x):
        """Map x of shape (..., in_features) to shape (..., out_features)."""
        lead_shape = x.shape[:-1]
        flat = torch.reshape(x, (-1, self.in_features))                     # (N, in)
        harmonics = torch.arange(1, self.gridsize + 1, device=flat.device)  # k = 1..g
        angles = flat.unsqueeze(-1) * harmonics                             # (N, in, g): k * x
        # Stack the cos(kx) / sin(kx) bases along a new leading axis to pair them
        # with fouriercoeffs[0] / fouriercoeffs[1] in a single contraction.
        basis = torch.stack((torch.cos(angles), torch.sin(angles)), dim=0)  # (2, N, in, g)
        out = torch.einsum("dbik,djik->bj", basis, self.fouriercoeffs)      # (N, out)
        if self.add_bias:
            out = out + self.bias
        return torch.reshape(out, (*lead_shape, self.out_features))


class CKANLayer(nn.Module): # Cheby-KAN Layer used in Ablation Study (cf. Table 2)
    """
    Chebyshev Kolmogorov-Arnold Network (Cheby-KAN) Layer.

    Chebyshev Polynomial T_n(t) = cos(n⋅arccos(t)), with t = tanh(x) squashing the
    input into (-1, 1). Output j = sum_i sum_{n=0..degree} c[i, j, n] * T_n(t_i).

    The polynomials are evaluated with the three-term recurrence
    T_0 = 1, T_1 = t, T_n = 2t⋅T_{n-1} - T_{n-2}. This equals cos(n⋅arccos(t)) on
    [-1, 1], but its derivative stays finite at t = ±1. arccos has an infinite
    derivative there, and tanh reaches exactly ±1 for large |x| in float32, which
    used to turn gradients into NaN.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        degree: highest polynomial degree (degree + 1 basis functions per input).
    """
    def __init__(self, in_features, out_features, degree):
        super(CKANLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.degree = degree
        self.cheby_coeffs = nn.Parameter(torch.empty(self.in_features, self.out_features, self.degree + 1))
        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (self.in_features * (self.degree + 1)))
        # No longer read by forward (the recurrence needs no degree vector), but kept
        # registered so state_dicts that contain "arange" keep loading.
        self.register_buffer("arange", torch.arange(0, self.degree + 1, 1))

    def forward(self, x):
        """Map x of shape (..., in_features) to shape (..., out_features)."""
        lead_shape = x.shape[:-1]
        t = torch.tanh(x).reshape(-1, self.in_features)  # (N, in), values in [-1, 1]
        polys = [torch.ones_like(t)]                     # T_0
        if self.degree >= 1:
            polys.append(t)                              # T_1
        for _ in range(2, self.degree + 1):
            polys.append(2 * t * polys[-1] - polys[-2])  # T_n = 2t*T_{n-1} - T_{n-2}
        basis = torch.stack(polys, dim=-1)               # (N, in, degree + 1)
        y = torch.einsum("bid,iod->bo", basis, self.cheby_coeffs)  # (N, out)
        # The last axis becomes out_features; the input's last axis (in_features)
        # must not be reused here, or the reshape breaks when in_features != out_features.
        return torch.reshape(y, (*lead_shape, self.out_features))


class VanillaKANLayer(nn.Module): # Vanilla-KAN Layer used in Ablation Study (cf. Table 2)
    """
    Vanilla Kolmogorov-Arnold Network (Vanilla-KAN) Layer (B-Spline basis).

    B-Spline Function = c⋅B(x). The forward pass computes

        out = base_weight @ SiLU(x) + spline_weight @ B(x)

    where B(x) are the B-spline bases of order ``spline_order`` on a uniform knot
    grid covering [-1, 1] with ``grid_size`` intervals. The grid is extended by
    ``spline_order`` knots on each side, giving ``grid_size + spline_order``
    bases per input feature.

    All computation runs on the device holding the layer's parameters; the knot
    grid is a (non-persistent) buffer, so it follows ``module.to(...)``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        grid_size: number of grid intervals on [-1, 1].
        spline_order: B-spline order (3 = cubic).
    """
    # Some code based on https://github.com/engichang1467/Simple-KAN/blob/main/model.py
    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        grid_step = 2 / grid_size
        grid_range = torch.arange(-spline_order, grid_size + spline_order + 1)
        grid_values = grid_range * grid_step - 1
        # Registered as a buffer so .to()/.cuda()/.double() move/cast it with the
        # parameters; persistent=False keeps it out of state_dict, as before.
        self.register_buffer("grid", grid_values.expand(in_features, -1).contiguous(), persistent=False)
        self.base_weight = nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = nn.Parameter(torch.empty(out_features, in_features, grid_size + spline_order))
        # NOTE(review): initialised but never read by forward, so it receives no
        # gradient. Left as-is because using it would change the layer's outputs.
        self.spline_scaler = nn.Parameter(torch.empty(out_features, in_features))
        self.base_activation = nn.SiLU()
        self.reset_parameters()

    @property
    def device(self):
        """Device of the layer's parameters; every computation is placed here."""
        return self.base_weight.device

    def reset_parameters(self):
        """Kaiming-init the base weights and fit spline weights to small random noise."""
        nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5))
        with torch.no_grad():
            noise_shape = (self.grid_size + 1, self.in_features, self.out_features)
            random_noise = (torch.rand(noise_shape) - 0.5) * 0.1 / self.grid_size
            # The grid_size + 1 knots lying inside [-1, 1] (drop the extension knots).
            grid_points = self.grid.T[self.spline_order : -self.spline_order]
            spline_coefficients = self.curve2coeff(grid_points, random_noise)
            self.spline_weight.data.copy_(spline_coefficients)
        nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5))

    def b_splines(self, x: torch.Tensor):
        """
        Evaluate the B-spline bases with the Cox-de Boor recursion.

        Args:
            x: tensor of shape (batch, in_features).
        Returns:
            Tensor of shape (batch, in_features, grid_size + spline_order), on the
            grid's device.
        """
        x = x.to(self.grid.device)
        expanded_grid = self.grid.unsqueeze(0).expand(x.size(0), *self.grid.size())
        input_tensor_expanded = x.unsqueeze(-1)
        # Order-0 bases: indicator of the knot interval [g_j, g_{j+1}) containing x.
        bases = (
            (input_tensor_expanded >= expanded_grid[:, :, :-1])
            & (input_tensor_expanded < expanded_grid[:, :, 1:])
        ).to(x.dtype)
        for order in range(1, self.spline_order + 1):
            left_term = (
                (input_tensor_expanded - expanded_grid[:, :, : -order - 1])
                / (expanded_grid[:, :, order:-1] - expanded_grid[:, :, : -order - 1])
            ) * bases[:, :, :-1]

            right_term = (
                (expanded_grid[:, :, order + 1 :] - input_tensor_expanded)
                / (expanded_grid[:, :, order + 1 :] - expanded_grid[:, :, 1:-order])
            ) * bases[:, :, 1:]
            bases = left_term + right_term
        return bases.contiguous()

    def curve2coeff(self, input_tensor: torch.Tensor, output_tensor: torch.Tensor):
        """
        Least-squares spline coefficients reproducing ``output_tensor`` at ``input_tensor``.

        Args:
            input_tensor: sample points, shape (batch, in_features).
            output_tensor: target values, shape (batch, in_features, out_features).
        Returns:
            Coefficients of shape (out_features, in_features, grid_size + spline_order).
        """
        b_splines_bases = self.b_splines(input_tensor)           # (batch, in, coeff)
        transposed_bases = b_splines_bases.transpose(0, 1)        # (in, batch, coeff)
        transposed_output = output_tensor.transpose(0, 1).to(transposed_bases.device)  # (in, batch, out)
        coefficients_solution = torch.linalg.lstsq(transposed_bases, transposed_output).solution  # (in, coeff, out)
        coefficients = coefficients_solution.permute(2, 0, 1)
        return coefficients.contiguous()

    def forward(self, x: torch.Tensor):
        """Map x of shape (..., in_features) to shape (..., out_features) on the layer's device."""
        original_shape = x.shape
        x = x.contiguous().view(-1, self.in_features).to(self.device)
        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        output = output.view(*original_shape[:-1], -1)
        return output