##
## (c) Anonymous authors (2026)
##
## > Elman-type recurrent neural network (RNN)
##
##

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import multiprocessing
import pandas as pd
import os
import tqdm


class ElmanRNN(nn.Module):

    """

    Elman-type RNN

        input_dim: input dimension
        hidden_dim: width of the network

    Recurrence (no bias terms, no readout layer):

        h_t = tanh(W h_{t-1} + U x_t)

    The hidden states themselves are returned as the outputs.

    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Initialization of the network weights. The torch.randn values are
        # overwritten by reset_parameters(); note that switching to
        # torch.empty would change the RNG draws consumed here and therefore
        # the weights obtained under a fixed seed.
        self.W = nn.Parameter(torch.randn(hidden_dim, hidden_dim))  # hidden -> hidden
        self.U = nn.Parameter(torch.randn(hidden_dim, input_dim))   # input -> hidden
        self.reset_parameters()

    def reset_parameters(self):
        """
        Reset of the network weights (Xavier/Glorot uniform for W and U).
        """
        nn.init.xavier_uniform_(self.W)
        nn.init.xavier_uniform_(self.U)

    def forward(self, inputs, hidden=None):
        """
        Forward pass: apply the recurrence step by step along axis 0 (time).

        Args:
            inputs: tensor of shape (seq_len, batch, input_dim), or
                (seq_len, input_dim) for a single unbatched sequence
                (more generally (seq_len, *batch_dims, input_dim)).
            hidden: optional initial hidden state of shape
                (*batch_dims, hidden_dim). Defaults to zeros with the same
                dtype and device as ``inputs``.

        Returns:
            (outputs, hidden): ``outputs`` stacks the hidden state produced at
            every step, shape (seq_len, *batch_dims, hidden_dim); ``hidden`` is
            the final hidden state (the initial state when seq_len == 0).
        """
        if hidden is None:
            # Allocate like `inputs` so the model works on any device/dtype;
            # shape[1:-1] drops the time axis and the feature axis, which also
            # makes unbatched (seq_len, input_dim) inputs produce a 1-D state.
            hidden = inputs.new_zeros((*inputs.shape[1:-1], self.hidden_dim))

        outputs = []
        for x_t in inputs:  # iterates over axis 0, i.e. inputs[t]
            hidden = torch.tanh(F.linear(hidden, self.W) + F.linear(x_t, self.U))
            outputs.append(hidden)

        if not outputs:
            # torch.stack rejects an empty list; return an empty sequence.
            return hidden.new_empty((0, *hidden.shape)), hidden
        return torch.stack(outputs), hidden
