import torch
import torch.nn.functional as F
import math


class KANLinear(torch.nn.Module):
    """KAN-style linear layer whose spline term is shared across input features.

    ``forward(x)`` returns ``F.linear(base_activation(x), base_weight)`` plus a
    spline term. Every input feature is expanded on the same uniform B-spline
    grid, the bases are summed over the feature axis (:meth:`b_splines_sum`),
    and that sum is mapped to the outputs with a single set of coefficients of
    shape ``(out_features, 1, grid_size + spline_order)``.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        grid_size: Number of grid intervals spanning ``grid_range``.
        spline_order: Degree of the B-spline basis (3 gives cubic splines).
        scale_noise: Amplitude of the random curve the spline coefficients
            are fitted to in :meth:`reset_parameters`.
        scale_base: Multiplies the Kaiming ``a`` used to initialise
            ``base_weight``.
        scale_spline: Multiplies the spline coefficients when there is no
            standalone scaler; otherwise multiplies the Kaiming ``a`` used to
            initialise ``spline_scaler``.
        enable_standalone_scale_spline: If true, learn a per-output
            ``spline_scaler`` that multiplies ``spline_weight``.
        base_activation: Activation *class* (instantiated once) applied
            before ``base_weight``.
        grid_eps: Stored on the module; not read by any method of this class.
        grid_range: ``(low, high)`` interval covered by the interior knots.
        block_size: Number of input features whose bases are materialised at
            once in :meth:`b_splines_sum`; bounds the size of the intermediate
            basis tensor.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
        block_size: int = 16,
    ):
        super().__init__()
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        # Honour the constructor argument (it used to be hard-coded to 32).
        self.block_size = block_size
        # Not used inside this class; presumably filled by external code/hooks — TODO confirm.
        self.grad_stats = []

        h = (grid_range[1] - grid_range[0]) / grid_size
        # Uniform knot vector extended by `spline_order` knots on each side,
        # replicated once per input feature:
        # shape (in_features, grid_size + 2 * spline_order + 1).
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        # A single set of spline coefficients shared by every input feature.
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, 1, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(torch.empty(out_features, 1))

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps
        # Not used in forward(); kept so the module's parameter / state_dict
        # layout stays unchanged.
        self.norm = torch.nn.LayerNorm(self.in_features)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``base_weight``, ``spline_weight`` and the optional ``spline_scaler``.

        The spline coefficients are the least-squares fit of a random curve,
        uniform in ``[-scale_noise / (2 * grid_size), scale_noise / (2 * grid_size))``,
        sampled at the ``grid_size + 1`` interior knots of the grid.
        """
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Sample only the (grid_size + 1, 1, out_features) noise that is
            # actually used. The previous code drew a full
            # (grid_size + 1, in_features, out_features) tensor and kept
            # feature index 1, which wasted memory and raised IndexError when
            # in_features == 1. It now matches the grid's dtype and device, so
            # lstsq also works after the module has been moved or cast.
            noise = (
                (
                    torch.rand(
                        self.grid_size + 1,
                        1,
                        self.out_features,
                        dtype=self.grid.dtype,
                        device=self.grid.device,
                    )
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            # Interior knots of the first (shared) grid row: (grid_size + 1, 1).
            interior_knots = self.grid.T[self.spline_order : -self.spline_order, None, 0]
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(interior_knots, noise)
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )

    @staticmethod
    def _cox_de_boor(x: torch.Tensor, grid: torch.Tensor, spline_order: int) -> torch.Tensor:
        """Evaluate B-spline bases of degree ``spline_order`` with the Cox–de Boor recursion.

        Args:
            x: Evaluation points with a trailing singleton axis, e.g. (batch, features, 1).
            grid: Knots along the last axis, broadcastable against ``x``
                (either one shared 1-D knot vector or one row per feature).
            spline_order: Degree of the basis.

        Returns:
            Bases with last axis of length ``grid.size(-1) - 1 - spline_order``.
        """
        # Degree-0 bases: indicator of the knot span containing x.
        bases = ((x >= grid[..., :-1]) & (x < grid[..., 1:])).to(x.dtype)
        for k in range(1, spline_order + 1):
            bases = (
                (x - grid[..., : -(k + 1)])
                / (grid[..., k:-1] - grid[..., : -(k + 1)])
                * bases[..., :-1]
            ) + (
                (grid[..., k + 1 :] - x)
                / (grid[..., k + 1 :] - grid[..., 1:(-k)])
                * bases[..., 1:]
            )
        return bases

    def b_splines_sum(self, x: torch.Tensor) -> torch.Tensor:
        """Return the B-spline bases of ``x`` summed over the feature axis.

        Uses only the first row of ``self.grid``. This assumes all rows are
        identical, which is how ``__init__`` builds the grid. Features are
        processed ``block_size`` at a time, so the full
        (batch, in_features, grid_size + spline_order) basis tensor is never
        materialised.

        Runs under ``torch.no_grad()`` and the result is detached, so no
        gradient reaches ``x`` through this term.

        Args:
            x: Input of shape (batch_size, in_features).

        Returns:
            Tensor of shape (batch_size, grid_size + spline_order).
        """
        with torch.no_grad():
            knots = self.grid[0]
            total = x.new_zeros(x.size(0), self.grid_size + self.spline_order)
            for chunk in x.split(self.block_size, dim=1):
                total += self._cox_de_boor(
                    chunk.unsqueeze(-1), knots, self.spline_order
                ).sum(dim=1)
            return total.contiguous().detach()

    @property
    def scaled_spline_weight(self):
        """``spline_weight`` multiplied by ``spline_scaler`` when the standalone scaler is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def b_splines(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the B-spline bases for the given input tensor.

        Runs under ``torch.no_grad()``. Column ``i`` of ``x`` is evaluated on
        grid row ``i``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, n), n <= in_features.

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, n, grid_size + spline_order).
        """
        with torch.no_grad():
            # One knot row per input column: (n, grid_size + 2 * spline_order + 1).
            knots = self.grid[: x.size(1)]
            return self._cox_de_boor(x.unsqueeze(-1), knots, self.spline_order).contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the coefficients of the curve that interpolates the given points.

        Solves a batched least-squares problem (one per input column) with
        ``torch.linalg.lstsq``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)
        return result.contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        Args:
            x: Input of shape (batch_size, in_features).

        Returns:
            Tensor of shape (batch_size, out_features): the base linear term
            plus the shared-spline term.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        base_output = F.linear(self.base_activation(x), self.base_weight)

        # (batch_size, grid_size + spline_order), summed over input features.
        spline_bases = self.b_splines_sum(x)

        spline_output = F.linear(
            spline_bases.view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )

        return base_output + spline_output

