import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.init import trunc_normal_

class PerceiverLayer(nn.Module):
    """Resample a variable-length feature sequence into ``target_len`` tokens.

    A fixed set of learned query vectors cross-attends (via
    ``nn.MultiheadAttention``) to the input sequence. The input is first
    projected from ``hidden_size`` to ``lang_dim``. A fixed sinusoidal
    positional table is added to the queries. A linearly resampled copy of the
    same table is added to the keys.

    Args:
        hidden_size: channel size of the incoming features (last dim of ``x``).
        lang_dim: channel size of the queries and of the output.
        target_len: number of learned queries, i.e. output tokens.
    """

    def __init__(self, hidden_size, lang_dim, target_len):
        super().__init__()
        self.num_queries = target_len
        self.embed_dim = lang_dim
        # One head per 128 channels, but never fewer than one head: with
        # lang_dim < 128 a plain floor division would yield 0 heads, which
        # nn.MultiheadAttention rejects. (MultiheadAttention also requires
        # lang_dim to be divisible by the head count.)
        self.num_heads = max(1, lang_dim // 128)

        # Fixed (non-trainable) sinusoidal table, one row per query.
        self.pos_embed = nn.Parameter(
            torch.from_numpy(positional_encoding(lang_dim, self.num_queries)).float()
        ).requires_grad_(False)

        # Learned query vectors, shape (num_queries, lang_dim).
        self.query = nn.Parameter(torch.zeros(self.num_queries, lang_dim))
        trunc_normal_(self.query, std=.02)

        # Projects the input features into the attention width.
        self.kv_proj = nn.Linear(hidden_size, lang_dim)

        # Default batch_first=False: tensors are (seq, batch, dim).
        self.attn = nn.MultiheadAttention(lang_dim, self.num_heads)
        self.ln_q = nn.LayerNorm(lang_dim, eps=1e-12)
        self.ln_kv = nn.LayerNorm(lang_dim, eps=1e-12)
        self.projection = nn.Sequential(
            nn.LayerNorm(lang_dim, eps=1e-12),
            nn.Linear(lang_dim, lang_dim),  # an example implementation
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; called on every submodule via ``self.apply``.

        Linear layers get truncated-normal weights (std 0.02) and zero bias.
        LayerNorms get unit weight and zero bias. Other module types are left
        untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        """Cross-attend the learned queries to ``x``.

        Args:
            x: input features of shape (batch, seq_len, hidden_size).
            attn_mask: optional mask forwarded unchanged as ``attn_mask`` to
                ``nn.MultiheadAttention``.

        Returns:
            Tensor of shape (batch, num_queries, lang_dim).
        """
        # Resample the query-length positional table to the key length.
        pos_embed = get_abs_pos(self.pos_embed, x.size(1))

        x = self.kv_proj(x)
        x = self.ln_kv(x).permute(1, 0, 2)  # (L, B, D)

        N = x.shape[1]  # batch size after the permute
        q = self.ln_q(self.query)
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),  # (Q, B, D)
            x + pos_embed.unsqueeze(1),  # keys carry positions
            x,  # values do not
            attn_mask=attn_mask)[0]
        return self.projection(out.permute(1, 0, 2))

    def _repeat(self, query, N: int):
        """Tile a (Q, D) tensor to (Q, N, D) along a new batch axis."""
        return query.unsqueeze(1).repeat(1, N, 1)
    
def get_abs_pos(abs_pos, tgt_size):
    """Linearly resample a positional table along its first axis.

    Args:
        abs_pos: tensor whose first dimension is the source length; any
            trailing dimensions are flattened into one channel axis.
        tgt_size: desired number of rows.

    Returns:
        ``abs_pos`` itself when the lengths already match. Otherwise a
        ``(tgt_size, -1)`` tensor, interpolated in float32 and cast back to
        ``abs_pos.dtype``.
    """
    n_src = int(abs_pos.size(0))
    n_tgt = int(tgt_size)
    if n_src == n_tgt:
        return abs_pos

    orig_dtype = abs_pos.dtype
    # F.interpolate's "linear" mode works on (batch, channels, length).
    as_channels = abs_pos.float().reshape(1, n_src, -1).transpose(1, 2)
    resized = F.interpolate(
        as_channels,
        size=n_tgt,
        mode="linear",
        align_corners=False,
    )
    return resized.transpose(1, 2).reshape(n_tgt, -1).to(dtype=orig_dtype)

def positional_encoding(d_model, max_len):
    """Build a sinusoidal positional-encoding table.

    Column ``2k`` holds ``sin(pos * w_k)`` and column ``2k + 1`` holds
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: number of columns (embedding width). May be odd; the last
            column is then a sine with no cosine partner.
        max_len: number of rows (positions 0 .. max_len - 1).

    Returns:
        float64 ndarray of shape (max_len, d_model).
    """
    position = np.arange(max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    # For odd d_model there is one fewer odd column than even columns, so the
    # cosine half uses only the first d_model // 2 frequencies (otherwise the
    # assignment fails with a shape mismatch).
    pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return pe
