
from einops import rearrange, repeat, einsum
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax

from torch.nn.modules.utils import _triple



def get_conv1d(in_channels, out_channels,
               kernel_size, stride=1,
               padding=None, dilation=1,
               groups=1, bias=True):
    """Build an ``nn.Conv1d``.

    When ``padding`` is None, ``kernel_size // 2`` is used; otherwise the given
    ``padding`` (int or string such as ``'same'``) is forwarded unchanged.
    """
    conv_kwargs = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2 if padding is None else padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    return nn.Conv1d(in_channels, out_channels, **conv_kwargs)

def get_bn1d(dim):
    """Return the normalization layer used after the 1-D convolutions.

    Despite the name, this is ``nn.LayerNorm(dim)``, not BatchNorm: it
    normalizes over the LAST axis, which must have size ``dim`` (the sequence
    length at the call sites in this file), independently per sample/channel.
    """
    return nn.LayerNorm(dim)

def auto_make_branches(kernel: int,
                       rf_target: int,
                       base_rf: int = 1,
                       max_extra_rf: int = 0):
    """Choose doubling dilations (1, 2, 4, ...) for dilated conv branches.

    Each branch with kernel ``kernel`` and dilation ``d`` adds ``(kernel - 1) * d``
    to a running receptive-field (RF) counter that starts at ``base_rf``.
    Branches keep being added while the counter is below
    ``rf_target + max_extra_rf``. Adding stops early once the counter has
    reached ``rf_target`` AND at least two branches exist.

    Note: the counter accumulates as if the branches were stacked. Callers that
    run the branches in parallel should confirm this is the intended measure.

    Args:
        kernel: kernel size shared by every branch.
        rf_target: RF the branches should reach at least.
        base_rf: RF already covered by the main conv.
        max_extra_rf: slack allowed above ``rf_target``.

    Returns:
        ``(sizes, dilates, rf)``: per-branch kernel sizes, per-branch
        dilations, and the final RF counter (for logging).

    Raises:
        ValueError: if ``kernel < 2`` while branches are still needed. Such a
            kernel adds no RF, so the loop could never terminate.
    """
    k = kernel
    dilates, rf, d = [], base_rf, 1
    limit = rf_target + max_extra_rf
    if k < 2 and rf < limit:
        raise ValueError(
            f"kernel={k} cannot enlarge the receptive field "
            f"(base_rf={base_rf}, target limit={limit})")
    while rf < limit:
        dilates.append(d)
        rf += (k - 1) * d
        # Stop once the target RF is reached and at least two branches exist.
        # This is deliberately NOT the chained comparison ``rf >= rf_target >= 2``,
        # which would test rf_target instead of the branch count.
        if rf >= rf_target and len(dilates) >= 2:
            break
        d *= 2
    sizes = [k] * len(dilates)
    return sizes, dilates, rf

class DilatedReparamBlock1d(nn.Module):
    """Depth-wise large-kernel conv plus parallel dilated depth-wise branches (1-D).

    ``forward`` sums the main path ``main_bn(main_conv(x))``, every branch
    ``bn(dilated_conv(x))``, and an identity residual ``x``. All convs are
    depth-wise (``groups=channels``) and use ``'same'`` padding, so the length
    stays ``seq_len``. Each norm is ``get_bn1d(seq_len)``, which requires the
    last axis to be ``seq_len``.

    Branch dilations come from ``auto_make_branches``, with ``base_rf`` set to
    the main kernel size and target ``int(seq_len * rf_ratio)``.
    NOTE(review): the branches run in parallel on the same input, so the
    effective receptive field is that of the widest branch, not the cumulative
    ``finalRF`` that is printed. Confirm this is the intended measure.

    Args:
        channels: number of channels (dim 1 of the input).
        kernel_size: kernel size of the main conv and of every branch.
        seq_len: length of the last axis of the input.
        rf_ratio: receptive-field target as a fraction of ``seq_len``.
        max_extra_rf: slack allowed above the target (see ``auto_make_branches``).
    """

    def __init__(self, channels, kernel_size, seq_len,
                 rf_ratio=1.0, max_extra_rf=11):
        super().__init__()
        k = kernel_size
        # Main depth-wise conv. 'same' padding keeps the length at seq_len for
        # odd and even k alike; for odd k it equals symmetric (k-1)//2 padding.
        self.main_conv = get_conv1d(channels, channels, k,
                                    padding='same',
                                    groups=channels, bias=True)
        self.main_bn = get_bn1d(seq_len)

        # Receptive field already covered by the main conv.
        base_rf = k
        rf_target = int(seq_len * rf_ratio)

        sizes, dilates, final_rf = auto_make_branches(
            kernel=k,
            rf_target=rf_target,
            base_rf=base_rf,
            max_extra_rf=max_extra_rf)

        print(f"[DilatedBlock] k={k}  baseRF={base_rf}  "
              f"branch dilations={dilates}  finalRF={final_rf}")

        self.branches = nn.ModuleList()
        for k_, r in zip(sizes, dilates):
            self.branches.append(
                nn.Sequential(
                    # 'same' padding: for odd k_ identical to r*(k_-1)//2 per side.
                    get_conv1d(channels, channels, k_,
                               padding='same',
                               dilation=r,
                               groups=channels, bias=False),
                    get_bn1d(seq_len))
            )

    def forward(self, x):
        """Return main path + all dilated branches + identity residual.

        Expects ``x`` shaped ``(batch, channels, seq_len)``; the output has the same shape.
        """
        # The main large-kernel path was previously built but never applied,
        # leaving its parameters unused; its RF is the base_rf assumed in __init__.
        out = self.main_bn(self.main_conv(x))
        for br in self.branches:
            out = out + br(x)
        return out + x


class FreMLP1D(nn.Module):
    """Frequency-domain MLP applied along the last axis.

    The forward pass runs in four steps:

    1. Take the real FFT of the input along the last axis (orthonormal).
    2. Pass the log-magnitude spectrum ``log1p(|X|)`` through a grouped 1x1-conv
       MLP (``nc -> expand*nc -> nc`` channels, GELU in between).
    3. Recombine the result with the original phase.
    4. Take the inverse FFT back to the input length.

    Args:
        nc: number of channels (dim 1 of the input).
        group: groups of both 1x1 convs; must divide ``nc`` and ``expand*nc``.
        expand: hidden-width multiplier.
    """

    def __init__(self, nc, group, expand=2):
        super().__init__()
        self.process = nn.Sequential(
            nn.Conv2d(nc, expand * nc, 1, groups=group),
            nn.GELU(),
            nn.Conv2d(expand * nc, nc, 1, groups=group),
        )

    def forward(self, x):
        """Return a tensor with the same shape as ``x``.

        Accepted input shapes:

        * ``(B, C, D, T)``: the shape used by ``Model``. The 1x1 ``Conv2d``
          mixes the C channels.
        * ``(B, C, T)``: handled by inserting a singleton D axis.
          Without this, ``Conv2d`` would read a 3-D tensor as an unbatched
          ``(C, H, W)`` sample and mix the wrong axis.
        """
        squeeze_back = x.dim() == 3
        if squeeze_back:
            x = x.unsqueeze(2)
        x_freq = torch.fft.rfft(x, dim=-1, norm='ortho')   # (..., T//2+1), complex
        mag = torch.log1p(torch.abs(x_freq))                # log-magnitude spectrum
        pha = torch.angle(x_freq)                           # phase spectrum
        # NOTE(review): the processed *log*-magnitude is used directly as the new
        # magnitude (there is no expm1 inverse); confirm this is intended.
        mag2 = self.process(mag)
        real = mag2 * torch.cos(pha)
        imag = mag2 * torch.sin(pha)
        x_out = torch.fft.irfft(torch.complex(real, imag),
                                n=x.size(-1), dim=-1, norm='ortho')
        return x_out.squeeze(2) if squeeze_back else x_out

class DmodelBlock(nn.Module):
    """Residual block: dilated depth-wise conv mixing, LayerNorm, grouped 1x1 FFN.

    The input/output shape is ``[B, M, D, N]`` ("b c d t"). ``M * D`` must equal
    ``nvars * dmodel``, and ``N`` must equal ``size`` (both ``LayerNorm(size)``
    and the dilated block normalize the last axis).
    """

    def __init__(self, large_size, dmodel, dff, nvars, size, group, drop=0.1, rf_ratio=1.0):
        super().__init__()
        width = nvars * dmodel
        hidden = nvars * dff
        # Local depth-wise conv over a 7-step window; 'same' keeps the length.
        self.dw2 = nn.Conv1d(width, width, groups=width, kernel_size=7, padding='same')
        # Large-kernel depth-wise conv plus dilated branches.
        self.dw1 = DilatedReparamBlock1d(width, kernel_size=large_size,
                                         seq_len=size, rf_ratio=rf_ratio)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(size)

        # Grouped point-wise feed-forward network: width -> hidden -> width.
        self.ffn1pw1 = nn.Conv1d(width, hidden, kernel_size=1, stride=1,
                                 padding=0, dilation=1, groups=group)
        self.ffn1act = nn.GELU()
        self.ffn1pw2 = nn.Conv1d(hidden, width, kernel_size=1, stride=1,
                                 padding=0, dilation=1, groups=group)

        self.ffn1drop1 = nn.Dropout(drop)
        self.ffn1drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Apply conv mixing, norm and FFN, and add the input back as a residual."""
        B, M, D, N = x.shape
        h = x.reshape(B, M * D, N)
        h = self.dw2(self.act(self.dw1(h)))
        # LayerNorm acts on the last axis only, so the 3-D view normalizes
        # exactly as a 4-D [B, M, D, N] view would.
        h = self.norm(h)
        h = self.ffn1drop1(self.ffn1pw1(h))
        h = self.ffn1act(h)
        h = self.ffn1drop2(self.ffn1pw2(h))
        return x + h.reshape(B, M, D, N)

import torch
import torch.nn as nn
# ---------- util ----------
def pad_to_even(x, k, stride):
    """Right-pad ``x`` by repeating its last time step so a strided conv fits.

    ``x`` is indexed as 3-D ``[N, C, L]``. The result has length at least
    ``(L // stride - 1) * stride + k``. That is what an unpadded conv with
    kernel ``k`` and stride ``stride`` needs to produce ``L // stride``
    outputs. If ``x`` is already that long, it is returned unchanged.
    """
    length = x.shape[-1]
    required = (length // stride - 1) * stride + k
    shortfall = required - length
    if shortfall <= 0:
        return x
    tail = x[:, :, -1:].expand(-1, -1, shortfall)
    return torch.cat([x, tail], dim=-1).contiguous()

class MultiScaleStem(nn.Module):
    """Multi-scale strided embedding applied to every channel independently.

    Each channel of a ``[B, C, L]`` input is treated as its own single-channel
    sequence and fed to three stems. Each stem is a strided ``Conv1d`` followed
    by ``LayerNorm(seq_len // stride)``. The kernel sizes scale with the stride:

        kernel_sizes = [stride*4, stride*2, stride]
        stride == 2 -> [8, 4, 2]
        stride == 4 -> [16, 8, 4]

    The three feature maps are concatenated along channels, smallest kernel
    first. A 1x1 conv then fuses them into ``d_model`` channels.
    ``group_emb`` is stored but not used by this class.
    """

    def __init__(
        self,
        seq_len,
        out_channels = 24,
        d_model = 64,
        group_emb = 4,
        stride = 4,
    ):
        super().__init__()

        self.group_emb = group_emb
        self.stride = stride
        # Kernel sizes grow with the stride.
        self.k1 = stride * 4
        self.k2 = stride * 2
        self.k3 = stride

        n_out = seq_len // stride

        def make_stem(kernel):
            # One single-input-channel strided conv followed by LayerNorm over time.
            return nn.Sequential(
                nn.Conv1d(1, out_channels, kernel_size=kernel, stride=stride),
                nn.LayerNorm(n_out),
            )

        # Built in k1, k2, k3 order so parameter initialization order is unchanged.
        self.stem_k1 = make_stem(self.k1)
        self.stem_k2 = make_stem(self.k2)
        self.stem_k3 = make_stem(self.k3)

        # 1x1 conv that fuses the three concatenated scales.
        self.fusion = nn.Conv1d(
            3 * out_channels, d_model, kernel_size=1, padding=0, groups=1
        )

    def forward(self, x: torch.Tensor) -> tuple:
        """Embed ``x`` of shape ``[B, C, L]``.

        Returns:
            ``(features, stats)``. ``features`` is ``[B*C, d_model, seq_len // stride]``
            and ``stats`` is a placeholder list of six zeros.
        """
        B, C, _ = x.shape
        scales = (
            (self.k3, self.stem_k3),
            (self.k2, self.stem_k2),
            (self.k1, self.stem_k1),
        )
        feature_maps = []
        for kernel, stem in scales:
            padded = pad_to_even(x, kernel, self.stride)
            # Fold channels into the batch: every channel is a 1-channel sequence.
            per_channel = padded.reshape(B * C, 1, padded.size(-1))
            feature_maps.append(stem(per_channel))

        fused = self.fusion(torch.cat(feature_maps, dim=1).contiguous())
        return fused, [0, 0, 0, 0, 0, 0]

import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.StandardNorm import Normalize
class Model(nn.Module):
    """Multi-scale convolutional forecaster with frequency-domain refinement.

    ``forecast`` pipeline (shapes follow the reshapes in the code):

    1. ``normalize_layer`` (``Normalize``; the inline comment calls it RevIN)
       is applied to ``x_enc`` ``[B, T, C]``, which is then permuted to ``[B, C, T]``.
    2. ``MultiScaleStem`` embeds each channel into ``[B*C, d_model, N]``, with
       ``N = seq_len // stride``. The result is reshaped to ``[B, C, d_model, N]``.
    3. A frequency-MLP residual (``fre0``) is gated by ``gamma0``, which is
       initialized to zero.
    4. ``DmodelBlock`` applies dilated depth-wise convs plus a grouped FFN.
    5. A stride-2 conv downsamples to ``[B, C, 2*d_model, N//2]``. A second
       residual through the SAME ``fre0`` module (shared weights) is gated by
       ``gamma1``.
    6. A linear head flattens the last two axes into ``pred_len`` outputs.
    """

    def __init__(self, configs):
        super().__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.down_sampling_window = configs.down_sampling_window
        self.channel_independence = configs.channel_independence
        self.drop = configs.dropout
        # non_norm is True exactly when configs.use_norm == 0.
        self.normalize_layer = Normalize(self.configs.enc_in, affine=True,
                                         non_norm=bool(configs.use_norm == 0))
        self.group = configs.enc_in
        self.multi_stem = MultiScaleStem(seq_len=configs.seq_len, d_model=configs.d_model,
                                         out_channels=configs.out_channels, stride=configs.stride)
        # One frequency MLP, reused at full resolution and after downsampling.
        self.fre0 = FreMLP1D(configs.enc_in, group=configs.enc_in)
        # Zero-initialized per-channel gates: both frequency residuals start as no-ops.
        self.gamma0 = nn.Parameter(torch.zeros((1, configs.enc_in, 1, 1)), requires_grad=True)
        self.gamma1 = nn.Parameter(torch.zeros((1, configs.enc_in, 1, 1)), requires_grad=True)
        self.blockbase = DmodelBlock(configs.large_size, configs.d_model, configs.d_ff, configs.enc_in,
                                     size=self.seq_len // configs.stride,
                                     drop=self.drop, rf_ratio=configs.rf_ratio, group=self.group)

        # Kernel 4 / stride 2 (after pad_to_even) halves the length and doubles the width.
        self.down = nn.Sequential(
            nn.Conv1d(configs.d_model, configs.d_model * 2, kernel_size=4, stride=2),
        )

        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            # 256 64 512/  720/4 /2/2
            head_in = configs.d_model * 2 * (self.seq_len // configs.stride // 2)
            # Same layer indices as before (0: Flatten, 1: Linear, 2: Dropout),
            # so existing state_dict keys still match.
            head_layers = [nn.Flatten(start_dim=-2), nn.Linear(head_in, self.pred_len)]
            if configs.head_dropout > 0:
                head_layers.append(nn.Dropout(configs.head_dropout))
            self.head64 = nn.ModuleList([nn.Sequential(*head_layers)])

        # Experiment log kept for reference (value pairs are presumably MSE / MAE).
        #842
        #ETTh1√  96  0.356   0.390   0.352 0.387       #ETTh2√  96  0.246   0.322   0.247 0.324     #ETTm1√  96  0.280 0.338  0.280 0.335
        #ETT    192 0.393   0.410   0.387 0.408        #ETT    192 0.297   0.360   0.296 0.361      #ETT    192 0.317 0.360  0.317 0.358
        #ETT    336 0.377   0.410   0.373 0.406        #ETT    192 0.297   0.360   0.296 0.361      #ETT    336 0.348 0.383  0.351 0.384
        #ETT    720 0.430   0.449   0.429 0.453        #ETT    720 0.375   0.422   0.379 0.424      #ETT    720 0.408 0.412  0.410 0.413
        # NOTE(review): the ETTh2 column's third row repeats its 192 row; the 336 values look missing.

        # ETTm2√  96  0.160   0.250  0.160 0.250       # weather# 96  0.141 0.192           # ILI   # 24    1.292 0.712
        # ETT    192 0.213   0.290  0.215 0.289                 # 192 0.184 0.236                   # 36    1.150 0.682
        # ETT    336 0.268   0.325  0.270 0.326                 # 336 0.230 0.276                   # 48    1.151 0.704
        # ETT    720 0.345   0.378  0.346 0.379                 # 720 0.302 0.326                   # 60    1.375 0.796

        # exchange rate √
        #96     0.080 0.195
        #192    0.167 0.289
        #336    0.305 0.397
        #720    0.657 0.582

        # Ablation study
        # without emb without emb without emb

        # ETTm1√ 96  0.285 0.339        # ETTm2√ 96  0.165 0.255        # ETTh1  # 96    0.360 0.392
        # ETT    192 0.326 0.364        # ETT    192 0.217 0.292            # 192   0.394 0.412
        # ETT    336 0.352 0.384        # ETT    336 0.274 0.329            # 336   0.379 0.410
        # ETT    720 0.419 0.417        # ETT    720 0.350 0.383            # 720   0.431 0.455

        # ETTh2
        # 96    0.252   0.325        # weather  # 96    0.145   0.200        # ILI  # 24    1.600 0.830
        # 192   0.313   0.368               # 192   0.190   0.244               # 36    1.212 0.687
        # 336   0.308   0.370               # 336   0.235   0.282               # 48    1.416 0.795
        # 720   0.390   0.430               # 720   0.310   0.334               # 60    1.884 0.947

        #exchange rate
        #96     0.081   0.196
        #192    0.167   0.288
        #336    0.310   0.400
        #720    0.658   0.583

        # ETTm1: 0.345, 0.376
        # ETTm2: 0.252, 0.315
        # ETTh1: 0.391, 0.417
        # ETTh2: 0.316, 0.373
        # Weather: 0.220, 0.265
        # ILI: 1.528, 0.815
        # Exchange_Rate: 0.304, 0.367

        # without cnn  without cnn  without cnn

        # ETTm1√ 96  0.295 0.338        # ETTm2√ 96  0.164 0.254        # ETTh1 # 96    0.360 0.391
        # ETT    192 0.331 0.362        # ETT    192 0.215 0.289                # 192   0.400 0.415
        # ETT    336 0.372 0.389        # ETT    336 0.269 0.326                # 336   0.389 0.412
        # ETT    720 0.419 0.413        # ETT    720 0.348 0.379                # 720   0.444 0.464

        # ETTh2
        # 96    0.248   0.323        # weather  # 96    0.144   0.193        # ILI  # 24    1.690 0.894
        # 192   0.311   0.365                   # 192   0.187   0.236               # 36    1.578 0.804
        # 336   0.298   0.366                   # 336   0.229   0.274               # 48    1.743 0.881
        # 720   0.373   0.421                   # 720   0.302   0.324               # 60    1.919 0.977


        # ETTm1: 0.354 0.376
        # ETTm2: 0.249 0.312
        # ETTh1: 0.398  0.421
        # ETTh2: 0.308 0.369
        # Weather: 0.216 0.257
        # ILI: 1.733 0.889
        # Exchange_Rate:

        # without downsample

        # ETTm1√ 96     0.281 0.338
        # ETT    192    0.319 0.361
        # ETT    336    0.351 0.384
        # ETT    720    0.411 0.412

        # ETTm2√ 96     0.164 0.256
        # ETT    192    0.217 0.293
        # ETT    336    0.272 0.329
        # ETT    720    0.350 0.382

        # ETTh1√ 96     0.359 0.391
        # ETT    192    0.390 0.410
        # ETT    336    0.387 0.412
        # ETT    720    0.438 0.459

        # ETTh2√ 96     0.251 0.325
        # ETT    192    0.297 0.359
        # ETT    336    0.303 0.367
        # ETT    720    0.390 0.431

        # weather    96  0.142 0.192
        # weather    192 0.186 0.236
        # weather    336 0.233 0.280
        # weather    720 0.315 0.338

        # ILI  √
        # 24    1.374 0.752
        # 36    1.567 0.781
        # 48    1.149 0.705
        # 60    1.467 0.804
        # ETTm1: 0.341, 0.374
        # ETTm2: 0.251, 0.315
        # ETTh1: 0.394, 0.418
        # ETTh2: 0.310, 0.370
        # Weather: 0.219, 0.262
        # ILI: 1.389, 0.760

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec,):
        """Forecast ``pred_len`` steps from ``x_enc``.

        Only ``x_enc`` (``[B, T, C]`` per the inline comment) is read. The
        other arguments are accepted for interface compatibility.

        Returns:
            ``(dec_out, 0, stats)`` in training mode and ``(dec_out, stats)``
            in eval mode. ``dec_out`` is the head output permuted to
            ``[B, pred_len, C]`` and then de-normalized. ``stats`` is a
            placeholder list of six zeros.
        """
        x = self.normalize_layer(x_enc, 'norm')            # [B, T, C]
        x = x.permute(0, 2, 1).contiguous()                # [B, C, T]
        B, C, _ = x.shape
        x, _ = self.multi_stem(x)                          # [B*C, d_model, N]
        stats = [0, 0, 0, 0, 0, 0]
        _, D_, N_ = x.shape
        x = x.reshape(B, C, D_, N_)
        x = x + self.fre0(x) * self.gamma0                 # gated frequency residual
        x_out = self.blockbase(x)

        x_down = x_out.reshape(B * C, D_, N_)
        x_down = self.down(pad_to_even(x_down, 4, 2))      # [B*C, 2*d_model, N//2]
        x_down = x_down.reshape(B, C, D_ * 2, N_ // 2)
        x_down = x_down + self.fre0(x_down) * self.gamma1  # same fre0, second gate
        x_out = self.head64[0](x_down)                     # [B, C, pred_len]
        x_out = x_out.permute(0, 2, 1).contiguous()        # [B, pred_len, C]

        dec_out = self.normalize_layer(x_out, 'denorm')

        if self.training:
            return dec_out, 0, stats
        return dec_out, stats

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``self.task_name``.

        The supported tasks behave as follows:

        * ``long_term_forecast`` returns the full tuple from ``forecast``.
        * ``short_term_forecast`` returns only the prediction (element 0).

        ``imputation``, ``anomaly_detection`` and ``classification`` dispatch
        to methods of the same name. Those methods are only called if they
        exist, for example in a subclass.

        Raises:
            ValueError: for any task without an implementation. Previously the
                three tasks above raised ``AttributeError`` because this class
                does not define those methods.
        """
        task = self.task_name
        if task == 'long_term_forecast':
            return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        if task == 'short_term_forecast':
            return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)[0]
        if task == 'imputation' and hasattr(self, 'imputation'):
            return self.imputation(x_enc, mask)  # [B, L, D]
        if task == 'anomaly_detection' and hasattr(self, 'anomaly_detection'):
            return self.anomaly_detection(x_enc, None, None, None)  # [B, L, D]
        if task == 'classification' and hasattr(self, 'classification'):
            return self.classification(x_enc, x_mark_enc, None, None)  # [B, N]
        raise ValueError(f"Task {task!r} is not implemented yet")




def stride_generator(N, reverse=False):
    """Return ``[1, 2] * 10`` sliced to ``[:N]``, reversed when ``reverse`` is True.

    The pattern alternates 1, 2, 1, 2, ... and has 20 entries, so at most 20
    strides are returned. Standard slice semantics apply to ``N``.
    """
    selected = ([1, 2] * 10)[:N]
    return selected[::-1] if reverse else selected

