"""Alternative recurrent architectures (LSTM, GTrXL, Mamba, minGRU)."""

import math
import torch
import torch.nn as nn
from mambapy.mamba import Mamba, MambaConfig
from minGRU_pytorch import minGRU


class LSTMPolicy(nn.Module):
    """LSTM policy, drop-in replacement for GRUPolicy.

    Pipeline: ``Linear + ReLU`` observation encoder -> single-layer
    ``nn.LSTM`` (``batch_first=True``) -> ``Dropout -> Linear -> ReLU -> Linear``
    head that emits one logit vector per time step.

    Args:
        obs_dim: Size of the trailing (feature) dimension of each observation.
        hidden_dim: Width of the encoder output, the LSTM state and the head.
        n_actions: Number of logits emitted per time step.
        dropout: Dropout probability applied to the LSTM output before the head.
    """

    def __init__(self, obs_dim, hidden_dim=64, n_actions=3, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_actions = n_actions
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
        )
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.policy_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, obs_seq, hidden=None):
        """Compute per-step action logits.

        Args:
            obs_seq: Batch-first observation tensor, ``(B, T, obs_dim)``.
            hidden: ``(h, c)`` state returned by a previous call (each
                ``(1, B, hidden_dim)``), or ``None`` to start from zeros.

        Returns:
            ``(logits, (h, c))``: ``logits`` is ``(B, T, n_actions)`` and
            ``(h, c)`` is the final LSTM state, which can be passed back in to
            continue the sequence.
        """
        encoded = self.obs_encoder(obs_seq)
        # When hidden is None, nn.LSTM itself allocates a zero (h0, c0) with the
        # input's dtype and device. Hand-built zeros were always float32, which
        # broke models cast with .double() / .half().
        lstm_out, hidden = self.lstm(encoded, hidden)
        logits = self.policy_head(lstm_out)
        return logits, hidden

    def get_recurrent_output(self, obs_seq):
        """Return the raw LSTM outputs for ``obs_seq``, starting from a zero state.

        The policy head (and its dropout) is not applied.
        """
        encoded = self.obs_encoder(obs_seq)
        lstm_out, _ = self.lstm(encoded)
        return lstm_out


class _SinusoidalPE(nn.Module):

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class _GTrXLBlock(nn.Module):
    """Single GTrXL block: pre-LN causal MHA + GRU gate + FFN."""

    def __init__(self, d_model, n_heads, d_inner, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.gate_attn = nn.GRUCell(d_model, d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )
        self.gate_ffn = nn.GRUCell(d_model, d_model)

    def forward(self, x, attn_mask):
        B, T, D = x.shape
        normed = self.ln1(x)
        attn_out, _ = self.attn(normed, normed, normed, attn_mask=attn_mask)
        x_flat = x.reshape(B * T, D)
        attn_flat = attn_out.reshape(B * T, D)
        x = self.gate_attn(attn_flat, x_flat).reshape(B, T, D)
        normed = self.ln2(x)
        ffn_out = self.ffn(normed)
        x_flat = x.reshape(B * T, D)
        ffn_flat = ffn_out.reshape(B * T, D)
        x = self.gate_ffn(ffn_flat, x_flat).reshape(B, T, D)
        return x


class GTrXLPolicy(nn.Module):
    """Gated Transformer-XL policy (Parisotto et al. 2020).

    Pipeline:
      1. Observations are embedded with ``Linear + ReLU``.
      2. Fixed sinusoidal position encodings are added.
      3. The result passes through ``n_layers`` causally-masked gated
         transformer blocks and a final LayerNorm.
      4. A dropout + two-layer MLP head maps it to per-step action logits.

    No segment memory is kept between calls. ``forward`` ignores ``hidden``
    and returns ``None`` in its place, so each step attends only to earlier
    steps of the same ``obs_seq``.

    Args:
        obs_dim: Size of the trailing (feature) dimension of each observation.
        hidden_dim: Model width; must be divisible by ``n_heads``.
        n_actions: Number of logits emitted per time step.
        dropout: Dropout probability used inside each block and before the head.
        n_layers: Number of gated transformer blocks.
        n_heads: Attention heads per block.
        d_inner: Hidden width of each block's feed-forward network.
        max_len: Longest sequence length supported by the positional table
            (defaults to 512, the previously hard-coded value).
    """

    def __init__(self, obs_dim, hidden_dim=64, n_actions=3, dropout=0.1,
                 n_layers=2, n_heads=2, d_inner=128, max_len=512):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_actions = n_actions
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
        )
        self.pos_enc = _SinusoidalPE(hidden_dim, max_len=max_len)
        self.blocks = nn.ModuleList([
            _GTrXLBlock(hidden_dim, n_heads, d_inner, dropout) for _ in range(n_layers)
        ])
        self.ln_out = nn.LayerNorm(hidden_dim)
        self.policy_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def _causal_mask(self, T, device):
        """Return a ``(T, T)`` float mask that blocks attention to future steps."""
        return nn.Transformer.generate_square_subsequent_mask(T, device=device)

    def forward(self, obs_seq, hidden=None):
        """Return ``(logits, None)``; ``logits`` is ``(B, T, n_actions)``.

        ``hidden`` is accepted but unused (see the class docstring).
        """
        logits = self.policy_head(self.get_recurrent_output(obs_seq))
        return logits, None

    def get_recurrent_output(self, obs_seq):
        """Return the final-LayerNorm features ``(B, T, hidden_dim)`` before the head.

        ``obs_seq`` is batch-first, ``(B, T, obs_dim)``.
        """
        _, T, _ = obs_seq.shape
        x = self.pos_enc(self.obs_encoder(obs_seq))
        mask = self._causal_mask(T, obs_seq.device)
        for block in self.blocks:
            x = block(x, mask)
        return self.ln_out(x)


class MambaPolicy(nn.Module):
    """Mamba policy built on ``mambapy``.

    Pipeline:
      1. Observations are embedded with ``Linear + ReLU``.
      2. They pass through a ``mambapy`` ``Mamba`` stack of ``n_layers``
         layers with width ``hidden_dim``.
      3. A dropout + two-layer MLP head maps the result to per-step action
         logits.

    ``forward`` accepts ``hidden`` but ignores it and always returns ``None``
    in its place; no state is carried between calls.
    """

    def __init__(self, obs_dim, hidden_dim=64, n_actions=3, dropout=0.1,
                 n_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_actions = n_actions
        # Submodule order (encoder, Mamba stack, head) fixes init order and
        # state_dict keys; keep it.
        self.obs_encoder = nn.Sequential(nn.Linear(obs_dim, hidden_dim), nn.ReLU())
        self.mamba = Mamba(MambaConfig(d_model=hidden_dim, n_layers=n_layers))
        self.policy_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def get_recurrent_output(self, obs_seq):
        """Return the Mamba stack's output for ``obs_seq`` (the head is not applied)."""
        return self.mamba(self.obs_encoder(obs_seq))

    def forward(self, obs_seq, hidden=None):
        """Return ``(logits, None)`` for ``obs_seq``."""
        features = self.get_recurrent_output(obs_seq)
        return self.policy_head(features), None


class MinGRUPolicy(nn.Module):
    """minGRU policy (Feng et al. 2024).

    Pipeline:
      1. Observations are embedded with ``Linear + ReLU``.
      2. They pass through a single ``minGRU`` layer of width ``hidden_dim``.
      3. A dropout + two-layer MLP head maps the result to per-step action
         logits.

    ``forward`` accepts ``hidden`` but ignores it and always returns ``None``
    in its place; no recurrent state is carried between calls.
    """

    def __init__(self, obs_dim, hidden_dim=64, n_actions=3, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_actions = n_actions
        # Submodule order (encoder, minGRU, head) fixes init order and
        # state_dict keys; keep it.
        self.obs_encoder = nn.Sequential(nn.Linear(obs_dim, hidden_dim), nn.ReLU())
        self.mingru = minGRU(dim=hidden_dim)
        self.policy_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def get_recurrent_output(self, obs_seq):
        """Return the minGRU output for ``obs_seq`` (the head is not applied)."""
        return self.mingru(self.obs_encoder(obs_seq))

    def forward(self, obs_seq, hidden=None):
        """Return ``(logits, None)`` for ``obs_seq``."""
        features = self.get_recurrent_output(obs_seq)
        return self.policy_head(features), None


# Short architecture name -> policy class. Every class listed accepts
# (obs_dim, hidden_dim=64, n_actions=3, dropout=0.1) as its leading
# constructor arguments and exposes forward() / get_recurrent_output().
ARCH_REGISTRY = dict(
    lstm=LSTMPolicy,
    gtrxl=GTrXLPolicy,
    mamba=MambaPolicy,
    mingru=MinGRUPolicy,
)
