# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""LSTM layers module."""

from torch import nn


class SLSTM(nn.Module):
    """Stateless multi-layer LSTM operating on convolutional-layout tensors.

    Wraps :class:`torch.nn.LSTM` so callers can pass tensors in the
    ``(batch, channels, time)`` layout used by 1-D convolutions. No initial
    state is supplied and the final hidden/cell states returned by the LSTM
    are discarded, so each call starts from the LSTM's default (zero) state.

    Args:
        dimension: Number of channels. Used as both the LSTM input size and
            hidden size, so the output has the same channel count as the input.
        num_layers: Number of stacked LSTM layers.
        skip: If True, add the input to the LSTM output (residual connection).
    """

    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        # input_size == hidden_size keeps the channel count unchanged, which
        # is what makes the residual add in ``forward`` shape-compatible.
        self.lstm = nn.LSTM(
            input_size=dimension,
            hidden_size=dimension,
            num_layers=num_layers,
        )

    def forward(self, x):
        """Run the LSTM along the time axis of ``x``.

        Args:
            x: Tensor laid out as ``(batch, channels, time)`` with
                ``channels == dimension``.

        Returns:
            Tensor with the same ``(batch, channels, time)`` shape as ``x``.
        """
        # (B, C, T) -> (T, B, C): nn.LSTM is sequence-first by default.
        seq_first = x.permute(2, 0, 1)
        out, _ = self.lstm(seq_first)
        # Restore (B, C, T) *before* the residual add, so the skip connection
        # sums against the caller's original tensor, not a permuted copy.
        # This order was deliberately swapped relative to a permute-after-add
        # variant; the elementwise result is the same either way.
        out = out.permute(1, 2, 0)
        return out + x if self.skip else out
