import torch.nn as nn
import torch.nn.functional as F
import torch
from layers.HSM_SSD import CHSMSSD1D,HSMSSD1D

class Model(nn.Module):
    """
    Unrolled ``k``-stage forecaster that refines a width-``r`` latent code.

    The input window is shifted by its last time step.  Each of the ``k``
    stages then
      1. maps the input window and the current code to a sequence of length
         ``seq_len + pred_len`` (two ReLU-activated linear maps, summed),
      2. takes the residual between the first ``seq_len`` steps of that
         sequence and the input window,
      3. maps the zero-extended residual back to width ``r`` and subtracts it
         from the code, scaled by a learnable scalar, and
      4. passes the code through a ``channel_z`` block.
    A final linear head maps the code to ``pred_len`` steps and the last-step
    offset is added back.

    The module does not choose a device itself; move it with ``.to(...)`` /
    ``.cuda()`` like any other ``nn.Module``.

    Args:
        configs: object exposing ``seq_len``, ``pred_len``, ``k``, ``r``,
            ``enc_in``, ``d_model``, ``dropout``, ``fc_dropout``, ``pre_norm``
            and ``head_dropout``, plus whatever ``channel_z`` reads from it.
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.k
        self.r = configs.r
        total_len = self.seq_len + self.pred_len

        # One set of linear maps per unrolled stage.
        self.linear1_list = nn.ModuleList([nn.Linear(self.seq_len, total_len) for _ in range(self.k)])
        self.linear2_list = nn.ModuleList([nn.Linear(self.r, total_len) for _ in range(self.k)])
        self.linear3_list = nn.ModuleList([nn.Linear(total_len, self.r) for _ in range(self.k)])

        self.channel_z = nn.ModuleList([
            channel_z(
                configs,
                in_planes=configs.enc_in,
                r=self.r,
                d_model=configs.d_model,
                dropout1=configs.dropout,
                dropout2=configs.fc_dropout,
                pre_norm=configs.pre_norm
            ) for _ in range(self.k)
        ])

        # Learnable step size of each stage.
        self.rou_list = nn.ParameterList([nn.Parameter(torch.randn(1)) for _ in range(self.k)])
        self.Linear_R = nn.Linear(self.r, self.pred_len)
        self.dropout = nn.Dropout(configs.head_dropout)
        # Not used in forward(); kept so existing checkpoints keep loading.
        self.Linear_Y_PROX = nn.ModuleList([nn.Linear(self.r, self.r) for _ in range(self.k)])

    def forward(self, x, batch_x_mark, dec_inp, batch_y_mark):
        """
        Args:
            x: tensor of shape [Batch, Input length, Channel]; the input
                length must equal ``seq_len``.
            batch_x_mark, dec_inp, batch_y_mark: accepted for interface
                compatibility and ignored.

        Returns:
            Tuple ``(forecast, X_list)`` where ``forecast`` has shape
            [Batch, pred_len, Channel] and ``X_list`` holds, for every stage,
            the detached [Batch, Channel, seq_len + pred_len] sequence.
        """
        # Remove the level of the series; it is added back to the forecast.
        seq_last = x[:, -1:, :].detach()
        x = x - seq_last

        # Allocate helper tensors where the input lives so the model works on
        # CPU, on any GPU index and with any floating dtype.
        device, dtype = x.device, x.dtype

        B, R, N = x.shape
        r = self.r
        Y_input = x.permute(0, 2, 1)  # [Batch, Channel, Input length]

        # E = [I_R; 0]: right-multiplying by E keeps the first R time steps,
        # right-multiplying by E^T zero-extends back to R + pred_len steps.
        E = torch.zeros(B, R + self.pred_len, R, device=device, dtype=dtype)
        E[:, :R, :] = torch.eye(R, device=device, dtype=dtype)
        E_T = E.permute(0, 2, 1)  # loop-invariant view

        Y_PROX = torch.zeros(B, N, r, device=device, dtype=dtype)

        X_list = []
        stages = zip(self.linear1_list, self.linear2_list, self.linear3_list,
                     self.rou_list, self.channel_z)
        for lin_a, lin_b, lin_back, rou, refine in stages:
            Y_A = F.relu(lin_a(Y_input))
            Y_p_B = F.relu(lin_b(Y_PROX))
            X = Y_A + Y_p_B                    # (YA + Y'B)
            Y = torch.bmm(X, E)                # (YA + Y'B)E
            X_noise = Y - Y_input              # (YA + Y'B)E - Y
            XE_T = torch.bmm(X_noise, E_T)     # residual zero-extended to full length
            Z = lin_back(XE_T)                 # back to width r
            Y_PROX = Y_PROX - rou * Z          # step scaled by the learnable rou
            Y_PROX = refine(Y_PROX)            # channel_z refinement ("prox")
            X_list.append(X.detach())

        x = self.Linear_R(Y_PROX).permute(0, 2, 1)
        x = self.dropout(x)
        x = x + seq_last
        return x, X_list  # [Batch, Output length, Channel]



class t_ssd(nn.Module):
    """
    Residual ``CHSMSSD1D`` mixing applied separately to every
    (batch, variable) slice of a 4-D tensor.

    The module does not pin itself to a device; it follows the parent
    module's ``.to(...)`` like any other submodule.

    Args:
        configs: object exposing ``patch_len`` and ``stride``.
        r: length used to derive the number of patches.
        d_model: passed to ``CHSMSSD1D`` as ``d_model``.
        in_planes, dropout1, dropout2, pre_norm, ratio: accepted for
            signature compatibility; not used by this block.
    """
    def __init__(self, configs, in_planes, r, d_model, dropout1, dropout2, pre_norm, ratio=2):
        super(t_ssd, self).__init__()
        # Patching bookkeeping (stored, not used in forward()).
        self.patch_len = configs.patch_len
        self.stride = configs.stride
        patch_num = int((r - configs.patch_len) / configs.stride + 1)
        self.padding_patch_layer = nn.ReplicationPad1d((0, configs.stride))
        patch_num += 1
        self.head_nf = d_model * patch_num

        # No hard-coded '.to("cuda")': that raised on CPU-only hosts and
        # ignored the device the caller actually selected.
        self.mixer_c = CHSMSSD1D(d_model=d_model, ssd_expand=1)

    def forward(self, z, nvars):
        """
        Args:
            z: 4-D tensor; the first two axes are merged before mixing.
                # NOTE(review): presumably [batch, vars, patches, d_model] - confirm
            nvars: size of the second axis to restore in the output.

        Returns:
            Tensor reshaped to ``(-1, nvars, z.shape[-2], z.shape[-1])``.
        """
        # Merge the two leading axes so each slice is mixed independently.
        z = torch.reshape(z, (z.shape[0] * z.shape[1], z.shape[2], z.shape[3]))
        residual = z
        # The mixer receives the last two axes swapped; its second return
        # value is not used here.
        z, _ = self.mixer_c(z.permute(0, 2, 1))
        z = z.permute(0, 2, 1)
        z = z + residual
        # Restore the leading axes.
        z = torch.reshape(z, (-1, nvars, z.shape[-2], z.shape[-1]))

        return z

class c_ssd(nn.Module):
    """
    Cross-variable mixing block.

    For every (batch, patch) slice the variable axis is projected from
    ``in_planes`` to ``configs.d_model_c``, mixed by ``HSMSSD1D`` with a
    residual connection, and projected back to ``in_planes``.

    Args:
        configs: object exposing ``patch_len``, ``stride`` and ``d_model_c``.
        in_planes: size of the variable axis.
        r: length used to derive the number of patches.
        d_model: passed to ``HSMSSD1D`` as ``d_model``.
        dropout1, dropout2, pre_norm, ratio, kernel_size: accepted for
            signature compatibility; not used by this block.
    """
    def __init__(self, configs, in_planes, r, d_model, dropout1, dropout2, pre_norm, ratio=2, kernel_size=7):
        super(c_ssd, self).__init__()
        self.patch_len = configs.patch_len
        self.stride = configs.stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, configs.stride))
        # Patch count of the padded sequence (one extra patch from padding).
        n_patches = int((r - configs.patch_len) / configs.stride + 1) + 1
        self.head_nf = d_model * n_patches

        self.mixer = HSMSSD1D(d_model=d_model, ssd_expand=1, state_dim=32)
        self.patch_num = n_patches
        self.embedding_c = nn.Linear(in_planes, configs.d_model_c)
        self.embedding_c_T = nn.Linear(configs.d_model_c, in_planes)

    def forward(self, x):
        """
        Args:
            x: 4-D tensor laid out as [B, C, N, T] (C = ``in_planes``,
                N = ``patch_num``).

        Returns:
            Tensor with the same [B, C, N, T] layout.
        """
        # [B, C, N, T] -> [B*N, C, T] -> [B*N, T, C]
        per_patch = x.permute(0, 2, 1, 3)
        per_patch = per_patch.reshape(
            per_patch.shape[0] * per_patch.shape[1], per_patch.shape[2], per_patch.shape[3]
        )
        embedded = self.embedding_c(per_patch.transpose(1, 2))  # [B*N, T, d_model_c]

        # Residual mixing; the mixer's second return value is not used.
        mixed, _ = self.mixer(embedded)
        mixed = mixed + embedded

        # Project back to the original variable count: [B*N, T, C] -> [B*N, C, T]
        restored = self.embedding_c_T(mixed).transpose(1, 2)
        restored = restored.reshape(-1, self.patch_num, restored.shape[-2], restored.shape[-1])  # [B, N, C, T]
        return restored.permute(0, 2, 1, 3)  # [B, C, N, T]

class channel_z(nn.Module):
    """
    Patch the last axis, run the ``t_ssd`` and ``c_ssd`` branches on the
    patch embeddings, fuse them, and project back to length ``r``.

    Args:
        configs: object exposing ``patch_len``, ``stride`` and ``d_ff``
            (plus whatever ``t_ssd`` / ``c_ssd`` read from it).
        in_planes: number of variables, forwarded to both branches.
        r: length of the last axis of the input and of the output.
        d_model: width of the patch embedding.
        dropout1: dropout probability applied to the output.
        dropout2: dropout probability inside the feed-forward fusion.
        pre_norm: forwarded to both branches.
        ratio, kernel_size: accepted for signature compatibility; unused.
    """
    def __init__(self, configs, in_planes, r, d_model, dropout1, dropout2, pre_norm, ratio=2, kernel_size=7):
        super(channel_z, self).__init__()
        self.patch_len = configs.patch_len
        self.stride = configs.stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, configs.stride))
        # Patch count of the padded sequence (one extra patch from padding).
        n_patches = int((r - configs.patch_len) / configs.stride + 1) + 1
        self.head_nf = n_patches * d_model
        self.embed = nn.Linear(self.patch_len, d_model)

        # The two mixing branches.
        self.t_ssd = t_ssd(configs, in_planes, r, d_model, dropout1, dropout2, pre_norm, ratio=2)
        self.c_ssd = c_ssd(configs, in_planes, r, d_model, dropout1, dropout2, pre_norm, ratio=2)

        # Fusion of the concatenated branch outputs (2 * d_model -> d_model).
        self.norm1 = nn.LayerNorm(d_model * 2)
        self.ff = nn.Sequential(
            nn.Linear(d_model * 2, configs.d_ff, bias=True),
            nn.ReLU(),
            nn.Dropout(dropout2),
            nn.Linear(configs.d_ff, d_model, bias=True),
        )

        # Head: flatten (patch, d_model) and map back to length r.
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(self.head_nf, r)
        self.dropout = nn.Dropout(dropout1)

    def forward(self, x):
        """
        Args:
            x: tensor whose axis 1 is the variable axis and whose last axis
                is padded and cut into patches.
                # NOTE(review): presumably [B, C, r] - confirm with caller

        Returns:
            Tensor whose last axis has length ``r``.
        """
        n_vars = x.shape[1]

        # Pad on the right by one stride, cut into patches, embed each patch.
        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        tokens = self.embed(patches)

        # Both branches consume the same patch embeddings.
        branch_t = self.t_ssd(tokens, n_vars)
        branch_c = self.c_ssd(tokens)

        # Concatenate on the feature axis, normalise, and reduce to d_model.
        fused = torch.cat((branch_c, branch_t), dim=-1)
        fused = self.ff(self.norm1(fused))

        # Head.
        out = self.linear(self.flatten(fused))
        return self.dropout(out)


class GatedFusion(nn.Module):
    """
    Gated fusion of two feature tensors that share the same last dimension.

    An element-wise gate ``g = sigmoid(W1 x1 + W2 x2 + bias)`` blends the
    inputs as ``g * x1 + (1 - g) * x2``.
    """

    def __init__(self, dim):
        """
        :param dim: feature dimension (both inputs must have this size on
                    their last axis)
        """
        super().__init__()
        self.W1 = nn.Linear(dim, dim)
        self.W2 = nn.Linear(dim, dim)
        # Learnable bias added to the gate pre-activation.
        self.bias = nn.Parameter(torch.zeros(dim))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):
        """
        :param x1: [B, D] or [B, T, D], first feature stream
        :param x2: [B, D] or [B, T, D], second feature stream
        :return: fused features with the same shape as the inputs
        """
        # Gate value in [0, 1] for every element.
        gate_logits = self.W1(x1) + self.W2(x2) + self.bias
        gate = self.sigmoid(gate_logits)

        # Blend the two streams: gate weights x1, its complement weights x2.
        return gate * x1 + (1 - gate) * x2


def create_patches(x: torch.Tensor, patch_len: int) -> torch.Tensor:
    """
    Cut the last (time) axis of ``x`` into non-overlapping patches.

    Input:  x of shape [B, C, T], with T divisible by ``patch_len``.
    Output: tensor of shape [B * C, N, patch_len] with N = T // patch_len;
            the batch and channel axes are merged into one.

    Non-contiguous inputs (e.g. the result of ``permute``) are accepted; the
    result shares memory with ``x`` whenever a view is possible and is a copy
    otherwise.
    """
    B, C, T = x.shape
    assert T % patch_len == 0, "T must be divisible by patch_len"
    N = T // patch_len
    # reshape instead of view: view raises on non-contiguous tensors.
    return x.reshape(B * C, N, patch_len)  # [B * C, N, patch_len]
