import torch
from torch import nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import PatchEmbedding
import math

# class Expert(nn.Module):
#     """每个专家可以是任意网络结构，这里用最简单的 Linear."""
#     def __init__(self, in_len, out_len):
#         super().__init__()
#         self.linear = nn.Linear(in_len, out_len)

#     def forward(self, x):
#         # x shape: [B, D, in_len]
#         # 这里简单地做一下 permute + linear + permute 回来
#         out = self.linear(x.permute(0, 2, 1))  # -> [B, in_len, D]
#         out = out.permute(0, 2, 1)            # -> [B, out_len, D]
#         return out

# class Router(nn.Module):
#     """路由器(或门控网络)根据输入特征，输出对每个专家的得分/权重."""
#     def __init__(self, d_in, num_experts, hidden_dim=64):
#         super().__init__()
#         self.num_experts = num_experts
#         # 一个简单的两层 MLP 来预测专家权重
#         self.mlp = nn.Sequential(
#             nn.Linear(d_in, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, num_experts)
#         )

#     def forward(self, x_summary):
#         """
#         x_summary: [B, d_in], 用来代表当前batch的输入特征(可根据需要提取)
#         返回 gating_weights: [B, num_experts]
#         """
#         logits = self.mlp(x_summary)                 # [B, num_experts]
#         gating_weights = F.softmax(logits, dim=-1)   # 变成概率分布
#         return gating_weights

# class Model(nn.Module):
#     def __init__(self, configs):
#         super().__init__()
#         self.pred_len = configs.test_pred_len
#         self.seq_len = configs.seq_len

#         # 3个专家, 你可以把它们对应到 Linear1/2/3 的功能
#         self.expert1 = Expert(self.seq_len, self.pred_len)
#         self.expert2 = Expert(self.seq_len//2, self.pred_len)
#         self.expert3 = Expert(self.seq_len//4, self.pred_len)

#         self.num_experts = 3
#         # 路由器: 输入维度可以是 seq_len，或者别的抽取方式
#         self.router = Router(d_in=self.seq_len, num_experts=self.num_experts)
#         self.use_norm = True

#     def forward(self, x, x_mark, y_mark):
#         """
#         x_full: [B, D, seq_len]
#         x_half: [B, D, seq_len//2]
#         x_quarter: [B, D, seq_len//4]
#         """
#         if self.use_norm:
#             # Normalization from Non-stationary Transformer
#             means = x.mean(1, keepdim=True).detach()
#             x = x - means
#             stdev = torch.sqrt(
#                 torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
#             x /= stdev
#         x_full, x_half, x_quarter = x, x[:, - self.seq_len // 2:, :], x[:, - self.seq_len // 4:, :]

#         # 1) 计算各专家的输出
#         out1 = self.expert1(x_full)      # [B, pred_len, D]
#         out2 = self.expert2(x_half)
#         out3 = self.expert3(x_quarter)

#         # 2) Router 生成 gating weights
#         #    这里很随意地取 x_full 在特征维做平均 => [B, seq_len] => 作为 router 的输入
#         x_summary = x_full.mean(dim=2)   # [B, seq_len]
#         gating_weights = self.router(x_summary)  # [B, 3]

#         # 3) 融合专家输出: out = α1 * out1 + α2 * out2 + α3 * out3
#         # gating_weights: [B, 3]
#         # outX: [B, pred_len, D]
#         # => 先 reshape gating_weights => [B, 3, 1, 1]
#         gw = gating_weights.view(-1, self.num_experts, 1, 1)  # [B, 3, 1, 1]

#         outs = torch.stack([out1, out2, out3], dim=1)  # [B, 3, pred_len, D]
#         out = gw * outs                                # [B, 3, pred_len, D]
#         dec_out = out.sum(dim=1)                           # [B, pred_len, D]

#         if self.use_norm:
#             # De-Normalization from Non-stationary Transformer
#             dec_out = dec_out * \
#                     (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
#             dec_out = dec_out + \
#                     (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))

#         return dec_out

class Model(nn.Module):
    """Input-gated mixture of three linear forecasters over nested look-back windows.

    Three ``nn.Linear`` heads map, per channel, the last ``seq_len``,
    ``seq_len // 2`` and ``seq_len // 4`` time steps to ``pred_len`` outputs.
    A linear router turns a per-channel summary of the input into three
    softmax weights (one triple per sample) that blend the heads.

    The input is expected as ``(B, T, D)`` with ``T == seq_len`` (required by
    ``Linear1``, which consumes the whole time axis) and ``D`` equal to the
    router's input width.

    Args:
        configs: object exposing ``test_pred_len`` and ``seq_len``. The router
            width is read from ``configs.c_in`` if present, else
            ``configs.enc_in``, else defaults to 7 (the previous hard-coded
            value).

    Raises:
        ValueError: if ``seq_len < 4``, because the quarter-window head would
            have no input features.
    """

    def __init__(self, configs):
        super().__init__()
        self.pred_len = configs.test_pred_len
        self.seq_len = configs.seq_len
        if self.seq_len < 4:
            raise ValueError(
                f"seq_len must be at least 4 so that seq_len // 4 >= 1, got {self.seq_len}"
            )

        # Layer creation order is kept (Linear1, Linear2, Linear3, router) so
        # seeded initialisation matches the previous implementation.
        self.Linear1 = nn.Linear(self.seq_len, self.pred_len)
        self.Linear2 = nn.Linear(self.seq_len // 2, self.pred_len)
        self.Linear3 = nn.Linear(self.seq_len // 4, self.pred_len)

        # Router width = number of channels D. Previously fixed to 7.
        # NOTE(review): `c_in` / `enc_in` are assumed to hold the channel
        # count of the input tensor -- confirm against the config in use.
        n_channels = getattr(configs, "c_in", None)
        if n_channels is None:
            n_channels = getattr(configs, "enc_in", None)
        if n_channels is None:
            n_channels = 7
        self.router = nn.Linear(n_channels, 3)
        self.use_norm = True

    def forecast(self, x, x_mark, y_mark):
        """Return the gated forecast of shape ``(B, pred_len, D)``.

        Args:
            x: input series, ``(B, T, D)``.
            x_mark: unused; kept for interface compatibility.
            y_mark: unused; kept for interface compatibility.
        """
        if self.use_norm:
            # Per-sample, per-channel statistics over the last seq_len steps.
            x_tail = x[:, -self.seq_len:, :]
            means = x_tail.mean(dim=1, keepdim=True).detach()  # (B, 1, D)
            stdev = torch.sqrt(
                torch.var(x_tail, dim=1, keepdim=True, unbiased=False) + 1e-5
            )  # (B, 1, D)
            # Standardise the whole sequence with those statistics.
            x = (x - means) / stdev

        # (B, T, D) -> (B, D, T) so the linear heads act along time.
        x = x.permute(0, 2, 1)

        # Router input: per-channel standard deviation over time, (B, D).
        # NOTE(review): when use_norm is on and T == seq_len, every channel has
        # just been scaled to (almost) unit variance, so this summary is nearly
        # constant and the gate barely depends on the input. Kept as-is to
        # preserve behaviour; consider feeding pre-normalisation statistics.
        channel_std = x.std(dim=-1)
        gate = F.softmax(self.router(channel_std), dim=-1)  # (B, 3), rows sum to 1
        gate = gate.unsqueeze(-1).unsqueeze(-1)  # (B, 3, 1, 1)

        # Window lengths must be computed BEFORE negation: `-seq_len // 2`
        # parses as `(-seq_len) // 2`, which floors away from zero and yields a
        # window one step too long whenever seq_len is not divisible by 2 / 4.
        half_len = self.seq_len // 2
        quarter_len = self.seq_len // 4

        dec_out = (
            gate[:, 0] * self.Linear1(x)
            + gate[:, 1] * self.Linear2(x[:, :, -half_len:])
            + gate[:, 2] * self.Linear3(x[:, :, -quarter_len:])
        )  # (B, D, pred_len)

        dec_out = dec_out.permute(0, 2, 1)  # (B, pred_len, D)

        if self.use_norm:
            # Undo the standardisation; (B, 1, D) statistics broadcast over
            # the pred_len axis.
            dec_out = dec_out * stdev + means

        return dec_out

    def forward(self, x, x_mark, y_mark):
        """Run :meth:`forecast` and return its last ``pred_len`` steps, ``(B, pred_len, D)``."""
        dec_out = self.forecast(x, x_mark, y_mark)
        return dec_out[:, -self.pred_len:, :]  # [B, L, D]


# class Model(nn.Module):
#     """
#     Paper link: https://arxiv.org/abs/2211.14730
#     The implementation of Moment is basically consistent with patchtst.
#     """
#     def __init__(self, configs):
#         super().__init__()
#         self.pred_len = configs.test_pred_len
#         self.seq_len = configs.seq_len

#         self.Linear1 = nn.Linear(self.seq_len, self.pred_len)
#         self.Linear2 = nn.Linear(self.seq_len // 2, self.pred_len)
#         self.Linear3 = nn.Linear(self.seq_len // 4, self.pred_len)
#         # 定义可学习的标量门控参数（初始值可以随意）
#         self.w1 = nn.Parameter(torch.tensor(0.33))
#         self.w2 = nn.Parameter(torch.tensor(0.33))
#         self.w3 = nn.Parameter(torch.tensor(0.33))
#         self.use_norm = True
        
#     def forecast(self, x, x_mark, y_mark):
#         if self.use_norm:
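#             # Statistics window: half_len is set to the full seq_len here, so mean/std
#             # are computed over the last seq_len steps (not the last half, despite the name).
#             half_len = self.seq_len
#             x_tail = x[:, -half_len:, :]  # last half_len time steps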
#             means = x_tail.mean(dim=1, keepdim=True).detach()
#             stdev = torch.sqrt(torch.var(x_tail, dim=1, keepdim=True, unbiased=False) + 1e-5)
    
#             # 对整条序列做减均值 & 除标准差
#             x = (x - means) / stdev
    
#         # 与原始代码相同：转置到 (B, D, T)
#         x = x.permute(0, 2, 1)
    
#         # 在转置后做线性映射
#         dec_out = self.w1 * self.Linear1(x) \
#                 + self.w2 * self.Linear2(x[:, :, -self.seq_len // 2:]) \
#                 + self.w3 * self.Linear3(x[:, :, -self.seq_len // 4:])
        
#         dec_out = dec_out.permute(0, 2, 1)
    
#         if self.use_norm:
#             # 反归一化时，仍然使用刚才统计到的 means 和 stdev
#             dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
#             dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    
#         return dec_out

    # def forecast(self, x, x_mark, y_mark):
    #     if self.use_norm:
    #         # Normalization from Non-stationary Transformer
    #         means = x.mean(1, keepdim=True).detach()
    #         x = x - means
    #         stdev = torch.sqrt(
    #             torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
    #         x /= stdev
    #     # do patching and embedding
    #     x = x.permute(0, 2, 1)

    #     # print(x.shape, self.seq_len)
    #     dec_out = self.Linear1(x) + self.Linear2(x[:, :, - self.seq_len // 2:]) + self.Linear3(x[:, :, - self.seq_len // 4:])
    #     # 用标量门控参数加权
    #     dec_out = self.w1 * self.Linear1(x) + self.w2 * self.Linear2(x[:, :, - self.seq_len // 2:]) + self.w3 * self.Linear3(x[:, :, - self.seq_len // 4:])
    #     dec_out = dec_out.permute(0, 2, 1)

    #     if self.use_norm:
    #         # De-Normalization from Non-stationary Transformer
    #         dec_out = dec_out * \
    #                 (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    #         dec_out = dec_out + \
    #                 (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    #     return dec_out

    # def forward(self, x, x_mark, y_mark):
    #     dec_out = self.forecast(x, x_mark, y_mark)
    #     return dec_out[:, -self.pred_len:, :]  # [B, L, D]


    # def forecast(self, x, x_mark, y_mark):
    #     x_full, x_half, x_quarter = x, x[:, - self.seq_len // 2:, :], x[:, - self.seq_len // 4:, :]
    #     # 对 x_full 做全段 mean/std
    #     mean_full = x_full.mean(dim=1, keepdim=True)
    #     std_full = torch.sqrt(
    #         torch.var(x_full, dim=1, keepdim=True, unbiased=False) + 1e-5)
    #     x_full_normed = (x_full - mean_full) / std_full
    
    #     # 对 x_half 做半段 mean/std
    #     mean_half = x_half.mean(dim=1, keepdim=True)
    #     std_half = torch.sqrt(
    #         torch.var(x_half, dim=1, keepdim=True, unbiased=False) + 1e-5)
    #     x_half_normed = (x_half - mean_half) / std_half
    
    #     # 对 x_quarter 做四分之一段 mean/std
    #     mean_quarter = x_quarter.mean(dim=1, keepdim=True)
    #     std_quarter = torch.sqrt(
    #         torch.var(x_quarter, dim=1, keepdim=True, unbiased=False) + 1e-5)
    #     x_quarter_normed = (x_quarter - mean_quarter) / std_quarter
    
    #     # 分别做线性映射
    #     y_full = self.Linear1(x_full_normed.permute(0, 2, 1))
    #     y_half = self.Linear2(x_half_normed.permute(0, 2, 1))
    #     y_quarter = self.Linear3(x_quarter_normed.permute(0, 2, 1))
    
    #     # 做反归一化
    #     y_full = y_full.permute(0, 2, 1) * std_full + mean_full
    #     y_half = y_half.permute(0, 2, 1) * std_half + mean_half
    #     y_quarter = y_quarter.permute(0, 2, 1) * std_quarter + mean_quarter
    
    #     # 最终再融合(比如相加) 
    #     out = self.w1 * y_full + self.w2 * y_half + self.w3 * y_quarter

    #     return out


# class Model(nn.Module):
#     """
#     Paper link: https://arxiv.org/abs/2211.14730
#     The implementation of Moment is basically consistent with patchtst.
#     """
#     def __init__(self, configs):
#         super().__init__()
#         self.pred_len = configs.test_pred_len
#         self.seq_len = configs.seq_len
#         self.use_norm = True

#         val = self.seq_len / 96.0
#         self.num_linears = int(math.ceil(math.log2(val))) + 1
#         self.num_linears = 3

#         # 使用 nn.ModuleList 存放 N 个线性层
#         self.linears = nn.ModuleList([
#             nn.Linear(self.seq_len // (2 ** factor), self.pred_len)
#             for factor in range(self.num_linears)
#         ])

#         # 定义可学习的 N 个标量门控权重，初始可以平均分配 1/N
#         init_w = [1.0 / self.num_linears] * self.num_linears
#         self.gates = nn.Parameter(torch.tensor(init_w, dtype=torch.float32))
    
#     def forward(self, x, x_mark, y_mark):
#         """
#         x shape: (B, T, D) - (batch, seq_len, channel_dim)
#         """
#         if self.use_norm:
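#             # Statistics window: tail_len equals the full seq_len (not the last quarter),
#             # and the window's values are halved before mean/std are taken.
#             tail_len = self.seq_len
#             x_tail = x[:, -tail_len:, :] / 2  # last tail_len time steps, values scaled by 1/2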
#             means = x_tail.mean(dim=1, keepdim=True).detach()
#             stdev = torch.sqrt(torch.var(x_tail, dim=1, keepdim=True, unbiased=False) + 1e-5)

#             # 对整条序列做减均值 & 除标准差
#             x = (x - means) / stdev

#         # 与原始代码相同：转置到 (B, D, T)
#         x = x.permute(0, 2, 1)

#         # 遍历所有线性层，并根据 gates[idx] 进行加权
#         dec_out = 0
#         for idx, linear_layer in enumerate(self.linears):
#             factor = 2 ** idx
#             # 这里举例：取后 seq_len//factor 的时间片段
#             x_sub = x[:, :, - (self.seq_len // factor):]
#             dec_out = dec_out + self.gates[idx] * linear_layer(x_sub)
        
#         # dec_out shape: (B, D, pred_len)
#         dec_out = dec_out.permute(0, 2, 1)  # 转回 (B, pred_len, D)

#         if self.use_norm:
#             # 反归一化时，仍然使用刚才统计到的 means 和 stdev
#             dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
#             dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))

#         return dec_out
