import torch.nn as nn

def init_lstm(lstm, forget_bias=1.0):
    """Initialise the parameters of a PyTorch LSTM module in place.

    Scheme, chosen by substring match on each parameter name:
      * ``weight_ih*`` (input-to-hidden): Xavier/Glorot uniform.
      * ``weight_hh*`` (hidden-to-hidden): orthogonal.
      * ``bias_ih*``: zeros, except the forget-gate slice, set to ``forget_bias``.
      * any other ``bias*`` (e.g. ``bias_hh*``): zeros.
    Parameters matching none of these (e.g. ``weight_hr*`` when the LSTM uses
    ``proj_size > 0``) are left unchanged.

    PyTorch's LSTM computes every gate with ``b_ih + b_hh``. The forget bias
    is therefore written into ``bias_ih`` only. Writing it into both vectors
    would make the effective forget-gate bias twice ``forget_bias``.

    Args:
        lstm: module whose ``named_parameters()`` follow PyTorch's LSTM naming
            (``weight_ih``/``weight_hh``/``bias_ih``/``bias_hh``), with gates
            stacked along dim 0 in (input, forget, cell, output) order, e.g.
            ``nn.LSTM`` or ``nn.LSTMCell``.
        forget_bias: effective forget-gate bias. Defaults to 1.0.
    """
    for name, param in lstm.named_parameters():
        if 'weight_ih' in name:
            # Input-to-hidden weights
            nn.init.xavier_uniform_(param.data)
        elif 'weight_hh' in name:
            # Hidden-to-hidden weights
            nn.init.orthogonal_(param.data)
        elif 'bias' in name:
            param.data.fill_(0)
            if 'bias_ih' in name:
                # Gates are stacked [input, forget, cell, output] along dim 0,
                # so the forget gate occupies the second quarter. bias_hh is
                # left at zero so b_ih + b_hh equals forget_bias exactly.
                n = param.size(0)
                param.data[n // 4:n // 2].fill_(forget_bias)

def init_gru(gru):
    """Initialise a PyTorch GRU's parameters in place.

    Each parameter is matched by substring against the table below, and the
    first matching rule wins:
      * ``weight_ih`` -> Xavier/Glorot uniform
      * ``weight_hh`` -> orthogonal
      * ``bias``      -> zeros
    Parameters matching no rule are left untouched.

    Args:
        gru: module exposing ``named_parameters()`` with PyTorch RNN naming,
            e.g. ``nn.GRU``.
    """
    # Ordered rules; order matters because the first match wins.
    rules = (
        ('weight_ih', nn.init.xavier_uniform_),
        ('weight_hh', nn.init.orthogonal_),
        ('bias', lambda tensor: tensor.fill_(0)),
    )
    for pname, weight in gru.named_parameters():
        for pattern, initialiser in rules:
            if pattern not in pname:
                continue
            initialiser(weight.data)
            break
