import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.KANlayers import FKANLayer
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath


class FKANBlock(nn.Module): # Refer to the Fourier-KAN Block in Figure 2 of the paper
    """Residual stack of Fourier-KAN layers that mixes along one tensor axis.

    In ``forward`` the axis ``dim`` is swapped to the last position and passed
    through ``num_layers`` ``FKANLayer`` modules followed by ``DropPath``. The
    result is added back to its (swapped) input as a residual connection and
    then swapped back, so the output keeps the input's layout.

    Args:
        dim: Axis of the input the KAN layers operate on.
        in_features: Input width of every ``FKANLayer`` (size of axis ``dim``).
        out_features: Output width of every ``FKANLayer``. Each layer is also
            built with ``in_features`` inputs and the output is added to the
            input, so in practice this should equal ``in_features``.
        gridsize: Grid size forwarded to every ``FKANLayer``.
        num_layers: How many ``FKANLayer`` modules to stack.
        drop: Probability passed to ``DropPath``.
    """

    def __init__(self, dim, in_features, out_features, gridsize, num_layers, drop):
        super().__init__()
        self.dim = dim
        self.in_features = in_features
        self.out_features = out_features
        self.gridsize = gridsize  # Hyperparameter "g" in Appendix (cf. Appendix F)
        self.num_layers = num_layers  # Hyperparameter "n" in Appendix (cf. Appendix F)
        self.drop = drop  # Hyperparameter "DropPath" in Appendix (cf. Appendix F)

        kan_stack = [
            FKANLayer(in_features=in_features,
                      out_features=out_features,
                      gridsize=gridsize,
                      add_bias=False)
            for _ in range(num_layers)
        ]
        # DropPath (Huang et al. 2016) on the residual branch to enhance
        # generalization (cf. Figure 2). Positional construction keeps the same
        # ``net.<idx>`` naming as appending one module at a time.
        self.net = nn.Sequential(*kan_stack, DropPath(drop))

    def forward(self, x):
        """Apply the KAN stack along ``self.dim`` with a residual connection."""
        swapped = x.transpose(self.dim, -1)
        mixed = self.net(swapped) + swapped
        return mixed.transpose(self.dim, -1)


class FKANMixerEncoder(nn.Module): # Refer to the Fourier-KAN Mixer Encoder in Figure 2 of the paper
    """Fourier-KAN Mixer encoder for a single patch scale.

    Takes a ``[batch, channel, length]`` tensor and splits it into patches of
    ``patch_size`` samples, giving ``[batch, channel, patch_num, patch_size]``.
    The length is expected to split into ``win_size // patch_size`` patches.
    The patched tensor is mixed along the channel, patch and within-patch
    (temporal) axes by residual ``FKANBlock`` modules, each preceded by a
    normalization layer. Finally every patch is linearly projected to a single
    value, producing ``[batch, channel, patch_num]``.

    Args:
        win_size: Length of the (already padded) input sequence.
        num_chn: Number of channels.
        patch_size: Number of samples per patch.
        gridsize: Grid size ``g`` forwarded to every KAN layer.
        num_layers: Number of stacked KAN layers ``n`` in each block.
        norm: ``'BN'`` (``BatchNorm2d(num_chn)``), ``'IN'``
            (``InstanceNorm2d(num_chn)``) or ``'LN'`` (``LayerNorm(patch_size)``,
            i.e. over the last, within-patch axis).
        drop: DropPath probability used in each block.

    Raises:
        ValueError: If ``norm`` is not ``'BN'``, ``'IN'`` or ``'LN'``.
    """

    def __init__(self, win_size, num_chn, patch_size, gridsize, num_layers, norm, drop):
        super().__init__()
        # Validate first: an unknown value previously left the norm layers
        # unbound and crashed later with an opaque UnboundLocalError.
        if norm not in ('BN', 'IN', 'LN'):
            raise ValueError(f"Unsupported norm {norm!r}; expected 'BN', 'IN' or 'LN'")
        self.win_size = win_size # Hyperparameter "Window Size" in Appendix (cf. Appendix F)
        self.num_chn = num_chn
        self.patch_size = patch_size
        self.patch_num = win_size // patch_size
        self.gridsize = gridsize # Hyperparameter "g" in Appendix (cf. Appendix F)
        self.num_layers = num_layers # Hyperparameter "n" in Appendix (cf. Appendix F)
        self.drop = drop # Hyperparameter "DropPath" in Appendix (cf. Appendix F)

        # The blocks see [batch, channel, patch_num, patch_size]; the first
        # argument selects the axis each one mixes over.
        # Channel-KAN Block
        channel_kan_block = FKANBlock(1, self.num_chn, self.num_chn, self.gridsize, self.num_layers, self.drop)
        # Patch-KAN Block
        patch_kan_block = FKANBlock(2, self.patch_num, self.patch_num, self.gridsize, self.num_layers, self.drop)
        # Temporal-KAN Block
        temporal_kan_block = FKANBlock(3, self.patch_size, self.patch_size, self.gridsize, self.num_layers, self.drop)
        # Projection: collapse each patch of patch_size samples to one value
        projection = nn.Linear(self.patch_size, 1)

        def make_norm():
            # A fresh normalization layer of the configured kind.
            if norm == 'BN':
                return nn.BatchNorm2d(self.num_chn)
            if norm == 'IN':
                return nn.InstanceNorm2d(self.num_chn)
            return nn.LayerNorm(self.patch_size)

        channel_norm = make_norm()
        patch_norm = make_norm()
        temporal_norm = make_norm()

        # Module order defines the ``net.<idx>`` state_dict keys; keep it stable.
        self.net = nn.Sequential(
            # Split input sequence into patches of length patch_size (cf. Eq (5))
            Rearrange("b c (l1 l2) -> b c l1 l2", l2=patch_size),
            channel_norm,
            channel_kan_block,  # Channel-wise Mixing via Channel-KAN Block (cf. Eq (9))
            patch_norm,
            patch_kan_block,  # Patch-wise Mixing via Patch-KAN Block (cf. Eq (10))
            temporal_norm,
            temporal_kan_block,  # Temporal-wise Mixing via Temporal-KAN Block (cf. Eq (11))
            projection,
            Rearrange("b c l1 1 -> b c l1"),  # Reduce Dim (cf. Figure 2)
        )

    def forward(self, x):
        """Encode ``[batch, channel, length]`` into ``[batch, channel, patch_num]``."""
        return self.net(x)


class FKANMixerDecoder(nn.Module): # Refer to the Fourier-KAN Mixer Decoder in Figure 2 of the paper
    """Fourier-KAN Mixer decoder for a single patch scale.

    Mirrors ``FKANMixerEncoder``. A ``[batch, channel, patch_num]`` latent is
    expanded and every value is linearly projected to a patch of ``patch_size``
    samples, giving ``[batch, channel, patch_num, patch_size]``. That tensor is
    mixed along the within-patch (temporal), patch and channel axes by residual
    ``FKANBlock`` modules, each preceded by a normalization layer. The patches
    are then flattened back into a ``[batch, channel, patch_num * patch_size]``
    sequence. ``patch_num`` is expected to equal ``win_size // patch_size``.

    Args:
        win_size: Length of the (padded) sequence being reconstructed.
        num_chn: Number of channels.
        patch_size: Number of samples per patch.
        gridsize: Grid size ``g`` forwarded to every KAN layer.
        num_layers: Number of stacked KAN layers ``n`` in each block.
        norm: ``'BN'`` (``BatchNorm2d(num_chn)``), ``'IN'``
            (``InstanceNorm2d(num_chn)``) or ``'LN'`` (``LayerNorm(patch_size)``,
            i.e. over the last, within-patch axis).
        drop: DropPath probability used in each block.

    Raises:
        ValueError: If ``norm`` is not ``'BN'``, ``'IN'`` or ``'LN'``.
    """

    def __init__(self, win_size, num_chn, patch_size, gridsize, num_layers, norm, drop):
        super().__init__()
        # Validate first: an unknown value previously left the norm layers
        # unbound and crashed later with an opaque UnboundLocalError.
        if norm not in ('BN', 'IN', 'LN'):
            raise ValueError(f"Unsupported norm {norm!r}; expected 'BN', 'IN' or 'LN'")
        self.win_size = win_size # Hyperparameter "Window Size" in Appendix (cf. Appendix F)
        self.num_chn = num_chn
        self.patch_size = patch_size
        self.patch_num = win_size // patch_size
        self.gridsize = gridsize # Hyperparameter "g" in Appendix (cf. Appendix F)
        self.num_layers = num_layers # Hyperparameter "n" in Appendix (cf. Appendix F)
        self.drop = drop # Hyperparameter "DropPath" in Appendix (cf. Appendix F)

        # The blocks see [batch, channel, patch_num, patch_size]; the first
        # argument selects the axis each one mixes over.
        # Channel-KAN Block
        channel_kan_block = FKANBlock(1, self.num_chn, self.num_chn, self.gridsize, self.num_layers, self.drop)
        # Patch-KAN Block
        patch_kan_block = FKANBlock(2, self.patch_num, self.patch_num, self.gridsize, self.num_layers, self.drop)
        # Temporal-KAN Block
        temporal_kan_block = FKANBlock(3, self.patch_size, self.patch_size, self.gridsize, self.num_layers, self.drop)
        # Projection: expand each latent value to a patch of patch_size samples
        projection = nn.Linear(1, self.patch_size)

        def make_norm():
            # A fresh normalization layer of the configured kind.
            if norm == 'BN':
                return nn.BatchNorm2d(self.num_chn)
            if norm == 'IN':
                return nn.InstanceNorm2d(self.num_chn)
            return nn.LayerNorm(self.patch_size)

        temporal_norm = make_norm()
        patch_norm = make_norm()
        channel_norm = make_norm()

        # Module order defines the ``net.<idx>`` state_dict keys; keep it stable.
        self.net = nn.Sequential(
            Rearrange("b c l1 -> b c l1 1"),  # Expand Dim (cf. Figure 2)
            projection,
            temporal_norm,
            temporal_kan_block,  # Temporal-wise Mixing via Temporal-KAN Block (cf. Eq (11))
            patch_norm,
            patch_kan_block,  # Patch-wise Mixing via Patch-KAN Block (cf. Eq (10))
            channel_norm,
            channel_kan_block,  # Channel-wise Mixing via Channel-KAN Block (cf. Eq (9))
            Rearrange("b c l1 l2 -> b c (l1 l2)"),  # Inverse Patching(=UnPatch)
        )

    def forward(self, x):
        """Decode ``[batch, channel, patch_num]`` into ``[batch, channel, patch_num * patch_size]``."""
        return self.net(x)


class KANomaly(nn.Module):
    """Multi-scale Fourier-KAN Mixer reconstruction model (KANomaly).

    Scales are applied one after another. For each patch size, the current
    ``[batch, channel, length]`` signal is processed as follows:

    * It is left-padded with zeros to a multiple of the patch size.
    * It is encoded and decoded by that scale's ``FKANMixerEncoder`` /
      ``FKANMixerDecoder`` pair.
    * The padded prefix is cropped off again.
    * The result is added to the signal as a residual connection.

    The input is assumed to be ``[batch, length, channel]`` with
    ``length == win_size`` (TODO confirm against callers); the output has the
    same layout.

    Args:
        win_size: Window length used to size the padding of every scale.
        num_chn: Number of channels.
        patch_sizes: Positive patch sizes; one encoder/decoder pair is built
            per entry, in order.
        gridsize: Grid size ``g`` forwarded to every encoder/decoder.
        num_layers: Number of KAN layers ``n`` per block.
        norm: Normalization kind (``'BN'``, ``'IN'`` or ``'LN'``).
        drop: DropPath probability.

    Raises:
        ValueError: If a patch size is not positive.
    """

    def __init__(self, win_size, num_chn, patch_sizes, gridsize, num_layers, norm, drop):
        super().__init__()
        self.win_size = win_size # Hyperparameter "Window Size" in Appendix (cf. Appendix F)
        self.num_chn = num_chn
        self.patch_sizes = patch_sizes # Hyperparameter "Patch Size" in Appendix (cf. Appendix F)
        self.gridsize = gridsize # Hyperparameter "g" in Appendix (cf. Appendix F)
        self.num_layers = num_layers # Hyperparameter "n" in Appendix (cf. Appendix F)
        self.norm = norm # Hyperparameter "Norm" in Appendix (cf. Appendix F)
        self.drop = drop # Hyperparameter "DropPath" in Appendix (cf. Appendix F)
        self.fkan_mixer_encoders = nn.ModuleList()
        self.fkan_mixer_decoders = nn.ModuleList()
        # Left padding per scale; a plain list, not a registered buffer.
        self.paddings = []
        for patch_size in patch_sizes: # Multi-Scale Patching Strategy (cf. Figure 2 & Eq (6))
            if patch_size <= 0:
                raise ValueError(f"patch sizes must be positive, got {patch_size!r}")
            # Smallest padding that makes the window a multiple of patch_size.
            res = win_size % patch_size
            padding = (patch_size - res) % patch_size
            self.paddings.append(padding)
            padded_len = win_size + padding
            self.fkan_mixer_encoders.append( # Fourier-KAN Mixer Encoder (cf. Figure 2 & Eq (13))
                FKANMixerEncoder(padded_len, num_chn, patch_size, gridsize, num_layers, norm, drop))
            self.fkan_mixer_decoders.append( # Fourier-KAN Mixer Decoder (cf. Figure 2 & Eq (14))
                FKANMixerDecoder(padded_len, num_chn, patch_size, gridsize, num_layers, norm, drop))

    def forward(self, x): # X ∈ [Batch, Length, Channel]
        """Reconstruct ``x`` by applying every scale in sequence; the output has the same shape as ``x``."""
        x = x.permute(0, 2, 1) # X ∈ [Batch, Channel, Length]
        scales = zip(self.fkan_mixer_encoders, self.fkan_mixer_decoders, self.paddings)
        for encoder, decoder, padding in scales:
            # Zero-pad on the left so the length splits into whole patches.
            x_in = F.pad(x, (padding, 0), "constant", 0)
            # Decode, then drop the padded prefix to realign with x.
            comp = decoder(encoder(x_in))[:, :, padding:]
            x = comp + x # Residual connection (He et al. 2016) (cf. Figure 2 & Eq (15))
        return x.permute(0, 2, 1)