import torch
import math
import sys
import convolution
import torch
import torch.nn as nn
import torch.nn.functional as F

from KA_Layer import KANLinear

from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class KANConv2d(torch.nn.Module):
    """2-D convolution whose per-window linear map is replaced by ``KANLinear``.

    The input is cut into sliding windows with ``nn.Unfold``; the flattened
    windows of each channel group are then fed through that group's own
    ``KANLinear`` layer, and the per-group results are concatenated along the
    channel axis.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the layer.
        kernel_size: Window size; an int is expanded to ``(k, k)``.
        stride: Window stride; an int is expanded to ``(s, s)``.
        padding: Padding handed to ``nn.Unfold``; an int is expanded to
            ``(p, p)``.
        dilation: Window dilation; an int is expanded to ``(d, d)``.
        groups: Requested number of channel groups.  The value actually used
            is capped at ``max(1, in_channels // 16, out_channels // 16)``,
            so it can be smaller than requested.
        bias: Accepted so the signature mirrors ``nn.Conv2d``; it is not read
            anywhere in this class.

    Raises:
        ValueError: If ``in_channels`` or ``out_channels`` is not divisible by
            the (capped) number of groups.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        # Zero-argument super(): the previous call named a class
        # (``KANConv2d_2``) that is not this one, so construction failed
        # before any attribute could be set.
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride = (stride, stride) if isinstance(stride, int) else stride
        self.padding = (padding, padding) if isinstance(padding, int) else padding
        self.dilation = (dilation, dilation) if isinstance(dilation, int) else dilation

        # Cap the group count so that narrow layers are not split into many
        # tiny groups; the divisibility checks below use the capped value.
        groups = min(groups, max(1, in_channels // 16, out_channels // 16))
        self.groups = groups

        if in_channels % groups != 0:
            raise ValueError(f"in_channels {in_channels} must be divisible by groups {groups}")
        if out_channels % groups != 0:
            raise ValueError(f"out_channels {out_channels} must be divisible by groups {groups}")

        # Extracts sliding windows as columns of length C * kh * kw.
        self.unfold = nn.Unfold(
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )
        self.in_channels_per_group = in_channels // groups
        self.out_channels_per_group = out_channels // groups
        # Length of one flattened window restricted to a single group.
        self.conv_flat_dim_per_group = self.in_channels_per_group * self.kernel_size[0] * self.kernel_size[1]

        # Fixed spline hyper-parameters shared by every group's KANLinear.
        grid_size = 5
        spline_order = 3
        # One independent KANLinear per group.  ``base_activation`` is passed
        # as the SiLU class itself, not an instance (presumably KANLinear
        # instantiates it -- confirm against KA_Layer).
        self.kanlinear = nn.ModuleList([
            KANLinear(
                in_features=self.conv_flat_dim_per_group,
                out_features=self.out_channels_per_group,
                grid_size=grid_size,
                spline_order=spline_order,
                scale_noise=0.1,
                scale_base=1,
                scale_spline=1,
                base_activation=torch.nn.SiLU,
                grid_eps=0.02,
                grid_range=[-1, 1],
            ) for _ in range(groups)
        ])
        # Visit this module and all sub-modules with the initializer below.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize ``m`` in place; used as the callback for ``Module.apply``.

        ``nn.Linear`` and ``nn.Conv2d`` weights are drawn from a zero-mean
        normal with std ``sqrt(2 / fan_out)`` and their biases are zeroed;
        ``nn.LayerNorm`` gets weight 1 and bias 0.  Any other module type is
        left untouched.

        Args:
            m: The module currently visited by ``apply``.
        """
        if isinstance(m, nn.Linear):
            # For linear layers fan_out is derived from this convolution's
            # own geometry rather than from the linear layer's shape.
            fan_out = self.kernel_size[0] * self.kernel_size[1] * self.out_channels
            fan_out //= self.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Apply the grouped KAN convolution.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, out_channels, out_h, out_w)`` where
            ``out_h`` / ``out_w`` follow the usual convolution output-size
            formula for the configured kernel, stride, padding and dilation.
        """
        batch_size, in_channels, height, width = x.shape

        # (batch, in_channels * kh * kw, num_windows)
        x_unfolded = self.unfold(x)
        num_windows = x_unfolded.shape[-1]

        out_h = (height + self.padding[0] * 2 - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        out_w = (width + self.padding[1] * 2 - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1

        # Unfold lays columns out channel-major, so splitting the column axis
        # into ``groups`` equal chunks separates contiguous channel groups.
        x_unfolded_grouped = x_unfolded.view(batch_size, self.groups, self.conv_flat_dim_per_group, num_windows)

        # Each group: (batch, flat, windows) -> (batch * windows, flat), then
        # through that group's layer (assumed to map the last dim to
        # out_channels_per_group -- confirm against KANLinear).
        group_outputs = [
            self.kanlinear[i](
                x_unfolded_grouped[:, i, :, :]
                .permute(0, 2, 1)
                .reshape(-1, self.conv_flat_dim_per_group)
            )
            for i in range(self.groups)
        ]

        # (batch * windows, out_channels), groups concatenated in order.
        kan_output = torch.cat(group_outputs, dim=1)

        # Restore the spatial layout and move channels to dim 1.
        output = kan_output.view(batch_size, out_h, out_w, self.out_channels).permute(0, 3, 1, 2)

        return output
