import torch
import torch.nn as nn
from layers.PatchTST_layers import *
from Mamba_simple.mamba_poly import Mamba, ModelArgs


class Model(nn.Module):
    """
    Patch-based, channel-independent forecaster with a stack of Mamba encoders
    (S-Mamba variant without GPU acceleration; Mamba comes from
    ``Mamba_simple.mamba_poly``).

    Data flow through ``forward`` (B batch, L input length, C channels,
    N = patch_num, P = patch_len, D = d_model, T = pred_len):
        input            [B, L, C]
        patching         [B, C, N, P]
        embedding        [B*C, N, D]   (+ learnable positional term, dropout)
        4 x Mamba        [B*C, N, D] -> [B*C, N, D]
        projection       [B, T, C]

    Paper link: https://arxiv.org/abs/2310.06625
    NOTE(review): this arXiv id appears to be the iTransformer paper rather
    than S-Mamba — confirm the intended citation.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention  # not read in this file
        self.use_norm = configs.use_norm

        # Patching: windows of length patch_len taken every `stride` steps.
        self.patch_len = configs.patch_len
        self.stride = configs.stride
        self.padding_patch = configs.padding_patch
        self.patch_num = int((self.seq_len - self.patch_len) / self.stride + 1)
        if self.padding_patch == 'end':  # can be modified to general case
            # Replicate the last time step `stride` times so one extra patch fits.
            self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
            self.patch_num += 1

        self.class_strategy = configs.class_strategy  # not read in this file

        # Hyper-parameters shared by every Mamba encoder below.
        args = ModelArgs(
            d_model=configs.d_model,
            n_layer=1,
            # NOTE(review): vocab_size is set to d_model; how mamba_poly.Mamba
            # uses it is not visible here — confirm.
            vocab_size=configs.d_model,
            d_state=configs.d_state,
            expand=1,
            dt_rank='auto',
            d_conv=2,
            pad_vocab_size_multiple=8,
            conv_bias=True,
            bias=False,
            d_C=configs.enc_in,
        )
        # Encoder-only architecture: four Mamba instances built from the same
        # args and applied in sequence in forward().
        self.encoder1 = Mamba(args)
        self.encoder2 = Mamba(args)
        self.encoder3 = Mamba(args)
        self.encoder4 = Mamba(args)

        # Patch embedding P -> D and positional term over the N patches.
        self.embedding_P = nn.Linear(self.patch_len, configs.d_model)
        self.W_pos = positional_encoding(pe='zero', learn_pe=True, q_len=self.patch_num, d_model=configs.d_model)
        # NOTE(review): reads configs.EM_dropout; an argparse flag spelled
        # '--Em_dropout' would populate configs.Em_dropout instead — confirm.
        self.dropout = nn.Dropout(configs.EM_dropout)
        # Flattened per-channel representation [N*D] -> forecast horizon [T].
        self.projector_P = nn.Linear(self.patch_num * configs.d_model, self.pred_len, bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """
        Forecast ``pred_len`` steps for every channel.

        Args:
            x_enc: input series of shape [B, L, C]; L must be ``seq_len`` so
                that the patch count matches ``projector_P``.
            x_mark_enc, x_dec, x_mark_dec, mask: accepted for interface
                compatibility; not used.

        Returns:
            Tensor of shape [B, pred_len, C].
        """
        if self.use_norm:
            # Per-sample, per-channel normalization (Non-stationary Transformer).
            # means/stdev are [B, 1, C].
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
            # Out of place on purpose: an in-place division would modify the
            # tensor torch.var saved for backward and break autograd.
            x_enc = x_enc / stdev

        # Patching [B, L, C] -> [B, C, N, P]
        z = x_enc.permute(0, 2, 1)
        if self.padding_patch == 'end':
            z = self.padding_patch_layer(z)
        z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride)

        # Embedding [B, C, N, P] -> [B*C, N, D]; channels are folded into the batch.
        z = self.embedding_P(z)
        B, C, N, D = z.shape
        z = z.reshape(B * C, N, D)
        z = self.dropout(z + self.W_pos)

        # Mamba stack [B*C, N, D] -> [B*C, N, D]
        for encoder in (self.encoder1, self.encoder2, self.encoder3, self.encoder4):
            z = encoder(z)

        # Projector [B*C, N, D] -> [B, C, N*D] -> [B, C, T] -> [B, T, C]
        z = z.reshape(B, C, N * D)
        dec_out = self.projector_P(z).permute(0, 2, 1)

        if self.use_norm:
            # De-normalization: [B, 1, C] statistics broadcast over the T axis.
            dec_out = dec_out * stdev + means

        return dec_out

