import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import timm
import math
from .cnn_parts import *
from .sclt import *
from .utils import *
from torch.nn.init import trunc_normal_
from collections import OrderedDict


class KANLinear(nn.Module):
    """Linear-like layer that adds a learnable B-spline path to an activated linear path.

    The output is ``F.linear(base_activation(x), base_weight)`` plus a linear
    map applied to the per-feature B-spline basis values of ``x``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        grid_size: Number of grid intervals spanning ``grid_range``.
        spline_order: Order of the spline basis; the grid is extended by this
            many knots on each side of ``grid_range``.
        scale_noise: Amplitude of the random curve the spline weights are
            fitted to in :meth:`reset_parameters`.
        scale_base: Multiplies the ``a`` argument of the Kaiming init used for
            ``base_weight``.
        scale_spline: Multiplies the ``a`` argument of the Kaiming init used
            for ``spline_scaler``; when the standalone scaler is disabled it
            scales the initial spline weights directly.
        enable_standalone_scale_spline: When true, a separate
            ``(out_features, in_features)`` scaler multiplies the spline
            weights.
        base_activation: Module class instantiated with no arguments and
            applied to the input of the base path.
        grid_eps: Blend factor between the uniform and the data-adaptive grid
            in :meth:`update_grid`.
        grid_range: ``(low, high)`` pair delimiting the initial grid. Any
            indexable pair (tuple or list) is accepted.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),  # immutable default; lists passed by callers still work
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector with `spline_order` extra knots on both sides,
        # replicated once per input feature.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1, dtype=torch.float32) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # Parameters are allocated uninitialised and filled by reset_parameters().
        self.base_weight = nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = nn.Parameter(
                torch.empty(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the base weight, spline weight and (optional) spline scaler."""
        nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Small random curve sampled at the grid_size + 1 knots inside the
            # original range; the spline weights are fitted to reproduce it.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """Evaluate the spline bases of every input feature.

        Args:
            x: Tensor of shape ``(batch, in_features)``.

        Returns:
            Contiguous tensor of shape
            ``(batch, in_features, grid_size + spline_order)``.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = self.grid  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step drops one basis.
        # The 1e-8 keeps the division finite when neighbouring knots coincide.
        for k in range(1, self.spline_order + 1):
            denom1 = grid[:, k:-1] - grid[:, : -(k + 1)] + 1e-8
            denom2 = grid[:, k + 1 :] - grid[:, 1:(-k)] + 1e-8
            bases = (
                (x - grid[:, : -(k + 1)]) / denom1 * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x) / denom2 * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """Least-squares fit of spline coefficients to sampled curves.

        Args:
            x: Sample locations, shape ``(batch, in_features)``.
            y: Target values, shape ``(batch, in_features, out_features)``.

        Returns:
            Contiguous tensor of shape
            ``(out_features, in_features, grid_size + spline_order)``.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(0, 1)  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline weight multiplied by the standalone scaler when it is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply the layer to the last dimension of ``x``.

        Args:
            x: Tensor of shape ``(..., in_features)``. Any number of leading
                dimensions (including none, or an empty batch) is accepted.

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        assert x.dim() >= 1 and x.size(-1) == self.in_features, (
            f"expected last dimension {self.in_features}, got shape {tuple(x.shape)}"
        )
        lead_shape = x.shape[:-1]
        x = x.reshape(-1, self.in_features)

        # Explicit sizes instead of -1 so the views also work for an empty batch.
        n_coeff = self.grid_size + self.spline_order
        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), self.in_features * n_coeff),
            self.scaled_spline_weight.view(self.out_features, self.in_features * n_coeff),
        )
        output = base_output + spline_output
        return output.reshape(*lead_shape, self.out_features)

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the grid to the distribution of ``x`` and refit the spline weights.

        The spline outputs on ``x`` are computed with the current grid, the
        grid is replaced by a blend of a uniform grid and a quantile-based
        grid, and the spline weights are then solved by least squares so that
        those outputs are reproduced on the new grid.

        Args:
            x: Tensor of shape ``(batch, in_features)``.
            margin: Padding added on both ends of the observed range when
                building the uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2).contiguous()  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0).contiguous()  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        ).contiguous()  # (batch, in, out)

        # Sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.long, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend by `spline_order` uniformly spaced knots on each side.
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Return a weighted sum of an L1 term and an entropy term on the spline weights.

        The L1 term is the sum over (out, in) pairs of the mean absolute
        spline weight; the entropy term is computed on those values
        normalised by their sum.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * torch.log(p + 1e-8))
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )

class KANLayer(nn.Module):
    """Token-wise projection, depthwise conv stage, then a second token-wise projection.

    Args:
        in_features: Channel size of the incoming tokens.
        hidden_features: Channel size between the two projections; falls back
            to ``in_features`` when falsy.
        out_features: Channel size of the result; falls back to
            ``in_features`` when falsy.
        act_layer: Accepted for signature compatibility; not used here.
        drop: Probability given to ``self.drop``. ``forward`` does not apply
            that module.
        no_kan: When true, plain ``nn.Linear`` projections replace the
            ``KANLinear`` ones.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., no_kan=False):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.dim = in_features

        if no_kan:
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.fc2 = nn.Linear(hidden_features, out_features)
        else:
            # Both KAN projections share one fixed hyper-parameter set.
            kan_kwargs = dict(
                grid_size=5,
                spline_order=3,
                scale_noise=0.1,
                scale_base=1.0,
                scale_spline=1.0,
                base_activation=nn.SiLU,
                grid_eps=0.02,
                grid_range=[-1, 1],
            )
            self.fc1 = KANLinear(in_features, hidden_features, **kan_kwargs)
            self.fc2 = KANLinear(hidden_features, out_features, **kan_kwargs)

        self.dwconv = DW_bn_relu(hidden_features)

        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``.

        Only ``nn.Linear``, ``nn.LayerNorm`` and ``nn.Conv2d`` instances are
        touched; every other module type is left as constructed.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            kernel_h, kernel_w = m.kernel_size
            fan_out = (kernel_h * kernel_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Run the layer on a token tensor.

        Args:
            x: Tensor of shape ``(B, N, C)``.
            H, W: Spatial size forwarded to the depthwise conv stage, which
                views the ``N`` tokens as an ``H x W`` map.

        Returns:
            Contiguous tensor of shape ``(B, N, out_channels)``.
        """
        batch, tokens, channels = x.shape

        # The projections work on 2-D input, so tokens are folded into the batch axis.
        flat = x.reshape(batch * tokens, channels)
        hidden = self.fc1(flat).reshape(batch, tokens, -1).contiguous()
        hidden = self.dwconv(hidden, H, W)
        projected = self.fc2(hidden.reshape(batch * tokens, -1).contiguous())
        return projected.reshape(batch, tokens, -1).contiguous()

class KANBlock(nn.Module):
    """Residual block: normalise the tokens, run a ``KANLayer``, add the input back.

    Args:
        dim: Channel size of the tokens; also used as the hidden size of the
            inner layer.
        drop: Forwarded to the inner ``KANLayer``.
        drop_path: Accepted for signature compatibility; ``self.drop_path`` is
            always ``nn.Identity``.
        act_layer: Forwarded to the inner ``KANLayer``.
        norm_layer: Normalisation class, instantiated as ``norm_layer(dim)``.
        no_kan: Forwarded to the inner ``KANLayer``.
    """

    def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, no_kan=False):
        super().__init__()

        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)

        self.layer = KANLayer(
            in_features=dim,
            hidden_features=dim,
            act_layer=act_layer,
            drop=drop,
            no_kan=no_kan,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``.

        Only ``nn.Linear``, ``nn.LayerNorm`` and ``nn.Conv2d`` instances are
        touched; every other module type is left as constructed.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            kernel_h, kernel_w = m.kernel_size
            fan_out = (kernel_h * kernel_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Return ``x`` plus the inner layer applied to the normalised ``x``.

        ``H`` and ``W`` are passed through unchanged to the inner layer.
        """
        branch = self.layer(self.norm2(x), H, W)
        return x + self.drop_path(branch)
    
class DW_bn_relu(nn.Module):
    """Depthwise 3x3 convolution, batch norm and ReLU applied to a token sequence.

    Args:
        dim: Number of channels; the convolution uses ``groups=dim``.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=True
        )
        self.bn = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU()

    def forward(self, x, H, W):
        """Process tokens as an ``H x W`` feature map.

        Args:
            x: Tensor of shape ``(B, N, C)`` where ``N`` must equal ``H * W``
                for the view to succeed.
            H, W: Spatial size of the map the tokens are laid out on.

        Returns:
            Contiguous tensor of shape ``(B, N, C)``.
        """
        batch, _, channels = x.shape
        # (B, N, C) -> (B, C, H, W); contiguous() is required before view().
        fmap = x.transpose(1, 2).contiguous().view(batch, channels, H, W)
        fmap = self.relu(self.bn(self.dwconv(fmap)))
        # Back to the token layout (B, N, C).
        return fmap.flatten(2).transpose(1, 2).contiguous()

class DimensionMatchingLayer(nn.Module):
    """1x1 convolution that changes the channel count of a feature map.

    Args:
        in_channels: Channels of the incoming map.
        out_channels: Channels of the produced map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, x):
        """Return ``x`` projected to ``out_channels`` channels."""
        return self.conv(x)

class FF(nn.Module):
    """Fuse two feature maps with a shared channel gate computed from both.

    The two inputs are concatenated, passed through a 3x3 and a 1x1
    convolution, squeezed into a per-channel sigmoid gate, and that one gate
    weights both inputs before they are summed.

    Args:
        feature_channels: Channels of each input map.
        output_dim: Channels of the gate. The gate is multiplied with the
            inputs, so it has to broadcast against ``feature_channels``.
    """

    def __init__(self, feature_channels, output_dim):
        super().__init__()
        self.feature_channels = feature_channels
        self.output_dim = output_dim
        self.conv3x3 = nn.Conv2d(feature_channels * 2, feature_channels, kernel_size=3, padding=1)
        self.conv1x1 = nn.Conv2d(feature_channels, output_dim, kernel_size=1)

        squeezed = output_dim // 8  # bottleneck width of the gate
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(output_dim, squeezed, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, output_dim, kernel_size=1),
            nn.Sigmoid(),
        )
        # Registered but not used by forward().
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, e1, y1):
        """Return ``e1 * gate + y1 * gate`` with the gate derived from both inputs.

        Args:
            e1, y1: Feature maps concatenated along dim 1; their shapes must
                match outside that dimension.
        """
        stacked = torch.cat((e1, y1), dim=1)
        mixed = self.conv1x1(self.conv3x3(stacked))
        gate = self.channel_attention(mixed)
        gated_e1 = e1 * gate
        gated_y1 = y1 * gate
        return gated_e1 + gated_y1

class FrequencySpatialAttention(nn.Module):  # FSA
    """Gate a feature map with a frequency-derived spatial map and a pooled channel gate.

    The 2-D FFT of the input is split by two complementary masks, one part is
    scaled by a learnable per-channel filter, both parts are transformed back
    and convolved into a sigmoid attention map. A second gate comes from
    global average and max pooling. The two gated copies of the input are
    summed.

    NOTE(review): no ``fftshift`` is applied to the spectrum, so the central
    block selected by the masks is not centred on the zero-frequency bin; the
    "high"/"low" naming looks inverted relative to that layout. Confirm the
    intent before changing anything, since trained weights depend on it.

    Args:
        in_channels: Channels of the input map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.learnable_filter = nn.Parameter(torch.randn(1, in_channels, 1, 1))
        doubled = in_channels * 2
        self.conv = nn.Conv2d(doubled, in_channels, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Conv2d(doubled, in_channels, kernel_size=1)

    def forward(self, x):
        """Return ``x * channel_gate + x * spatial_map``.

        Args:
            x: 4-D tensor; the masks index it as ``[:, :, rows, cols]``.
        """
        device = x.device
        spectrum = torch.fft.fft2(x)
        spectrum_size = spectrum.size()

        # Split the spectrum with the two complementary masks.
        high_part = spectrum * self.create_high_mask(spectrum_size, device)
        low_part = spectrum * self.create_low_mask(spectrum_size, device)
        low_part = low_part * self.learnable_filter.to(device)

        # Back to the spatial domain; only the real part is kept.
        bands = torch.cat(
            [torch.fft.ifft2(high_part).real, torch.fft.ifft2(low_part).real],
            dim=1,
        )
        spatial_map = self.sigmoid(self.conv(bands))
        spatially_gated = x * spatial_map

        # Channel gate from globally pooled statistics.
        pooled = torch.cat([self.gap(x), self.max_pool(x)], dim=1)
        channel_gate = self.sigmoid(self.fc(pooled))
        channel_gated = x * channel_gate

        return channel_gated + spatially_gated

    @staticmethod
    def _central_block(size):
        """Row and column slices covering the middle half of the last two dims."""
        h, w = size[-2], size[-1]
        return slice(h // 4, 3 * h // 4), slice(w // 4, 3 * w // 4)

    def create_high_mask(self, size, device):
        """Complex mask of ones with the central block set to zero."""
        rows, cols = self._central_block(size)
        mask = torch.ones(size, dtype=torch.cfloat, device=device)
        mask[:, :, rows, cols] = 0
        return mask

    def create_low_mask(self, size, device):
        """Complex mask of zeros with the central block set to one."""
        rows, cols = self._central_block(size)
        mask = torch.zeros(size, dtype=torch.cfloat, device=device)
        mask[:, :, rows, cols] = 1
        return mask



###########################################################
#                                                         #
#                      model                              #
#                                                         #
###########################################################
class RacoNet(nn.Module):
    """Two-branch encoder with a U-shaped decoder that returns a dense map and class logits.

    Both ``self.sclt`` and ``self.alt`` receive the input. Their per-scale
    features are fused by the ``cwf*`` modules, decoded by ``up1``..``up4``
    with an ``STMBlock`` after each step, and projected by ``outc``. The
    first three decoder stages are additionally passed through KAN blocks,
    globally pooled and fed to ``final_kan`` to produce the logits.

    Args:
        n_channels: Input channels of ``self.inc``.
        n_classes: Output channels of the dense head ``self.outc``.
        num_classes: Output size of the logits head ``self.final_kan``.
        output_size: Target size of the final bilinear resize; when ``None``
            the decoder output is upsampled by a factor of 2 instead.
        bilinear: Forwarded to the ``Up`` blocks; also halves some channel
            widths through ``factor``.
        alt_weight_path: Checkpoint file loaded into ``self.alt``. Defaults to
            the path that was previously hard-coded.

    Notes:
        ``self.inc`` and ``self.down1``..``self.down4`` are constructed but
        not referenced by ``forward``; they are kept so existing state dicts
        still load.
    """

    def __init__(self, n_channels, n_classes, num_classes, output_size, bilinear=False,
                 alt_weight_path='./model/RACONET/ALT.pth'):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.output_size = output_size

        # 1x1 projections from the `alt` feature widths to the fusion widths.
        self.dim_match_layer1 = DimensionMatchingLayer(96, 64)
        self.dim_match_layer2 = DimensionMatchingLayer(96, 128)
        self.dim_match_layer3 = DimensionMatchingLayer(192, 256)
        self.dim_match_layer4 = DimensionMatchingLayer(384, 512)
        self.dim_match_layer5 = DimensionMatchingLayer(768, 1024)

        self.cwf64 = FF(64, 64)
        self.cwf128 = FF(128, 128)
        self.cwf256 = FF(256, 256)
        self.cwf512 = FF(512, 512)
        self.cwf1024 = FF(1024, 1024)

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = (OutConv(64, n_classes))

        self.sclt = sclt_m1_1(pretrained=True, num_classes=1000)
        self.alt = timm.create_model(
            'maxxvit_rmlp_small_rw_256.sw_in1k',
            pretrained=False,
            features_only=True,
        )
        self._load_alt_weights(alt_weight_path)

        self.stm_block64 = STMBlock(hidden_dim=64)
        self.stm_block128 = STMBlock(hidden_dim=128)
        self.stm_block256 = STMBlock(hidden_dim=256)
        self.stm_block512 = STMBlock(hidden_dim=512)
        self.stm_block1024 = STMBlock(hidden_dim=1024)

        self.fsa64 = FrequencySpatialAttention(64)
        self.fsa128 = FrequencySpatialAttention(128)
        self.fsa256 = FrequencySpatialAttention(256)
        self.fsa512 = FrequencySpatialAttention(512)
        self.fsa1024 = FrequencySpatialAttention(1024)

        self.kan_block1 = KANBlock(dim=512)
        self.kan_block2 = KANBlock(dim=256)
        self.kan_block3 = KANBlock(dim=128)
        self.kan_block4 = KANBlock(dim=64)

        self.final_kan = KANLinear(in_features=512 + 256 + 128, out_features=num_classes)

    def _load_alt_weights(self, weight_path):
        """Load the checkpoint at ``weight_path`` into ``self.alt`` after renaming its keys."""
        # map_location keeps loading independent of the device the checkpoint
        # was saved from; load_state_dict copies into the existing parameters.
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only point this at checkpoint files from a trusted source.
        state_dict = torch.load(weight_path, map_location='cpu', weights_only=False)

        # These four keys keep their first '.'; every other key has it turned into '_'.
        stem_keys = ("stem_conv1.weight", "stem_norm1.weight", "stem_norm1.bias", "stem_conv2.weight")
        renamed = {}
        for key, value in state_dict.items():
            new_key = key.replace('.', '_', 1)
            if new_key in stem_keys:
                new_key = new_key.replace('_', '.', 1)
            renamed[new_key] = value

        # Drop checkpoint entries the backbone does not have; missing entries
        # still raise because load_state_dict stays strict.
        expected_keys = self.alt.state_dict().keys()
        filtered = {key: value for key, value in renamed.items() if key in expected_keys}
        self.alt.load_state_dict(filtered)

    @staticmethod
    def _kan_pooled(kan_block, feature_map):
        """Run ``kan_block`` on a ``(B, C, H, W)`` map and return ``(B, C)`` average-pooled features."""
        B, C, H, W = feature_map.shape
        tokens = feature_map.view(B, C, H * W).transpose(1, 2).contiguous()
        kan_tokens = kan_block(tokens, H, W)
        kan_map = kan_tokens.transpose(1, 2).view(B, C, H, W)
        return F.adaptive_avg_pool2d(kan_map, (1, 1)).view(B, -1)

    def forward(self, x):
        """Compute the dense prediction and the classification logits.

        Args:
            x: Input batch passed to both ``self.sclt`` and ``self.alt``.

        Returns:
            Tuple ``(img, logits)``: the output of ``self.outc`` on the
            resized decoder features, and the output of ``self.final_kan``
            on the pooled KAN features of the first three decoder stages.
        """
        # presumably y1 is an indexable, mutable collection whose entries 1..5
        # are feature maps with 64..1024 channels - confirm against sclt.
        y1, _ = self.sclt(x)
        e1, e2, e3, e4, e5 = self.alt(x)[:5]

        e1 = self.dim_match_layer1(e1)
        e2 = self.dim_match_layer2(e2)
        e3 = self.dim_match_layer3(e3)
        e4 = self.dim_match_layer4(e4)
        e5 = self.dim_match_layer5(e5)

        # Attention is written back into y1 in place, as before.
        y1[1] = self.fsa64(y1[1])
        y1[2] = self.fsa128(y1[2])
        y1[3] = self.fsa256(y1[3])
        y1[4] = self.fsa512(y1[4])
        y1[5] = self.fsa1024(y1[5])

        v1 = self.cwf64(y1[1], e1)
        v2 = self.cwf128(y1[2], e2)
        v3 = self.cwf256(y1[3], e3)
        v4 = self.cwf512(y1[4], e4)
        v5 = self.cwf1024(y1[5], e5)
        v5 = self.stm_block1024(v5)

        x = self.up1(v5, v4)
        x = self.stm_block512(x)
        kan_feat1_pooled = self._kan_pooled(self.kan_block1, x)

        x = self.up2(x, v3)
        x = self.stm_block256(x)
        kan_feat2_pooled = self._kan_pooled(self.kan_block2, x)

        x = self.up3(x, v2)
        x = self.stm_block128(x)
        kan_feat3_pooled = self._kan_pooled(self.kan_block3, x)

        x = self.up4(x, v1)
        x = self.stm_block64(x)
        if self.output_size is not None:
            x = F.interpolate(x, size=self.output_size, mode='bilinear', align_corners=True)
        else:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        img = self.outc(x)

        concat_features = torch.cat([kan_feat1_pooled, kan_feat2_pooled, kan_feat3_pooled], dim=1)
        logits = self.final_kan(concat_features)

        return img, logits
