import torch
import math
from torch import nn
from ts_benchmark.baselines.dualmoe.layers.Transformer_EncDec import EncoderLayer, FusedEncoderLayer
from ts_benchmark.baselines.dualmoe.layers.SelfAttention_Family import FullAttention, AttentionLayer
from ts_benchmark.baselines.dualmoe.layers.Embed import PatchEmbedding
import time  # 顶部导入time模块
from torch_pca import PCA

class EMA(nn.Module):
    """
    Exponential Moving Average (EMA) block to highlight the trend of a time series.

    The output follows the recurrence ``s_0 = x_0`` and
    ``s_t = alpha * x_t + (1 - alpha) * s_{t-1}`` along dimension 1.

    Args:
        alpha: Smoothing factor applied to the newest observation.
    """
    def __init__(self, alpha):
        super(EMA, self).__init__()
        self.alpha = alpha

    # Vectorized implementation: the recurrence is evaluated with a single
    # cumulative sum instead of a Python loop over time steps.
    def forward(self, x):
        """
        Smooth ``x`` along its second dimension.

        Args:
            x: Tensor of shape [Batch, Input, Channel].

        Returns:
            float32 tensor with the same shape and device as ``x``.
        """
        _, t, _ = x.shape
        # Build the weights on the input's device rather than a hard-coded
        # 'cuda', so the block also runs on CPU and on non-default GPUs.
        # Double precision limits the error of dividing by small powers.
        powers = torch.flip(
            torch.arange(t, dtype=torch.double, device=x.device), dims=(0,))
        weights = torch.pow((1 - self.alpha), powers)
        # The divisor rescales each prefix sum back to its own time step.
        divisor = weights.clone()
        # Every step but the first carries the extra factor alpha; the first
        # step keeps weight (1 - alpha)^(t-1) so that s_0 == x_0.
        weights[1:] = weights[1:] * self.alpha
        weights = weights.reshape(1, t, 1)
        divisor = divisor.reshape(1, t, 1)
        x = torch.cumsum(x * weights, dim=1)
        x = torch.div(x, divisor)
        return x.to(torch.float32)

class DECOMP(nn.Module):
    """
    Series decomposition block: splits the input into a smoothed trend
    (produced by an EMA) and the remainder left after removing that trend.
    """
    def __init__(self, alpha=0.3, beta=0.3):
        super().__init__()
        # ``beta`` is accepted for interface compatibility but is not used.
        self.ma = EMA(alpha)

    def forward(self, x):
        """Return ``(residual, trend)`` where ``residual = x - trend``."""
        trend = self.ma(x)
        return x - trend, trend

class Encoder(nn.Module):
    """
    Stack of attention layers, optionally interleaved with conv layers and
    followed by an optional normalisation module.

    Args:
        attn_layers: Iterable of modules; each is called as
            ``layer(x, attn_mask=..., tau=..., delta=...)`` and must return
            an ``(output, attention)`` pair.
        conv_layers: Optional iterable of modules applied after each paired
            attention layer. When given, the last attention layer runs on its
            own after the paired layers.
        norm_layer: Optional module applied to the final output.
    """
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(
            conv_layers) if conv_layers is not None else None
        self.norm = norm_layer
        # NOTE: a per-sample PCA list used to be created here from a
        # ``configs`` name that is not in scope in this constructor, which
        # raised NameError on every instantiation. The encoder itself never
        # used it, so it has been removed.

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """
        Run ``x`` through the layer stack.

        Returns:
            Tuple ``(x, attns)`` where ``attns`` lists the attention output
            returned by each attention layer, in call order.
        """
        attns = []
        if self.conv_layers is not None:
            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
                # ``delta`` is only forwarded to the first layer.
                delta = delta if i == 0 else None
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                x = conv_layer(x)
                attns.append(attn)
            # The final attention layer has no conv partner.
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)
        return x, attns

class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` so it can sit in ``nn.Sequential``."""

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        """Swap the configured dimensions; optionally return a contiguous copy."""
        swapped = x.transpose(*self.dims)
        return swapped.contiguous() if self.contiguous else swapped

class FlattenHead(nn.Module):
    """
    Prediction head: merges the last two dimensions, projects them to
    ``target_window`` values with a linear layer, then applies dropout.

    Args:
        nf: Input feature count of the linear layer (size of the two merged
            trailing dimensions).
        target_window: Output size of the linear layer.
        head_dropout: Dropout probability applied to the projection.
    """

    def __init__(self, nf, target_window, head_dropout=0):
        super().__init__()
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Map x of shape [bs x nvars x d_model x patch_num] to [bs x nvars x target_window]."""
        flat = self.flatten(x)
        return self.dropout(self.linear(flat))
    
class ChannelFusion(nn.Module):
    """
    Learnable linear mixing across the channel dimension.

    The mixing matrix starts as the identity, so a freshly built module
    returns its input unchanged.

    Args:
        num_channels: Number of channels C; the mixing matrix is C x C.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        identity = torch.eye(num_channels) * 1
        # The zero-scaled noise term contributes nothing to the values but
        # still draws from the global RNG, which keeps seeded runs identical.
        noise = torch.randn(num_channels, num_channels) * 0
        self.weights = nn.Parameter(identity + noise)

    def forward(self, x):
        """Mix channels of x ([B, C, d_model, num_patch]); output has the same shape."""
        batch, channels, width, patches = x.shape

        # Move channels last and collapse everything else into rows so that a
        # single matrix product mixes the channels: [B*d_model*num_patch, C].
        rows = x.permute(0, 2, 3, 1).reshape(-1, channels)
        mixed = torch.matmul(rows, self.weights)

        # Restore the original [B, C, d_model, num_patch] layout.
        return mixed.reshape(batch, width, patches, channels).permute(0, 3, 1, 2)
    
class Model(nn.Module):
    """
    Multi-branch patch-Transformer forecaster with a learned branch router.

    Each branch encodes the most recent ``lookback_windows[i]`` steps of the
    input with its own encoder and head. A small router network turns simple
    input statistics into softmax weights that blend the branch forecasts.

    Args:
        configs: Configuration object. The attributes read here are
            ``seq_len``, ``horizon``, ``norm``, ``lookback_windows`` (a
            comma-separated string), ``is_channel_fusion``, ``pca_denoise``,
            ``patch_len``, ``d_model``, ``dropout``, ``c_in``, ``n_heads``,
            ``d_ff``, ``activation``, ``e_layers`` and ``temporal_individual``.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.horizon
        self.use_norm = configs.norm
        self.lookback_windows = [int(window) for window in configs.lookback_windows.split(',')]
        self.is_channel_fusion = configs.is_channel_fusion
        self.pca_denoise = configs.pca_denoise
        # Number of principal components kept by the optional PCA denoising
        # step; at least one so a single-channel input stays valid.
        self.pca_components = max(configs.c_in - 1, 1)
        patch_len = configs.patch_len
        # Stride equals the patch length, i.e. non-overlapping patches.
        stride = configs.patch_len
        padding = stride
        self.n_branchs = len(self.lookback_windows)

        # Patching and embedding (shared by all branches).
        self.patch_embedding = PatchEmbedding(
            configs.d_model, patch_len, stride, padding, configs.dropout)

        # One channel-fusion layer per branch.
        self.channel_fusions = nn.ModuleList([
            ChannelFusion(configs.c_in) for _ in range(self.n_branchs)
        ])

        # One encoder per branch.
        self.encoders = nn.ModuleList([Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, attention_dropout=configs.dropout,
                                      output_attention=False),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation,
                ) for _ in range(configs.e_layers)
            ],
            norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(configs.d_model), Transpose(1, 2))
        ) for _ in range(self.n_branchs)])
        self.decomp = DECOMP()

        # Number of patches produced for each branch's lookback window.
        patch_nums = [int((window - patch_len) / stride + 2)
                      for window in self.lookback_windows]
        self.head_nf = [configs.d_model * p_n for p_n in patch_nums]
        self.heads = nn.ModuleList([
            FlattenHead(nf, self.pred_len, configs.dropout)
            for nf in self.head_nf])

        # Mixture-of-experts style router. With ``temporal_individual`` it
        # sees 4 statistics per channel (std, mean, max, min); otherwise it
        # sees the per-channel std vector of length ``c_in``.
        self.temporal_individual = configs.temporal_individual
        router_in = 4 if self.temporal_individual else configs.c_in
        self.router = nn.Sequential(
            nn.Linear(router_in, 32, bias=False),
            nn.ReLU(),
            nn.Dropout(configs.dropout),
            nn.Linear(32, self.n_branchs, bias=False),
            nn.Softmax(dim=-1)
        )

    def _pca_reconstruct(self, x):
        """Project each sample onto its leading principal components and back."""
        _, seq_len, _ = x.shape
        restored = []
        for sample in x:
            # A fresh PCA is fitted per sample, so batches of any size work.
            pca = PCA(n_components=self.pca_components)
            reduced = pca.fit_transform(sample)
            # Keep the [1, L, C] layout expected by the rest of ``forecast``.
            restored.append(pca.inverse_transform(reduced).reshape(1, seq_len, -1))
        return torch.cat(restored, dim=0)

    def _router_stats(self, x):
        """Compute the statistics fed to the router from x of shape [B, L, C]."""
        per_channel = x.permute(0, 2, 1)  # [B, C, L]
        if self.temporal_individual:
            return torch.cat([
                per_channel.std(-1, keepdim=True),
                per_channel.mean(-1, keepdim=True),
                per_channel.max(-1, keepdim=True)[0],
                per_channel.min(-1, keepdim=True)[0],
            ], dim=-1)  # [B, C, 4]
        return per_channel.std(-1)  # [B, C]

    def _run_branch(self, i, x):
        """Embed, encode and project branch ``i`` for x of shape [B, C, L]."""
        enc_out, n_vars = self.patch_embedding(x[:, :, -self.lookback_windows[i]:])
        enc_out, _ = self.encoders[i](enc_out)

        # Un-merge the batch and channel axes, then put d_model before patches.
        enc_out = enc_out.reshape(-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
        enc_out = enc_out.permute(0, 1, 3, 2)  # [B, C, d_model, num_patch]
        if self.is_channel_fusion:
            enc_out = self.channel_fusions[i](enc_out)
        return self.heads[i](enc_out)  # [B, C, pred_len]

    def _run_all_branches(self, x):
        """Run every branch; use one CUDA stream per branch when x is on a GPU."""
        if not x.is_cuda:
            # CUDA streams do not exist off-GPU; run the branches in order.
            return [self._run_branch(i, x) for i in range(self.n_branchs)]

        producer = torch.cuda.current_stream(x.device)
        outputs = []
        for i in range(self.n_branchs):
            stream = torch.cuda.Stream(device=x.device)
            # The side stream must not start before ``x`` has been produced
            # on the current stream.
            stream.wait_stream(producer)
            with torch.cuda.stream(stream):
                outputs.append(self._run_branch(i, x))
        # Join all side streams before the results are combined.
        torch.cuda.synchronize(x.device)
        return outputs

    def forecast(self, x):
        """
        Produce the blended multi-branch forecast.

        Args:
            x: Input tensor, treated as [B, L, C] throughout.

        Returns:
            Tensor of shape [B, pred_len, C].
        """
        if self.pca_denoise:
            x = self._pca_reconstruct(x)

        # Routing weights: [B, C, n_branchs] or [B, n_branchs].
        weights = self.router(self._router_stats(x))

        if self.use_norm:
            # Per-sample, per-channel standardisation over the time axis.
            means = x.mean(1, keepdim=True).detach()
            x = x - means
            stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False)) + 1e-5
            x = x / stdev

        x = x.permute(0, 2, 1)  # [B, C, L]

        dec_outs = self._run_all_branches(x)

        # Stack the branch forecasts and blend them with the router weights.
        stacked_dec = torch.stack(dec_outs, dim=-1)  # [B, C, pred_len, n_branchs]
        if self.temporal_individual:
            branch_weights = weights.unsqueeze(2)               # [B, C, 1, n_branchs]
        else:
            branch_weights = weights.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, n_branchs]
        weighted_dec = (stacked_dec * branch_weights).sum(dim=-1)  # [B, C, pred_len]

        # Back to [B, pred_len, C], then undo the input normalisation.
        final_out = weighted_dec.permute(0, 2, 1)
        if self.use_norm:
            final_out = final_out * stdev[:, 0, :].unsqueeze(1) + means[:, 0, :].unsqueeze(1)
        return final_out

    def forward(self, x):
        """Return the last ``pred_len`` steps of the forecast for ``x``."""
        dec_out = self.forecast(x)
        return dec_out[:, -self.pred_len:, :]