from __future__ import annotations
import torch
import torch.nn as nn

class SeqPredGRU(nn.Module):
    """Roll a stacked GRU over per-step features, seeded by an initial observation.

    ``obs0`` is encoded into the GRU's initial hidden state; the same encoding
    seeds every GRU layer. Each step of ``feats`` is encoded and fed to the GRU
    as input. Every GRU output step is then decoded back to ``obs_dim`` by
    ``head``. Presumably the per-step outputs are predictions of the
    observation after each feature step — TODO confirm against the training
    code.

    Args:
        obs_dim: Size of the observation vector (input of ``obs_enc`` and
            output of ``head``).
        feat_dim: Size of each per-step feature vector.
        hidden: Width of both encoders, the GRU hidden state and the head.
        k: Stored as ``self.k``. ``forward`` does not read it and accepts any
            sequence length.
        num_layers: Number of stacked GRU layers. Defaults to 2, the value
            this module previously hard-coded.
    """

    def __init__(
        self,
        obs_dim: int,
        feat_dim: int = 10,
        hidden: int = 128,
        k: int = 3,
        num_layers: int = 2,
    ):
        super().__init__()
        # Observation encoder -> initial GRU hidden state. The final Tanh keeps
        # the encoding in (-1, 1), the same range GRU hidden states occupy.
        self.obs_enc = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
        )
        # Per-step feature encoder. Linear layers act on the last dim, so it
        # is applied independently to every step of the sequence.
        self.feat_enc = nn.Sequential(
            nn.Linear(feat_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
        )
        # Stacked GRU over the encoded features; batch_first -> [B, T, H].
        self.gru = nn.GRU(hidden, hidden, num_layers=num_layers, batch_first=True)
        # Decodes each GRU output step back to observation space.
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, obs_dim),
        )
        self.k = k
        self.hidden_size = hidden

    def forward(
        self, obs0: torch.Tensor, feats: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the GRU over ``feats`` starting from the encoding of ``obs0``.

        Args:
            obs0: Initial observation, expected shape ``[B, obs_dim]``.
            feats: Per-step features, expected shape ``[B, T, feat_dim]``.
                T is typically ``k``, but any length works.

        Returns:
            ``(preds, out)``:

            - ``preds`` has shape ``[B, T, obs_dim]`` and is ``head`` applied
              to every step.
            - ``out`` has shape ``[B, T, hidden]`` and holds the top GRU
              layer's output at every step.
        """
        h0 = self.obs_enc(obs0)  # [B, H]
        f = self.feat_enc(feats)  # [B, T, H]
        # nn.GRU expects hx as [num_layers, B, H]; seed every layer with the
        # same encoding. repeat (unlike expand) yields a contiguous tensor.
        h0_stack = h0.unsqueeze(0).repeat(self.gru.num_layers, 1, 1)
        out, _ = self.gru(f, h0_stack)
        preds = self.head(out)  # [B, T, obs_dim]
        return preds, out
