import torch.nn as nn
import torch.nn.functional as F
import torch

class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder block.

    Self-attention and a position-wise feed-forward network (two
    kernel-size-1 ``Conv1d`` layers) are each wrapped in a residual
    connection followed by ``LayerNorm``.

    Args:
        attention: module called as ``attention(q, k, v, attn_mask=..., tau=...,
            delta=...)`` that returns ``(output, attn)``.
        d_model: feature width of the input and output (last axis).
        d_ff: hidden width of the feed-forward network; a falsy value means
            ``4 * d_model``.
        dropout: dropout probability applied to the sub-layer outputs.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        hidden = d_ff if d_ff else 4 * d_model
        self.attention = attention
        # kernel_size=1 convolutions act as per-position linear maps.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.gelu if activation != "relu" else F.relu

    def _feed_forward(self, h):
        """Run conv1 -> activation -> dropout -> conv2 -> dropout over the last axis."""
        # Conv1d expects channels on dim 1, so move the feature axis there and back.
        inner = self.dropout(self.activation(self.conv1(h.transpose(-1, 1))))
        return self.dropout(self.conv2(inner).transpose(-1, 1))

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Apply attention + residual + norm1, then FFN + residual + norm2.

        Args:
            x: input tensor whose last axis has size ``d_model``
                (presumably ``[batch, length, d_model]`` — confirm with caller).
            attn_mask, tau, delta: forwarded unchanged to ``self.attention``.

        Returns:
            ``(output, attn)`` where ``attn`` is whatever the attention module returned.
        """
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
        h = self.norm1(x + self.dropout(attended))
        return self.norm2(h + self._feed_forward(h)), attn

class ChannelFusion(nn.Module):
    """Learnable linear mixing across dim 1 (the channel axis) of a 4-D tensor.

    For an input ``x`` of shape ``[B, C, d1, d2]`` the output has the same
    shape, with ``out[:, j] = sum_i x[:, i] * weights[i, j]``. The two trailing
    axes are carried through unchanged. ``weights`` starts as the identity
    matrix, so a freshly constructed module returns its input unchanged.

    Args:
        num_channels: size ``C`` of the mixing matrix.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        identity = torch.eye(num_channels)
        # The zero-scaled noise contributes nothing numerically, but it still
        # draws from torch's global RNG; keep it so subsequent parameter
        # initialisation sees the same random stream.
        noise = 0 * torch.randn(num_channels, num_channels)
        self.weights = nn.Parameter(identity + noise)

    def forward(self, x):
        """Mix the channels of ``x`` (a 4-D tensor) with ``self.weights``.

        Args:
            x: tensor of shape ``[B, C, d1, d2]``; ``C`` must match the size of
                ``self.weights``.

        Returns:
            Tensor of shape ``[B, C, d1, d2]``.
        """
        batch, channels, dim_a, dim_b = x.shape
        # One row per (batch, dim_a, dim_b) position, holding that position's channel values.
        rows = x.permute(0, 2, 3, 1).reshape(-1, channels)
        mixed = rows @ self.weights
        return mixed.reshape(batch, dim_a, dim_b, channels).permute(0, 3, 1, 2)
    
class EnhancedChannelFusion(nn.Module):
    """Cross-channel mixing followed by a per-position MLP over the last axis.

    The input is a 4-D tensor whose dim 1 holds the ``num_channels`` channels
    to be mixed and whose last dim has size ``d_model``. The output has the
    same shape and axis order as the input.

    Args:
        d_model: size of the last input axis (input/output width of the MLP).
        d_ff: hidden width of the MLP.
        num_channels: number of channels mixed by ``ChannelFusion``.
        activation: ``"relu"`` selects ReLU for the MLP; any other value
            selects GELU, matching how the other layers in this file choose
            their activation.
    """

    def __init__(self, d_model, d_ff, num_channels, activation='relu'):
        super().__init__()
        self.activation = F.relu if activation == "relu" else F.gelu
        # Kept inside nn.Sequential: unwrapping it would rename the parameter
        # key ("channel_fusion.0.weights") and break existing checkpoints.
        self.channel_fusion = nn.Sequential(ChannelFusion(num_channels))
        self.projection = nn.Sequential(
            nn.Linear(d_model, d_ff),
            # Honour `activation`; this used to be a hard-coded ReLU that
            # silently ignored the argument.
            nn.ReLU() if activation == "relu" else nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        """Mix channels (dim 1), then project the last axis; shape is preserved.

        Args:
            x: 4-D tensor with the channels on dim 1 and ``d_model`` features
                on the last dim.

        Returns:
            Tensor with the same shape and axis order as ``x``.
        """
        x = self.channel_fusion(x)
        # nn.Linear acts on the last axis, so no permute is needed around the
        # projection. The previous trailing permute(0, 3, 1, 2) was the inverse
        # of a permute that had been commented out; on its own it moved d_model
        # to dim 1, and callers reshaping the result back to the input layout
        # received scrambled data.
        return self.projection(x)
    
class FusedEncoderLayer(nn.Module):
    """Encoder block whose feed-forward stage blends a conv FFN with channel fusion.

    After self-attention + residual + ``norm1``, two branches run on the
    normalised tensor ``h``:

    * a position-wise FFN (two kernel-size-1 ``Conv1d`` layers), and
    * ``EnhancedChannelFusion`` applied after regrouping dim 0 into
      ``[-1, num_channels, ...]``; its result replaces the FFN where
      ``gate`` is 0.

    The branches are blended as ``gate * ffn + (1 - gate) * fusion`` with a
    learnable ``gate`` initialised to 1.0, so the fusion branch starts with
    zero weight. The blend is then added to ``h`` and normalised with
    ``norm2``.

    Args:
        attention: module called as ``attention(q, k, v, attn_mask=..., tau=...,
            delta=...)`` that returns ``(output, attn)``.
        d_model: feature width of the last axis.
        d_ff: FFN / fusion-MLP hidden width; a falsy value means ``4 * d_model``.
        dropout: dropout probability.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
        num_channels: group size used when regrouping dim 0 for channel fusion.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu", num_channels=7):
        super().__init__()
        self.activation = F.gelu if activation != "relu" else F.relu
        hidden = d_ff if d_ff else 4 * d_model
        self.num_channels = num_channels
        # Submodule creation order is kept as-is: it fixes both parameter
        # registration order and the RNG draws made during initialisation.
        self.attention = attention
        self.channel_fusion = EnhancedChannelFusion(
            d_model=d_model,
            d_ff=hidden,
            num_channels=num_channels,
            activation=activation,
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Learnable blend weight between the FFN branch and the fusion branch.
        self.gate = nn.Parameter(torch.tensor([1.0]))
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Run attention, then the gated FFN / channel-fusion blend.

        Args:
            x: 3-D tensor whose last axis has size ``d_model`` and whose dim 0 is
                divisible by ``num_channels``. Consecutive runs of
                ``num_channels`` rows are grouped together for fusion
                (presumably the variables of one sample — confirm with caller).
            attn_mask, tau, delta: forwarded unchanged to ``self.attention``.

        Returns:
            ``(output, attn)``; ``output`` has the same shape as ``x``.
        """
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
        h = self.norm1(x + self.dropout(attended))
        seq_len, width = h.shape[-2], h.shape[-1]

        # Channel-fusion branch (used in place of / blended with the FFN):
        # regroup rows into [N, num_channels, seq_len, width], fuse, flatten back.
        # NOTE(review): assumes channel_fusion returns its result in the same
        # [N, num_channels, seq_len, width] axis order it was given — confirm.
        grouped = h.reshape(-1, self.num_channels, seq_len, width)
        fused = self.channel_fusion(grouped).reshape(-1, seq_len, width)

        # Conv FFN branch; Conv1d wants the feature axis on dim 1.
        ffn = self.dropout(self.activation(self.conv1(h.transpose(-1, 1))))
        ffn = self.dropout(self.conv2(ffn).transpose(-1, 1))

        blended = self.gate * ffn + (1 - self.gate) * fused
        # NOTE(review): `ffn` already had dropout applied above, so along this
        # path it is dropped twice while `fused` is dropped once — confirm intended.
        return self.norm2(h + self.dropout(blended)), attn

class DecoderLayer(nn.Module):
    """Post-norm Transformer decoder block.

    The block runs three stages, each followed by a residual add and
    ``LayerNorm``:

    1. self-attention on ``x``;
    2. cross-attention from ``x`` to ``cross``;
    3. a position-wise FFN built from two kernel-size-1 ``Conv1d`` layers.

    Args:
        self_attention: module called as ``(q, k, v, attn_mask=..., tau=..., delta=...)``;
            element 0 of its return value is used.
        cross_attention: same calling convention as ``self_attention``.
        d_model: feature width of the last axis.
        d_ff: FFN hidden width; a falsy value means ``4 * d_model``.
        dropout: dropout probability.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        hidden = d_ff if d_ff else 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """Decode ``x`` against ``cross``.

        ``delta`` is only forwarded to cross-attention; self-attention always
        receives ``delta=None``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        self_out = self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
        after_self = self.norm1(x + self.dropout(self_out))

        cross_out = self.cross_attention(
            after_self, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
        )[0]
        after_cross = self.norm2(after_self + self.dropout(cross_out))

        ff = self.conv1(after_cross.transpose(-1, 1))
        ff = self.dropout(self.activation(ff))
        ff = self.dropout(self.conv2(ff).transpose(-1, 1))
        return self.norm3(after_cross + ff)


class DecoderOnlyLayer(nn.Module):
    """Post-norm attention + feed-forward block for decoder-only stacks.

    Computationally the same as a post-norm encoder block. It applies
    self-attention and then a kernel-size-1 ``Conv1d`` FFN, each followed by
    a residual add and ``LayerNorm``. Any causal masking must arrive via
    ``attn_mask`` or be handled inside ``attention``.

    Args:
        attention: module called as ``attention(q, k, v, attn_mask=..., tau=...,
            delta=...)`` that returns ``(output, attn)``.
        d_model: feature width of the last axis.
        d_ff: FFN hidden width; a falsy value means ``4 * d_model``.
        dropout: dropout probability.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(DecoderOnlyLayer, self).__init__()
        ff_width = d_ff if d_ff else 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=ff_width, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=ff_width, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Return ``(output, attn)``; ``output`` has the same shape as ``x``."""
        attn_out, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
        x = self.norm1(x + self.dropout(attn_out))

        # Position-wise FFN: Conv1d wants the feature axis on dim 1.
        ff = x.transpose(-1, 1)
        ff = self.conv1(ff)
        ff = self.dropout(self.activation(ff))
        ff = self.conv2(ff).transpose(-1, 1)
        ff = self.dropout(ff)

        return self.norm2(x + ff), attn


class Encoder(nn.Module):
    """Stack of attention layers with optional conv layers in between.

    Args:
        attn_layers: layers called as ``layer(x, attn_mask=..., tau=..., delta=...)``
            and returning ``(x, attn)``.
        conv_layers: optional modules applied to ``x`` after each attention layer.
            They are paired with ``attn_layers`` via ``zip``. When given,
            ``attn_layers[-1]`` is applied once more after the paired loop,
            without ``attn_mask`` and with ``delta=None``.
        norm_layer: optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Run the stack; returns ``(x, attns)`` with one ``attn`` per layer call.

        ``x`` is presumably ``[B, L, D]`` (per the original comment) — confirm.
        """
        attns = []
        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, attn = layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)
        else:
            for idx, (layer, conv) in enumerate(zip(self.attn_layers, self.conv_layers)):
                # Only the first layer in this branch sees `delta`.
                x, attn = layer(x, attn_mask=attn_mask, tau=tau,
                                delta=delta if idx == 0 else None)
                x = conv(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)
        return x, attns


class Decoder(nn.Module):
    """Sequential stack of decoder layers with optional final norm and projection.

    Args:
        layers: modules called as ``layer(x, cross, x_mask=..., cross_mask=...,
            tau=..., delta=...)`` that return the new ``x``.
        norm_layer: optional module applied after the last layer.
        projection: optional module applied after the norm.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        """Feed ``x`` through every layer (all receive the same ``cross`` and kwargs)."""
        out = x
        for layer in self.layers:
            out = layer(out, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
        if self.norm is not None:
            out = self.norm(out)
        return out if self.projection is None else self.projection(out)


class DecoderOnly(nn.Module):
    """Stack of decoder-only layers with optional conv layers in between.

    Args:
        attn_layers: layers called as ``layer(x, attn_mask=..., tau=..., delta=...)``
            and returning ``(x, attn)``.
        conv_layers: optional modules applied to ``x`` after each attention layer.
            They are paired with ``attn_layers`` via ``zip``. When given,
            ``attn_layers[-1]`` runs once more at the end, without
            ``attn_mask`` and with ``delta=None``.
        norm_layer: optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(DecoderOnly, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        if conv_layers is None:
            self.conv_layers = None
        else:
            self.conv_layers = nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Run the stack; returns ``(x, attns)`` with one ``attn`` per layer call."""
        attns = []
        if self.conv_layers is not None:
            layer_delta = delta  # handed to the first paired layer only
            for layer, conv in zip(self.attn_layers, self.conv_layers):
                x, attn = layer(x, attn_mask=attn_mask, tau=tau, delta=layer_delta)
                x = conv(x)
                attns.append(attn)
                layer_delta = None
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)
        else:
            for layer in self.attn_layers:
                x, attn = layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)
        return x, attns


class TimerLayer(nn.Module):
    """Post-norm attention + feed-forward block that also forwards Timer metadata.

    This is the same computation as the other post-norm blocks in this file.
    The difference is that ``n_vars``, ``n_tokens`` and ``use_kv_cache`` are
    passed through to the attention module.

    Args:
        attention: module called as ``attention(q, k, v, n_vars=..., n_tokens=...,
            attn_mask=..., tau=..., delta=..., use_kv_cache=...)`` that returns
            ``(output, attn)``.
        d_model: feature width of the last axis.
        d_ff: FFN hidden width; a falsy value means ``4 * d_model``.
        dropout: dropout probability.
        activation: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(TimerLayer, self).__init__()
        inner_dim = d_ff if d_ff else 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=inner_dim, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=inner_dim, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.gelu if activation != "relu" else F.relu

    def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None, use_kv_cache=False):
        """Return ``(output, attn)``; ``output`` has the same shape as ``x``."""
        attn_kwargs = dict(
            n_vars=n_vars,
            n_tokens=n_tokens,
            attn_mask=attn_mask,
            tau=tau,
            delta=delta,
            use_kv_cache=use_kv_cache,
        )
        attended, attn = self.attention(x, x, x, **attn_kwargs)
        normed = self.norm1(x + self.dropout(attended))

        # Conv1d wants the feature axis on dim 1; transpose in and back out.
        hidden = self.dropout(self.activation(self.conv1(normed.transpose(-1, 1))))
        projected = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(normed + projected), attn


class TimerBlock(nn.Module):
    """Stack of Timer-style layers with optional conv layers in between.

    Args:
        attn_layers: layers called as ``layer(x, n_vars, n_tokens, attn_mask=...,
            tau=..., delta=..., use_kv_cache=...)`` that return ``(x, attn)``.
        conv_layers: optional modules applied to ``x`` after each attention layer.
            They are paired with ``attn_layers`` via ``zip``. When given,
            ``attn_layers[-1]`` is applied once more after the paired loop,
            without ``attn_mask`` and with ``delta=None``.
        norm_layer: optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(TimerBlock, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(
            conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None, use_kv_cache=False):
        """Run the stack; returns ``(x, attns)`` with one ``attn`` per layer call.

        ``n_vars``, ``n_tokens`` and ``use_kv_cache`` are forwarded to every
        layer call. ``x`` is presumably ``[B, L, D]`` — confirm with caller.
        """
        attns = []
        if self.conv_layers is not None:
            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
                # Only the first paired layer sees `delta`.
                delta = delta if i == 0 else None
                # n_vars / n_tokens are positional parameters of the layers (as
                # the other calls here already pass them); omitting them made
                # this branch raise TypeError.
                x, attn = attn_layer(
                    x, n_vars, n_tokens,
                    attn_mask=attn_mask, tau=tau, delta=delta, use_kv_cache=use_kv_cache)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, n_vars,
                                           n_tokens, tau=tau, delta=None, use_kv_cache=use_kv_cache)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, n_vars, n_tokens,
                                     attn_mask=attn_mask, tau=tau, delta=delta, use_kv_cache=use_kv_cache)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns


# class TimerLayer(nn.Module):
#     def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
#         super(TimerLayer, self).__init__()
#         d_ff = d_ff or 4 * d_model
#         self.attention = attention
#         self.conv1 = nn.Conv1d(in_channels=d_model,
#                                out_channels=d_ff, kernel_size=1)
#         self.conv2 = nn.Conv1d(
#             in_channels=d_ff, out_channels=d_model, kernel_size=1)
#         self.norm1 = nn.LayerNorm(d_model)
#         self.norm2 = nn.LayerNorm(d_model)
#         self.dropout = nn.Dropout(dropout)
#         self.activation = F.relu if activation == "relu" else F.gelu

#     def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None):
#         new_x, attn = self.attention(
#             x, x, x,
#             n_vars=n_vars,
#             n_tokens=n_tokens,
#             attn_mask=attn_mask,
#             tau=tau, delta=delta
#         )
#         x = x + self.dropout(new_x)

#         y = x = self.norm1(x)
#         y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
#         y = self.dropout(self.conv2(y).transpose(-1, 1))

#         return self.norm2(x + y), attn


# class TimerBlock(nn.Module):
#     def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
#         super(TimerBlock, self).__init__()
#         self.attn_layers = nn.ModuleList(attn_layers)
#         self.conv_layers = nn.ModuleList(
#             conv_layers) if conv_layers is not None else None
#         self.norm = norm_layer

#     def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None):
#         # x [B, L, D]
#         attns = []
#         if self.conv_layers is not None:
#             for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
#                 delta = delta if i == 0 else None
#                 x, attn = attn_layer(
#                     x, attn_mask=attn_mask, tau=tau, delta=delta)
#                 x = conv_layer(x)
#                 attns.append(attn)
#             x, attn = self.attn_layers[-1](x, n_vars,
#                                            n_tokens, tau=tau, delta=None)
#             attns.append(attn)
#         else:
#             for attn_layer in self.attn_layers:
#                 x, attn = attn_layer(x, n_vars, n_tokens,
#                                      attn_mask=attn_mask, tau=tau, delta=delta)
#                 attns.append(attn)

#         if self.norm is not None:
#             x = self.norm(x)

#         return x, attns
