import torch
import torch.nn as nn


class RNNDistributionPredictor(nn.Module):
    """Score every timestep of a sequence with a stacked ``tanh`` RNN.

    A (uni- or bi-directional) Elman RNN encodes the sequence and a linear
    head maps each timestep's output features to a single scalar. No
    normalization (e.g. softmax) is applied here, so the outputs are raw
    scores -- presumably logits consumed by the caller; confirm against usage.

    Args:
        input_size: Number of features per timestep of the input.
        hidden_size: Hidden-state size of each RNN direction.
        d: Number of directions: 1 for a unidirectional RNN, 2 for a
            bidirectional one.
        rnn_layers: Number of stacked RNN layers.
        dropout: Dropout probability passed to ``nn.RNN`` (applied between
            stacked layers, not after the last one).

    Raises:
        ValueError: If ``d`` is not 1 or 2.
    """

    def __init__(self, input_size, hidden_size, d, rnn_layers, dropout):
        super().__init__()
        if d not in (1, 2):
            raise ValueError(
                f"d must be 1 (unidirectional) or 2 (bidirectional), got {d!r}"
            )
        self.rnn = nn.RNN(
            input_size,
            hidden_size,
            num_layers=rnn_layers,
            nonlinearity="tanh",
            bias=True,
            batch_first=True,
            dropout=dropout,
            bidirectional=d == 2,
        )
        # The RNN concatenates hidden_size features per direction, so the
        # head's input width must scale with the number of directions.
        self.linear = nn.Linear(hidden_size * d, 1)
        self.d = d
        self.hidden_size = hidden_size
        self.rnn_layers = rnn_layers

    def forward(self, x):
        """Return one scalar score per timestep.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)`` (the RNN is
                ``batch_first``), or unbatched ``(seq_len, input_size)``.

        Returns:
            Scores of shape ``(batch, seq_len)`` (or ``(seq_len,)`` for
            unbatched input).
        """
        # Leaving the initial hidden state out lets nn.RNN start from zeros
        # that match the input's dtype, device and (un)batched layout.
        rnn_out, _ = self.rnn(x)
        # Linear head has out_features == 1; drop that trailing axis.
        return self.linear(rnn_out).squeeze(-1)


class AttentionDistributionPredictor(nn.Module):
    """Score every position of a sequence with one multi-head self-attention layer.

    The input is projected into query/key/value tensors by a bias-free linear
    layer, passed through ``nn.MultiheadAttention``, and each attended
    position is mapped to a single scalar. No normalization is applied, so
    the outputs are raw logits.

    Args:
        hid_dim: Feature size of each position. It must be divisible by
            ``num_heads``, as required by ``nn.MultiheadAttention``.
        num_heads: Number of attention heads.
    """

    def __init__(self, hid_dim, num_heads):
        super().__init__()
        # Learned projection producing the q/k/v inputs; note that
        # nn.MultiheadAttention applies its own in-projection on top of this.
        self.QKV = nn.Linear(hid_dim, hid_dim * 3, bias=False)
        self.hid_dim = hid_dim
        self.self_attention = nn.MultiheadAttention(
            embed_dim=hid_dim, num_heads=num_heads, batch_first=True
        )
        self.logit_layer = nn.Linear(hid_dim, 1)

    def forward(self, x):
        """Return one logit per position.

        Args:
            x: Input of shape ``(bs, n_layers, hid_dim)``, or unbatched
                ``(n_layers, hid_dim)``.

        Returns:
            Logits of shape ``(bs, n_layers)`` (or ``(n_layers,)`` for
            unbatched input).
        """
        # Split on the last (feature) axis so batched and unbatched inputs
        # both work; each of q, k, v has hid_dim features.
        q, k, v = self.QKV(x).chunk(3, dim=-1)
        # The attention weights are discarded, so skip computing them; this
        # also allows MultiheadAttention's optimized kernel path.
        attn_output, _ = self.self_attention(q, k, v, need_weights=False)
        # logit_layer has out_features == 1; drop that trailing axis.
        return self.logit_layer(attn_output).squeeze(-1)
