import torch
import math
import torch.nn.functional as F
import math

class KANLayer(torch.nn.Module):
    """Kolmogorov-Arnold Network layer with B-spline edge activations.

    Every (output j, input i) edge carries a learnable 1-D function
    ``phi_ji(x_i) = base_weight[j, i] * base_activation(x_i) + spline_ji(x_i)``,
    where ``spline_ji`` is a B-spline of order ``spline_order`` on a per-input
    knot grid with ``grid_size`` intervals inside ``grid_range``. Output ``j``
    is the sum over inputs of ``phi_ji`` or, for nodes selected by
    ``multiplicative_mask`` when ``compute_mult`` is enabled, their product.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        grid_size: Number of grid intervals inside ``grid_range``.
        spline_order: Order of the B-spline bases.
        scale_noise: Amplitude of the uniform noise the spline coefficients are
            fitted to at initialization.
        scale_base: Multiplier on the ``a`` argument of the Kaiming-uniform
            init of ``base_weight``.
        scale_spline: Applied to the fitted spline coefficients when
            ``enable_standalone_scale_spline`` is False; otherwise multiplies the
            ``a`` argument of the Kaiming-uniform init of ``spline_scaler``.
        enable_standalone_scale_spline: If True, learn a separate per-edge
            multiplier ``spline_scaler`` on the spline coefficients.
        base_activation: Module class instantiated for the residual base branch.
        grid_eps: Blend factor used by ``update_grid`` between a uniform grid
            (1.0) and a data-quantile grid (0.0).
        grid_range: ``(low, high)`` interval covered by the initial grid.
        compute_symbolic: If True, the ``store_act=True`` path blends the
            spline/base edges with ``symbolic_functions`` via ``layer_mask`` /
            ``symb_mask``.
        compute_mult: If True (and ``in_features > 1``), the last
            ``out_features - ceil(out_features / 2)`` outputs multiply their
            edges instead of summing them. Only supported with ``store_act=True``.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
        compute_symbolic=False,
        compute_mult=False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.compute_symbolic = compute_symbolic
        self.compute_mult = compute_mult

        # The first ceil(out/2) nodes are additive, the remaining ones are
        # multiplicative; multiplication is only enabled with more than one input.
        out_features_sum = math.ceil(out_features / 2)
        out_features_mult = out_features - out_features_sum
        mult_enabled = bool(compute_mult) and in_features > 1
        multiplicative_mask = (
            torch.tensor(
                [False] * out_features_sum + [True] * out_features_mult,
                dtype=torch.bool,
            )
            & mult_enabled
        )
        # Non-persistent buffer: follows module.to(device) without adding a
        # state_dict entry (so existing checkpoints still load).
        self.register_buffer("multiplicative_mask", multiplicative_mask, persistent=False)

        # Knots extend spline_order intervals past each end of grid_range so every
        # basis is fully defined: shape (in_features, grid_size + 2 * spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0])
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # Storage only; values are filled in by init_params().
        self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(torch.empty(out_features, in_features))

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        # Frozen per-edge blend weights between the spline/base branch and the
        # symbolic branch (only used when compute_symbolic is True).
        self.layer_mask = torch.nn.Parameter(torch.ones((out_features, in_features))).requires_grad_(False)
        self.symb_mask = torch.nn.Parameter(torch.zeros((out_features, in_features))).requires_grad_(False)
        # symbolic_functions[j][i] maps a NumPy column to edge (j, i)'s output; defaults to zero.
        self.symbolic_functions = [[lambda x: 0 * x for _ in range(in_features)] for _ in range(out_features)]
        self.cache_act = None
        self.cache_preact = None
        self.symb_dict_names = {}

        # Per-edge activation scale, set by get_activations() and read by regularization_loss_orig().
        self.acts_scale_spline = None

        self.init_params()

    def init_params(self):
        """(Re)initialize ``base_weight``, ``spline_weight`` and, if enabled, ``spline_scaler``.

        The spline coefficients are least-squares fitted to small uniform noise
        sampled at the ``grid_size + 1`` knots inside the original grid range.
        """
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(
                        self.grid_size + 1,
                        self.in_features,
                        self.out_features,
                        dtype=self.grid.dtype,
                    )
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval [t_i, t_{i+1}).
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Cox-de Boor recursion; each order-k basis is supported on [t_i, t_{i+k+1}).
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order),
            on the same device as ``x``.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        # The system is solved on the CPU (torch's CUDA lstsq only offers the
        # 'gels' driver, which assumes a full-rank A); the result is moved back
        # to the input's device rather than to whatever GPU happens to exist.
        A = self.b_splines(x).transpose(0, 1).cpu()  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1).cpu()  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(A, B).solution  # (in_features, grid_size + spline_order, out_features)
        solution = solution.to(x.device)

        result = solution.permute(2, 0, 1)  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-place the spline knots to follow the distribution of ``x`` and refit the splines.

        The new interior knots blend evenly spaced points over
        ``[min(x) - margin, max(x) + margin]`` (weight ``grid_eps``) with
        per-feature order statistics of ``x`` (weight ``1 - grid_eps``);
        ``spline_order`` extra knots are appended at each end using the uniform
        step. ``spline_weight`` is then least-squares refitted so the spline
        branch approximately reproduces its previous output on ``x``.

        Args:
            x: (batch, in_features) samples.
            margin: Padding added around the data range for the uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        # Evaluate the current *unscaled* spline curves on x. spline_scaler
        # multiplies each (out, in) curve elementwise and least squares is linear
        # in its targets, so refitting the unscaled curves into spline_weight keeps
        # scaled_spline_weight consistent. Fitting the scaled output into
        # spline_weight would apply spline_scaler twice.
        splines = self.b_splines(x).permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.spline_weight.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff).permute(1, 0, 2)  # (batch, in, out)

        # Data-adaptive knots: grid_size + 1 evenly spaced order statistics per feature.
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)
        ]

        # Uniform knots spanning the (padded) per-feature data range.
        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        # The new grid must be in place before refitting: curve2coeff evaluates
        # the bases on self.grid.
        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    @property
    def scaled_spline_weight(self):
        """Spline coefficients times ``spline_scaler`` (if enabled): (out, in, coeff)."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def get_symbolic_output(self, x):
        """Evaluate ``symbolic_functions`` on every edge.

        Each ``symbolic_functions[j][i]`` is called with input column ``i`` as a
        NumPy array, so no gradient flows through this branch. Scalar results
        are broadcast over the batch.

        Args:
            x: (batch, in_features) input.

        Returns:
            torch.Tensor: (batch, out_features, in_features), same dtype/device as ``x``.
        """
        # NOTE(review): a sparse representation could avoid evaluating every edge.
        # Convert each input column to NumPy once instead of once per output.
        columns = [x[:, i].detach().cpu().numpy() for i in range(self.in_features)]
        postacts = []
        for j in range(self.out_features):
            row = [
                torch.as_tensor(
                    self.symbolic_functions[j][i](columns[i]),
                    dtype=x.dtype,
                    device=x.device,
                ).expand(x.size(0))
                for i in range(self.in_features)
            ]
            postacts.append(torch.stack(row, dim=1))
        return torch.stack(postacts, dim=1)

    def get_activations(self, x):
        """Per-edge forward pass that also caches activations.

        Args:
            x: (batch, in_features) input.

        Returns:
            (batch, out_features) tensor. Each output is the sum over inputs of
            its edge activations, or their signed product for multiplicative nodes
            when ``compute_mult`` is set.

        Side effects:
            Sets ``cache_preact`` (x), ``cache_act`` (edge activations after symbolic
            blending, (batch, out, in)) and ``acts_scale_spline`` (per-edge max
            |activation| over the batch, (out, in)).
        """
        base_act = self.base_activation(x)
        base_output = base_act[:, None, :] * self.base_weight  # (batch, out_features, in_features)

        splines = self.b_splines(x)  # (batch, in, coeff)
        output_b_spline = torch.einsum(
            "jik, bik -> bji", self.scaled_spline_weight, splines
        )  # (batch, out_features, in_features)

        output_layer = output_b_spline + base_output

        if self.compute_symbolic:
            symb_output = self.get_symbolic_output(x)  # (batch, out_features, in_features)
            output = self.layer_mask[None, :, :] * output_layer + self.symb_mask[None, :, :] * symb_output
        else:
            output = output_layer

        self.cache_act = output.detach()
        self.cache_preact = x.detach()

        # Per-edge scale for the regularization loss: max |activation| over the batch.
        # NOTE(review): taken from the spline/base branch only, ignoring symb_mask — confirm intended.
        self.acts_scale_spline = torch.max(torch.abs(output_layer), dim=0).values  # (out_features, in_features)

        output_sum = output.sum(dim=2)  # (batch, out_features)
        if not self.compute_mult:
            return output_sum

        # Signed product over inputs computed in log space; |output| is clamped at
        # 1e-5 so log() stays finite, while sign() still yields 0 for exact zeros.
        sign_product = torch.sign(output).prod(dim=2)
        log_abs = torch.log(torch.clamp(torch.abs(output), min=1e-5))
        prod_output = sign_product * torch.exp(log_abs.sum(dim=2))

        mask = self.multiplicative_mask.to(output.device)[None, :]
        return torch.where(mask, prod_output, output_sum)

    def get_activations_efficient(self, x):
        """Fast forward pass without caching, symbolic or multiplicative nodes.

        Contracts both branches with ``F.linear`` instead of materializing the
        (batch, out, in) edge tensor.

        Args:
            x: (batch, in_features) input (``b_splines`` requires 2-D).

        Returns:
            (batch, out_features) tensor.
        """
        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    def forward(self, x: torch.Tensor, store_act=False):
        """Apply the layer to the last dimension of ``x``.

        Args:
            x: (..., in_features) input; leading dimensions are flattened for the
                computation and restored on the output.
            store_act: If True, use the caching per-edge path (required for
                ``compute_mult`` and for symbolic blending); otherwise the fast path.

        Returns:
            (..., out_features) tensor.
        """
        assert x.size(-1) == self.in_features
        assert store_act or not self.compute_mult, "Multiplicative node allowed only with original formulation"

        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        if store_act:
            output = self.get_activations(x)  # (batch, out_features)
        else:
            output = self.get_activations_efficient(x)  # (batch, out_features)

        return output.reshape(*original_shape[:-1], self.out_features)

    def regularization_loss_fake(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Parameter-based surrogate for the L1 + entropy regularizer.

        Uses mean |spline coefficient| and |base weight| per edge instead of
        cached activations, so it works without a ``store_act=True`` pass.

        Returns:
            (weighted total, L1 term, entropy term).
        """
        # NOTE(review): uses the raw spline_weight, not scaled_spline_weight — confirm intended.
        l1_spline = self.spline_weight.abs().mean(-1)  # (out, in)
        l1_base = self.base_weight.abs()  # (out, in)
        regularization_loss_activation = l1_spline.sum() + l1_base.sum()

        l1_total_flat = torch.cat([l1_spline.view(-1), l1_base.view(-1)], dim=0)
        p = l1_total_flat / (regularization_loss_activation + 1e-8)
        regularization_loss_entropy = -torch.sum(p * torch.log(p + 1e-8))

        reg_loss = (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
        return reg_loss, regularization_loss_activation, regularization_loss_entropy

    def regularization_loss_orig(self, mu_1=1.0, mu_2=1.0):
        """L1 + entropy regularizer over the activation scales cached by ``get_activations``.

        Requires a preceding ``forward(..., store_act=True)`` call.

        Returns:
            (weighted total, L1 term, entropy term).
        """
        assert self.acts_scale_spline is not None, 'Cannot use original L1 norm if activations are not saved'

        # Total L1 norm of all activation functions
        l1 = torch.sum(self.acts_scale_spline)

        p = self.acts_scale_spline.view(-1) / (l1 + 1e-8)
        entropy = -torch.sum(p * torch.log(p + 1e-8))

        reg = mu_1 * l1 + mu_2 * entropy

        return reg, l1, entropy