import math
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

class Frequency(nn.Module):
    """Multi-scale frequency-domain enhancement block with gated fusion.

    The input is average-pooled at several scales. Each pooled map is moved
    to the frequency domain with a real 2-D FFT, its magnitude and phase are
    modulated per channel by a small gating network, and the result is
    transformed back, upsampled to the input size and averaged over scales.
    ``forward`` blends the input and the enhanced map with two learned
    spatial gates, post-processes the blend and adds a residual connection.

    Args:
        in_channels: Number of channels of the input (and output) feature map.
        reduction: Channel reduction factor inside the band-selection network.
        scales: Average-pooling kernel sizes, one per scale branch.
    """

    def __init__(self, in_channels, reduction=2, scales=(1, 2, 4)):
        super().__init__()
        self.in_channels = in_channels
        # Copy into a fresh list so instances never share the default object.
        self.scales = list(scales)

        # Multi-scale branches. Only the pooling layer (index 0) is used by
        # ``freq_enhance``; the conv/ReLU layers are kept so that parameter
        # names (and therefore saved state dicts) stay unchanged.
        self.scale_branches = nn.ModuleList([
            nn.Sequential(
                nn.AvgPool2d(scale, ceil_mode=True),
                nn.Conv2d(in_channels, in_channels//2, 1),
                nn.ReLU()
            ) for scale in self.scales
        ])

        # Dynamic band selection: global pooling -> bottleneck -> 2*C values
        # (C magnitude gains followed by C phase offsets).
        self.band_selector = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels//reduction, 1),
            nn.LayerNorm([in_channels//reduction, 1, 1]),
            nn.GELU(),
            nn.Conv2d(in_channels//reduction, in_channels*2, 1)
        )

        # Cross attention: predicts two spatial gates in (0, 1), one for the
        # input and one for the frequency-enhanced features.
        self.cross_att = nn.Sequential(
            nn.Conv2d(in_channels*2, in_channels//4, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(in_channels//4, 2, 3, padding=1),
            nn.Sigmoid()
        )

        # Post-processing after the inverse transform: depthwise 3x3 conv,
        # batch norm, pointwise 1x1 conv.
        self.post_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, in_channels, 1)
        )

    def freq_enhance(self, x):
        """Run the core multi-scale frequency-domain pipeline.

        Args:
            x: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of the same shape as ``x``: the mean over all scale
            branches of the modulated, inverse-transformed and upsampled maps.
        """
        batch = x.shape[0]
        out_size = tuple(x.shape[-2:])

        scale_feats = []
        for branch in self.scale_branches:
            x_scale = branch[0](x)  # downsample (average pooling only)
            x_fft = torch.fft.rfft2(x_scale, dim=(-2, -1), norm='ortho')

            # Dynamic band selection: per-channel magnitude gain and phase offset.
            mag = torch.abs(x_fft)
            phase = torch.angle(x_fft)
            mag_gain, phase_shift = self.band_selector(x_scale).chunk(2, dim=1)
            mag = mag * (1 + mag_gain.view(batch, -1, 1, 1))
            phase = phase + 0.1 * phase_shift.view(batch, -1, 1, 1)

            # Inverse transform. The normalisation must match the forward
            # transform ('ortho'); with the default 'backward' mode the result
            # would be scaled by 1/sqrt(h*w) of the pooled map, i.e. by a
            # different factor for every scale branch.
            x_trans = torch.fft.irfft2(
                mag * torch.exp(1j * phase),
                s=x_scale.shape[-2:],
                dim=(-2, -1),
                norm='ortho',
            )
            # Upsample back to the input resolution.
            x_trans = F.interpolate(
                x_trans, size=out_size, mode='bilinear', align_corners=False
            )
            scale_feats.append(x_trans)

        return torch.stack(scale_feats).mean(dim=0)

    def forward(self, x):
        """Enhance ``x`` in the frequency domain and fuse it with the input.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``C == in_channels``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Main path.
        x_freq = self.freq_enhance(x)

        # Gated fusion: channel 0 of the attention map weights the input,
        # channel 1 weights the frequency-enhanced features.
        att = self.cross_att(torch.cat([x, x_freq], dim=1))
        out = att[:, 0:1] * x + att[:, 1:2] * x_freq

        # Residual connection.
        return self.post_conv(out) + x

class S2KANFTT(nn.Module):
    """Shifted token mixer combining KAN layers with frequency branches.

    Four branches are computed from the same input and then fused:

    * two KAN branches (``fc1`` followed by ``fc2``, applied to the channel
      vector of every pixel), and
    * two frequency branches (``freq1`` followed by ``freq2``).

    Every stage is preceded by a per-channel circular shift in which channel
    ``i`` is rolled by ``i`` positions along one spatial axis. Within each
    pair, the first branch shifts along axis 2 then axis 3; the second branch
    shifts backwards along axis 3 and then forwards along axis 2.

    Args:
        in_features: Number of input channels.
        hidden_features: Output width of the KAN layers; defaults to
            ``in_features``.
        out_features: Number of output channels; defaults to ``in_features``.
        act_layer: Activation class instantiated once and shared by all stages.
        drop: Dropout probability.
        shift_size: Stored as ``self.shift_size`` (and ``self.pad``); not used
            by ``forward``.

    Raises:
        ValueError: If the three feature sizes differ. The intermediate
            tensors are reshaped back to ``(B, C, H, W)`` and fed to layers
            built for ``in_features`` channels, so other configurations could
            only fail later inside ``forward``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        if not (in_features == hidden_features == out_features):
            raise ValueError(
                "S2KANFTT requires in_features == hidden_features == out_features, "
                f"got {in_features}, {hidden_features}, {out_features}"
            )
        self.dim = in_features
        self.fc1 = KANLinear(in_features, hidden_features)
        self.fc2 = KANLinear(in_features, hidden_features)
        self.freq1 = nn.Sequential(
            SingleConvLayer(in_features, out_features),
            Frequency(out_features),
            nn.GELU()
        )
        self.freq2 = nn.Sequential(
            SingleConvLayer(in_features, out_features),
            Frequency(out_features),
            nn.GELU()
        )
        self.fusion1 = SingleConvLayer(hidden_features*2, out_features)
        self.fusion2 = SingleConvLayer(out_features*3, out_features)
        self.drop = nn.Dropout(drop)
        self.act = act_layer()
        self.shift_size = shift_size
        self.pad = shift_size // 2
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear``, ``nn.LayerNorm`` and ``nn.Conv2d`` submodules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # He-style normal init based on the fan-out of each group.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @staticmethod
    def _roll_channels(x, dim, reverse=False):
        """Circularly shift channel ``i`` of ``x`` by ``i`` (or ``-i``) along ``dim``."""
        sign = -1 if reverse else 1
        channels = torch.chunk(x, x.shape[1], 1)
        rolled = [torch.roll(ch, sign * idx, dim) for idx, ch in enumerate(channels)]
        return torch.cat(rolled, 1)

    def _kan_stage(self, x, layer, dim, reverse):
        """Shift ``x``, mix channels per pixel with a KAN layer, activate and drop."""
        B, C, H, W = x.shape
        shifted = self._roll_channels(x, dim, reverse)
        # (B, C, H, W) -> (B, H*W, C): one channel vector per pixel.
        tokens = shifted.reshape(B, C, H * W).contiguous().transpose(1, 2)
        tokens = self.drop(self.act(layer(tokens)))
        return tokens.transpose(1, 2).reshape(B, C, H, W).contiguous()

    def _freq_stage(self, x, layer, dim, reverse):
        """Shift ``x``, run a frequency sub-network, activate and drop."""
        return self.drop(self.act(layer(self._roll_channels(x, dim, reverse))))

    @staticmethod
    def _two_stage(x, stage, layers, dims, reverse_first):
        """Chain two stages and return the sum of both stage outputs.

        Only the first stage may use a reversed shift; the second one always
        shifts forwards.
        """
        first = stage(x, layers[0], dims[0], reverse_first)
        second = stage(first, layers[1], dims[1], False)
        return torch.add(first, second)

    def forward(self, x):
        """Compute the four shifted branches and fuse them.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``C == in_features``.

        Returns:
            Tensor of shape ``(B, out_features, H, W)``.
        """
        kan_layers = (self.fc1, self.fc2)
        freq_layers = (self.freq1, self.freq2)

        # The branch order is kept fixed so that dropout RNG consumption and
        # batch-norm statistic updates happen in a stable sequence.
        x1 = self._two_stage(x, self._kan_stage, kan_layers, (2, 3), False)
        x2 = self._two_stage(x, self._kan_stage, kan_layers, (3, 2), True)
        x3 = self._two_stage(x, self._freq_stage, freq_layers, (2, 3), False)
        x4 = self._two_stage(x, self._freq_stage, freq_layers, (3, 2), True)

        fuse1 = self.fusion1(torch.cat([x1, x2], dim=1))
        fuse2 = self.fusion2(torch.cat([fuse1, x3, x4], dim=1))
        return self.drop(fuse2)

class SFABlock(nn.Module):
    """Residual block: 1x1 conv stem, S2KANFTT mixer, and a skip connection.

    The skip path is a 1x1 convolution when the channel count changes and an
    identity mapping otherwise.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        # Pointwise projection to the working width.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(inplace=True),
        )

        # Token mixer followed by normalisation and activation.
        self.conv2 = nn.Sequential(
            S2KANFTT(output_dim),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(inplace=True),
        )

        # Shortcut: project only when the channel count differs.
        if input_dim != output_dim:
            self.skip = nn.Conv2d(input_dim, output_dim, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        """Return ``conv2(conv1(x)) + skip(x)``."""
        shortcut = self.skip(x)
        features = self.conv1(x)
        features = self.conv2(features)
        return features + shortcut

class KANLinear(torch.nn.Module):
    """Linear-style layer whose edges are learnable B-spline functions.

    The output is the sum of a "base" path (``base_activation`` followed by a
    linear map with ``base_weight``) and a "spline" path (B-spline bases of
    the input combined with ``spline_weight``, optionally scaled per edge by
    ``spline_scaler``).

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        grid_size: Number of grid intervals inside ``grid_range``.
        spline_order: Order of the B-splines; the grid is extended by this
            many knots on each side.
        scale_noise: Amplitude of the random curve used to initialise the
            spline coefficients.
        scale_base: Multiplier on the ``a`` argument of the Kaiming
            initialisation of ``base_weight``.
        scale_spline: Scale of the spline path at initialisation.
        enable_standalone_scale_spline: If true, learn a separate
            ``(out_features, in_features)`` scale for the spline weights.
        base_activation: Activation class used on the base path.
        grid_eps: Blend factor between a uniform grid (``1``) and a
            data-adaptive grid (``0``) in ``update_grid``.
        grid_range: Two-element ``(low, high)`` range of the initial grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector, extended by ``spline_order`` knots on each
        # side and replicated for every input feature.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # Parameters are allocated uninitialised and filled in by
        # ``reset_parameters`` below.
        self.base_weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.empty(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the base weights, spline coefficients and spline scaler."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Small random curve sampled at the interior grid points; the
            # spline coefficients are fitted to reproduce it.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step blends neighbouring
        # lower-order bases and shortens the last dimension by one.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        The coefficients are obtained per input feature as the least-squares
        solution of ``b_splines(x) @ coeff = y``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline weights multiplied by the per-edge scaler (when enabled)."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply the layer to the last dimension of ``x``.

        Args:
            x (torch.Tensor): Tensor of shape (..., in_features).

        Returns:
            torch.Tensor: Tensor of shape (..., out_features).
        """
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        # Flatten all leading dimensions into a single batch dimension.
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the grid to the distribution of ``x`` and refit the splines.

        The new grid blends a uniform grid over the range of ``x`` (widened
        by ``margin``) with a grid taken from evenly spaced order statistics
        of ``x``, weighted by ``grid_eps``. The spline weights are then
        refitted so that the spline path reproduces its previous outputs on
        ``x`` in the least-squares sense.

        Args:
            x (torch.Tensor): Samples of shape (batch_size, in_features).
            margin (float): Padding added on both sides of the data range
                for the uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        # Spline outputs per (sample, input feature, output feature) under
        # the current grid; these are the targets for the refit.
        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend the grid by ``spline_order`` uniformly spaced knots on each
        # side. ``torch.cat`` is used because ``torch.concatenate`` is only an
        # alias that older PyTorch releases do not provide.
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a simplified stand-in for the sample-based L1 regularization
        described in the KAN paper: that version needs absolute values and an
        entropy computed from the expanded (batch, in_features, out_features)
        intermediate tensor, which is hidden behind F.linear in this
        memory-efficient implementation.

        Here the L1 term is the mean absolute value of the spline weights per
        edge, summed over all edges, and the entropy term is the entropy of
        those per-edge values normalised to sum to one.

        Args:
            regularize_activation: Weight of the L1 term.
            regularize_entropy: Weight of the entropy term.

        Returns:
            torch.Tensor: Scalar regularization loss.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )

class SingleConvLayer(nn.Module):
    """3x3 convolution followed by batch normalisation and an in-place ReLU.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        stages = [
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        out = self.conv(x)
        return out

class Model(nn.Module):
    """U-Net-style encoder/decoder with an SFABlock bottleneck.

    The encoder applies five ``SingleConvLayer`` stages separated by 2x2 max
    pooling. The bottleneck is an ``SFABlock``. The decoder upsamples by a
    factor of two (bicubic), applies a ``SingleConvLayer`` and adds the
    matching encoder feature map (additive skip connections).

    Args:
        num_classes: Number of output channels of the final convolution.
        input_channels: Number of channels of the input image.
        c: Channel widths of the five encoder stages.
    """

    def __init__(self, num_classes=2, input_channels=3, c=(32, 64, 128, 256, 512)):
        super().__init__()

        # Encoding path
        self.enc1 = nn.Sequential(
            SingleConvLayer(input_channels, c[0])
        )
        self.enc2 = nn.Sequential(
            SingleConvLayer(c[0], c[1])
        )
        self.enc3 = nn.Sequential(
            SingleConvLayer(c[1], c[2])
        )
        self.enc4 = nn.Sequential(
            SingleConvLayer(c[2], c[3])
        )
        self.enc5 = nn.Sequential(
            SingleConvLayer(c[3], c[4])
        )

        # Bottleneck
        self.s2kanftt = nn.Sequential(
            SFABlock(c[4], c[4])
        )

        # Decoding path
        self.upconv4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bicubic'),
            SingleConvLayer(c[4], c[3])
        )
        self.upconv3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bicubic'),
            SingleConvLayer(c[3], c[2])
        )
        self.upconv2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bicubic'),
            SingleConvLayer(c[2], c[1])
        )
        self.upconv1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bicubic'),
            SingleConvLayer(c[1], c[0])
        )
        self.final_conv = nn.Conv2d(c[0], num_classes, 3, stride=1, padding=1)

    @staticmethod
    def _upsample_add(up_block, x, skip):
        """Upsample ``x`` with ``up_block`` and add the encoder feature ``skip``.

        Max pooling floors odd sizes, so after a 2x upsample the decoder map
        can be one pixel smaller than the skip tensor. In that case it is
        resized to the skip tensor's spatial size before the addition. When
        the sizes already match, nothing extra is done.
        """
        up = up_block(x)
        if up.shape[-2:] != skip.shape[-2:]:
            up = F.interpolate(
                up, size=tuple(skip.shape[-2:]), mode='bicubic', align_corners=False
            )
        return torch.add(up, skip)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(B, input_channels, H, W)``.

        Returns:
            Tensor of shape ``(B, num_classes, H, W)``.
        """
        # Encoding
        enc1 = self.enc1(x)
        enc2 = self.enc2(F.max_pool2d(enc1, 2))
        enc3 = self.enc3(F.max_pool2d(enc2, 2))
        enc4 = self.enc4(F.max_pool2d(enc3, 2))
        enc5 = self.enc5(F.max_pool2d(enc4, 2))

        # Bottleneck
        dec5 = self.s2kanftt(enc5)

        # Decoding with additive skip connections
        dec4 = self._upsample_add(self.upconv4, dec5, enc4)
        dec3 = self._upsample_add(self.upconv3, dec4, enc3)
        dec2 = self._upsample_add(self.upconv2, dec3, enc2)
        dec1 = self._upsample_add(self.upconv1, dec2, enc1)

        out = self.final_conv(dec1)
        return out
