import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Autoformer_EncDec import series_decomp
from layers.Embed import DataEmbedding_wo_pos
from layers.StandardNorm import Normalize

import numpy as np
import random


class DFT_series_decomp_1(nn.Module):
    """
    Series decomposition block based on top-k DFT filtering.

    A real FFT is taken along the last dimension of ``x``. The ``top_k``
    frequency components with the largest magnitude form the seasonal part,
    which is recovered with an inverse real FFT. The zero-frequency (DC)
    term is excluded from that ranking. The trend is the remainder
    ``x - season``.

    If the number of frequency bins is not larger than ``top_k`` (or
    ``top_k <= 0``), every component is kept, so the season equals ``x``.

    When the ``torch.fft`` module is available and ``x`` is float32/float64,
    the transform runs in torch, so gradients flow through the seasonal part.
    Otherwise it falls back to NumPy, which works with old torch versions
    such as 1.3. In that case autograd treats the seasonal part as a constant.
    """

    def __init__(self, top_k=5):
        super(DFT_series_decomp_1, self).__init__()
        self.top_k = top_k

    @staticmethod
    def _torch_fft_available():
        """Return True if ``torch.fft.rfft`` / ``torch.fft.irfft`` can be called."""
        # On old torch ``torch.fft`` is a function (or absent), not a module.
        fft = getattr(torch, 'fft', None)
        return hasattr(fft, 'rfft') and hasattr(fft, 'irfft')

    def _filter_needed(self, freq_len):
        """Return True if only a subset of the ``freq_len`` bins should be kept."""
        return 0 < self.top_k < freq_len

    def _season_torch(self, x):
        """Differentiable seasonal component computed with ``torch.fft``."""
        xf = torch.fft.rfft(x, dim=-1)
        if self._filter_needed(xf.size(-1)):
            # The selection is not differentiable; rank on a detached copy.
            freq = xf.detach().abs()
            freq[..., 0] = 0  # never select the DC component
            _, top_idx = torch.topk(freq, self.top_k, dim=-1)
            mask = torch.zeros_like(freq).scatter_(-1, top_idx, 1.0)
            xf = xf * mask
        return torch.fft.irfft(xf, n=x.size(-1), dim=-1)

    def _season_numpy(self, x):
        """Seasonal component computed with NumPy (no gradient w.r.t. ``x``)."""
        x_np = x.detach().cpu().numpy()
        xf = np.fft.rfft(x_np, axis=-1)
        if self._filter_needed(xf.shape[-1]):
            freq = np.abs(xf)
            freq[..., 0] = 0  # never select the DC component
            # Vectorised top-k per series (replaces the per-row Python loop).
            top_idx = np.argpartition(
                freq, -self.top_k, axis=-1)[..., -self.top_k:]
            keep = np.zeros(freq.shape, dtype=bool)
            np.put_along_axis(keep, top_idx, True, axis=-1)
            xf = np.where(keep, xf, 0)
        x_season_np = np.fft.irfft(xf, n=x_np.shape[-1], axis=-1)
        return torch.from_numpy(x_season_np).to(x.device).type(x.dtype)

    def forward(self, x):
        """Split ``x`` into ``(season, trend)`` along its last dimension.

        Args:
            x: real tensor of any rank >= 1; the transform acts on dim -1.

        Returns:
            ``(x_season, x_trend)``, both shaped like ``x``, with
            ``x_season + x_trend == x``.
        """
        use_torch = (self._torch_fft_available()
                     and x.dtype in (torch.float32, torch.float64))
        x_season = self._season_torch(x) if use_torch else self._season_numpy(x)
        x_trend = x - x_season
        return x_season, x_trend


# 替代版本：使用torch 1.3兼容的方法（如果不想依赖numpy）
class DFT_series_decomp_2(nn.Module):
    """
    Series decomposition block - pure torch fallback.

    ``forward`` does NOT use a DFT. It approximates the decomposition with a
    moving average along the last dimension (the trend) and returns the
    residual as the seasonal part.

    ``dft`` / ``idft`` are standalone helpers. They compute a real-input DFT
    and the real part of its inverse using explicit cosine/sine matrices, so
    no complex tensors are needed.
    """

    def __init__(self, top_k=5):
        super(DFT_series_decomp_2, self).__init__()
        self.top_k = top_k

    @staticmethod
    def _dft_angles(N, device, dtype):
        """Return the N x N (symmetric) matrix of angles 2*pi*k*n/N."""
        idx = torch.arange(N, device=device)
        # Reduce k*n modulo N before scaling so the angles stay small and the
        # float cos/sin values stay accurate for larger N.
        phase = (idx.view(-1, 1) * idx.view(1, -1)) % N
        return phase.to(dtype) * (2 * np.pi / N)

    def dft(self, x):
        """DFT of real ``x`` along its last dimension.

        Returns:
            ``(dft_real, dft_imag)``, each shaped like ``x``. All N bins are
            returned, not only the one-sided rfft half.
        """
        N = x.size(-1)
        angle = self._dft_angles(N, x.device, x.dtype)
        # X_k = sum_n x_n * (cos(theta_kn) - i * sin(theta_kn))
        dft_real = torch.matmul(x, torch.cos(angle))
        dft_imag = -torch.matmul(x, torch.sin(angle))
        return dft_real, dft_imag

    def idft(self, dft_real, dft_imag, n):
        """Real part of the inverse DFT along the last dimension.

        The result is truncated to the first ``n`` samples when ``n`` is
        smaller than the number of bins.
        """
        N = dft_real.size(-1)
        angle = self._dft_angles(N, dft_real.device, dft_real.dtype)
        # Re(x_n) = (1/N) * sum_k (Re X_k cos(theta_kn) - Im X_k sin(theta_kn))
        x_real = (torch.matmul(dft_real, torch.cos(angle)) -
                  torch.matmul(dft_imag, torch.sin(angle))) / N
        return x_real[..., :n] if n < N else x_real

    def forward(self, x):
        """Moving-average decomposition along the last dimension.

        Args:
            x: tensor whose last dimension is time; all leading dimensions are
                treated as independent series.

        Returns:
            ``(x_season, x_trend)`` shaped like ``x``, where ``x_trend`` is the
            moving average and ``x_season = x - x_trend``.
        """
        # Warn once per instance instead of on every forward call.
        if not getattr(self, '_warned', False):
            print("警告：使用纯torch实现的DFT，速度较慢。建议使用numpy版本。")
            self._warned = True

        length = x.size(-1)
        kernel_size = max(3, min(self.top_k * 2 + 1, length // 4))
        kernel = torch.ones(1, 1, kernel_size, device=x.device,
                            dtype=x.dtype) / kernel_size

        # Asymmetric padding keeps the output length equal to the input length
        # for both odd and even kernel sizes (identical to k//2 per side when
        # the kernel is odd).
        pad_left = (kernel_size - 1) // 2
        pad_right = kernel_size // 2

        # Fold every leading dimension into the batch: one 1-channel series per row.
        series = x.reshape(-1, 1, length)
        padded = F.pad(series, (pad_left, pad_right), mode='reflect')
        x_trend = F.conv1d(padded, kernel).reshape(x.shape)

        x_season = x - x_trend
        return x_season, x_trend


class DFT_series_decomp(nn.Module):
    """
    Series decomposition block (requires the ``torch.fft`` module).

    A real FFT is taken along the last dimension. For every series the
    ``top_k`` largest-magnitude frequency components are kept; ties at the
    k-th magnitude are kept as well. The DC term is always excluded. The
    kept components are inverted to form the seasonal part, and the trend is
    ``x - season``.
    """

    def __init__(self, top_k=5):
        super(DFT_series_decomp, self).__init__()
        self.top_k = top_k

    def forward(self, x):
        """Return ``(x_season, x_trend)``, both shaped like ``x``."""
        xf = torch.fft.rfft(x, dim=-1)
        # Magnitudes are only used to build the mask; no gradient needed.
        freq = xf.detach().abs()
        freq[..., 0] = 0  # exclude the DC bin of every series (not batch 0)

        k = min(self.top_k, freq.size(-1))
        if k < 1:
            x_season = torch.zeros_like(x)
        else:
            top_vals, _ = torch.topk(freq, k, dim=-1)
            kth_largest = top_vals[..., -1:]  # per-series threshold
            keep = (freq >= kth_largest) & (freq > 0)
            xf = xf * keep.to(freq.dtype)
            # Invert along the same axis and length used by rfft.
            x_season = torch.fft.irfft(xf, n=x.size(-1), dim=-1)
        x_trend = x - x_season
        return x_season, x_trend


class MultiScaleSeasonMixing(nn.Module):
    """
    Bottom-up mixing of seasonal patterns.

    Scale ``i`` (finer) is mapped onto the length of scale ``i + 1``
    (coarser) by ``down_sampling_layers[i]`` and added to it. The sum is then
    carried on to the next, coarser scale.
    """

    def __init__(self, configs):
        super(MultiScaleSeasonMixing, self).__init__()

        # Layer i maps length seq_len // w**i to seq_len // w**(i + 1).
        self.down_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                    ),

                )
                for i in range(configs.down_sampling_layers)
            ]
        )

    def forward(self, season_list):
        """Mix seasonal components from the finest to the coarsest scale.

        Args:
            season_list: per-scale tensors, finest first, with time on the last
                dimension (the linear layers act on it). A single-scale list is
                allowed and is returned unmixed.

        Returns:
            List of mixed tensors in the same order, each with dims 1 and 2
            swapped back (``permute(0, 2, 1)``).
        """
        out_high = season_list[0]
        out_season_list = [out_high.permute(0, 2, 1)]
        for i, out_low in enumerate(season_list[1:]):
            # Coarser scale plus the projected, already-mixed finer scale.
            out_high = out_low + self.down_sampling_layers[i](out_high)
            out_season_list.append(out_high.permute(0, 2, 1))
        return out_season_list


class MultiScaleTrendMixing(nn.Module):
    """
    Top-down mixing of trend patterns.

    Starting from the coarsest scale, each scale is up-sampled to the length
    of the next finer scale by a linear block and added to it. The sum is
    then carried on to the next finer scale.
    """

    def __init__(self, configs):
        super(MultiScaleTrendMixing, self).__init__()

        # Built coarse-to-fine: up_sampling_layers[j] maps the j-th coarsest
        # length onto the next finer one.
        self.up_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                        configs.seq_len // (configs.down_sampling_window ** i),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.seq_len // (configs.down_sampling_window ** i),
                    ),
                )
                for i in reversed(range(configs.down_sampling_layers))
            ])

    def forward(self, trend_list):
        """Mix trend components from the coarsest to the finest scale.

        Args:
            trend_list: per-scale tensors, finest first, with time on the last
                dimension. The list itself is not modified. A single-scale list
                is allowed and is returned unmixed.

        Returns:
            List of mixed tensors in the original (finest-first) order, each
            with dims 1 and 2 swapped back (``permute(0, 2, 1)``).
        """
        coarse_first = list(reversed(trend_list))
        out_low = coarse_first[0]
        out_trend_list = [out_low.permute(0, 2, 1)]
        for i, out_high in enumerate(coarse_first[1:]):
            # Finer scale plus the up-sampled, already-mixed coarser scale.
            out_low = out_high + self.up_sampling_layers[i](out_low)
            out_trend_list.append(out_low.permute(0, 2, 1))

        out_trend_list.reverse()
        return out_trend_list


class PastDecomposableMixing(nn.Module):
    """
    Past Decomposable Mixing (PDM) block.

    Every scale in the input list is split into seasonal and trend parts
    using ``configs.decomp_method``. The parts are mixed across scales
    according to ``configs.mixed_strategies`` and then summed back together.

    Supported ``decomp_method`` values:
      - "moving_avg"
      - "dft_decomp"
      - "copy": season = trend = input
      - "global_avg": trend = mean over the last dimension

    Supported ``mixed_strategies`` values: "season-trend", "trend-season",
    "season", "trend" and "none".
    """

    _MIXED_STRATEGIES = ('season-trend', 'trend-season', 'season', 'trend', 'none')

    def __init__(self, configs):
        super(PastDecomposableMixing, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.down_sampling_window = configs.down_sampling_window

        # NOTE(review): layer_norm and dropout are not used by forward(); they
        # are kept so existing state_dict keys still match.
        self.layer_norm = nn.LayerNorm(configs.d_model)
        self.dropout = nn.Dropout(configs.dropout)
        self.channel_independence = configs.channel_independence

        # Validate up front instead of failing later inside forward().
        if configs.mixed_strategies not in self._MIXED_STRATEGIES:
            raise ValueError(
                'mixed_strategies is not recognized. Expected "season-trend", '
                '"trend-season", "season", "trend", or "none".')
        self.mixed_strategies = configs.mixed_strategies

        # Plain functions (not lambdas) keep the module picklable.
        if configs.decomp_method == 'moving_avg':
            self.decompsition = series_decomp(configs.moving_avg)
        elif configs.decomp_method == "dft_decomp":
            self.decompsition = DFT_series_decomp_1(configs.top_k)
        elif configs.decomp_method == "copy":
            self.decompsition = self._copy_decompose
        elif configs.decomp_method == "global_avg":
            self.decompsition = self._global_mean_decompose
        else:
            raise ValueError(
                'decomp_method is not recognized. Expected "moving_avg", '
                '"dft_decomp", "copy", or "global_avg".')

        # Feature-mixing MLP applied to season and trend when channels are mixed.
        if configs.channel_independence == 0:
            self.cross_layer = nn.Sequential(
                nn.Linear(in_features=configs.d_model,
                          out_features=configs.d_ff),
                nn.GELU(),
                nn.Linear(in_features=configs.d_ff,
                          out_features=configs.d_model),
            )

        # Mixing season
        self.mixing_multi_scale_season = MultiScaleSeasonMixing(configs)

        # Mixing trend
        self.mixing_multi_scale_trend = MultiScaleTrendMixing(configs)

        self.out_cross_layer = nn.Sequential(
            nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
            nn.GELU(),
            nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
        )

    @staticmethod
    def _copy_decompose(x):
        """'copy' decomposition: return the input as both season and trend."""
        return x, x

    @staticmethod
    def _global_mean_decompose(x):
        """'global_avg' decomposition: trend is the mean over the last dimension.

        NOTE(review): forward() presumably passes (batch, time, d_model)
        tensors, so this averages over the feature axis rather than over
        time; confirm this is intended.
        """
        trend = x.mean(dim=-1, keepdim=True).expand_as(x)
        season = x - trend
        return season, trend

    def _mix(self, season_list, trend_list):
        """Apply the configured cross-scale mixing strategy.

        Inputs have time on the last dimension. Outputs have it back on
        dimension 1. Components that are not mixed are only permuted back.
        """
        def unmixed(items):
            return [item.permute(0, 2, 1) for item in items]

        strategy = self.mixed_strategies
        if strategy == 'season-trend':
            return (self.mixing_multi_scale_season(season_list),
                    self.mixing_multi_scale_trend(trend_list))
        if strategy == 'trend-season':
            # Swapped roles: seasons mixed top-down, trends bottom-up.
            return (self.mixing_multi_scale_trend(season_list),
                    self.mixing_multi_scale_season(trend_list))
        if strategy == 'season':
            return self.mixing_multi_scale_season(season_list), unmixed(trend_list)
        if strategy == 'trend':
            return unmixed(season_list), self.mixing_multi_scale_trend(trend_list)
        if strategy == 'none':
            return unmixed(season_list), unmixed(trend_list)
        raise ValueError('mixed_strategies is not recognized: %r' % (strategy,))

    def forward(self, x_list):
        """Decompose, mix across scales and recombine every scale.

        Args:
            x_list: per-scale tensors; dim 1 is treated as the length of each
                scale (presumably (batch, time, d_model)).

        Returns:
            List of tensors, one per scale, truncated to the original length
            along dim 1.
        """
        length_list = [x.size(1) for x in x_list]

        # Decompose to obtain the season and trend
        season_list = []
        trend_list = []
        for x in x_list:
            season, trend = self.decompsition(x)
            if self.channel_independence == 0:
                season = self.cross_layer(season)
                trend = self.cross_layer(trend)
            # The mixing layers act on the last dimension (time).
            season_list.append(season.permute(0, 2, 1))
            trend_list.append(trend.permute(0, 2, 1))

        out_season_list, out_trend_list = self._mix(season_list, trend_list)

        out_list = []
        for ori, out_season, out_trend, length in zip(x_list, out_season_list, out_trend_list, length_list):
            out = out_season + out_trend
            # NOTE(review): the residual + output MLP is only applied when
            # channel_independence is truthy (kept as in the original).
            if self.channel_independence:
                out = ori + self.out_cross_layer(out)
            out_list.append(out[:, :length, :])

        return out_list


class Model(nn.Module):

    def __init__(self, configs):
        """Build the multi-scale forecasting model from a configuration object.

        Args:
            configs: configuration namespace. Every attribute read below must
                exist, for example ``task_name``, ``seq_len``,
                ``down_sampling_window``, ``down_sampling_layers``,
                ``channel_independence``, ``d_model``, ``enc_in`` and
                ``e_layers``.

        NOTE(review): ``DynamicScaleSelection`` and ``AdaptivePredictorEnsemble``
        are neither defined nor imported in the visible part of this file;
        presumably they are defined further down. Confirm.
        NOTE: submodules are created in a fixed order. Their parameter
        initialisation draws from the global RNG in that order, so reordering
        these statements changes the initial weights for a fixed seed.
        """
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.down_sampling_window = configs.down_sampling_window
        self.channel_independence = configs.channel_independence
        # One Past-Decomposable-Mixing block per encoder layer.
        self.pdm_blocks = nn.ModuleList([PastDecomposableMixing(configs)
                                         for _ in range(configs.e_layers)])

        # Used by pre_enc() to split inputs when channels are not independent.
        self.preprocess = series_decomp(configs.moving_avg)
        self.enc_in = configs.enc_in
        self.use_future_temporal_feature = configs.use_future_temporal_feature

        if self.channel_independence == 1:
            # Channel independence folds variables into the batch, so each
            # embedded series has a single input channel.
            self.enc_embedding = DataEmbedding_wo_pos(1, configs.d_model, configs.embed, configs.freq,
                                                      configs.dropout)
        else:
            self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                                      configs.dropout)

        self.layer = configs.e_layers

        # One normaliser per scale (original + each down-sampled level);
        # normalisation is disabled when configs.use_norm == 0.
        self.normalize_layers = torch.nn.ModuleList(
            [
                Normalize(self.configs.enc_in, affine=True,
                          non_norm=True if configs.use_norm == 0 else False)
                for i in range(configs.down_sampling_layers + 1)
            ]
        )

        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            # Per-scale temporal projection: length of scale i -> pred_len.
            self.predict_layers = torch.nn.ModuleList(
                [
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.pred_len,
                    )
                    for i in range(configs.down_sampling_layers + 1)
                ]
            )

            if self.channel_independence == 1:
                self.projection_layer = nn.Linear(
                    configs.d_model, 1, bias=True)
            else:
                self.projection_layer = nn.Linear(
                    configs.d_model, configs.c_out, bias=True)

                # Residual path used by out_projection() (channel-mixing only).
                self.out_res_layers = torch.nn.ModuleList([
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.seq_len // (configs.down_sampling_window ** i),
                    )
                    for i in range(configs.down_sampling_layers + 1)
                ])

                self.regression_layers = torch.nn.ModuleList(
                    [
                        torch.nn.Linear(
                            configs.seq_len // (configs.down_sampling_window ** i),
                            configs.pred_len,
                        )
                        for i in range(configs.down_sampling_layers + 1)
                    ]
                )
        if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            if self.channel_independence == 1:
                self.projection_layer = nn.Linear(
                    configs.d_model, 1, bias=True)
            else:
                self.projection_layer = nn.Linear(
                    configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.d_model * configs.seq_len, configs.num_class)

        # Copy a `device` attribute onto configs if this module has one.
        # NOTE(review): nn.Module does not define `device` by default, so this
        # is presumably a no-op unless something else sets it. Confirm.
        if hasattr(self, 'device'):
            configs.device = self.device

        # Dynamic scale selector (per-scale weights used when down-sampling)
        # and adaptive ensemble (combines per-scale predictions).
        self.scale_selector = DynamicScaleSelection(configs)
        self.adaptive_ensemble = AdaptivePredictorEnsemble(configs)

    def out_projection(self, dec_out, i, out_res):
        dec_out = self.projection_layer(dec_out)
        out_res = out_res.permute(0, 2, 1)
        out_res = self.out_res_layers[i](out_res)
        out_res = self.regression_layers[i](out_res).permute(0, 2, 1)
        dec_out = dec_out + out_res
        return dec_out

    def pre_enc(self, x_list):
        if self.channel_independence == 1:
            # 确保返回格式一致
            return (x_list, [None] * len(x_list))
        else:
            out1_list = []
            out2_list = []
            for x in x_list:
                x_1, x_2 = self.preprocess(x)
                out1_list.append(x_1)
                out2_list.append(x_2)
            return (out1_list, out2_list)

    def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
        # 添加动态尺度选择
        scale_weights = self.scale_selector(x_enc)  # [B, max_scales]
        # print(scale_weights)

        if self.configs.down_sampling_method == 'adaptive':
            x_enc = x_enc.permute(0, 2, 1)  # [B, C, T]
            x_enc_ori = x_enc
            x_mark_enc_ori = x_mark_enc

            x_enc_sampling_list = []
            x_mark_sampling_list = []
            x_enc_sampling_list.append(x_enc.permute(0, 2, 1))  # [B, T, C]
            x_mark_sampling_list.append(x_mark_enc)

            for i in range(self.configs.down_sampling_layers):
                # 计算输出大小，处理边界情况
                output_size = max(
                    1, self.seq_len // (self.configs.down_sampling_window ** (i + 1)))
                adaptive_pool = nn.AdaptiveAvgPool1d(output_size)
                x_enc_sampling = adaptive_pool(
                    x_enc_ori)  # [B, C, output_size]

                # 获取对应尺度的权重
                scale_weight = scale_weights[:, i]  # [B]

                # 扩展维度以便广播
                scale_weight = scale_weight.unsqueeze(
                    1).unsqueeze(2)  # [B, 1, 1]

                # 判断是否启用权重控制
                if self.configs.use_scale_weighting == 1:
                    # print("use_scale_weighting == 1")
                    x_enc_sampling = x_enc_sampling * scale_weight  # 应用权重
                else:
                    # print("use_scale_weighting == 0")
                    x_enc_sampling = x_enc_sampling  # 不应用权重，保持原始值

                x_enc_sampling_list.append(
                    x_enc_sampling.permute(0, 2, 1))  # [B, output_size, C]
                x_enc_ori = x_enc_sampling

                if x_mark_enc_ori is not None:
                    # 对时间标记使用相同的adaptive pooling策略
                    # 先转换维度: [B, T, mark_dim] -> [B, mark_dim, T]
                    x_mark_temp = x_mark_enc_ori.permute(0, 2, 1)
                    # 对每个标记维度分别进行adaptive pooling
                    mark_dim = x_mark_temp.shape[1]
                    x_mark_sampling_list_temp = []
                    for dim in range(mark_dim):
                        mark_adaptive_pool = nn.AdaptiveAvgPool1d(output_size)
                        mark_sampling = mark_adaptive_pool(
                            x_mark_temp[:, dim:dim+1, :])  # [B, 1, output_size]
                        x_mark_sampling_list_temp.append(mark_sampling)

                    # 合并所有标记维度: [B, mark_dim, output_size]
                    x_mark_sampling = torch.cat(
                        x_mark_sampling_list_temp, dim=1)
                    # 转换回原始维度: [B, mark_dim, output_size] -> [B, output_size, mark_dim]
                    x_mark_sampling = x_mark_sampling.permute(0, 2, 1)
                    x_mark_sampling_list.append(x_mark_sampling)
                    x_mark_enc_ori = x_mark_sampling

            return x_enc_sampling_list, x_mark_sampling_list if x_mark_enc is not None else None

        else:
            if self.configs.down_sampling_method == 'max':
                down_pool = torch.nn.MaxPool1d(
                    self.configs.down_sampling_window, return_indices=False)
            elif self.configs.down_sampling_method == 'avg':
                down_pool = torch.nn.AvgPool1d(
                    self.configs.down_sampling_window)
            elif self.configs.down_sampling_method == 'conv':
                padding = 1 if torch.__version__ >= '1.5.0' else 2
                down_pool = nn.Conv1d(in_channels=self.configs.enc_in, out_channels=self.configs.enc_in,
                                      kernel_size=3, padding=padding,
                                      stride=self.configs.down_sampling_window,
                                      padding_mode='circular',
                                      bias=False)
            else:
                return x_enc, x_mark_enc

            # B,T,C -> B,C,T
            x_enc = x_enc.permute(0, 2, 1)

            x_enc_ori = x_enc
            x_mark_enc_ori = x_mark_enc

            x_enc_sampling_list = []
            x_mark_sampling_list = []
            x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
            x_mark_sampling_list.append(x_mark_enc)

            # 累积下采样倍数
            cumulative_stride = 1

            for i in range(self.configs.down_sampling_layers):
                x_enc_sampling = down_pool(x_enc_ori)
                x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
                x_enc_ori = x_enc_sampling

                # 更新累积下采样倍数
                cumulative_stride *= self.configs.down_sampling_window

                if x_mark_enc_ori is not None:
                    # 使用累积步长确保与x_enc的长度一致
                    # 计算实际的输出长度（考虑边界情况）
                    actual_length = x_enc_sampling.shape[2]  # 从实际的x_enc输出获取长度

                    # 确保不超出原始序列长度
                    if cumulative_stride < x_mark_enc.shape[1]:
                        # 使用均匀采样来匹配长度
                        indices = torch.linspace(
                            0, x_mark_enc.shape[1] - 1, actual_length, dtype=torch.long)
                        x_mark_sampling = x_mark_enc[:, indices, :]
                    else:
                        # 如果累积步长超出序列，取最后一个时间点
                        x_mark_sampling = x_mark_enc[:, -1:,
                                                     :].expand(-1, actual_length, -1)

                    x_mark_sampling_list.append(x_mark_sampling)
                    x_mark_enc_ori = x_mark_sampling

            return x_enc_sampling_list, x_mark_sampling_list if x_mark_enc is not None else None

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        if self.use_future_temporal_feature:
            if self.channel_independence == 1:
                B, T, N = x_enc.size()
                x_mark_dec = x_mark_dec.repeat(N, 1, 1)
                self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
            else:
                self.x_mark_dec = self.enc_embedding(None, x_mark_dec)

        x_enc, x_mark_enc = self.__multi_scale_process_inputs(
            x_enc, x_mark_enc)

        x_list = []
        x_mark_list = []
        if x_mark_enc is not None:
            for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
                B, T, N = x.size()
                # 保存原始形状用于后续denorm
                original_shape = x.shape
                x = self.normalize_layers[i](x, 'norm')
                if self.channel_independence == 1:
                    x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                    x_mark = x_mark.repeat(N, 1, 1)
                x_list.append(x)
                x_mark_list.append(x_mark)
        else:
            for i, x in zip(range(len(x_enc)), x_enc):
                B, T, N = x.size()
                # 保存原始形状用于后续denorm
                original_shape = x.shape
                x = self.normalize_layers[i](x, 'norm')
                if self.channel_independence == 1:
                    x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                x_list.append(x)

        # embedding
        enc_out_list = []

        # # --- REFINED DEBUGGING CODE START (Before pre_enc) ---
        # print("\n--- Debugging x_list (Before pre_enc) ---")
        # print(f"Type of x_list: {type(x_list)}")
        # print(f"Length of x_list: {len(x_list)}")
        # # This loop iterates through *every element* in x_list
        # for idx, item in enumerate(x_list):
        #     print(f"  x_list[{idx}] Type: {type(item)}")
        #     if isinstance(item, torch.Tensor):
        #         print(f"  x_list[{idx}] Shape: {item.shape}")
        #     elif isinstance(item, list): # Handles cases where x_list might contain nested lists
        #         print(f"  x_list[{idx}] is a list of length: {len(item)}")
        #         if len(item) > 0 and isinstance(item[0], torch.Tensor):
        #             print(f"  x_list[{idx}][0] Shape: {item[0].shape}")
        #     # If 'item' is neither a Tensor nor a list, you'd see its type here
        # print("------------------------------------------")
        # # --- REFINED DEBUGGING CODE END ---

        # # --- REFINED DEBUGGING CODE START (Before pre_enc) ---
        # print("\n--- Debugging x_mark_list (Before pre_enc) ---")
        # print(f"Type of x_mark_list: {type(x_mark_list)}")
        # print(f"Length of x_mark_list: {len(x_mark_list)}")
        # # This loop iterates through *every element* in x_mark_list
        # for idx, item in enumerate(x_mark_list):
        #     print(f"  x_mark_list[{idx}] Type: {type(item)}")
        #     if isinstance(item, torch.Tensor):
        #         print(f"  x_mark_list[{idx}] Shape: {item.shape}")
        #     elif isinstance(item, list): # Handles cases where x_mark_list might contain nested lists
        #         print(f"  x_mark_list[{idx}] is a list of length: {len(item)}")
        #         if len(item) > 0 and isinstance(item[0], torch.Tensor):
        #             print(f"  x_mark_list[{idx}][0] Shape: {item[0].shape}")
        #     # If 'item' is neither a Tensor nor a list, you'd see its type here
        # print("------------------------------------------")
        # # --- REFINED DEBUGGING CODE END ---

        x_list = self.pre_enc(x_list)
        if x_mark_enc is not None:
            for i, x, x_mark in zip(range(len(x_list[0])), x_list[0], x_mark_list):
                # print(f"x: {x.shape}")
                # print(f"x_mark: {x_mark.shape}")
                enc_out = self.enc_embedding(x, x_mark)  # [B,T,C]
                enc_out_list.append(enc_out)
        else:
            for i, x in zip(range(len(x_list[0])), x_list[0]):
                enc_out = self.enc_embedding(x, None)  # [B,T,C]
                enc_out_list.append(enc_out)

        # Past Decomposable Mixing as encoder for past
        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        # for i, tensor in enumerate(enc_out_list):
        #     print(f"Tensor {i}: shape = {tensor.shape}")

        if self.configs.use_ensemble_aggregation == 1:
            # print("use_ensemble_aggregation == 1")
            # Future Multipredictor Mixing as decoder for future
            dec_out = self.future_multi_mixing(B, enc_out_list, x_list)

            # 确保dec_out的形状与原始输入匹配
            # print(dec_out.shape)
            # print(original_shape)
            if dec_out.shape != original_shape:
                if self.channel_independence == 1:
                    dec_out = dec_out.reshape(B, N, -1).permute(0, 2, 1)

            # 使用第一个normalize层进行denorm，因为它包含了原始数据的统计信息
            dec_out = self.normalize_layers[0](dec_out, 'denorm')
        else:
            # print("use_ensemble_aggregation == 0")
            # Future Multipredictor Mixing as decoder for future
            dec_out_list = self.future_multi_mixing_none(
                B, enc_out_list, x_list)

            dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
            dec_out = self.normalize_layers[0](dec_out, 'denorm')
        return dec_out

    def future_multi_mixing_none(self, B, enc_out_list, x_list):
        dec_out_list = []
        if self.channel_independence == 1:
            x_list = x_list[0]
            for i, enc_out in zip(range(len(x_list)), enc_out_list):
                dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
                    0, 2, 1)  # align temporal dimension
                if self.use_future_temporal_feature:
                    dec_out = dec_out + self.x_mark_dec
                    dec_out = self.projection_layer(dec_out)
                else:
                    dec_out = self.projection_layer(dec_out)
                dec_out = dec_out.reshape(
                    B, self.configs.c_out, self.pred_len).permute(0, 2, 1).contiguous()
                dec_out_list.append(dec_out)

        else:
            for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]):
                dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
                    0, 2, 1)  # align temporal dimension
                dec_out = self.out_projection(dec_out, i, out_res)
                dec_out_list.append(dec_out)

        return dec_out_list

    def future_multi_mixing(self, B, enc_out_list, x_list):
        dec_out_list = []
        features_list = []

        # # 🔍 打印输入调试信息
        # print("=== DEBUG: enc_out_list ===")
        # for idx, eo in enumerate(enc_out_list):
        #     print(f"  enc_out_list[{idx}]: shape = {eo.shape}")

        # print("=== DEBUG: x_list ===")
        # if isinstance(x_list, (list, tuple)):
        #     for i, xl in enumerate(x_list):
        #         if isinstance(xl, (list, tuple)):
        #             print(f"  x_list[{i}] is a list/tuple of length {len(xl)}")
        #             for j, sub in enumerate(xl):
        #                 print(f"    x_list[{i}][{j}]: shape = {getattr(sub, 'shape', type(sub))}")
        #         else:
        #             print(f"  x_list[{i}]: shape = {getattr(xl, 'shape', type(xl))}")
        # else:
        #     print("x_list is not a list or tuple!")

        # === DEBUG: enc_out_list ===
        #   enc_out_list[0]: shape = torch.Size([16, 16, 32])
        #   enc_out_list[1]: shape = torch.Size([16, 8, 32])
        #   enc_out_list[2]: shape = torch.Size([16, 4, 32])
        #   enc_out_list[3]: shape = torch.Size([16, 2, 32])
        # === DEBUG: x_list ===
        #   x_list[0] is a list/tuple of length 4
        #     x_list[0][0]: shape = torch.Size([16, 16, 7])
        #     x_list[0][1]: shape = torch.Size([16, 8, 7])
        #     x_list[0][2]: shape = torch.Size([16, 4, 7])
        #     x_list[0][3]: shape = torch.Size([16, 2, 7])
        #   x_list[1] is a list/tuple of length 4
        #     x_list[1][0]: shape = torch.Size([16, 16, 7])
        #     x_list[1][1]: shape = torch.Size([16, 8, 7])
        #     x_list[1][2]: shape = torch.Size([16, 4, 7])
        #     x_list[1][3]: shape = torch.Size([16, 2, 7])

        if self.channel_independence == 1:
            for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]):
                dec_out = self.predict_layers[i](
                    enc_out.permute(0, 2, 1)).permute(0, 2, 1)
                dec_out = self.projection_layer(dec_out)
                dec_out_list.append(dec_out)
                features_list.append(enc_out)
        else:
            for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]):
                dec_out = self.predict_layers[i](
                    enc_out.permute(0, 2, 1)).permute(0, 2, 1)
                dec_out = self.out_projection(dec_out, i, out_res)
                dec_out_list.append(dec_out)
                features_list.append(enc_out)

        if not features_list:
            raise ValueError("features_list is empty")

        # for i, tensor in enumerate(dec_out_list):
        #     print(f"dec_out_list[{i}] shape: {tensor.shape}")

        # for i, tensor in enumerate(features_list):
        #     print(f"features_list[{i}] shape: {tensor.shape}")

        # 打印维度信息以进行调试
        # for i, feat in enumerate(features_list):
        #     print(f"Feature {i} shape: {feat.shape}")

        # 保持原始维度，不进行插值
        # aligned_features = []
        # for feat in features_list:
        #     aligned_features.append(feat)
        aligned_features = features_list

        # 在时间维度上填充较短的特征
        max_len = max(feat.size(1) for feat in aligned_features)
        padded_features = []

        for feat in aligned_features:
            if feat.size(1) < max_len:
                # 计算需要填充的数量
                pad_size = max_len - feat.size(1)
                # 使用最后一个时间步的值进行填充
                pad = feat[:, -1:, :].repeat(1, pad_size, 1)
                feat = torch.cat([feat, pad], dim=1)
            padded_features.append(feat)

        # 在特征维度上连接填充后的特征
        features = torch.cat(padded_features, dim=2)  # [B, T, C*num_scales]

        # for i, tensor in enumerate(dec_out_list):
        #     print(f"dec_out_list[{i}] shape: {tensor.shape}")

        # 使用自适应集成
        final_prediction, _ = self.adaptive_ensemble(dec_out_list, features)

        if len(dec_out_list) == 1:
            return dec_out_list[0]

        return final_prediction

    def classification(self, x_enc, x_mark_enc):
        """Classify each input window from its multi-scale representation.

        Args:
            x_enc: raw input series, split into scales by the class's
                multi-scale preprocessing (called without time marks).
            x_mark_enc: per-step multiplier applied to the first scale's
                features; per the original comment it zeroes out padding.
                It is unsqueezed on the last axis, so presumably [B, T] —
                TODO confirm.

        Returns:
            Output of ``self.projection`` on the flattened features,
            documented upstream as (batch_size, num_classes).
        """
        scale_inputs, _ = self.__multi_scale_process_inputs(x_enc, None)

        # Embed every scale on its own; no time-mark features are used.
        hidden_per_scale = [self.enc_embedding(scale, None)
                            for scale in scale_inputs]

        # Apply the stacked mixing blocks, each consuming the full list.
        for depth in range(self.layer):
            hidden_per_scale = self.pdm_blocks[depth](hidden_per_scale)

        # Only the first scale's representation feeds the classifier head;
        # the encoder output has no non-linearity, so add one here.
        hidden = self.dropout(self.act(hidden_per_scale[0]))
        # Zero out padded positions.
        hidden = hidden * x_mark_enc.unsqueeze(-1)
        # Flatten to (batch_size, seq_length * d_model) for the linear head.
        flat = hidden.reshape(hidden.shape[0], -1)
        return self.projection(flat)

    def anomaly_detection(self, x_enc):
        """Reconstruct the input window (reconstruction-based anomaly scoring).

        Args:
            x_enc: input tensor unpacked as [B, T, N].

        Returns:
            The de-normalized reconstruction, laid out as [B, -1, c_out].
        """
        batch, _, _ = x_enc.size()
        scales, _ = self.__multi_scale_process_inputs(x_enc, None)

        prepared = []
        for idx, scale in enumerate(scales):
            batch, steps, channels = scale.size()
            normed = self.normalize_layers[idx](scale, 'norm')
            if self.channel_independence == 1:
                # Fold channels into the batch axis: [B, T, N] -> [B*N, T, 1].
                normed = normed.permute(0, 2, 1).contiguous().reshape(
                    batch * channels, steps, 1)
            prepared.append(normed)

        # Embed every scale on its own; no time-mark features are used.
        embedded = [self.enc_embedding(item, None) for item in prepared]

        # Apply the stacked mixing blocks, each consuming the full list.
        for depth in range(self.layer):
            embedded = self.pdm_blocks[depth](embedded)

        projected = self.projection_layer(embedded[0])
        # Reshape to [B, c_out, -1], then swap the last two axes.
        projected = projected.reshape(batch, self.configs.c_out, -1)
        projected = projected.permute(0, 2, 1).contiguous()

        return self.normalize_layers[0](projected, 'denorm')

    def imputation(self, x_enc, x_mark_enc, mask):
        """Reconstruct a partially observed series.

        The input is standardized per (sample, channel) using only the
        observed entries, encoded at multiple scales, projected back and
        de-standardized.

        Args:
            x_enc: input tensor unpacked as [B, T, N]. Every entry is summed
                when computing the mean, so masked-out entries are presumably
                already zero — TODO confirm against the data pipeline.
            x_mark_enc: time-mark features or None. They are only passed to the
                multi-scale preprocessing; the embedding is called with None.
            mask: same shape as ``x_enc``; 1 marks observed entries, 0 missing.

        Returns:
            The reconstruction, de-normalized and repeated over ``self.seq_len``
            steps, laid out as [B, seq_len, c_out].
        """
        # Observed points per (sample, channel). Clamp to 1 so a fully masked
        # channel yields mean 0 / stdev sqrt(1e-5) instead of a 0/0 NaN that
        # would propagate through the whole network.
        observed = torch.sum(mask == 1, dim=1).clamp(min=1)
        means = (torch.sum(x_enc, dim=1) / observed).unsqueeze(1).detach()
        x_enc = x_enc - means
        x_enc = x_enc.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / observed + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x_enc /= stdev

        B, T, N = x_enc.size()
        x_enc, x_mark_enc = self.__multi_scale_process_inputs(
            x_enc, x_mark_enc)

        # Time marks are not consumed below; when present they only bound how
        # many scales are processed (zip stops at the shorter sequence).
        if x_mark_enc is not None:
            scales = [x for x, _ in zip(x_enc, x_mark_enc)]
        else:
            scales = x_enc

        x_list = []
        for x in scales:
            B, T, N = x.size()
            if self.channel_independence == 1:
                # Fold channels into the batch axis: [B, T, N] -> [B*N, T, 1].
                x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
            x_list.append(x)

        # Embed every scale on its own (no time-mark features).
        enc_out_list = [self.enc_embedding(x, None) for x in x_list]

        # Apply the stacked mixing blocks, each consuming the full list.
        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        dec_out = self.projection_layer(enc_out_list[0])
        dec_out = dec_out.reshape(
            B, self.configs.c_out, -1).permute(0, 2, 1).contiguous()

        # Undo the per-(sample, channel) standardization.
        dec_out = dec_out * \
            (stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        dec_out = dec_out + \
            (means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the head selected by ``self.task_name``.

        Args:
            x_enc, x_mark_enc: encoder input and its time marks.
            x_dec, x_mark_dec: decoder input and marks (forecasting only).
            mask: observation mask (imputation only).

        Returns:
            The selected head's output: [B, L, D] for imputation and anomaly
            detection, [B, N] for classification, the forecast otherwise.

        Raises:
            ValueError: if ``self.task_name`` is not a supported task.
        """
        task = self.task_name
        if task in ('long_term_forecast', 'short_term_forecast'):
            return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        if task == 'imputation':
            return self.imputation(x_enc, x_mark_enc, mask)  # [B, L, D]
        if task == 'anomaly_detection':
            return self.anomaly_detection(x_enc)  # [B, L, D]
        if task == 'classification':
            return self.classification(x_enc, x_mark_enc)  # [B, N]
        raise ValueError(f"Other tasks not implemented yet: {task!r}")


class DynamicScaleSelection(nn.Module):
    """Produce per-sample soft weights over the down-sampling scales.

    The whole input window is flattened, projected to ``d_model``, passed
    through a multi-head attention layer (over a length-1 sequence per
    sample) and mapped to a softmax over ``configs.down_sampling_layers``
    scales.
    """

    def __init__(self, configs):
        super(DynamicScaleSelection, self).__init__()
        self.d_model = configs.d_model
        self.max_scales = configs.down_sampling_layers
        self.device = self._get_device(configs)

        # Feature extractor over the flattened window (seq_len * enc_in values).
        self.feature_extractor = nn.Sequential(
            nn.Linear(configs.seq_len * configs.enc_in, configs.d_model),
            nn.ReLU(),
            nn.Linear(configs.d_model, configs.d_model)
        ).to(self.device)

        # Scale-weight generator: one softmax-normalised weight per scale.
        self.scale_weight_generator = nn.Sequential(
            nn.Linear(configs.d_model, self.max_scales),
            nn.Softmax(dim=-1)
        ).to(self.device)

        # Attention layer; constructed without batch_first, so it expects
        # inputs laid out as (L, N, E). d_model must be divisible by 4.
        self.temporal_attention = nn.MultiheadAttention(
            configs.d_model,
            num_heads=4,
            dropout=configs.dropout
        ).to(self.device)

    def _get_device(self, configs):
        """Pick the construction device from ``configs`` (CPU unless use_gpu)."""
        if hasattr(configs, 'use_gpu') and configs.use_gpu:
            if hasattr(configs, 'device'):
                return configs.device
            else:
                return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        return torch.device('cpu')

    def forward(self, x):
        """Compute scale weights.

        Args:
            x: 3-D tensor [B, T, C]; T * C must equal seq_len * enc_in.

        Returns:
            Tensor [B, max_scales] whose rows sum to 1.
        """
        # Follow the module's actual placement (it may have been moved with
        # .to() after construction) rather than the construction-time device.
        device = next(self.feature_extractor.parameters()).device
        x = x.to(device)

        B, _, _ = x.size()
        x_flat = x.reshape(B, -1)
        features = self.feature_extractor(x_flat)  # [B, d_model]

        # nn.MultiheadAttention (batch_first=False) reads (L, N, E). Use a
        # length-1 sequence per sample, (1, B, d_model), so samples stay
        # independent. The previous unsqueeze(1) produced (B, 1, d_model),
        # which the layer treated as one sequence of length B and let every
        # sample in the batch attend to every other one.
        features = features.unsqueeze(0)

        temporal_features, _ = self.temporal_attention(
            features, features, features
        )

        scale_weights = self.scale_weight_generator(
            temporal_features.squeeze(0)
        )

        return scale_weights


class AdaptivePredictorEnsemble(nn.Module):
    """Combine per-scale predictions with learned, uncertainty-discounted weights.

    Two MLPs read the flattened feature tensor: one produces a raw score, the
    other a non-negative (Softplus) uncertainty. The score is multiplied by
    ``exp(-uncertainty)`` and softmax-normalised across predictions. The
    predictions chosen by ``aggregation_strategy`` ('coarsest', 'finest',
    'random', 'all') are then averaged with those weights, renormalised over
    the chosen subset.
    """

    def __init__(self, configs):
        super(AdaptivePredictorEnsemble, self).__init__()
        self.d_model = configs.d_model
        self.pred_len = configs.pred_len
        self.device = self._get_device(configs)
        self.aggregation_strategy = configs.aggregation_strategy
        # Flattened feature size expected by both MLPs. Presumably the caller
        # pads every scale's encoder output to seq_len steps and concatenates
        # them on the last axis — TODO confirm. Per-scale shapes once observed:
        # [112, 96, 16], [112, 48, 16], [112, 24, 16], [112, 12, 16].
        self.total_feature_dim = configs.seq_len * \
            configs.d_model * (configs.down_sampling_layers + 1)

        hidden_dim = 512  # wide hidden layer for the large flattened input

        # Score network.
        self.performance_evaluator = nn.Sequential(
            nn.Linear(self.total_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(configs.dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(configs.dropout),
            nn.Linear(hidden_dim // 2, configs.d_model),
            nn.ReLU(),
            nn.Linear(configs.d_model, 1)
        ).to(self.device)

        # Uncertainty network (Softplus keeps the output non-negative).
        self.uncertainty_estimator = nn.Sequential(
            nn.Linear(self.total_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(configs.dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(configs.dropout),
            nn.Linear(hidden_dim // 2, configs.d_model),
            nn.ReLU(),
            nn.Linear(configs.d_model, 1),
            nn.Softplus()
        ).to(self.device)

    def _get_device(self, configs):
        """Pick the construction device from ``configs`` (CPU unless use_gpu)."""
        if hasattr(configs, 'use_gpu') and configs.use_gpu:
            if hasattr(configs, 'device'):
                return configs.device
            else:
                return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        return torch.device('cpu')

    def _select_indices(self, num_predictions):
        """Return the prediction indices combined under ``aggregation_strategy``.

        NOTE(review): if predictions are ordered like the caller's encoder
        outputs (index 0 = longest, original-resolution sequence), then
        'coarsest' actually picks the finest-resolution scale and vice versa.
        Confirm the intended naming.

        Raises:
            ValueError: for an unknown strategy.
        """
        strategy = self.aggregation_strategy
        if strategy == 'coarsest':
            return [0]
        if strategy == 'finest':
            return [num_predictions - 1]
        if strategy == 'random':
            # Up to two distinct scales; never more than are available.
            return random.sample(range(num_predictions),
                                 k=min(2, num_predictions))
        if strategy == 'all':
            return list(range(num_predictions))
        raise ValueError(f"Unsupported aggregation_strategy: {strategy!r}")

    def forward(self, predictions, features):
        """Weight and combine the per-scale predictions.

        Args:
            predictions: non-empty sequence of same-shaped tensors whose first
                axis is the batch (presumably [B, pred_len, C]).
            features: tensor with batch first that flattens to
                ``total_feature_dim`` values per sample.

        Returns:
            (final_prediction, weights): the weighted combination of the
            selected predictions (same shape as ``predictions[0]``), and the
            softmax weights over ALL predictions, shape [B, num_predictions].

        Raises:
            ValueError: on empty ``predictions``, a feature-size mismatch, or
                an unknown aggregation strategy.
        """
        if len(predictions) == 0:
            raise ValueError("predictions must contain at least one tensor")

        # Follow the module's actual placement rather than the
        # construction-time device.
        device = next(self.performance_evaluator.parameters()).device
        features = features.to(device)
        predictions = [pred.to(device) for pred in predictions]

        B = features.size(0)
        features_flat = features.reshape(B, -1)  # [B, total_feature_dim]

        if features_flat.shape[1] != self.total_feature_dim:
            raise ValueError(
                "Mismatch in feature dimensions: expected "
                f"{self.total_feature_dim}, got {features_flat.shape[1]}")

        performance_scores = []
        uncertainties = []
        # NOTE(review): both networks only see `features`, never the
        # individual prediction, so every prediction gets the same score up to
        # dropout noise (identical in eval mode -> uniform weights).
        for _ in predictions:
            performance_scores.append(self.performance_evaluator(features_flat))
            uncertainties.append(self.uncertainty_estimator(features_flat))

        weights = torch.cat(performance_scores, dim=1)  # [B, num_predictions]
        uncertainties = torch.cat(uncertainties, dim=1)  # [B, num_predictions]

        weights = weights * torch.exp(-uncertainties)
        weights = F.softmax(weights, dim=1)

        selected_indices = self._select_indices(len(predictions))
        # Renormalise over the chosen subset so the result is a weighted
        # average of the selected predictions. Without this, selecting a
        # single scale would return roughly pred / num_predictions.
        selected_weights = weights[:, selected_indices]
        selected_weights = selected_weights / \
            selected_weights.sum(dim=1, keepdim=True)

        final_prediction = torch.zeros_like(predictions[0])
        for col, idx in enumerate(selected_indices):
            final_prediction = final_prediction + predictions[idx] * \
                selected_weights[:, col:col + 1].unsqueeze(-1)

        return final_prediction, weights
