import torch
import torch.nn as nn

class LearnableLSTM(nn.Module):
    """Batch-first LSTM whose initial hidden and cell states are trainable.

    Instead of starting every sequence from zero states, the module keeps
    learnable ``h0``/``c0`` parameters. On each forward pass they are tiled
    across the batch, so every sequence starts from the same learned state.

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Number of features in the LSTM hidden/cell state.
        num_layers: Number of stacked LSTM layers (default 1). Each layer
            gets its own learnable initial state.
    """

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers=num_layers, batch_first=True
        )
        # Learnable initial states of shape (num_layers, 1, hidden_size).
        # The singleton axis is the batch axis; it is tiled to the actual
        # batch size in forward(). With num_layers=1 the shapes match the
        # original (1, 1, hidden_size), keeping state_dicts compatible.
        self.h0 = nn.Parameter(torch.randn(num_layers, 1, hidden_size))
        self.c0 = nn.Parameter(torch.randn(num_layers, 1, hidden_size))

    def forward(self, x):
        """Run the LSTM starting from the learned initial states.

        Args:
            x: Batched input of shape (batch, seq_len, input_size), an
                unbatched input of shape (seq_len, input_size), or a
                ``torch.nn.utils.rnn.PackedSequence``.

        Returns:
            The LSTM output for every time step, in the same layout as
            ``x`` (tensor or PackedSequence). The final ``(h_n, c_n)``
            states are discarded.
        """
        if isinstance(x, nn.utils.rnn.PackedSequence):
            # The first time step of a packed batch contains every sequence,
            # so its size is the batch size. (x has no .shape to read.)
            batch_size = int(x.batch_sizes[0])
        elif x.dim() == 2:
            # Unbatched input: nn.LSTM expects 2-D (num_layers, hidden_size)
            # states here, so drop the singleton batch axis instead of
            # tiling it by x.shape[0] (which would be seq_len).
            output, _ = self.lstm(x, (self.h0.squeeze(1), self.c0.squeeze(1)))
            return output
        else:
            batch_size = x.shape[0]

        # Tile the shared learned state across the batch; gradients from all
        # batch entries flow back into the single h0/c0 parameters.
        h0 = self.h0.repeat(1, batch_size, 1)
        c0 = self.c0.repeat(1, batch_size, 1)
        output, _ = self.lstm(x, (h0, c0))
        return output
