import torch
import torch.nn as nn


class AttentionLayer(nn.Module):
    """Pre-norm transformer block: self-attention followed by a GELU feed-forward net.

    Each sub-layer applies LayerNorm to its input, and its output is added back
    onto the un-normalized input as a residual connection.

    Args:
        embed_dim: Token width (last dimension of the input and output).
        num_heads: Number of attention heads; ``nn.MultiheadAttention``
            requires it to divide ``embed_dim``.
        hidden_dim: Hidden width of the feed-forward network.
    """

    def __init__(self, embed_dim, num_heads, hidden_dim) -> None:
        super(AttentionLayer, self).__init__()
        # Submodules are created in this exact order: it fixes the parameter
        # order (state_dict / optimizer) and the RNG draws used at init.
        self.self_attn = nn.MultiheadAttention(
            embed_dim,
            num_heads,
            bias=False,
            batch_first=True,  # inputs are laid out as (batch, seq, embed_dim)
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),  # ffn.0: expand
            nn.GELU(),                          # ffn.1: activation
            nn.Linear(hidden_dim, embed_dim),  # ffn.2: project back
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run the attention and feed-forward sub-layers with residuals.

        Args:
            x: Tensor of shape (batch, seq, embed_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        normed = self.norm1(x)
        # need_weights=False: only the attended values are used.
        attn_out, _ = self.self_attn(normed, normed, normed, need_weights=False)
        hidden = x + attn_out
        return hidden + self.ffn(self.norm2(hidden))


class PromptGenerator(nn.Module):
    """Refine a bank of learnable query tokens jointly with the input tokens.

    ``num_query`` learnable tokens are appended to the input sequence, and the
    combined sequence is passed through ``num_layers`` ``AttentionLayer`` blocks.

    Args:
        num_layers: Number of stacked ``AttentionLayer`` blocks.
        num_query: Number of learnable query tokens appended to the input.
        embed_dim: Token width ``C``; must equal the last dimension of ``x``.
        return_prompts: If True, ``forward`` returns the whole token sequence
            (input tokens followed by query tokens). If False, it drops only the
            first token of that sequence.
        num_heads: Attention heads per layer (keyword-only; must divide
            ``embed_dim``).
        hidden_dim: Feed-forward hidden width per layer (keyword-only).
    """

    def __init__(self, num_layers, num_query, embed_dim=256, return_prompts=False,
                 *, num_heads=8, hidden_dim=1024):
        super().__init__()
        self.num_query = num_query
        self.return_prompts = return_prompts
        # Learned queries with a leading batch axis of 1 so they can be
        # broadcast across the batch with expand() in forward().
        self.additional_query = nn.Parameter(torch.randn(1, num_query, embed_dim))
        self.layers = nn.ModuleList(
            [AttentionLayer(embed_dim, num_heads=num_heads, hidden_dim=hidden_dim)
             for _ in range(num_layers)])

    def forward(self, x):
        """Append the learned queries to ``x`` and run the attention stack.

        Args:
            x: Tensor of shape [B, L, C] with ``C == embed_dim``.

        Returns:
            [B, L + N, C] if ``return_prompts`` is True, otherwise
            [B, L + N - 1, C] (the first token removed), where N = ``num_query``.
        """
        batch_size = x.shape[0]
        # expand() creates a broadcast view, so the queries are not copied per sample.
        extra_tokens = self.additional_query.expand(batch_size, -1, -1)  # [B, N, C]
        tokens = torch.cat([x, extra_tokens], dim=1)  # [B, L + N, C]
        for layer in self.layers:
            tokens = layer(tokens)
        if self.return_prompts:
            return tokens
        # NOTE(review): dropping only the first token leaves exactly the N query
        # outputs when x has a single token (L == 1, as in the __main__ demo).
        # For L > 1, the remaining L - 1 input tokens are also returned. Confirm
        # the intended behaviour for longer inputs.
        return tokens[:, 1:, :]


if __name__ == '__main__':
    # Smoke test: 3 layers, 128 learned queries, one 256-dim input token per sample.
    model = PromptGenerator(3, 128)
    x = torch.randn((2, 1, 256))
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        y = model(x)
    # 1 input token + 128 queries, minus the dropped first token -> [2, 128, 256].
    print(y.shape)
